diff --git a/pyproject.toml b/pyproject.toml index 6fe0fb663..5c0dd92e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,7 @@ GitHub = "https://github.com/DiamondLightSource/python-murfey" [project.scripts] murfey = "murfey.client:run" "murfey.add_user" = "murfey.cli.add_user:run" +"murfey.create_config" = "murfey.cli.create_config:run" "murfey.create_db" = "murfey.cli.create_db:run" "murfey.db_sql" = "murfey.cli.murfey_db_sql:run" "murfey.decrypt_password" = "murfey.cli.decrypt_db_password:run" diff --git a/src/murfey/cli/add_user.py b/src/murfey/cli/add_user.py index 3a7696d05..55011b865 100644 --- a/src/murfey/cli/add_user.py +++ b/src/murfey/cli/add_user.py @@ -4,7 +4,7 @@ from murfey.server.api.auth import hash_password from murfey.server.murfey_db import url -from murfey.util.config import get_security_config +from murfey.util.config import get_global_config from murfey.util.db import MurfeyUser as User @@ -21,7 +21,7 @@ def run(): new_user = User( username=args.username, hashed_password=hash_password(args.password) ) - _url = url(get_security_config()) + _url = url(get_global_config()) engine = create_engine(_url) with Session(engine) as murfey_db: murfey_db.add(new_user) diff --git a/src/murfey/cli/create_config.py b/src/murfey/cli/create_config.py new file mode 100644 index 000000000..ff77f621d --- /dev/null +++ b/src/murfey/cli/create_config.py @@ -0,0 +1,1255 @@ +from __future__ import annotations + +import argparse +import json +import re +from ast import literal_eval +from pathlib import Path +from typing import Any, Callable, Optional, Type + +import yaml +from pydantic import ValidationError +from pydantic.error_wrappers import ErrorWrapper +from pydantic.fields import ModelField, UndefinedType +from rich.console import Console +from rich.panel import Panel +from rich.text import Text + +from murfey.util.config import MachineConfig + +# Create a console object for pretty printing +console = Console() + + +def prompt(message: str, style: str = "") -> str: + """ + Helper function to pretty print a message and have the user input their response + on a new line. + """ + console.print(message, style=style) + return input("> ") + + +def print_welcome_message(): + welcome_message = Text( + "Welcome to the Murfey configuration setup tool!", style="bold bright_magenta" + ) + panel_content = Text() + panel_content.append( + "This tool will walk you through the process of setting up Murfey's " + "configuration file for your instrument, allowing you to supercharge " + "your data processing pipeline with Murfey's capacity for automated " + "data transfer and data processing coordination across devices.", + style="bright_white", + ) + panel = Panel( + panel_content, + expand=True, + ) + console.rule(welcome_message) + console.print(panel, justify="center") + console.rule() + + input("Press 'Enter' to begin the setup") + + +def print_field_info(field: ModelField): + """ + Helper function to print out the name of the key being set up, along with a short + description of what purpose the key serves. 
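+
+    Example output, assuming a hypothetical 'camera' field whose default
+    is 'FALCON':
+
+        Camera (camera)
+        Name of the camera model installed on this instrument.
+        Default: 'FALCON'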
+ """ + console.print() + console.print( + f"{field.name.replace('_', ' ').title()} ({field.name})", + style="bold bright_cyan", + ) + console.print(field.field_info.description, style="bright_white") + if not isinstance(field.field_info.default, UndefinedType): + console.print(f"Default: {field.field_info.default!r}", style="bold white") + + +def ask_for_permission(message: str) -> bool: + """ + Helper function to generate a Boolean based on user input + """ + while True: + answer = prompt(message, style="bright_yellow").lower().strip() + if answer in ("y", "yes"): + return True + if answer in ("n", "no"): + return False + console.print("Invalid input. Please try again.", style="bright_red") + continue + + +def ask_for_input(parameter: str, again: bool = False): + """ + Asks the user if another value should be entered into the current data structure. + """ + message = ( + "Would you like to add " + + ( + "another" + if again + else ( + "an" if parameter.lower().startswith(("a", "e", "i", "o", "u")) else "a" + ) + ) + + f" {parameter}? [bold bright_magenta](y/n)[/bold bright_magenta]" + ) + return ask_for_permission(message) + + +def ask_to_use_default(field: ModelField): + """ + Asks the user if they want to populate the current configuration key with the + default value. + """ + message = ( + "Would you like to use the default value for this field? " + "[bold bright_magenta](y/n)[/bold bright_magenta] \n" + f"{field.field_info.default!r}" + ) + return ask_for_permission(message) + + +def confirm_overwrite(value: str): + """ + Asks the user if a value that already exists should be overwritten. + """ + message = f"{value!r} already exists; do you wish to overwrite it? [bold bright_magenta](y/n)[/bold bright_magenta]" + return ask_for_permission(message) + + +def confirm_duplicate(value: str): + """ + Asks the user if a duplicate value should be allowed. + """ + message = f"{value!r} already exists; do you want to add a duplicate? [bold bright_magenta](y/n)[/bold bright_magenta]" + return ask_for_permission(message) + + +def get_folder_name(message: Optional[str] = None) -> str: + """ + Helper function to interactively generate, validate, and return a folder name. + """ + while True: + message = "Please enter the folder name." if message is None else message + value = prompt(message, style="bright_yellow").strip() + if bool(re.fullmatch(r"[\w\s\-]*", value)) is True: + return value + console.print( + "There are unsafe characters present in this folder name. Please " + "use a different one.", + style="bright_red", + ) + if ask_for_input("folder name", True) is False: + return "" + continue + + +def get_folder_path(message: Optional[str] = None) -> Path | None: + """ + Helper function to interactively generate, validate, and return the full path + to a folder. + """ + while True: + message = ( + "Please enter the full path to the folder." if message is None else message + ) + value = prompt(message, style="bright_yellow").strip() + if not value: + return None + try: + path = Path(value).resolve() + return path + except Exception: + console.print("Unable to resolve provided file path", style="bright_red") + if ask_for_input("file path", True) is False: + return None + continue + + +def get_file_path(message: Optional[str] = None) -> Path | None: + """ + Helper function to interactively generate, validate, and return the full path + to a file. + """ + while True: + message = ( + "Please enter the full path to the file." 
if message is None else message + ) + value = prompt(message, style="bright_yellow").strip() + if not value: + return None + file = Path(value).resolve() + if file.suffix: + return file + console.print(f"{str(file)!r} doesn't appear to be a file", style="bright_red") + if ask_for_input("file", True) is False: + return None + continue + + +def construct_list( + value_name: str, + value_method: Optional[Callable] = None, + value_method_args: dict = {}, + allow_empty: bool = False, + allow_eval: bool = True, + many_types: bool = True, + restrict_to_types: Optional[Type[Any] | tuple[Type[Any]]] = None, + sort_values: bool = True, + debug: bool = False, +) -> list[Any]: + """ + Helper function to facilitate interactive construction of a list. + """ + lst: list = [] + add_entry = ask_for_input(value_name, False) + while add_entry is True: + value = ( + prompt( + "Please enter " + + ("an" if value_name.startswith(("a", "e", "i", "o", "u")) else "a") + + f" {value_name}", + style="bright_yellow", + ) + if value_method is None + else value_method(**value_method_args) + ) + # Reject empty inputs if set + if not value and not allow_empty: + console.print("No value provided.", style="bright_red") + add_entry = ask_for_input(value_name, True) + continue + # Convert values if set + try: + eval_value = literal_eval(value) if allow_eval else value + except Exception: + eval_value = value + # Check if it's a permitted type (continue to allow None as value) + if restrict_to_types is not None: + allowed_types = ( + (restrict_to_types,) + if not isinstance(restrict_to_types, (list, tuple)) + else restrict_to_types + ) + if not isinstance(eval_value, allowed_types): + console.print( + f"The provided value ({type(eval_value)}) is not an allowed type.", + style="bright_red", + ) + add_entry = ask_for_input(value_name, True) + continue + # Confirm if duplicate entry should be added + if eval_value in lst and confirm_duplicate(str(eval_value)) is False: + add_entry = ask_for_input(value_name, True) + continue + lst.append(eval_value) + # Reject list with multiple types if set + if not many_types and len({type(item) for item in lst}) > 1: + console.print( + "The provided value is of a different type to the other members. It " + "won't be added to the list.", + style="bright_red", + ) + lst = lst[:-1] + # Sort values if set + # Sort numeric values differently from alphanumeric ones + lst = ( + sorted( + lst, + key=lambda v: ( + (0, float(v), 0) + if isinstance(v, (int, float)) + else ( + (1, abs(v), v.real) + if isinstance(v, complex) + else (2, str(v), "") + ) + ), + ) + if sort_values + else lst + ) + add_entry = ask_for_input(value_name, True) + continue + return lst + + +def construct_dict( + dict_name: str, + key_name: str, + value_name: str, + key_method: Optional[Callable] = None, + key_method_args: dict = {}, + value_method: Optional[Callable] = None, + value_method_args: dict = {}, + allow_empty_key: bool = True, + allow_empty_value: bool = True, + allow_eval: bool = True, + sort_keys: bool = True, + restrict_to_types: Optional[Type[Any] | tuple[Type[Any], ...]] = None, + debug: bool = False, +) -> dict[str, Any]: + """ + Helper function to facilitate the interative construction of a dictionary. 
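+
+    For example, when called with key_name="file extension" and
+    value_name="file substrings" (as is done further below when building the
+    data_required_substrings field), the returned dictionary might look like
+    {".tiff": ["Fractions"]}.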
+ """ + + def is_type(value: str, instance: Type[Any] | tuple[Type[Any], ...]) -> bool: + """ + Checks if the string provided evaluates to one of the desired types + """ + instance = (instance,) if not isinstance(instance, (list, tuple)) else instance + try: + eval_value = literal_eval(value) + except Exception: + eval_value = value + return isinstance(eval_value, instance) + + dct: dict = {} + add_entry = ask_for_input(dict_name, False) + key_message = f"Please enter the {key_name}" + value_message = f"Please enter the {value_name}" + while add_entry is True: + # Add key + key = str( + prompt(key_message, style="bright_yellow").strip() + if key_method is None + else key_method(**key_method_args) + ) + # Reject empty keys if set + if not allow_empty_key and not key: + console.print(f"No {key_name} provided.", style="bright_red") + add_entry = ask_for_input(dict_name, True) + continue + # Confirm overwrite key on duplicate + if key in dct.keys(): + if confirm_overwrite(key) is False: + add_entry = ask_for_input(dict_name, True) + continue + # Add value + value = ( + prompt(value_message, style="bright_yellow").strip() + if value_method is None + else value_method(**value_method_args) + ) + # Reject empty values if set + if not allow_empty_value and not value: + console.print("No value provided", style="bright_red") + add_entry = ask_for_input(dict_name, True) + continue + # Convert values if set + try: + eval_value = literal_eval(value) if allow_eval else value + except Exception: + eval_value = value + # Reject incorrect value types if set + if restrict_to_types is not None: + allowed_types = ( + (restrict_to_types,) + if not isinstance(restrict_to_types, (tuple, list)) + else restrict_to_types + ) + if not isinstance(eval_value, allowed_types): + console.print( + "The value is not of an allowed type.", style="bright_red" + ) + add_entry = ask_for_input(dict_name, True) + continue + # Assign value to key + dct[key] = eval_value + add_entry = ask_for_input(dict_name, True) + continue + + # Sort keys if set + # Sort numeric keys separately from alphanumeric ones + dct = ( + { + key: dct[key] + for key in sorted( + dct.keys(), + key=lambda k: ( + (0, float(k), 0) + if is_type(k, (int, float)) + else ( + (1, abs(complex(k)), complex(k).real) + if is_type(k, complex) + else (2, str(k), "") + ) + ), + ) + } + if sort_keys + else dct + ) + return dct + + +def validate_value(value: Any, key: str, field: ModelField, debug: bool = False) -> Any: + """ + Helper function to validate the value of a field in the Pydantic model. + """ + validated_value, errors = field.validate(value, {}, loc=key) + if errors: + raise ValidationError( + ([errors] if isinstance(errors, ErrorWrapper) else errors), MachineConfig + ) + console.print(f"{key!r} validated successfully.", style="bright_green") + if debug: + console.print(f"Type: {type(validated_value)}", style="bright_green") + console.print(f"{validated_value!r}", style="bright_green") + return validated_value + + +def populate_field(key: str, field: ModelField, debug: bool = False) -> Any: + """ + General function for inputting and validating the value of a single field against + its Pydantic model. + """ + + # Display information on the field to be filled + print_field_info(field) + + defaults_prompt = ( + f"press Enter to use the default value of {field.field_info.default!r}" + if not isinstance(field.field_info.default, UndefinedType) + else "this field is mandatory" + ) + message = f"Please provide a value ({defaults_prompt})." 
+ while True: + # Get value + answer = prompt(message, style="bright_yellow") + + # Parse field input if a default has been provided + if not isinstance(field.field_info.default, UndefinedType): + # Convert empty console inputs into default field values + if not answer: + value = field.field_info.default + # Convert inverted commas into empty strings + elif answer in ('""', "''") and isinstance(field.field_info.default, str): + value = "" + else: + value = answer + else: + value = answer + + # Validate and return + try: + return validate_value(value, key, field, debug) + except ValidationError as error: + if debug: + console.print(error, style="bright_red") + console.print( + f"Invalid input for {key!r}. Please try again", style="bright_red" + ) + continue + + +def add_calibrations( + key: str, field: ModelField, debug: bool = False +) -> dict[str, dict]: + """ + Populate the 'calibrations' field with dictionaries. + """ + # Known calibrations and what to call their keys and values + known_calibrations: dict[str, tuple[str, str]] = { + # Calibration type | Key name | Value name + "magnification": ("magnification", "pixel size (in angstroms)") + } + + # Start of add_calibrations + print_field_info(field) + category = "calibration setting" + calibrations: dict = {} + add_calibration = ask_for_input(category, False) + while add_calibration is True: + calibration_type = prompt( + "What type of calibration settings are you providing?", + style="bright_yellow", + ).lower() + # Check if it's a known type of calibration + if calibration_type not in known_calibrations.keys(): + console.print( + f"{calibration_type!r} is not a known type of calibration", + style="bright_red", + ) + add_calibration = ask_for_input(category, True) + continue + # Handle duplicate keys + if calibration_type in calibrations.keys(): + if confirm_overwrite(calibration_type) is False: + add_calibration = ask_for_input(category, True) + continue + # Skip failed inputs + calibration_values = construct_dict( + f"{calibration_type} calibration", + known_calibrations[calibration_type][0], + known_calibrations[calibration_type][1], + allow_empty_key=False, + allow_empty_value=False, + allow_eval=True, + sort_keys=True, + ) + if not calibration_values: + add_calibration = ask_for_input(category, True) + continue + + # Add calibration to master dict + calibrations[calibration_type] = calibration_values + console.print( + f"Added {calibration_type} to the calibrations field", + style="bright_green", + ) + if debug: + console.print(f"{calibration_values}", style="bright_green") + + # Check if any more calibrations need to be added + add_calibration = ask_for_input("calibration setting", again=True) + + # Validate the nested dictionary structure + try: + return validate_value(calibrations, key, field, debug) + except ValidationError as error: + if debug: + console.print(error, style="bright_red") + console.print(f"Failed to validate {key!r}.", style="bright_red") + if ask_for_input(category, True) is True: + return add_calibrations(key, field, debug) + console.print("Returning an empty dictionary", style="bright_red") + return {} + + +def add_software_packages(config: dict, debug: bool = False) -> dict[str, Any]: + def get_software_name() -> str: + """ + Function to interactively generate, validate, and return the name of a + supported software package. + """ + message = ( + "What is the name of the software package? 
Supported options: 'autotem', " + "'epu', 'leica', 'serialem', 'tomo'" + ) + name = prompt(message, style="bright_yellow").lower().strip() + # Validate name against "acquisition_software" field + try: + field = MachineConfig.__fields__["acquisition_software"] + return validate_value([name], "acquisition_software", field, False)[0] + except ValidationError: + console.print("Invalid software name.", style="bright_red") + if ask_for_input("software package", True) is True: + return get_software_name() + console.print("Returning an empty string.", style="bright_red") + return "" + + def ask_about_settings_file() -> bool: + message = ( + "Does this software package have a settings file that needs modification? " + "[bold bright_magenta](y/n)[/bold bright_magenta]" + ) + return ask_for_permission(message) + + def get_settings_tree_path() -> str: + message = "What is the path through the XML file to the node to overwrite?" + xml_tree_path = prompt(message, style="bright_yellow").strip() + # TODO: Currently no test cases for this method + return xml_tree_path + + """ + Start of add_software_packages + """ + console.print() + console.print( + "Acquisition Software (acquisition_software)", + style="bold bright_cyan", + ) + console.print( + "This is where aquisition software packages present on the instrument machine " + "can be specified, along with the output file names and extensions that are of " + "interest.", + style="bright_white", + ) + package_info: dict = {} + category = "software package" + add_input = ask_for_input(category, again=False) + while add_input: + # Collect software name + console.print( + "Acquisition Software (acquisition_software)", + style="bold bright_cyan", + ) + console.print( + "Name of the acquisition software installed on this instrument.", + style="bright_white", + ) + console.print( + "Options: 'autotem', 'epu', 'leica', 'serialem', 'tomo'", + style="bright_cyan", + ) + name = get_software_name() + if name in package_info.keys(): + if confirm_overwrite(name) is False: + add_input = ask_for_input(category, False) + continue + + # Collect version info + console.print( + "Software Versions (software_versions)", + style="bold bright_cyan", + ) + version = prompt( + "What is the version number of this software package? Press Enter to leave " + "it blank if you're unsure.", + style="bright_yellow", + ) + + # Collect settings files and modifications + console.print( + "Software Settings Output Directories (software_settings_output_directories)", + style="bold bright_cyan", + ) + console.print( + "Some software packages will have settings files that require modification " + "in order to ensure files are saved to the desired folders. The paths to " + "the files and the path to the nodes in the settings files both need to be " + "provided.", + style="bright_white", + ) + settings_file: Optional[Path] = ( + get_file_path( + "What is the full path to the settings file? This is usually an XML file." + ) + if ask_about_settings_file() is True + else None + ) + settings_tree_path = ( + get_settings_tree_path().split("/") if settings_file else [] + ) + + # Collect extensions and filename substrings + console.print( + "Data Required Substrings (data_required_substrings)", + style="bold bright_cyan", + ) + console.print( + "Different software packages will generate different output files. Only " + "files with certain extensions and keywords in their filenames are needed " + "for data processing. 
They are listed out here.", + style="bright_white", + ) + extensions_and_substrings: dict[str, list[str]] = construct_dict( + dict_name="file extension configuration", + key_name="file extension", + value_name="file substrings", + value_method=construct_list, + value_method_args={ + "value_name": "file substring", + "allow_empty": False, + "allow_eval": False, + "many_types": False, + "restrict_to_types": str, + "sort_values": True, + }, + allow_empty_key=False, + allow_empty_value=False, + allow_eval=False, + sort_keys=True, + restrict_to_types=list, + ) + + # Compile keys for this package as a dict + package_info[name] = { + "version": version, + "settings_file": settings_file, + "settings_tree_path": settings_tree_path, + "extensions_and_substrings": extensions_and_substrings, + } + add_input = ask_for_input(category, again=True) + continue + + # Re-pack keys and values according to the current config field structures + console.print("Compiling and validating inputs...") + acquisition_software: list = [] + software_versions: dict = {} + software_settings_output_directories: dict = {} + data_required_substrings: dict = {} + + # Add keys after sorting + for key in sorted(package_info.keys()): + acquisition_software.append(key) + if package_info[key]["version"]: + software_versions[key] = package_info[key]["version"] + if package_info[key]["settings_file"]: + software_settings_output_directories[ + str(package_info[key]["settings_file"]) + ] = package_info[key]["settings_tree_path"] + if package_info[key]["extensions_and_substrings"]: + data_required_substrings[key] = package_info[key][ + "extensions_and_substrings" + ] + + # Validate against their respective fields + to_validate = ( + ("acquisition_software", acquisition_software), + ("software_versions", software_versions), + ("software_settings_output_directories", software_settings_output_directories), + ("data_required_substrings", data_required_substrings), + ) + for field_name, value in to_validate: + try: + field = MachineConfig.__fields__[field_name] + config[field_name] = validate_value(value, field_name, field, debug) + except ValidationError as error: + if debug: + console.print(error, style="bright_red") + console.print(f"Failed to validate {field_name!r}", style="bright_red") + if ask_for_input("software package configuration", True) is True: + return add_software_packages(config) + console.print(f"Skipped adding {field_name!r}.", style="bright_red") + + # Return updated dictionary + return config + + +def add_data_directories( + key: str, field: ModelField, debug: bool = False +) -> list[Path]: + """ + Function to facilitate populating the data_directories field. + """ + print_field_info(field) + description = "data directory path" + data_directories: list[Path] = construct_list( + description, + value_method=get_folder_path, + value_method_args={ + "message": ( + "Please enter the full path to the data directory " + "where the files are stored." 
+ ), + }, + allow_empty=False, + allow_eval=False, + many_types=False, + restrict_to_types=Path, + sort_values=True, + debug=debug, + ) + + # Validate and return + try: + return validate_value(data_directories, key, field, debug) + except ValidationError as error: + if debug: + console.print(error, style="bright_red") + console.print(f"Failed to validate {key!r}.", style="bright_red") + if ask_for_input(description, True) is True: + return add_data_directories(key, field, debug) + console.print("Returning an empty dictionary.", style="bright_red") + return [] + + +def add_create_directories( + key: str, field: ModelField, debug: bool = False +) -> dict[str, str]: + """ + Function to populate the create_directories field. + """ + print_field_info(field) + + # Manually enter fields if default value is not used + description = "folder for Murfey to create" + folders_to_create: dict[str, str] = ( + field.field_info.default + if ask_to_use_default(field) is True + else construct_dict( + dict_name=description, + key_name="folder alias", + value_name="folder name", + key_method=get_folder_name, + key_method_args={ + "message": "Please enter the name Murfey should remember the folder as.", + }, + value_method=get_folder_name, + value_method_args={ + "message": "Please enter the name of the folder for Murfey to create.", + }, + allow_empty_key=False, + allow_empty_value=False, + allow_eval=False, + sort_keys=True, + restrict_to_types=str, + ) + ) + + # Validate and return + try: + return validate_value(folders_to_create, key, field, debug) + except ValidationError as error: + if debug: + console.print(error, style="bright_red") + console.print(f"Failed to validate {key!r}.", style="bright_red") + if ask_for_input(description, True) is True: + return add_create_directories(key, field, debug) + console.print("Returning an empty dictionary.", style="bright_red") + return {} + + +def add_analyse_created_directories( + key: str, field: ModelField, debug: bool = False +) -> list[str]: + """ + Function to populate the analyse_created_directories field + """ + print_field_info(field) + category = "folder for Murfey to analyse" + + folders_to_analyse: list[str] = construct_list( + value_name=category, + value_method=get_folder_name, + value_method_args={ + "message": "Please enter the name of the folder that Murfey is to analyse." + }, + allow_empty=False, + allow_eval=False, + many_types=False, + restrict_to_types=str, + sort_values=True, + ) + + # Validate and return + try: + return sorted(validate_value(folders_to_analyse, key, field, debug)) + except ValidationError as error: + if debug: + console.print(error, style="bright_red") + console.print(f"Failed to validate {key!r}.", style="bright_red") + if ask_for_input(category, True) is True: + return add_analyse_created_directories(key, field, debug) + console.print("Returning an empty list.", style="bright_red") + return [] + + +def set_up_data_transfer(config: dict, debug: bool = False) -> dict: + """ + Helper function to set up the data transfer fields in the configuration + """ + + def get_upstream_data_directories( + key: str, field: ModelField, debug: bool = False + ) -> list[Path]: + print_field_info(field) + category = "upstream data directory" + upstream_data_directories = construct_list( + category, + value_method=get_folder_path, + value_method_args={ + "message": ( + "Please enter the full path to the data directory " + "you wish to search for files in." 
+ ), + }, + allow_empty=False, + allow_eval=False, + many_types=False, + restrict_to_types=Path, + sort_values=True, + ) + try: + return validate_value(upstream_data_directories, key, field, debug) + except ValidationError as error: + if debug: + console.print(error, style="bright_red") + console.print(f"Failed to validate {key!r}.", style="bright_red") + if ask_for_input(category, True) is True: + return get_upstream_data_directories(key, field, debug) + console.print("Returning an empty list.", style="bright_red") + return [] + + def get_upstream_data_tiff_locations( + key: str, field: ModelField, debug: bool = False + ) -> list[str]: + print_field_info(field) + category = "remote folder containing TIFF files" + upstream_data_tiff_locations = construct_list( + category, + value_method=get_folder_name, + value_method_args={ + "message": ( + "Please enter the name of the folder on the remote machines " + "in which to search for TIFF files." + ) + }, + allow_empty=False, + allow_eval=False, + many_types=False, + restrict_to_types=str, + sort_values=True, + ) + try: + return validate_value(upstream_data_tiff_locations, key, field, debug) + except ValidationError as error: + if debug: + console.print(error, style="bright_red") + console.print(f"Failed to validate {key!r}.", style="bright_red") + if ask_for_input(category, True) is True: + return get_upstream_data_tiff_locations(key, field, debug) + console.print("Returning an empty list.", style="bright_red") + return [] + + """ + Start of set_up_data_transfer + """ + for key in ( + "data_transfer_enabled", + "rsync_basepath", + "rsync_module", + "rsync_url", + "allow_removal", + "upstream_data_directories", + "upstream_data_download_directory", + "upstream_data_tiff_locations", + ): + field = MachineConfig.__fields__[key] + # Skip everything in this section if data transfer is set to False + if config.get("data_transfer_enabled", None) is False: + continue + # Construct more complicated data structures + if key == "upstream_data_directories": + validated_value: Any = get_upstream_data_directories(key, field, debug) + elif key == "upstream_data_tiff_locations": + validated_value = get_upstream_data_tiff_locations(key, field, debug) + # Use populate field to process simpler keys + else: + validated_value = populate_field(key, field, debug) + + # Add to config + config[key] = validated_value + + return config + + +def set_up_data_processing(config: dict, debug: bool = False) -> dict: + """ + Helper function to set up the data processing fields in the config. 
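+
+    Note: if 'processing_enabled' is answered with False, the remaining
+    processing-related fields are skipped rather than prompted for.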
+ """ + + def add_recipes(key: str, field: ModelField, debug: bool = False) -> dict[str, str]: + print_field_info(field) + + # Manually construct the dictionary if the default value is not used + category = "processing recipe" + recipes = ( + field.field_info.default + if ask_to_use_default(field) is True + else construct_dict( + category, + key_name="name of the recipe", + value_name="name of the recipe file", + allow_empty_key=False, + allow_empty_value=False, + allow_eval=False, + sort_keys=True, + restrict_to_types=str, + ) + ) + try: + return validate_value(recipes, key, field, debug) + except ValidationError as error: + if debug: + console.print(error, style="bright_red") + console.print(f"Failed to validate {key!r}.", style="bright_red") + if ask_for_input(category, True) is True: + return add_recipes(key, field, debug) + console.print("Returning an empty dictionary.", style="bright_red") + return {} + + """ + Start of set_up_data_processing + """ + # Process in order + for key in ( + "processing_enabled", + "process_by_default", + "gain_directory_name", + "processed_directory_name", + "processed_extra_directory", + "recipes", + "default_model", + "model_search_directory", + "initial_model_search_directory", + ): + field = MachineConfig.__fields__[key] + # Skip this section of processing is disabled + if config.get("processing_enabled", None) is False: + continue + # Handle complex keys + if key == "recipes": + validated_value: Any = add_recipes(key, field, debug) + # Populate fields of simple keys + else: + validated_value = populate_field(key, field, debug) + config[key] = validated_value + + return config + + +def add_external_executables( + key: str, field: ModelField, debug: bool = False +) -> dict[str, Path]: + print_field_info(field) + category = "external executable" + external_executables = construct_dict( + dict_name=category, + key_name="name of the executable", + value_name="full file path to the executable", + value_method=get_folder_path, + value_method_args={ + "message": ("Please enter the full file path to the executable"), + }, + allow_empty_key=False, + allow_empty_value=False, + allow_eval=False, + sort_keys=True, + restrict_to_types=Path, + ) + try: + return validate_value(external_executables, key, field, debug) + except ValidationError as error: + if debug: + console.print(error, style="bright_red") + console.print(f"Failed to validate {key!r}.", style="bright_red") + if ask_for_input(category, True) is True: + return add_external_executables(key, field, debug) + console.print("Returning an empty dictionary.", style="bright_red") + return {} + + +def add_external_environment( + key: str, field: ModelField, debug: bool = False +) -> dict[str, str]: + print_field_info(field) + category = "external environment" + external_environment = construct_dict( + dict_name=category, + key_name="name of the environment", + value_name="full path to the folder", + allow_empty_key=False, + allow_empty_value=False, + allow_eval=False, + sort_keys=True, + restrict_to_types=str, + ) + try: + return validate_value(external_environment, key, field, debug) + except ValidationError as error: + if debug: + console.print(error, style="bright_red") + console.print(f"Failed to validate {key!r}.", style="bright_red") + if ask_for_input(category, True) is True: + return add_external_environment(key, field, debug) + console.print("Returning an empty dictionary.", style="bright_red") + return {} + + +def add_murfey_plugins(key: str, field: ModelField, debug: bool = False) -> dict: + """ + 
Helper function to set up the Murfey plugins field in the config. + """ + print_field_info(field) + category = "Murfey plugin package" + plugins = construct_dict( + dict_name=category, + key_name="name of the plugin", + value_name="full file path to the plugin", + value_method=get_file_path, + value_method_args={ + "message": "Please enter the full file path to the plugin.", + }, + allow_empty_key=False, + allow_empty_value=False, + allow_eval=False, + sort_keys=True, + restrict_to_types=Path, + ) + try: + return validate_value(plugins, key, field, debug) + except ValidationError as error: + if debug: + console.print(error, style="bright_red") + console.print(f"Failed to validate {key!r}.", style="bright_red") + if ask_for_input(category, True) is True: + return add_murfey_plugins(key, field, debug) + console.print("Returning an empty dictionary.", style="bright_red") + return {} + + +def set_up_machine_config(debug: bool = False): + """ + Main function which runs through the setup process. + """ + + print_welcome_message() + + new_config: dict = {} + for key, field in MachineConfig.__fields__.items(): + """ + Logic for complicated or related fields + """ + if key == "superres": + camera: str = new_config["camera"] + new_config[key] = True if camera.lower().startswith("gatan") else False + continue + if key == "calibrations": + new_config[key] = add_calibrations(key, field, debug) + continue + + # Acquisition software section + if key == "acquisition_software": + new_config = add_software_packages(new_config, debug) + continue + + if key == "data_directories": + new_config[key] = add_data_directories(key, field, debug) + continue + if key == "create_directories": + new_config[key] = add_create_directories(key, field, debug) + continue + if key == "analyse_created_directories": + new_config[key] = add_analyse_created_directories(key, field, debug) + continue + + # Data transfer section + if key == "data_transfer_enabled": + new_config = set_up_data_transfer(new_config, debug) + continue + + # Data processing section + if key == "processing_enabled": + new_config = set_up_data_processing(new_config, debug) + continue + + # External plugins and executables section + if key in ("external_executables", "external_executables_eer"): + new_config[key] = add_external_executables(key, field, debug) + continue + if key == "external_environment": + new_config[key] = add_external_environment(key, field, debug) + continue + + if key == "plugin_packages": + new_config[key] = add_murfey_plugins(key, field, debug) + continue + + # All the keys that can be skipped + if key in ( + # Acquisition software section + "software_versions", + "software_settings_output_directories", + "data_required_substrings", + # Data transfer section + "allow_removal", + "rsync_basepath", + "rsync_module", + "rsync_url", + "upstream_data_directories", + "upstream_data_download_directory", + "upstream_data_tiff_locations", + # Data processing section + "process_by_default", + "gain_directory_name", + "processed_directory_name", + "processed_extra_directory", + "recipes", + "modular_spa", + "default_model", + "model_search_directory", + "initial_model_search_directory", + ): + continue + + """ + Standard method of inputting values + """ + new_config[key] = populate_field(key, field, debug) + + # Validate the entire config and convert into JSON/YAML-safe dict + try: + new_config_safe: dict = json.loads(MachineConfig(**new_config).json()) + except ValidationError as exception: + # Print out validation errors found + 
console.print("Validation failed", style="bright_red") + for error in exception.errors(): + console.print(f"{error}", style="bright_red") + # Offer to redo the setup, otherwise quit setup + if ask_for_input("machine configuration", True) is True: + return set_up_machine_config(debug) + return False + + # Save config under its instrument name + master_config: dict[str, dict] = { + new_config_safe["instrument_name"]: new_config_safe + } + + # Create save path for config + console.print("Machine config successfully validated.", style="green") + config_name = prompt( + "What would you like to name the file? (E.g. 'my_machine_config')", + style="bright_yellow", + ) + config_path = Path( + prompt("Where would you like to save this config?", style="bright_yellow") + ) + config_file = config_path / f"{config_name}.yaml" + config_path.mkdir(parents=True, exist_ok=True) + + # Check if config file already exists at the location + if config_file.exists(): + with open(config_file) as existing_file: + try: + old_config: dict[str, dict] = yaml.safe_load(existing_file) + except yaml.YAMLError as error: + console.print(error, style="bright_red") + # Provide option to quit or try again + if ask_for_input("machine configuration", True) is True: + return set_up_machine_config(debug) + console.print("Exiting machine configuration setup guide") + exit() + # Check if settings already exist for this machine + for key in master_config.keys(): + # Check if overwriting of existing config is needed + if key in old_config.keys() and confirm_overwrite(key) is False: + old_config[key].update(master_config[key]) + # Add new machine config + else: + old_config[key] = master_config[key] + # Regenerate dictionary and store machine configs alphabetically + master_config = {key: old_config[key] for key in sorted(old_config.keys())} + with open(config_file, "w") as save_file: + yaml.dump(master_config, save_file, default_flow_style=False, sort_keys=False) + console.print( + f"Machine configuration for {new_config_safe['instrument_name']!r} " + f"successfully saved as {str(config_file)!r}", + style="bright_green", + ) + console.print("Machine configuration complete", style="bright_green") + + # Provide option to set up another machine configuration + if ask_for_input("machine configuration", True) is True: + return set_up_machine_config(debug) + console.print("Exiting machine configuration setup guide", style="bright_green") + return True + + +def run(): + # Set up arg parser + parser = argparse.ArgumentParser() + parser.add_argument( + "--debug", + action="store_true", + help="Prints additional messages to show setup progress.", + ) + args = parser.parse_args() + + set_up_machine_config(args.debug) diff --git a/src/murfey/cli/decrypt_db_password.py b/src/murfey/cli/decrypt_db_password.py index 0e019a1d7..ff6281739 100644 --- a/src/murfey/cli/decrypt_db_password.py +++ b/src/murfey/cli/decrypt_db_password.py @@ -2,7 +2,7 @@ from cryptography.fernet import Fernet -from murfey.util.config import get_security_config +from murfey.util.config import get_global_config def run(): @@ -12,6 +12,6 @@ def run(): args = parser.parse_args() - security_config = get_security_config() - f = Fernet(security_config.crypto_key.encode("ascii")) + global_config = get_global_config() + f = Fernet(global_config.crypto_key.encode("ascii")) print(f.decrypt(args.password.encode("ascii")).decode()) diff --git a/src/murfey/cli/generate_db_password.py b/src/murfey/cli/generate_db_password.py index 431ede7e7..ba9f07d42 100644 --- 
a/src/murfey/cli/generate_db_password.py +++ b/src/murfey/cli/generate_db_password.py @@ -3,12 +3,12 @@ from cryptography.fernet import Fernet -from murfey.util.config import get_security_config +from murfey.util.config import get_global_config def run(): - security_config = get_security_config() - f = Fernet(security_config.crypto_key.encode("ascii")) + global_config = get_global_config() + f = Fernet(global_config.crypto_key.encode("ascii")) alphabet = string.ascii_letters + string.digits password = "".join(secrets.choice(alphabet) for i in range(32)) print(f.encrypt(password.encode("ascii")).decode()) diff --git a/src/murfey/cli/inject_spa_processing.py b/src/murfey/cli/inject_spa_processing.py index 2294835a5..5f0ba5dcd 100644 --- a/src/murfey/cli/inject_spa_processing.py +++ b/src/murfey/cli/inject_spa_processing.py @@ -10,7 +10,7 @@ from murfey.server.ispyb import TransportManager from murfey.server.murfey_db import url -from murfey.util.config import get_machine_config, get_microscope, get_security_config +from murfey.util.config import get_global_config, get_machine_config, get_microscope from murfey.util.db import ( AutoProcProgram, ClientEnvironment, @@ -97,13 +97,13 @@ def run(): os.environ["BEAMLINE"] = args.microscope machine_config = get_machine_config() - security_config = get_security_config() + global_config = get_global_config() _url = url(machine_config) engine = create_engine(_url) murfey_db = Session(engine) _transport_object = TransportManager(args.transport) - _transport_object.feedback_queue = security_config.feedback_queue + _transport_object.feedback_queue = global_config.feedback_queue query = ( select(Movie) diff --git a/src/murfey/cli/spa_ispyb_messages.py b/src/murfey/cli/spa_ispyb_messages.py index 6c54d5e2f..5f8484a62 100644 --- a/src/murfey/cli/spa_ispyb_messages.py +++ b/src/murfey/cli/spa_ispyb_messages.py @@ -22,7 +22,7 @@ from murfey.server.ispyb import Session, TransportManager, get_session_id from murfey.server.murfey_db import url from murfey.util import db -from murfey.util.config import get_machine_config, get_microscope, get_security_config +from murfey.util.config import get_global_config, get_machine_config, get_microscope def run(): @@ -341,7 +341,7 @@ def run(): .where(db.ProcessingJob.recipe == "em-spa-preprocess") ).one() machine_config = get_machine_config() - security_config = get_security_config() + global_config = get_global_config() params = db.SPARelionParameters( pj_id=collected_ids[2].id, angpix=float(metadata["pixel_size_on_image"]) * 1e10, @@ -378,7 +378,7 @@ def run(): if args.flush_preprocess: _transport_object = TransportManager(args.transport) - _transport_object.feedback_queue = security_config.feedback_queue + _transport_object.feedback_queue = global_config.feedback_queue stashed_files = murfey_db.exec( select(db.PreprocessStash) .where(db.PreprocessStash.session_id == args.session_id) diff --git a/src/murfey/server/__init__.py b/src/murfey/server/__init__.py index 335525851..a2675f5f9 100644 --- a/src/murfey/server/__init__.py +++ b/src/murfey/server/__init__.py @@ -53,10 +53,10 @@ from murfey.util import LogFilter from murfey.util.config import ( MachineConfig, + get_global_config, get_hostname, get_machine_config, get_microscope, - get_security_config, ) from murfey.util.processing_params import default_spa_parameters from murfey.util.state import global_state @@ -76,7 +76,7 @@ _transport_object: TransportManager | None = None try: - _url = url(get_security_config()) + _url = url(get_global_config()) engine = 
create_engine(_url) murfey_db = Session(engine, expire_on_commit=False) except Exception: @@ -278,12 +278,12 @@ def run(): args, unknown = parser.parse_known_args() # Load the security configuration - security_config = get_security_config() + global_config = get_global_config() # Set up GrayLog handler if provided in the configuration - if security_config.graylog_host: + if global_config.graylog_host: handler = graypy.GELFUDPHandler( - security_config.graylog_host, security_config.graylog_port, level_names=True + global_config.graylog_host, global_config.graylog_port, level_names=True ) root_logger = logging.getLogger() root_logger.addHandler(handler) @@ -294,15 +294,18 @@ def run(): # Run in demo mode with no connections set up os.environ["MURFEY_DEMO"] = "1" else: + if not global_config.rabbitmq_credentials: + raise FileNotFoundError("No RabbitMQ credentials file provided") # Load RabbitMQ configuration and set up the connection - PikaTransport().load_configuration_file(security_config.rabbitmq_credentials) + PikaTransport().load_configuration_file(global_config.rabbitmq_credentials) _set_up_transport("PikaTransport") # Set up logging now that the desired verbosity is known _set_up_logging(quiet=args.quiet, verbosity=args.verbose) + global_config = get_global_config() if not args.temporary and _transport_object: - _transport_object.feedback_queue = security_config.feedback_queue + _transport_object.feedback_queue = global_config.feedback_queue rabbit_thread = Thread( target=feedback_listen, daemon=True, @@ -1688,13 +1691,13 @@ def _resize_intial_model( downscaled_pixel_size: float, input_path: Path, output_path: Path, - executables: Dict[str, str], + executables: Dict[str, Path], env: Dict[str, str], ) -> None: if executables.get("relion_image_handler"): comp_proc = subprocess.run( [ - f"{executables['relion_image_handler']}", + f"{str(executables['relion_image_handler'])}", "--i", str(input_path), "--new_box", diff --git a/src/murfey/server/api/__init__.py b/src/murfey/server/api/__init__.py index 2a19c7153..7c204b85f 100644 --- a/src/murfey/server/api/__init__.py +++ b/src/murfey/server/api/__init__.py @@ -52,7 +52,7 @@ from murfey.server.gain import Camera, prepare_eer_gain, prepare_gain from murfey.server.murfey_db import murfey_db from murfey.util import secure_path -from murfey.util.config import MachineConfig, from_file, settings +from murfey.util.config import MachineConfig, machine_config_from_file, settings from murfey.util.db import ( AutoProcProgram, ClientEnvironment, @@ -147,9 +147,9 @@ def connections_check(): def machine_info() -> Optional[MachineConfig]: instrument_name = os.getenv("BEAMLINE") if settings.murfey_machine_configuration and instrument_name: - return from_file(Path(settings.murfey_machine_configuration), instrument_name)[ - instrument_name - ] + return machine_config_from_file( + Path(settings.murfey_machine_configuration), instrument_name + )[instrument_name] return None @@ -157,9 +157,9 @@ def machine_info() -> Optional[MachineConfig]: @router.get("/instruments/{instrument_name}/machine") def machine_info_by_name(instrument_name: str) -> Optional[MachineConfig]: if settings.murfey_machine_configuration: - return from_file(Path(settings.murfey_machine_configuration), instrument_name)[ - instrument_name - ] + return machine_config_from_file( + Path(settings.murfey_machine_configuration), instrument_name + )[instrument_name] return None @@ -1268,6 +1268,10 @@ async def request_tomography_preprocessing( murfey_ids = _murfey_id(appid, db, number=1, close=False) 
if not mrc_out.parent.exists(): mrc_out.parent.mkdir(parents=True, exist_ok=True) + # Handle case when gain reference file is None + if not proc_file.gain_ref: + log.error("No gain reference file was provided in the ProcessFile object") + return proc_file zocalo_message: dict = { "recipes": ["em-tomo-preprocess"], "parameters": { @@ -1286,7 +1290,9 @@ async def request_tomography_preprocessing( "fm_dose": proc_file.dose_per_frame, "gain_ref": ( str(machine_config.rsync_basepath / proc_file.gain_ref) - if proc_file.gain_ref and machine_config.data_transfer_enabled + if proc_file.gain_ref + and machine_config.data_transfer_enabled + and machine_config.rsync_basepath else proc_file.gain_ref ), "fm_int_file": proc_file.eer_fractionation_file, @@ -1299,7 +1305,7 @@ async def request_tomography_preprocessing( _transport_object.send("processing_recipe", zocalo_message) else: log.error( - f"Pe-processing was requested for {sanitise(ppath.name)} but no Zocalo transport object was found" + f"Preprocessing was requested for {sanitise(ppath.name)} but no Zocalo transport object was found" ) return proc_file else: @@ -1578,7 +1584,7 @@ async def process_gain( env = machine_config.external_environment safe_path_name = secure_filename(gain_reference_params.gain_ref.name) filepath = ( - Path(machine_config.rsync_basepath) + machine_config.rsync_basepath / (machine_config.rsync_module or "data") / str(datetime.datetime.now().year) / secure_filename(visit_name) @@ -1666,7 +1672,7 @@ async def write_eer_fractionation_file( ) / secure_filename(fractionation_params.fractionation_file_name) else: file_path = ( - Path(machine_config.rsync_basepath) + machine_config.rsync_basepath / (machine_config.rsync_module or "data") / str(datetime.datetime.now().year) / secure_filename(visit_name) @@ -1711,7 +1717,7 @@ async def make_gif( instrument_name ] output_dir = ( - Path(machine_config.rsync_basepath) + machine_config.rsync_basepath / (machine_config.rsync_module or "data") / secure_filename(year) / secure_filename(visit_name) diff --git a/src/murfey/server/api/auth.py b/src/murfey/server/api/auth.py index c962fa65f..66c1dc109 100644 --- a/src/murfey/server/api/auth.py +++ b/src/murfey/server/api/auth.py @@ -19,7 +19,7 @@ from murfey.server import sanitise from murfey.server.murfey_db import murfey_db, url -from murfey.util.config import get_machine_config, get_security_config +from murfey.util.config import get_global_config, get_machine_config from murfey.util.db import MurfeyUser as User from murfey.util.db import Session as MurfeySession @@ -63,19 +63,19 @@ async def __call__(self, request: Request): # Set up variables used for authentication -security_config = get_security_config() +global_config = get_global_config() machine_config = get_machine_config() auth_url = ( machine_config[os.getenv("BEAMLINE", "")].auth_url if machine_config.get(os.getenv("BEAMLINE", "")) else "" ) -ALGORITHM = security_config.auth_algorithm or "HS256" -SECRET_KEY = security_config.auth_key or secrets.token_hex(32) -if security_config.auth_type == "password": +ALGORITHM = global_config.auth_algorithm or "HS256" +SECRET_KEY = global_config.auth_key or secrets.token_hex(32) +if global_config.auth_type == "password": oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") else: - oauth2_scheme = CookieScheme(cookie_key=security_config.cookie_key) + oauth2_scheme = CookieScheme(cookie_key=global_config.cookie_key) pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") instrument_server_tokens: Dict[float, dict] = {} @@ -96,7 
+96,7 @@ def hash_password(password: str) -> str: # Set up database engine try: - _url = url(security_config) + _url = url(global_config) engine = create_engine(_url) except Exception: engine = None @@ -114,7 +114,7 @@ def validate_user(username: str, password: str) -> bool: def validate_visit(visit_name: str, token: str) -> bool: if validators := entry_points().select( group="murfey.auth.session_validation", - name=security_config.auth_type, + name=global_config.auth_type, ): return validators[0].load()(visit_name, token) return True @@ -166,12 +166,12 @@ async def validate_token(token: Annotated[str, Depends(oauth2_scheme)]): if auth_url: headers = ( {} - if security_config.auth_type == "cookie" + if global_config.auth_type == "cookie" else {"Authorization": f"Bearer {token}"} ) cookies = ( - {security_config.cookie_key: token} - if security_config.auth_type == "cookie" + {global_config.cookie_key: token} + if global_config.auth_type == "cookie" else {} ) async with aiohttp.ClientSession(cookies=cookies) as session: @@ -186,7 +186,7 @@ async def validate_token(token: Annotated[str, Depends(oauth2_scheme)]): else: if validators := entry_points().select( group="murfey.auth.token_validation", - name=security_config.auth_type, + name=global_config.auth_type, ): validators[0].load()(token) else: @@ -290,8 +290,8 @@ async def mint_session_token(session_id: MurfeySessionID, db=murfey_db): db.exec(select(MurfeySession).where(MurfeySession.id == session_id)).one().visit ) expiry_time = None - if security_config.session_token_timeout: - expiry_time = time.time() + security_config.session_token_timeout + if global_config.session_token_timeout: + expiry_time = time.time() + global_config.session_token_timeout token = create_access_token( { "session": session_id, diff --git a/src/murfey/server/api/spa.py b/src/murfey/server/api/spa.py index bc7594254..bfb889a65 100644 --- a/src/murfey/server/api/spa.py +++ b/src/murfey/server/api/spa.py @@ -20,6 +20,11 @@ def _cryolo_model_path(visit: str, instrument_name: str) -> Path: machine_config = get_machine_config(instrument_name=instrument_name)[ instrument_name ] + # Raise error if relevant keys weren't set in MachineConfig + if not machine_config.default_model: + raise ValueError("No default crYOLO model was set") + + # Find user-provided crYOLO model if machine_config.model_search_directory: visit_directory = ( machine_config.rsync_basepath @@ -32,10 +37,14 @@ def _cryolo_model_path(visit: str, instrument_name: str) -> Path: ) if possible_models: return sorted(possible_models, key=lambda x: x.stat().st_ctime)[-1] + + # Return default crYOLO model otherwise return machine_config.default_model @router.get("/sessions/{session_id}/cryolo_model") def get_cryolo_model_path(session_id: int, db=murfey_db): session = db.exec(select(MurfeySession).where(MurfeySession.id == session_id)).one() - return {"model_path": _cryolo_model_path(session.visit, session.instrment_name)} + return { + "model_path": str(_cryolo_model_path(session.visit, session.instrment_name)) + } diff --git a/src/murfey/server/demo_api.py b/src/murfey/server/demo_api.py index 5439d53c2..4f47c688d 100644 --- a/src/murfey/server/demo_api.py +++ b/src/murfey/server/demo_api.py @@ -40,7 +40,7 @@ from murfey.server.api import MurfeySessionID from murfey.server.api.auth import validate_token from murfey.server.murfey_db import murfey_db -from murfey.util.config import MachineConfig, from_file +from murfey.util.config import MachineConfig, machine_config_from_file from murfey.util.db import ( 
AutoProcProgram, ClientEnvironment, @@ -113,7 +113,9 @@ class Settings(BaseSettings): machine_config: dict[str, MachineConfig] = {} if settings.murfey_machine_configuration: microscope = get_microscope() - machine_config = from_file(Path(settings.murfey_machine_configuration), microscope) + machine_config = machine_config_from_file( + Path(settings.murfey_machine_configuration), microscope + ) # This will be the homepage for a given microscope. @@ -134,9 +136,9 @@ async def root(request: Request): def machine_info() -> Optional[MachineConfig]: instrument_name = os.getenv("BEAMLINE") if settings.murfey_machine_configuration and instrument_name: - return from_file(Path(settings.murfey_machine_configuration), instrument_name)[ - instrument_name - ] + return machine_config_from_file( + Path(settings.murfey_machine_configuration), instrument_name + )[instrument_name] return None @@ -144,9 +146,9 @@ def machine_info() -> Optional[MachineConfig]: @router.get("/instruments/{instrument_name}/machine") def machine_info_by_name(instrument_name: str) -> Optional[MachineConfig]: if settings.murfey_machine_configuration: - return from_file(Path(settings.murfey_machine_configuration), instrument_name)[ - instrument_name - ] + return machine_config_from_file( + Path(settings.murfey_machine_configuration), instrument_name + )[instrument_name] return None diff --git a/src/murfey/server/gain.py b/src/murfey/server/gain.py index f026be9a2..2109c0941 100644 --- a/src/murfey/server/gain.py +++ b/src/murfey/server/gain.py @@ -24,7 +24,7 @@ def _sanitise(gain_path: Path) -> Path: async def prepare_gain( camera: int, gain_path: Path, - executables: Dict[str, str], + executables: Dict[str, Path], env: Dict[str, str], rescale: bool = True, tag: str = "", @@ -57,7 +57,7 @@ async def prepare_gain( gain_path_mrc = gain_path.with_suffix(".mrc") gain_path_superres = gain_path.parent / (gain_path.name + "_superres.mrc") dm4_proc = await asyncio.create_subprocess_shell( - f"{executables['dm2mrc']} {gain_path} {gain_path_mrc}", + f"{str(executables['dm2mrc'])} {gain_path} {gain_path_mrc}", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) @@ -65,7 +65,7 @@ async def prepare_gain( if dm4_proc.returncode: return None, None clip_proc = await asyncio.create_subprocess_shell( - f"{executables['clip']} {flip} {secure_path(gain_path_mrc)} {secure_path(gain_path_superres) if rescale else secure_path(gain_out)}", + f"{str(executables['clip'])} {flip} {secure_path(gain_path_mrc)} {secure_path(gain_path_superres) if rescale else secure_path(gain_out)}", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) @@ -74,7 +74,7 @@ async def prepare_gain( return None, None if rescale: newstack_proc = await asyncio.create_subprocess_shell( - f"{executables['newstack']} -bin 2 {secure_path(gain_path_superres)} {secure_path(gain_out)}", + f"{str(executables['newstack'])} -bin 2 {secure_path(gain_path_superres)} {secure_path(gain_out)}", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) @@ -88,7 +88,7 @@ async def prepare_gain( async def prepare_eer_gain( - gain_path: Path, executables: Dict[str, str], env: Dict[str, str], tag: str = "" + gain_path: Path, executables: Dict[str, Path], env: Dict[str, str], tag: str = "" ) -> Tuple[Path | None, Path | None]: if not executables.get("tif2mrc"): return None, None @@ -98,7 +98,7 @@ async def prepare_eer_gain( for k, v in env.items(): os.environ[k] = v mrc_convert = await asyncio.create_subprocess_shell( - f"{executables['tif2mrc']} {secure_path(gain_path)} 
{secure_path(gain_out)}" + f"{str(executables['tif2mrc'])} {secure_path(gain_path)} {secure_path(gain_out)}" ) await mrc_convert.communicate() if mrc_convert.returncode: diff --git a/src/murfey/server/main.py b/src/murfey/server/main.py index 96533fd7d..ec99d5598 100644 --- a/src/murfey/server/main.py +++ b/src/murfey/server/main.py @@ -21,7 +21,7 @@ import murfey.server.websocket import murfey.util.models from murfey.server import template_files -from murfey.util.config import get_security_config +from murfey.util.config import get_global_config # Import Murfey server or demo server based on settings if os.getenv("MURFEY_DEMO"): @@ -39,7 +39,7 @@ class Settings(BaseSettings): murfey_machine_configuration: str = "" -security_config = get_security_config() +global_config = get_global_config() settings = Settings() @@ -50,7 +50,7 @@ class Settings(BaseSettings): app.add_middleware( CORSMiddleware, - allow_origins=security_config.allow_origins, + allow_origins=global_config.allow_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], diff --git a/src/murfey/server/murfey_db.py b/src/murfey/server/murfey_db.py index 2d0d52cf7..cc3085c8c 100644 --- a/src/murfey/server/murfey_db.py +++ b/src/murfey/server/murfey_db.py @@ -8,23 +8,27 @@ from sqlalchemy.pool import NullPool from sqlmodel import Session, create_engine -from murfey.util.config import Security, get_security_config +from murfey.util.config import GlobalConfig, get_global_config -def url(security_config: Security | None = None) -> str: - security_config = security_config or get_security_config() - with open(security_config.murfey_db_credentials, "r") as stream: +def url(global_config: GlobalConfig | None = None) -> str: + global_config = global_config or get_global_config() + if global_config.murfey_db_credentials is None: + raise ValueError( + "No database credentials file was provided for this instance of Murfey" + ) + with open(global_config.murfey_db_credentials, "r") as stream: creds = yaml.safe_load(stream) - f = Fernet(security_config.crypto_key.encode("ascii")) + f = Fernet(global_config.crypto_key.encode("ascii")) p = f.decrypt(creds["password"].encode("ascii")) return f"postgresql+psycopg2://{creds['username']}:{p.decode()}@{creds['host']}:{creds['port']}/{creds['database']}" def get_murfey_db_session( - security_config: Security | None = None, + global_config: GlobalConfig | None = None, ) -> Session: # type: ignore - _url = url(security_config) - if security_config and not security_config.sqlalchemy_pooling: + _url = url(global_config) + if global_config and not global_config.sqlalchemy_pooling: engine = create_engine(_url, poolclass=NullPool) else: engine = create_engine(_url) @@ -37,7 +41,7 @@ def get_murfey_db_session( murfey_db_session = partial( get_murfey_db_session, - get_security_config(), + get_global_config(), ) murfey_db: Session = Depends(murfey_db_session) diff --git a/src/murfey/util/config.py b/src/murfey/util/config.py index 00750c806..3e53870ae 100644 --- a/src/murfey/util/config.py +++ b/src/murfey/util/config.py @@ -4,75 +4,483 @@ import socket from functools import lru_cache from pathlib import Path -from typing import Dict, List, Literal, Optional, Union +from typing import Any, Literal, Mapping, Optional, Union import yaml from backports.entry_points_selectable import entry_points -from pydantic import BaseModel, BaseSettings, Extra, validator - - -class MachineConfig(BaseModel, extra=Extra.allow): # type: ignore - acquisition_software: List[str] - calibrations: Dict[str, Dict[str, 
Union[dict, float]]] - data_directories: List[Path] - rsync_basepath: Path - default_model: Path - display_name: str = "" - instrument_name: str = "" - image_path: Optional[Path] = None - software_versions: Dict[str, str] = {} - external_executables: Dict[str, str] = {} - external_executables_eer: Dict[str, str] = {} - external_environment: Dict[str, str] = {} - rsync_module: str = "" - create_directories: Dict[str, str] = {"atlas": "atlas"} - analyse_created_directories: List[str] = [] - gain_reference_directory: Optional[Path] = None - eer_fractionation_file_template: str = "" - processed_directory_name: str = "processed" - gain_directory_name: str = "processing" - node_creator_queue: str = "node_creator" - superres: bool = False - camera: str = "FALCON" - data_required_substrings: Dict[str, Dict[str, List[str]]] = {} - allow_removal: bool = False - modular_spa: bool = False - data_transfer_enabled: bool = True - processing_enabled: bool = True - machine_override: str = "" - processed_extra_directory: str = "" - plugin_packages: Dict[str, Path] = {} - software_settings_output_directories: Dict[str, List[str]] = {} - process_by_default: bool = True - recipes: Dict[str, str] = { - "em-spa-bfactor": "em-spa-bfactor", - "em-spa-class2d": "em-spa-class2d", - "em-spa-class3d": "em-spa-class3d", - "em-spa-preprocess": "em-spa-preprocess", - "em-spa-refine": "em-spa-refine", - "em-tomo-preprocess": "em-tomo-preprocess", - "em-tomo-align": "em-tomo-align", - } +from pydantic import ( + BaseConfig, + BaseModel, + BaseSettings, + Extra, + Field, + root_validator, + validator, +) +from pydantic.errors import NoneIsNotAllowedError + + +class MachineConfig(BaseModel): + """ + General information about the instrument being supported + """ - # Find and download upstream directories - upstream_data_directories: List[Path] = [] # Previous sessions - upstream_data_download_directory: Optional[Path] = None # Set by microscope config - upstream_data_tiff_locations: List[str] = ["processed"] # Location of CLEM TIFFs + display_name: str = Field( + default="", + description="Name of instrument used for display purposes, i.e. Krios I.", + ) + instrument_name: str = Field( + default="", + description=( + "Computer-friendly instrument reference name, i.e. m02. " + "The name must not contain special characters or whitespace." + ), + ) + image_path: Optional[Path] = Field( + default=None, + description="Path to an image of the instrument for display purposes.", + ) + machine_override: str = Field( + default="", + description=( + "Override the instrument name as defined in the environment variable or " + "the configuration with this one. This is used if, for example, many " + "machines are sharing a server, and need to be named differently." + ), + ) - model_search_directory: str = "processing" - initial_model_search_directory: str = "processing/initial_model" + """ + Information about the hardware and software on the instrument machine + """ + camera: Literal["FALCON", "K3_FLIPX", "K3_FLIPY", ""] = Field( + default="", + description=( + "Name of the camera used by the TEM. This is only relevant for TEMs to " + "determine how the gain reference needs to be processed, e.g., if it has " + "to be binned down from superres or flipped along the x- or y-axis. " + "Options: 'FALCON', 'K3_FLIPX', 'K3_FLIPY', ''" + ), + # NOTE: + # Eventually need to support Falcon 4, Falcon 4I, K2, K3 (superres) + # _FLIPX/_FLIPY is to tell it what to do with the gain reference. 
+        # - These will eventually be removed, leaving only the camera name
+        # - Will need to create a new key to record whether the gain reference
+        #   image needs to be flipped (flip_gain: X, Y, None)
+    )
+    superres: bool = Field(
+        default=False,
+        description=(
+            "State whether the superres feature is present on this microscope. "
+            "For a Gatan K3, this will be set to True."
+        ),
+    )
+    flip_gain: Literal["x", "y", ""] = Field(
+        default="",
+        description=(
+            "State if the gain reference needs to be flipped along a specific axis. "
+            "Options: 'x', 'y', or ''."
+        ),
+        # NOTE: This is a placeholder for a key that will be implemented in the future
+    )
+    calibrations: dict[str, dict[str, Union[dict, float]]] = Field(
+        default={},
+        description=(
+            "Nested dictionary containing the calibrations for this microscope. "
+            "E.g., 'magnification' would be a valid dictionary, in which the "
+            "pixel size (in angstroms) at each magnification level is provided as a "
+            "key-value pair. Options: 'magnification'"
+        ),
+    )
+
+    # NOTE:
+    # acquisition_software, software_versions, and software_settings_output_directories
+    # can all potentially be combined into one nested dictionary
+    acquisition_software: list[
+        Literal["epu", "tomo", "serialem", "autotem", "leica"]
+    ] = Field(
+        default=[],
+        description=("List of all the acquisition software present on this machine."),
+    )
+    software_versions: dict[str, str] = Field(
+        default={},
+        description=(
+            "Dictionary containing the version number of the acquisition software as "
+            "key-value pairs."
+        ),
+    )
+    software_settings_output_directories: dict[str, list[str]] = Field(
+        default={},
+        description=(
+            "A dictionary in which the keys are the full file paths to the settings "
+            "for the acquisition software packages, and the values are lists of keys "
+            "through the layered structure of the XML settings files to where the save "
+            "directory can be overwritten."
+        ),
+    )
+
+    # Instrument-side file paths
+    data_required_substrings: dict[str, dict[str, list[str]]] = Field(
+        default={},
+        description=(
+            "Nested dictionary stating the file suffixes to look for as part of the "
+            "processing workflow for a given software package, and subsequently the "
+            "key phrases to search for within the file name for it to be selected for "
+            "processing."
+        ),
+    )
+    data_directories: list[Path] = Field(
+        default=[],
+        description=(
+            "List of full paths to where data is stored on the instrument machine."
+        ),
+    )
+    create_directories: dict[str, str] = Field(
+        default={"atlas": "atlas"},
+        description=(
+            "Dictionary describing the directories to create within each visit on the "
+            "instrument machine. The key will be what Murfey calls the folder "
+            "internally, while the value is what the folder is actually called on the "
+            "file system."
+        ),
+        # NOTE: This key should eventually be changed into a list of strings
+    )
+    analyse_created_directories: list[str] = Field(
+        default=[],
+        description=(
+            "List of folders to be considered for analysis by Murfey. This will "
+            "generally be a subset of the list of folders specified earlier when "
+            "creating the directories for each visit."
+ ), + ) + gain_reference_directory: Optional[Path] = Field( + default=None, + description=( + "Full path to where the gain reference from the detector is saved." + ), + ) + eer_fractionation_file_template: str = Field( + default="", + description=( + "File path template that can be provided if the EER fractionation files " + "are saved in a location separate from the rest of the data. This will " + "be a string, with '{visit}' and '{year}' being optional arguments that " + "can be embedded in the string. E.g.: '/home/user/data/{year}/{visit}'" + ), + # Only if Falcon is used + # To avoid others having to follow the {year}/{visit} format we are doing + ) + """ + Data transfer-related settings + """ + # rsync-related settings (only if rsync is used) + data_transfer_enabled: bool = Field( + default=False, + description=("Toggle whether to enable data transfer via rsync."), + # NOTE: Only request input for this code block if data transfer is enabled + ) + allow_removal: bool = Field( + default=False, description="Allow original files to be removed after rsync." + ) + rsync_basepath: Path = Field( + default=Path("/"), + description=( + "Full path on the storage server that the rsync daemon will append the " + "relative paths of the transferred files to." + ), + # If rsync is disabled, rsync_basepath works out to be "/". + # Must always be set. + ) + rsync_module: str = Field( + default="", + description=( + "Name of the rsync module the files are being transferred with. The module " + "will be appended to the rsync base path, and the relative paths will be " + "appended to the module. This is particularly useful when many instrument " + "machines are transferring to the same storage server, as you can specify " + "different sub-folders to save the data to." + ), + ) + rsync_url: str = Field( + default="", + description=( + "URL to a remote rsync daemon. By default, the rsync daemon will be " + "running on the client machine, and this defaults to an empty string." + ), + ) + + # Related visits and data + upstream_data_directories: list[Path] = Field( + default=[], + description=( + "List of full paths to folders on other machines for Murfey to look for the " + "current visit in. This is primarily used for multi-instrument workflows " + "that use processed data from other instruments as input." + ), + ) + upstream_data_download_directory: Optional[Path] = Field( + default=None, + description=( + "Path to the folder on this instrument machine to transfer files from other " + "machines to." + ), + ) + upstream_data_tiff_locations: list[str] = Field( + default=["processed"], + description=( + "Name of the sub-folder within the visit folder from which to transfer the " + "results. This would typically be the 'processed' folder." + ), + # NOTE: This should eventually be converted into a dictionary, which looks for + # files in different locations according to the workflows they correspond to + ) + + """ + Data processing-related settings + """ + # Processing-related keys + processing_enabled: bool = Field( + default=False, + description="Toggle whether to enable data processing.", + # NOTE: Only request input for this code block if processing is enabled + ) + process_by_default: bool = Field( + default=True, + description=( + "Toggle whether processing should be enabled by default. If False, Murfey " + "will ask the user whether they want to process the data in their current " + "session." 
+        ),
+    )
+
+    # Server-side file paths
+    gain_directory_name: str = Field(
+        default="processing",
+        description=(
+            "Name of the folder to save the files used to facilitate data processing to. "
+            "This folder will be located under the current visit."
+        ),
+    )
+    processed_directory_name: str = Field(
+        default="processed",
+        description=(
+            "Name of the folder to save the output of the data processing workflow to. "
+            "This folder will be located under the current visit."
+        ),
+    )
+    processed_extra_directory: str = Field(
+        default="",
+        description=(
+            "Name of the sub-folder in the processed directory to save the output of "
+            "additional processing workflows to. E.g., if you are using Relion for "
+            "processing, its output files could be stored in a 'relion' sub-folder."
+        ),
+        # NOTE: This should eventually be a list of strings, if we want to allow
+        # users to add more processing options to their workflow
+    )
+
+    # TEM-related processing workflows
+    recipes: dict[
+        Literal[
+            "em-spa-bfactor",
+            "em-spa-class2d",
+            "em-spa-class3d",
+            "em-spa-preprocess",
+            "em-spa-refine",
+            "em-tomo-preprocess",
+            "em-tomo-align",
+        ],
+        str,
+    ] = Field(
+        default={
+            "em-spa-bfactor": "em-spa-bfactor",
+            "em-spa-class2d": "em-spa-class2d",
+            "em-spa-class3d": "em-spa-class3d",
+            "em-spa-preprocess": "em-spa-preprocess",
+            "em-spa-refine": "em-spa-refine",
+            "em-tomo-preprocess": "em-tomo-preprocess",
+            "em-tomo-align": "em-tomo-align",
+        },
+        description=(
+            "A dictionary of recipes for Murfey to run to facilitate data processing. "
+            "The key represents the name of the recipe used by Murfey, while its value "
+            "is the name of the recipe in the repository it's in."
+        ),
+        # NOTE: Currently, this recipe-searching structure is tied to the GitLab repo;
+        # need to provide an option to map it to file paths instead, or even a folder.
+        # A parameter like recipe_folder might work?
+    )
+    modular_spa: bool = Field(
+        default=True,
+        description=(
+            "Deprecated key to toggle SPA processing; will be phased out eventually."
+        ),
+    )
-
-def from_file(config_file_path: Path, instrument: str = "") -> Dict[str, MachineConfig]:
+    # Particle picking settings
+    default_model: Optional[Path] = Field(
+        default=None,
+        description=(
+            "Path to the default machine learning model used for particle picking."
+        ),
+    )
+    model_search_directory: str = Field(
+        default="processing",
+        description=(
+            "Relative path to where user-uploaded machine learning models are stored. "
+            "Murfey will look for the folders under the current visit."
+        ),
+    )
+    initial_model_search_directory: str = Field(
+        default="processing/initial_model",  # User-uploaded electron density models
+        description=(
+            "Relative path to where user-uploaded electron density models are stored. "
+            "Murfey will look for the folders under the current visit."
+        ),
+    )
+
+    # Extra plugins for data acquisition(?)
+    external_executables: dict[str, Path] = Field(
+        default={},
+        description=(
+            "Dictionary containing additional software packages to be used as part of "
+            "the processing workflow. The keys are the names of the packages and the "
+            "values are the full paths to where the executables are located."
+        ),
+    )
+    external_executables_eer: dict[str, Path] = Field(
+        default={},
+        description=(
+            "A similar dictionary, but for the executables associated with processing "
+            "EER files."
+        ),
+        # NOTE: Both external_executables variables should be combined into one.
The + # EER ones could be their own key, where different software packages are + # provided for different file types in different workflows. + ) + external_environment: dict[str, str] = Field( + default={}, + description=( + "Dictionary containing full paths to folders containing the supporting " + "software needed to run the executables to be used. These paths will be " + "appended to the $PATH environment variable, so if multiple paths are " + "associated with a single executable, they need to be provided as colon-" + "separated strings. E.g. '/this/is/one/folder:/this/is/another/one'" + ), + ) + plugin_packages: dict[str, Path] = Field( + default={}, + description=( + "Dictionary containing full paths to additional plugins for Murfey that " + "help support the data collection and processing workflow." + ), + ) + + """ + Server and network-related configurations + """ + # Security-related keys + global_configuration_path: Optional[Path] = Field( + description=( + "Full file path to the YAML file containing the configurations for the " + "Murfey server." + ), + alias="security_configuration_path", + ) + # Network connections + frontend_url: str = Field( + default="http://localhost:3000", + description="URL to the Murfey frontend.", + ) + murfey_url: str = Field( + default="http://localhost:8000", + description="URL to the Murfey API.", + ) + instrument_server_url: str = Field( + default="http://localhost:8001", + description="URL to the instrument server.", + ) + auth_url: str = Field( + default="", + description="URL to where users can authenticate their Murfey sessions.", + ) + + # RabbitMQ-specific keys + failure_queue: str = Field( + default="", + description="Name of RabbitMQ queue where failed API calls will be recorded.", + ) + node_creator_queue: str = Field( + default="node_creator", + description=( + "Name of the RabbitMQ queue where requests for creating job nodes are sent." + ), + ) + + class Config(BaseConfig): + """ + Additional settings for how this Pydantic model behaves + """ + + extra = Extra.allow + json_encoders = {Path: str} + + @validator("camera", always=True, pre=True) + def __validate_camera_model__(cls, value: str): + # Let non-strings fail validation naturally + if not isinstance(value, str): + return value + # Handle empty string + if len(value) == 0: + return value + # Match string to known camera models + supported_camera_models = ("FALCON", "K3") + if value.upper().startswith( + supported_camera_models + ): # Case-insensitive matching + return value.upper() + else: + raise ValueError( + f"unexpected value; permitted: {supported_camera_models!r} " + f"(type=value_error.const; given={value!r}; " + f"permitted={supported_camera_models!r})" + ) + + @root_validator(pre=False) + def __validate_superres__(cls, model: dict): + camera: str = model.get("camera", "") + model["superres"] = True if camera.startswith("K3") else False + return model + + @validator("rsync_basepath", always=True) + def __validate_rsync_basepath_if_transfer_enabled__( + cls, v: Optional[str], values: Mapping[str, Any] + ) -> Any: + """ + If data transfer is enabled, an rsync basepath must be provided. + """ + if values.get("data_transfer_enabled"): + if v is None: + raise NoneIsNotAllowedError + return v + + @validator("default_model", always=True) + def __validate_default_model_if_processing_enabled_and_spa_possible__( + cls, v: Optional[str], values: Mapping[str, Any] + ) -> Any: + """ + If data processing is enabled, a machine learning model must be provided. 
+        """
+        if values.get("processing_enabled") and "epu" in values.get(
+            "acquisition_software", []
+        ):
+            if v is None:
+                raise NoneIsNotAllowedError
+        return v
+
+
+def machine_config_from_file(
+    config_file_path: Path, instrument: str = ""
+) -> dict[str, MachineConfig]:
     with open(config_file_path, "r") as config_stream:
         config = yaml.safe_load(config_stream)
     return {
@@ -82,19 +490,61 @@ def from_file(config_file_path: Path, instrument: str = "") -> Dict[str, Machine
     }
 
 
-class Security(BaseModel):
-    rabbitmq_credentials: str
-    murfey_db_credentials: str
-    crypto_key: str
+class GlobalConfig(BaseModel):
+    # Database connection settings
+    murfey_db_credentials: Optional[Path] = Field(
+        description=(
+            "Full file path to where Murfey's SQL database credentials are stored. "
+            "This is typically a YAML file."
+        ),
+    )
+    sqlalchemy_pooling: bool = Field(
+        default=True,
+        description=(
+            "Toggles connection pooling functionality in the SQL database. If 'True', "
+            "clients will connect to the database using an existing pool of connections "
+            "instead of creating a new one every time."
+        ),
+    )
+    crypto_key: str = Field(
+        default="",
+        description=(
+            "The encryption key used for the SQL database. This can be generated by "
+            "Murfey using the 'murfey.generate_key' command."
+        ),
+    )
+
+    # RabbitMQ settings
+    rabbitmq_credentials: Optional[Path]
+    feedback_queue: str = Field(
+        default="murfey_feedback",
+        description=(
+            "The name of the RabbitMQ queue that will receive instructions and "
+            "the results of processing jobs on behalf of Murfey. This queue can be "
+            "shared by multiple server instances, which is why it's stored here "
+            "instead of in the machine configuration."
+        ),
+    )
+
+    # Server authentication settings
+    auth_type: Literal["password", "cookie"] = Field(
+        default="password",
+        description=(
+            "Choose how Murfey will authenticate new connections that it receives. "
+            "This can be done at present via password authentication or exchanging "
+            "cookies."
+        ),
+    )
     auth_key: str = ""
     auth_algorithm: str = ""
-    sqlalchemy_pooling: bool = True
-    allow_origins: List[str] = ["*"]
-    session_validation: str = ""
-    session_token_timeout: Optional[int] = None
-    auth_type: Literal["password", "cookie"] = "password"
     cookie_key: str = ""
-    feedback_queue: str = "murfey_feedback"
+    session_validation: str = ""
+    session_token_timeout: Optional[int] = (
+        None  # seconds; typically the length of a microscope session plus a bit
+    )
+    allow_origins: list[str] = ["*"]  # Restrict to only certain hostnames
+
+    # Graylog settings
     graylog_host: str = ""
     graylog_port: Optional[int] = None
@@ -107,15 +557,15 @@ def check_port_present_if_host_is(
         return v
 
 
-def security_from_file(config_file_path: Path) -> Security:
+def global_config_from_file(config_file_path: Path) -> GlobalConfig:
     with open(config_file_path, "r") as config_stream:
         config = yaml.safe_load(config_stream)
-    return Security(**config)
+    return GlobalConfig(**config)
 
 
 class Settings(BaseSettings):
     murfey_machine_configuration: str = ""
-    murfey_security_configuration: str = ""
+    murfey_global_configuration: str = ""
 
 
 settings = Settings()
@@ -126,6 +576,8 @@ def get_hostname():
     return socket.gethostname()
 
 
+# How does microscope_name differ from instrument_name?
+# Should we stick to one?
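+# (Illustration, based only on the function below: a config with
+# machine_override="m02" resolves to "m02" even when BEAMLINE="m01" is set,
+# while an empty machine_override falls back to the BEAMLINE variable.)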
def get_microscope(machine_config: MachineConfig | None = None) -> str: if machine_config: microscope_name = machine_config.machine_override or os.getenv("BEAMLINE", "") @@ -135,19 +587,20 @@ def get_microscope(machine_config: MachineConfig | None = None) -> str: @lru_cache(maxsize=1) -def get_security_config() -> Security: - if settings.murfey_security_configuration: - return security_from_file(Path(settings.murfey_security_configuration)) +def get_global_config() -> GlobalConfig: + if settings.murfey_global_configuration: + return global_config_from_file(Path(settings.murfey_global_configuration)) if settings.murfey_machine_configuration and os.getenv("BEAMLINE"): machine_config = get_machine_config(instrument_name=os.getenv("BEAMLINE"))[ os.getenv("BEAMLINE", "") ] - if machine_config.security_configuration_path: - return security_from_file(machine_config.security_configuration_path) - return Security( - rabbitmq_credentials="", + if not machine_config.global_configuration_path: + raise FileNotFoundError("No global configuration file provided") + return global_config_from_file(machine_config.global_configuration_path) + return GlobalConfig( + rabbitmq_credentials=None, session_validation="", - murfey_db_credentials="", + murfey_db_credentials=None, crypto_key="", auth_key="", auth_algorithm="", @@ -156,20 +609,19 @@ def get_security_config() -> Security: @lru_cache(maxsize=1) -def get_machine_config(instrument_name: str = "") -> Dict[str, MachineConfig]: +def get_machine_config(instrument_name: str = "") -> dict[str, MachineConfig]: machine_config = { "": MachineConfig( acquisition_software=[], calibrations={}, data_directories=[], rsync_basepath=Path("dls/tmp"), - murfey_db_credentials="", default_model="/tmp/weights.h5", ) } if settings.murfey_machine_configuration: microscope = instrument_name - machine_config = from_file( + machine_config = machine_config_from_file( Path(settings.murfey_machine_configuration), microscope ) return machine_config diff --git a/tests/cli/test_decrypt_password.py b/tests/cli/test_decrypt_password.py index 65952e5e8..703823d36 100644 --- a/tests/cli/test_decrypt_password.py +++ b/tests/cli/test_decrypt_password.py @@ -5,16 +5,16 @@ from cryptography.fernet import Fernet from murfey.cli.decrypt_db_password import run -from murfey.util.config import get_security_config +from murfey.util.config import get_global_config def test_decrypt_password(capsys, tmp_path): - security_config = get_security_config() + global_config = get_global_config() crypto_key = Fernet.generate_key() - security_config.crypto_key = crypto_key.decode("ascii") + global_config.crypto_key = crypto_key.decode("ascii") with open(tmp_path / "config.yaml", "w") as cfg: - yaml.dump(security_config.dict(), cfg) - os.environ["MURFEY_SECURITY_CONFIGURATION"] = str(tmp_path / "config.yaml") + yaml.dump(global_config.dict(), cfg) + os.environ["MURFEY_GLOBAL_CONFIGURATION"] = str(tmp_path / "config.yaml") password = "abcd" f = Fernet(crypto_key) encrypted_password = f.encrypt(password.encode("ascii")).decode() diff --git a/tests/cli/test_generate_password.py b/tests/cli/test_generate_password.py index fa48e9cf2..43cfa00d6 100644 --- a/tests/cli/test_generate_password.py +++ b/tests/cli/test_generate_password.py @@ -4,16 +4,16 @@ from cryptography.fernet import Fernet from murfey.cli.generate_db_password import run -from murfey.util.config import get_security_config +from murfey.util.config import get_global_config def test_generate_password(capsys, tmp_path): - security_config = 
get_security_config() + global_config = get_global_config() crypto_key = Fernet.generate_key() - security_config.crypto_key = crypto_key.decode("ascii") + global_config.crypto_key = crypto_key.decode("ascii") with open(tmp_path / "config.yaml", "w") as cfg: - yaml.dump(security_config.dict(), cfg) - os.environ["MURFEY_SECURITY_CONFIGURATION"] = str(tmp_path / "config.yaml") + yaml.dump(global_config.dict(), cfg) + os.environ["MURFEY_GLOBAL_CONFIGURATION"] = str(tmp_path / "config.yaml") run() captured = capsys.readouterr() f = Fernet(crypto_key)
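
For orientation, a minimal usage sketch of the renamed configuration surface introduced by this patch. It is illustrative only: the YAML path below is a hypothetical placeholder, and the file must exist for get_global_config() to succeed.

# Usage sketch (illustrative; hypothetical YAML path, not a shipped default)
import os

# get_global_config() now reads MURFEY_GLOBAL_CONFIGURATION in place of the
# old MURFEY_SECURITY_CONFIGURATION, as exercised by the updated tests above,
# so the variable must be set before the module is imported.
os.environ["MURFEY_GLOBAL_CONFIGURATION"] = "/path/to/global_config.yaml"

from murfey.util.config import MachineConfig, get_global_config

global_config = get_global_config()
print(global_config.auth_type)  # "password" unless overridden in the YAML

# The new MachineConfig validators: camera names are matched
# case-insensitively and upper-cased, and superres is derived from the
# camera model by the new root validator.
cfg = MachineConfig(camera="k3_flipx")
assert cfg.camera == "K3_FLIPX"
assert cfg.superres is True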