diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..0c8b06b --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,35 @@ +name: Python package test + +on: + push: + branches: + - master + pull_request: + workflow_dispatch: + +jobs: + test: + strategy: + matrix: + system: [["3.11", "ubuntu-latest"], ["3.12", "macos-latest"], ["3.12", "ubuntu-latest"]] + runs-on: ${{ matrix.system[1] }} + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.system[0] }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.system[0] }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 pytest + pip install --editable .[test] + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + pytest diff --git a/.gitignore b/.gitignore index 83833c4..f10d00e 100644 --- a/.gitignore +++ b/.gitignore @@ -111,4 +111,7 @@ IM2Deep.code-workspace # Testing test_data/ -test.ipynb \ No newline at end of file +test.ipynb + +# Profiles +profiles/ \ No newline at end of file diff --git a/README.md b/README.md index 2f47265..b927b98 100644 --- a/README.md +++ b/README.md @@ -4,45 +4,93 @@ Collisional cross-section prediction for (modified) peptides. --- ## Introduction -IM2Deep is a CCS predictor for (modified) peptides. -It is able to accurately predict CCS for modified peptides, even if the modification wasn't observed during training. +IM2Deep is a deep learning-based CCS predictor for (modified) peptides. It accurately predicts collisional cross-section (CCS) values for modified peptides, even if the modification wasn't observed during training. 
The tool supports both single-conformer and multi-conformer predictions for peptide ions. ## Installation Install with pip: -`pip install im2deep` +```bash +pip install im2deep +``` -If you want to use the multi-output model for CCS prediction of multiconformational peptide ions, use the following installation command: +## Usage -`pip install 'im2deep[er]'` +### Command Line Interface (CLI) -## Usage -### Basic CLI usage: -```sh +**Basic prediction:** +```bash im2deep ``` -If you want to calibrate your predictions (HIGHLY recommended), please provide a calibration file: -```sh -im2deep --calibration-file + +**With calibration (HIGHLY recommended):** +```bash +im2deep --calibration-precursors +``` + +**Calibration options:** +- `--calibrate-per-charge`: Calculate separate calibration shift factors per charge state (recommended, default true) +- `--use-charge-state`: Charge state for global calibration when --calibrate-per-charge is disabled + +**Multi-conformer prediction:** +To use the multi-output prediction model (requires optional dependencies): +```bash +im2deep --calibration-precursors --multi +``` + +**Output options:** +```bash +im2deep --output-file predictions.csv ``` -To use the multi-output prediction model on top of the original model, provide the -e flag -(make sure you have the optional dependencies installed!): -```sh -im2deep --calibration-file -e + +For a complete overview of all CLI arguments, run: +```bash +im2deep --help ``` -For an overview of all CLI arguments, run `im2deep --help`. -## Input files -Both peptide and calibration files are expected to be comma-separated values (CSV) with the following columns: - - `seq`: unmodified peptide sequence - - `modifications`: every modifications should be listed as `location|name`, separated by a pipe character (`|`) - between the location, the name, and other modifications. `location` is an integer counted starting at 1 for the - first AA. 
0 is reserved for N-terminal modifications, -1 for C-terminal modifications. `name` has to correspond - to a Unimod (PSI-MS) name. - - `charge`: peptide precursor charge - - `CCS`: collisional cross-section (only for calibration file) +### Python API + +IM2Deep can also be used programmatically: + +```python +from im2deep import predict, predict_and_calibrate +from psm_utils import PSMList -For example: +# Load your peptides as PSMList +psm_list = PSMList(psm_list=[...]) # or use psm_utils.io.read_file() + +# Simple prediction +predictions = predict(psm_list) + +# Prediction with calibration +psm_list_calibration = PSMList(psm_list=[...]) # Must contain CCS values +calibrated_predictions = predict_and_calibrate( + psm_list=psm_list, + psm_list_cal=psm_list_calibration +) +``` + +## Input Files + +### Standard Format +IM2Deep accepts any format supported by [psm_utils](https://github.com/compomics/psm_utils), including: +- Peptide Record (.peprec) +- MaxQuant msms.txt +- MSFragger PSM files +- And more... + +### Legacy CSV Format +Alternatively, use comma-separated values (CSV) with the following columns: + +- **`seq`**: Unmodified peptide sequence +- **`modifications`**: Modifications listed as `location|name`, separated by pipe (`|`) characters + - `location`: Integer starting at 1 for the first amino acid + - `0` = N-terminal modification + - `-1` = C-terminal modification + - `name`: Must correspond to a Unimod (PSI-MS) name +- **`charge`**: Peptide precursor charge state +- **`CCS`**: Collisional cross-section (only required for calibration files) + +**Example:** ```csv seq,modifications,charge,CCS @@ -54,6 +102,11 @@ DEELIHLDGK,,2,383.8693416055445 IPQEKCILQTDVK,5|Butyryl|6|Carbamidomethyl,3,516.2079366048176 ``` +## Important Notes + +- **Calibration**: Highly recommended for accurate predictions. Calibration corrects for systematic differences between predicted and observed CCS values. 
+- **Charge states**: IM2Deep predictions are reliable for charge states up to z=6. PSMs with higher charge states are automatically filtered out during validation. + ## Citing If you use IM2Deep within the context of [(TI)MS2Rescore](https://github.com/compomics/ms2rescore), please cite the following: > **TIMS²Rescore: A DDA-PASEF optimized data-driven rescoring pipeline based on MS²Rescore.** @@ -66,4 +119,3 @@ In other cases, please cite the following: > _Anal. Chem._ (2025) [doi:10.1021/acs.analchem.5c01142](https://pubs.acs.org/doi/10.1021/acs.analchem.5c01142) - diff --git a/im2deep/__init__.py b/im2deep/__init__.py index 2224ab7..dbc4298 100644 --- a/im2deep/__init__.py +++ b/im2deep/__init__.py @@ -39,16 +39,13 @@ __version__ = "1.2.0" # Import main functionality for easier access -from im2deep.im2deep import predict_ccs -from im2deep.calibrate import linear_calibration +from importlib.metadata import version from im2deep.utils import ccs2im, im2ccs -from im2deep._exceptions import IM2DeepError, CalibrationError +__version__: str = version("im2deep") __all__ = [ - "predict_ccs", - "linear_calibration", + "predict", + "predict_and_calibrate", "ccs2im", "im2ccs", - "IM2DeepError", - "CalibrationError", ] diff --git a/im2deep/__main__.py b/im2deep/__main__.py index e3b23d5..9e237fd 100644 --- a/im2deep/__main__.py +++ b/im2deep/__main__.py @@ -1,4 +1,5 @@ """ +# TODO: update docstring Command line interface for IM2Deep. 
This module provides a comprehensive command-line interface for the IM2Deep @@ -41,199 +42,82 @@ from __future__ import annotations import logging -import sys -from pathlib import Path -from typing import Optional +import cProfile import click -import pandas as pd - -from psm_utils.io import read_file -from psm_utils.io.exceptions import PSMUtilsIOException -from psm_utils.io.peptide_record import peprec_to_proforma -from psm_utils.psm import PSM -from psm_utils.psm_list import PSMList -from rich.logging import RichHandler +from rich.console import Console -REFERENCE_DATASET_PATH = Path(__file__).parent / "reference_data" / "reference_ccs.zip" +from im2deep import __version__, core +from pathlib import Path +from im2deep.utils import ( + setup_logging, + parse_input, + build_credits, + write_output, + infer_output_name, + DefaultCommandGroup, +) +console = Console() LOGGER = logging.getLogger(__name__) -def setup_logging(passed_level: str) -> None: - """ - Configure logging with Rich formatting. - - Parameters - ---------- - passed_level : str - Logging level name (debug, info, warning, error, critical) - - Raises - ------ - ValueError - If invalid logging level provided - """ - log_mapping = { - "debug": logging.DEBUG, - "info": logging.INFO, - "warning": logging.WARNING, - "error": logging.ERROR, - "critical": logging.CRITICAL, - } - - if passed_level.lower() not in log_mapping: - raise ValueError( - f"Invalid log level: {passed_level}. " f"Should be one of {list(log_mapping.keys())}" - ) - - logging.basicConfig( - level=log_mapping[passed_level.lower()], - format="%(message)s", - datefmt="[%X]", - handlers=[RichHandler()], - ) - - -def check_optional_dependencies() -> None: - """ - Check if optional dependencies for multi-conformer prediction are available. 
+# Command line interface +@click.group(cls=DefaultCommandGroup, default_command="predict", invoke_without_command=True) +@click.pass_context +@click.option( + "--logging-level", + "-l", + type=click.Choice(["debug", "info", "warning", "error", "critical"], case_sensitive=False), + default="info", + help="Set logging verbosity level.", +) +@click.option( + "--profile", + is_flag=True, + default=False, + help="Enable profiling with cProfile. Results saved to 'im2deep_profile.prof'.", +) +@click.option( + "--profile-name", + type=click.Path(dir_okay=False), + default="im2deep_profile.prof", + help="Output file name for cProfile results when --profile is enabled.", +) +@click.version_option(version=__version__) +def cli(ctx, logging_level, profile, profile_name): + """IM2Deep: Predict CCS values for peptides using deep learning. - Raises - ------ - SystemExit - If required dependencies are missing - """ - try: - import torch - import im2deeptrainer - - LOGGER.debug("Optional dependencies for multi-conformer prediction found") - except ImportError: - LOGGER.error( - "Multi-conformer prediction requires optional dependencies.\n" - "Please install IM2Deep with optional dependencies:\n" - "pip install 'im2deep[er]'" - ) - sys.exit(1) + Run prediction with: im2deep INPUT_FILE [OPTIONS] + With calibration: im2deep INPUT_FILE -c CALIBRATION_FILE -def _validate_file_format(file_path: str, file_type: str = "input") -> bool: + Use subcommands for additional functionality: + im2deep train ... """ - Validate file format and accessibility. 
- - Parameters - ---------- - file_path : str - Path to file to validate - file_type : str - Type of file for error messages - - Returns - ------- - bool - True if file is valid - - Raises - ------ - click.ClickException - If file validation fails - """ - path = Path(file_path) - - if not path.exists(): - raise click.ClickException(f"{file_type.capitalize()} file not found: {file_path}") + setup_logging(logging_level) - if not path.is_file(): - raise click.ClickException(f"{file_type.capitalize()} path is not a file: {file_path}") + # Store parameters in context for subcommands + ctx.ensure_object(dict) + ctx.obj["logging_level"] = logging_level + ctx.obj["profile"] = profile + ctx.obj["profile_name"] = profile_name - if path.suffix.lower() not in [".csv", ".txt", ".tsv"]: - LOGGER.warning(f"Unexpected file extension for {file_type} file: {path.suffix}") + console.print(build_credits()) - try: - with open(file_path, "r", encoding="utf-8") as f: - first_line = f.readline().strip() - if not first_line: - raise click.ClickException(f"{file_type.capitalize()} file appears to be empty") - except Exception as e: - raise click.ClickException(f"Error reading {file_type} file: {e}") - - return True - -def _parse_csv_input(file_path: str, file_type: str = "prediction") -> PSMList: - """ - Parse CSV input file into PSMList. 
- - Parameters - ---------- - file_path : str - Path to CSV file - file_type : str - Type of file for error messages - - Returns - ------- - PSMList - Parsed PSM data - - Raises - ------ - click.ClickException - If parsing fails - """ - try: - df = pd.read_csv(file_path) - df = df.fillna("") - - required_cols = ["seq", "modifications", "charge"] - missing_cols = set(required_cols) - set(df.columns) - if missing_cols: - raise click.ClickException( - f"Missing required columns in {file_type} file: {missing_cols}\n" - f"Required columns: {required_cols}" - ) - - if file_type == "calibration" and "CCS" not in df.columns: - raise click.ClickException("Calibration file must contain 'CCS' column") - - list_of_psms = [] - for idx, row in df.iterrows(): - try: - peptidoform = peprec_to_proforma(row["seq"], row["modifications"], row["charge"]) - metadata = {} - if file_type == "calibration" and "CCS" in row: - metadata["CCS"] = float(row["CCS"]) - - psm = PSM(peptidoform=peptidoform, metadata=metadata, spectrum_id=idx) - list_of_psms.append(psm) - except Exception as e: - LOGGER.warning(f"Skipping row {idx} due to parsing error: {e}") - continue - - if not list_of_psms: - raise click.ClickException(f"No valid peptides found in {file_type} file") - - LOGGER.info(f"Parsed {len(list_of_psms)} peptides from {file_type} file") - return PSMList(psm_list=list_of_psms) - - except pd.errors.EmptyDataError: - raise click.ClickException(f"{file_type.capitalize()} file is empty") - except pd.errors.ParserError as e: - raise click.ClickException(f"Error parsing {file_type} file: {e}") - except Exception as e: - raise click.ClickException(f"Unexpected error reading {file_type} file: {e}") - - -# Command line interface with comprehensive options -@click.command() -@click.argument("psm-file", type=click.Path(exists=True, dir_okay=False), metavar="INPUT_FILE") +# Implement psm_utils reading for calibration and prediction PSMLists +@cli.command() +@click.pass_context +@click.argument( + 
"precursors", type=click.Path(exists=True, dir_okay=False), metavar="INPUT_FILE", required=True +) @click.option( "-c", - "--calibration-file", + "--calibration-precursors", type=click.Path(exists=True, dir_okay=False), default=None, - help="Path to calibration file with known CCS values. Highly recommended for accurate predictions.", + help="Path to file with precursors with known CCS values. If provided, calibration is performed.", ) @click.option( "-o", @@ -247,26 +131,19 @@ def _parse_csv_input(file_path: str, file_type: str = "prediction") -> PSMList: "--model-name", type=click.Choice(["tims"], case_sensitive=False), default="tims", - help="Neural network model to use for prediction.", + help="Neural network model to use for prediction. Currently only 'tims' is supported.", ) @click.option( "-e", "--multi", is_flag=True, default=False, - help="Enable multi-conformer prediction. Requires optional dependencies: pip install 'im2deep[er]'", -) -@click.option( - "-l", - "--log-level", - type=click.Choice(["debug", "info", "warning", "error", "critical"], case_sensitive=False), - default="info", - help="Set logging verbosity level.", + help="Enable multi-conformer prediction.", ) @click.option( "-n", - "--n-jobs", - type=click.IntRange(min=1), + "--processes", + type=int, default=None, help="Number of parallel jobs for model inference. 
Default uses all available CPU cores.", ) @@ -295,161 +172,286 @@ def _parse_csv_input(file_path: str, file_type: str = "prediction") -> PSMList: default=False, help="Output ion mobility (1/K0) instead of CCS values.", ) -def main( - psm_file: str, - calibration_file: Optional[str] = None, - output_file: Optional[str] = None, - model_name: str = "tims", - multi: bool = False, - log_level: str = "info", - n_jobs: Optional[int] = None, - use_single_model: bool = True, - calibrate_per_charge: bool = True, - use_charge_state: int = 2, - ion_mobility: bool = False, -) -> None: - """ - IM2Deep: Predict CCS values for peptides using deep learning. - - IM2Deep predicts Collisional Cross Section (CCS) values for peptides, - including those with post-translational modifications. The tool supports - both single-conformer and multi-conformer predictions with optional - calibration using reference datasets. - - INPUT_FILE should be a CSV file with columns: - \b - - seq: Peptide sequence (required) - - modifications: Modifications in format "position|name" (required, can be empty) - - charge: Charge state (required) - - For calibration files, an additional 'CCS' column with observed values is required. - - Examples: - \b - # Basic prediction - im2deep peptides.csv - - # With calibration (recommended) - im2deep peptides.csv -c calibration.csv +@click.option( + "-l", + "--logging-level", + type=click.Choice(["debug", "info", "warning", "error", "critical"], case_sensitive=False), + default="info", + help="Set logging verbosity level.", +) +def predict(ctx, *args, **kwargs): + """Predict CCS values for peptides (default command). - # Multi-conformer prediction - im2deep peptides.csv -c calibration.csv -e + If no calibration file is provided with -c, performs prediction only. + With -c, performs calibration and prediction for improved accuracy. 
+ """ + # Check if profiling is enabled from parent context + profile_enabled = ctx.obj.get("profile", False) - # Ion mobility output - im2deep peptides.csv -c calibration.csv -i + if profile_enabled: + # Run with profiling + profiler = cProfile.Profile() + profiler.enable() - # Ensemble prediction with custom output - im2deep peptides.csv -c calibration.csv -o results.csv --use-single-model False - """ try: - # Setup logging first - setup_logging(log_level) - - LOGGER.info("IM2Deep command-line interface started") - LOGGER.debug( - f"Input arguments: psm_file={psm_file}, calibration_file={calibration_file}, " - f"multi={multi}, ion_mobility={ion_mobility}" - ) + _run_predict(*args, **kwargs) + finally: + if profile_enabled: + profiler.disable() + + # Get the IM2Deep root directory (two levels up from this file) + root_dir = Path(__file__).parent.parent + profiles_dir = root_dir / "profiles" + profiles_dir.mkdir(exist_ok=True) + + profile_output = profiles_dir / ctx.obj.get("profile_name", "im2deep_profile.prof") + profiler.dump_stats(profile_output) + LOGGER.info(f"Profiling data saved to {profile_output}") + LOGGER.info(f"View with: snakeviz {profile_output}") + + +def _run_predict(*args, **kwargs): + """Internal function that performs the actual prediction work.""" + # Setup logging first + setup_logging(kwargs.get("logging_level", "info")) + + LOGGER.info("Starting IM2Deep CCS prediction...") + LOGGER.debug( + f"Input arguments: precursors={kwargs.get('precursors')}, " + f"calibration_precursors={kwargs.get('calibration_precursors')}, multi={kwargs.get('multi')}, " + f"ion_mobility={kwargs.get('ion_mobility')}" + ) - # Import main functionality (after logging setup) - from im2deep._exceptions import IM2DeepError - from im2deep.im2deep import predict_ccs - - # Check optional dependencies if multi-conformer requested - if multi: - check_optional_dependencies() - - # Validate input files - _validate_file_format(psm_file, "input") - if calibration_file: - 
_validate_file_format(calibration_file, "calibration") - - # Parse input files - LOGGER.info("Parsing input files...") - - # Try to determine file format - with open(psm_file, "r", encoding="utf-8") as f: - first_line = f.readline().strip() - - # Check if it's the expected CSV format - if "modifications" in first_line and "seq" in first_line: - psm_list_pred = _parse_csv_input(psm_file, "prediction") - df_pred = pd.read_csv(psm_file).fillna("") - else: - # Try psm_utils for other formats - try: - psm_list_pred = read_file(psm_file) - df_pred = None - LOGGER.info(f"Loaded {len(psm_list_pred)} PSMs using psm_utils") - except PSMUtilsIOException as e: - raise click.ClickException( - f"Could not parse input file. Expected CSV with columns 'seq', 'modifications', 'charge' " - f"or a format supported by psm_utils. Error: {e}" - ) - - # Parse calibration file - psm_list_cal = None - df_cal = None - if calibration_file: - with open(calibration_file, "r", encoding="utf-8") as f: - cal_first_line = f.readline().strip() - - if ( - "modifications" in cal_first_line - and "seq" in cal_first_line - and "CCS" in cal_first_line - ): - psm_list_cal = _parse_csv_input(calibration_file, "calibration") - df_cal = pd.read_csv(calibration_file).fillna("") - else: - raise click.ClickException( - "Calibration file must be CSV with columns: 'seq', 'modifications', 'charge', 'CCS'" - ) - else: - LOGGER.warning( - "No calibration file provided. Predictions will be uncalibrated. " - "Calibration is HIGHLY recommended for accurate results." 
- ) - - # Set up output file - if not output_file: - input_path = Path(psm_file) - output_file = input_path.parent / f"{input_path.stem}_IM2Deep-predictions.csv" - - LOGGER.info(f"Output will be written to: {output_file}") - - # Run prediction - LOGGER.info("Starting CCS prediction...") - predict_ccs( - psm_list_pred, - psm_list_cal, - output_file=output_file, - model_name=model_name, - multi=multi, - calibrate_per_charge=calibrate_per_charge, - use_charge_state=use_charge_state, - n_jobs=n_jobs, - use_single_model=use_single_model, - ion_mobility=ion_mobility, - pred_df=df_pred, - cal_df=df_cal, - write_output=True, + # Parse input files + LOGGER.info("Parsing input files...") + psm_list = parse_input(kwargs.get("precursors")) + + # Run prediction + LOGGER.info("Running CCS prediction...") + if kwargs.get("calibration_precursors"): + LOGGER.info("Calibration file provided, performing calibration and prediction...") + psm_list_cal = parse_input(kwargs.get("calibration_precursors")) + predictions = core.predict_and_calibrate(psm_list, psm_list_cal, *args, **kwargs) + else: + LOGGER.info( + "No calibration file provided (calibration is HIGHLY recommended), performing prediction only..." 
) - - LOGGER.info("IM2Deep completed successfully!") - - except IM2DeepError as e: - LOGGER.error(f"IM2Deep error: {e}") - sys.exit(1) - except click.ClickException: - # Re-raise click exceptions to preserve formatting - raise - except Exception as e: - LOGGER.error(f"Unexpected error: {e}") - if log_level.lower() == "debug": - LOGGER.exception("Full traceback:") - sys.exit(1) + predictions = core.predict(*args, **kwargs) + + # Output results + LOGGER.info("IM2Deep CCS prediction completed successfully!") + output_name = kwargs.pop("output_file") + output_name = infer_output_name(kwargs["precursors"], output_name).with_suffix(".csv") + LOGGER.info(f"Writing output file to {output_name}...") + write_output(output_name, predictions, psm_list, kwargs.get("ion_mobility", False)) + LOGGER.info("Output file written successfully.") + LOGGER.info("IM2Deep finished.") + + +# TODO: implement train command +# @cli.command() +# @click.argument("training_data", type=click.Path(exists=True, dir_okay=False)) +# @click.option( +# "-o", +# "--output-model", +# type=click.Path(dir_okay=False), +# required=True, +# help="Path to save the trained model.", +# ) +# @click.option( +# "--epochs", +# type=int, +# default=100, +# help="Number of training epochs.", +# ) +# @click.option( +# "-l", +# "--logging-level", +# type=click.Choice(["debug", "info", "warning", "error", "critical"], case_sensitive=False), +# default="info", +# help="Set logging verbosity level.", +# ) +# def train(training_data, output_model, epochs, logging_level): +# """Train a new IM2Deep model. + +# Example: im2deep train training_data.csv -o my_model.ckpt +# """ +# setup_logging(logging_level) +# LOGGER.info("Starting IM2Deep training...") + +# # Parse training data +# psm_list_train = _parse_csv_input(training_data, "training") + +# # Call training function +# core.train( +# psm_list=psm_list_train, +# model_save_path=output_model, +# training_kwargs={"epochs": epochs}, +# ) + +# LOGGER.info(f"Training completed. 
Model saved to {output_model}") + + +def main(): + # try: + cli(obj={}) + # except Exception as e: + # LOGGER.error(f"Unexpected error in IM2Deep CLI: {e}") + # sys.exit(1) + + +# def main( +# psm_file: str, +# calibration_file: Optional[str] = None, +# output_file: Optional[str] = None, +# model_name: str = "tims", +# multi: bool = False, +# log_level: str = "info", +# n_jobs: Optional[int] = None, +# use_single_model: bool = True, +# calibrate_per_charge: bool = True, +# use_charge_state: int = 2, +# ion_mobility: bool = False, +# ) -> None: +# """ +# IM2Deep: Predict CCS values for peptides using deep learning. + +# IM2Deep predicts Collisional Cross Section (CCS) values for peptides, +# including those with post-translational modifications. The tool supports +# both single-conformer and multi-conformer predictions with optional +# calibration using reference datasets. + +# INPUT_FILE should be a CSV file with columns: +# \b +# - seq: Peptide sequence (required) +# - modifications: Modifications in format "position|name" (required, can be empty) +# - charge: Charge state (required) + +# For calibration files, an additional 'CCS' column with observed values is required. 
+ +# Examples: +# \b +# # Basic prediction +# im2deep peptides.csv + +# # With calibration (recommended) +# im2deep peptides.csv -c calibration.csv + +# # Multi-conformer prediction +# im2deep peptides.csv -c calibration.csv -e + +# # Ion mobility output +# im2deep peptides.csv -c calibration.csv -i + +# # Ensemble prediction with custom output +# im2deep peptides.csv -c calibration.csv -o results.csv --use-single-model False +# """ +# try: +# # Setup logging first +# setup_logging(log_level) + +# LOGGER.info("IM2Deep command-line interface started") +# LOGGER.debug( +# f"Input arguments: psm_file={psm_file}, calibration_file={calibration_file}, " +# f"multi={multi}, ion_mobility={ion_mobility}" +# ) + +# # Import main functionality (after logging setup) +# from im2deep._exceptions import IM2DeepError +# from im2deep.im2deep import predict_ccs + +# # Validate input files +# _validate_file_format(psm_file, "input") +# if calibration_file: +# _validate_file_format(calibration_file, "calibration") + +# # Parse input files +# LOGGER.info("Parsing input files...") + +# # Try to determine file format +# with open(psm_file, "r", encoding="utf-8") as f: +# first_line = f.readline().strip() + +# # Check if it's the expected CSV format +# if "modifications" in first_line and "seq" in first_line: +# psm_list_pred = _parse_csv_input(psm_file, "prediction") +# df_pred = pd.read_csv(psm_file).fillna("") +# else: +# # Try psm_utils for other formats +# try: +# psm_list_pred = read_file(psm_file) +# df_pred = None +# LOGGER.info(f"Loaded {len(psm_list_pred)} PSMs using psm_utils") +# except PSMUtilsIOException as e: +# raise click.ClickException( +# f"Could not parse input file. Expected CSV with columns 'seq', 'modifications', 'charge' " +# f"or a format supported by psm_utils. 
Error: {e}" +# ) + +# # Parse calibration file +# psm_list_cal = None +# df_cal = None +# if calibration_file: +# with open(calibration_file, "r", encoding="utf-8") as f: +# cal_first_line = f.readline().strip() + +# if ( +# "modifications" in cal_first_line +# and "seq" in cal_first_line +# and "CCS" in cal_first_line +# ): +# psm_list_cal = _parse_csv_input(calibration_file, "calibration") +# df_cal = pd.read_csv(calibration_file).fillna("") +# else: +# raise click.ClickException( +# "Calibration file must be CSV with columns: 'seq', 'modifications', 'charge', 'CCS'" +# ) +# else: +# LOGGER.warning( +# "No calibration file provided. Predictions will be uncalibrated. " +# "Calibration is HIGHLY recommended for accurate results." +# ) + +# # Set up output file +# if not output_file: +# input_path = Path(psm_file) +# output_file = input_path.parent / f"{input_path.stem}_IM2Deep-predictions.csv" + +# LOGGER.info(f"Output will be written to: {output_file}") + +# # Run prediction +# LOGGER.info("Starting CCS prediction...") +# predict_ccs( +# psm_list_pred, +# psm_list_cal, +# output_file=output_file, +# model_name=model_name, +# multi=multi, +# calibrate_per_charge=calibrate_per_charge, +# use_charge_state=use_charge_state, +# n_jobs=n_jobs, +# use_single_model=use_single_model, +# ion_mobility=ion_mobility, +# pred_df=df_pred, +# cal_df=df_cal, +# write_output=True, +# ) + +# LOGGER.info("IM2Deep completed successfully!") + +# except IM2DeepError as e: +# LOGGER.error(f"IM2Deep error: {e}") +# sys.exit(1) +# except click.ClickException: +# # Re-raise click exceptions to preserve formatting +# raise +# except Exception as e: +# LOGGER.error(f"Unexpected error: {e}") +# if log_level.lower() == "debug": +# LOGGER.exception("Full traceback:") +# sys.exit(1) if __name__ == "__main__": main() + build_credits() diff --git a/im2deep/_architecture.py b/im2deep/_architecture.py new file mode 100644 index 0000000..11d02ff --- /dev/null +++ b/im2deep/_architecture.py @@ -0,0 
+1,1625 @@ +import sys +from pathlib import Path +import torch +import torch.nn as nn +import torch.nn.functional as F +import lightning as L +import logging + +try: + import wandb +except ImportError: + wandb = None + +from im2deep.constants import ( + BASEMODELCONFIG, +) + +PACKAGE_DATA_PATH = Path(__file__).parent / "package_data" + +logger = logging.getLogger(__name__) + + +class LogLowestMAE(L.Callback): + def __init__(self, config): + super(LogLowestMAE, self).__init__() + self.bestMAE = float("inf") + self.config = config + + def on_validation_end(self, trainer, pl_module): + try: + currentMAE = trainer.callback_metrics["Validation MAE"] + except KeyError: # Multi + currentMAE = trainer.callback_metrics["Val Mean MAE"] + if currentMAE < self.bestMAE: + self.bestMAE = currentMAE + if self.config["wandb"]["enabled"]: + if wandb is not None: + wandb.log({"Best Val MAE": self.bestMAE}) + + +class LRelu_with_saturation(nn.Module): + def __init__(self, negative_slope, saturation): + super(LRelu_with_saturation, self).__init__() + self.negative_slope = negative_slope + self.saturation = saturation + self.leaky_relu = nn.LeakyReLU(self.negative_slope) + + def forward(self, x): + activated = self.leaky_relu(x) + return torch.clamp(activated, max=self.saturation) + + +class Conv1dActivation(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding, + initializer, + negative_slope, + saturation, + ): + super(Conv1dActivation, self).__init__() + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding) + self.initializer = initializer + self.activation = LRelu_with_saturation( + negative_slope=negative_slope, saturation=saturation + ) + + initializer(self.conv.weight, 0.0, 0.05) + + def forward(self, x): + return self.activation(self.conv(x)) + + +class DenseActivation(nn.Module): + def __init__(self, in_features, out_features, initializer, negative_slope, saturation): + super(DenseActivation, self).__init__() + 
self.linear = nn.Linear(in_features, out_features) + self.initializer = initializer + self.activation = LRelu_with_saturation( + negative_slope=negative_slope, saturation=saturation + ) + + initializer(self.linear.weight, 0.0, 0.05) + + def forward(self, x): + return self.activation(self.linear(x)) + + +class SelfAttention(nn.Module): + def __init__(self, feature_dim, heads=1): + super(SelfAttention, self).__init__() + self.feature_dim = feature_dim + self.heads = heads + # self.padded_dim = self.feature_dim + (self.feature_dim % self.heads) + self.query_dim = self.feature_dim // self.heads + self.extra_dim = self.feature_dim % self.heads + + self.query = nn.Linear( + self.feature_dim, + (self.query_dim + (self.extra_dim if self.extra_dim > 0 else 0)) * self.heads, + ) + self.key = nn.Linear( + self.feature_dim, + (self.query_dim + (self.extra_dim if self.extra_dim > 0 else 0)) * self.heads, + ) + self.value = nn.Linear( + self.feature_dim, + (self.query_dim + (self.extra_dim if self.extra_dim > 0 else 0)) * self.heads, + ) + + self.fc_out = nn.Linear(self.feature_dim, self.feature_dim) + + def forward(self, x): + + batch_size, seq_len, feature_dim = x.size() + queries = self.query(x).view( + batch_size, + seq_len, + self.heads, + self.query_dim + (self.extra_dim if self.extra_dim > 0 else 0), + ) + keys = self.key(x).view( + batch_size, + seq_len, + self.heads, + self.query_dim + (self.extra_dim if self.extra_dim > 0 else 0), + ) + values = self.value(x).view( + batch_size, + seq_len, + self.heads, + self.query_dim + (self.extra_dim if self.extra_dim > 0 else 0), + ) + + attention_scores = torch.einsum("bqhd,bkhd->bhqk", [queries, keys]) / (self.query_dim**0.5) + attention_scores = F.softmax(attention_scores, dim=-1) + + out = torch.einsum("bhqk,bkhd->bqhd", [attention_scores, values]) + + out = out.view( + batch_size, + seq_len, + self.heads * (self.query_dim + (self.extra_dim if self.extra_dim > 0 else 0)), + ) + out = out[:, :, : self.feature_dim] + out = 
class Branch(nn.Module):
    """Per-output prediction head: optional hidden ReLU layer + 1-unit output."""

    def __init__(self, input_size, output_size, add_layer=1, dropout_rate=0.0):
        super().__init__()
        self.add_layer = add_layer
        if self.add_layer:
            self.fc1 = nn.Linear(input_size, output_size)
            self.fcoutput = nn.Linear(output_size, 1)
        else:
            self.fcoutput = nn.Linear(input_size, 1)

    def forward(self, x):
        # Fix: use the same truthiness test as __init__ (was `== 1`), so any
        # truthy add_layer both builds AND uses the hidden layer.
        if self.add_layer:
            x = F.relu(self.fc1(x))
        return self.fcoutput(x)


def _lrelu_kwargs(config):
    """Shared saturating-LeakyReLU settings for all building blocks."""
    return {
        "negative_slope": config["LRelu_negative_slope"],
        "saturation": config["LRelu_saturation"],
    }


def _conv(config, initi, in_ch, out_ch, kernel_size):
    """One Conv1dActivation with 'same' padding and the shared activation."""
    return Conv1dActivation(
        in_ch, out_ch, kernel_size, padding="same", initializer=initi, **_lrelu_kwargs(config)
    )


def _dense(config, initi, in_f, out_f):
    """One DenseActivation with the shared activation settings."""
    return DenseActivation(in_f, out_f, initializer=initi, **_lrelu_kwargs(config))


def _build_atom_conv(config, initi):
    """Conv stack over atom-composition features (6 input channels)."""
    c = config["AtomComp_out_channels_start"]
    k = config["AtomComp_kernel_size"]
    pool = config["AtomComp_MaxPool_kernel_size"]
    return nn.ModuleList(
        [
            _conv(config, initi, 6, c, k),
            _conv(config, initi, c, c, k),
            nn.MaxPool1d(pool, pool),
            _conv(config, initi, c, c // 2, k),
            _conv(config, initi, c // 2, c // 2, k),
            nn.MaxPool1d(pool, pool),
            _conv(config, initi, c // 2, c // 4, k),
            _conv(config, initi, c // 4, c // 4, k),
            nn.Flatten(),
        ]
    )


def _build_diatom_conv(config, initi):
    """Conv stack over diatom-composition features (6 input channels)."""
    c = config["DiatomComp_out_channels_start"]
    k = config["DiatomComp_kernel_size"]
    pool = config["DiatomComp_MaxPool_kernel_size"]
    return nn.ModuleList(
        [
            _conv(config, initi, 6, c, k),
            _conv(config, initi, c, c, k),
            nn.MaxPool1d(pool, pool),
            _conv(config, initi, c, c // 2, k),
            _conv(config, initi, c // 2, c // 2, k),
            nn.Flatten(),
        ]
    )


def _build_global_dense(config, initi):
    """Dense stack over the 60 global features."""
    u = config["Global_units"]
    return nn.ModuleList(
        [
            _dense(config, initi, 60, u),
            _dense(config, initi, u, u),
            _dense(config, initi, u, u),
        ]
    )


def _build_one_hot_conv(config, initi):
    """Conv stack over the one-hot sequence encoding (20 input channels)."""
    c = config["OneHot_out_channels"]
    k = config["One_hot_kernel_size"]
    pool = config["OneHot_MaxPool_kernel_size"]
    return nn.ModuleList(
        [
            _conv(config, initi, 20, c, k),
            _conv(config, initi, c, c, k),
            nn.MaxPool1d(pool, pool),
            nn.Flatten(),
        ]
    )


def _build_mol_conv(config, initi):
    """Conv stack over molecular descriptors (13 input channels)."""
    c = config["Mol_out_channels_start"]
    k = config["Mol_kernel_size"]
    pool = config["Mol_MaxPool_kernel_size"]
    return nn.ModuleList(
        [
            _conv(config, initi, 13, c, k),
            _conv(config, initi, c, c, k),
            nn.MaxPool1d(pool, pool),
            _conv(config, initi, c, c // 2, k),
            _conv(config, initi, c // 2, c // 2, k),
            nn.MaxPool1d(pool, pool),
            _conv(config, initi, c // 2, c // 4, k),
            _conv(config, initi, c // 4, c // 4, k),
            nn.Flatten(),
        ]
    )


def _build_concat_dense(config, initi, input_size, n_layers=5):
    """Dense stack applied to the concatenated branch outputs.

    Produces the same ModuleList structure (and hence state-dict keys) as the
    original inline construction: one input projection + n_layers-1 hidden.
    """
    u = config["Concat_units"]
    layers = [_dense(config, initi, input_size, u)]
    layers.extend(_dense(config, initi, u, u) for _ in range(n_layers - 1))
    return nn.ModuleList(layers)


class IM2Deep(L.LightningModule):
    """Single-output CCS predictor.

    Four parallel feature branches (atom composition 6ch, diatom composition
    6ch, one-hot sequence 20ch, 60 global features; optionally 13ch molecular
    descriptors when config["add_X_mol"]) are flattened, concatenated and fed
    through a dense head ending in a single linear output unit.

    Module attribute names and ModuleList ordering are identical to the
    original inline construction, so existing checkpoints keep loading.
    """

    def __init__(self, config, criterion):
        super().__init__()
        self.config = config
        self.criterion = criterion
        self.mae = nn.L1Loss()

        initi = self.configure_init()

        self.ConvAtomComp = _build_atom_conv(config, initi)
        self.ConvDiatomComp = _build_diatom_conv(config, initi)
        self.ConvGlobal = _build_global_dense(config, initi)
        self.OneHot = _build_one_hot_conv(config, initi)
        if config["add_X_mol"]:
            self.MolDesc = _build_mol_conv(config, initi)

        self.total_input_size = calculate_concat_shape(self.config)
        logger.debug(f"Total input size: {self.total_input_size}")

        self.Concat = _build_concat_dense(config, initi, self.total_input_size)
        # Final 1-unit regression output (index 5 in the Concat ModuleList).
        self.Concat.append(nn.Linear(self.config["Concat_units"], 1))

    def regularized_loss(self, y_hat, y):
        """Base criterion plus an L1 penalty on all parameters."""
        standard_loss = self.criterion(y_hat, y)
        l1_norm = sum(torch.norm(p, 1) for p in self.parameters())
        return standard_loss + self.config["L1_alpha"] * l1_norm

    def forward(self, atom_comp, diatom_comp, global_feats, one_hot, mol_desc=None):
        """Return a (batch, 1) tensor of CCS predictions."""
        # Conv branches expect (batch, channels, seq).
        atom_comp = atom_comp.permute(0, 2, 1)
        diatom_comp = diatom_comp.permute(0, 2, 1)
        one_hot = one_hot.permute(0, 2, 1)

        for layer in self.ConvAtomComp:
            atom_comp = layer(atom_comp)
        for layer in self.ConvDiatomComp:
            diatom_comp = layer(diatom_comp)
        for layer in self.ConvGlobal:
            global_feats = layer(global_feats)
        for layer in self.OneHot:
            one_hot = layer(one_hot)

        if self.config["add_X_mol"]:
            for layer in self.MolDesc:
                mol_desc = layer(mol_desc)

        concatenated = torch.cat((atom_comp, diatom_comp, one_hot, global_feats), 1)
        if self.config["add_X_mol"]:
            concatenated = torch.cat((concatenated, mol_desc), 1)

        for layer in self.Concat:
            concatenated = layer(concatenated)
        return concatenated

    def _forward_batch(self, batch):
        """Unpack a batch, run the forward pass; return (y_hat, y)."""
        if self.config["add_X_mol"]:
            atom_comp, diatom_comp, global_feats, one_hot, y, mol_desc = batch
            y_hat = self(atom_comp, diatom_comp, global_feats, one_hot, mol_desc).squeeze(1)
        else:
            atom_comp, diatom_comp, global_feats, one_hot, y = batch
            y_hat = self(atom_comp, diatom_comp, global_feats, one_hot).squeeze(1)
        return y_hat, y

    def training_step(self, batch, batch_idx):
        y_hat, y = self._forward_batch(batch)
        loss = self.regularized_loss(y_hat, y)
        self.log("Train loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log(
            "Train MAE",
            self.mae(y_hat, y),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        y_hat, y = self._forward_batch(batch)
        loss = self.criterion(y_hat, y)
        self.log("Validation loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log(
            "Validation MAE",
            self.mae(y_hat, y),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def test_step(self, batch, batch_idx):
        y_hat, y = self._forward_batch(batch)
        loss = self.criterion(y_hat, y)
        self.log("Test loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log(
            "Test MAE",
            self.mae(y_hat, y),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        y_hat, _ = self._forward_batch(batch)
        return y_hat

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.config["learning_rate"])

    def configure_init(self):
        """Map config["init"] to a torch init function (default: normal)."""
        init_name = self.config["init"]
        if not init_name or init_name == "normal":
            return nn.init.normal_
        if init_name == "xavier":
            return nn.init.xavier_normal_
        if init_name == "kaiming":
            return nn.init.kaiming_normal_
        # Fail fast instead of returning None and crashing later at call time.
        raise ValueError(f"Unknown init '{init_name}'")
def _lrelu_kwargs(config):
    """Shared saturating-LeakyReLU settings for all building blocks."""
    return {
        "negative_slope": config["LRelu_negative_slope"],
        "saturation": config["LRelu_saturation"],
    }


def _conv(config, initi, in_ch, out_ch, kernel_size):
    """One Conv1dActivation with 'same' padding and the shared activation."""
    return Conv1dActivation(
        in_ch, out_ch, kernel_size, padding="same", initializer=initi, **_lrelu_kwargs(config)
    )


def _dense(config, initi, in_f, out_f):
    """One DenseActivation with the shared activation settings."""
    return DenseActivation(in_f, out_f, initializer=initi, **_lrelu_kwargs(config))


def _build_atom_conv(config, initi):
    """Conv stack over atom-composition features (6 input channels)."""
    c = config["AtomComp_out_channels_start"]
    k = config["AtomComp_kernel_size"]
    pool = config["AtomComp_MaxPool_kernel_size"]
    return nn.ModuleList(
        [
            _conv(config, initi, 6, c, k),
            _conv(config, initi, c, c, k),
            nn.MaxPool1d(pool, pool),
            _conv(config, initi, c, c // 2, k),
            _conv(config, initi, c // 2, c // 2, k),
            nn.MaxPool1d(pool, pool),
            _conv(config, initi, c // 2, c // 4, k),
            _conv(config, initi, c // 4, c // 4, k),
            nn.Flatten(),
        ]
    )


def _build_diatom_conv(config, initi):
    """Conv stack over diatom-composition features (6 input channels)."""
    c = config["DiatomComp_out_channels_start"]
    k = config["DiatomComp_kernel_size"]
    pool = config["DiatomComp_MaxPool_kernel_size"]
    return nn.ModuleList(
        [
            _conv(config, initi, 6, c, k),
            _conv(config, initi, c, c, k),
            nn.MaxPool1d(pool, pool),
            _conv(config, initi, c, c // 2, k),
            _conv(config, initi, c // 2, c // 2, k),
            nn.Flatten(),
        ]
    )


def _build_global_dense(config, initi):
    """Dense stack over the 60 global features."""
    u = config["Global_units"]
    return nn.ModuleList(
        [
            _dense(config, initi, 60, u),
            _dense(config, initi, u, u),
            _dense(config, initi, u, u),
        ]
    )


def _build_one_hot_conv(config, initi):
    """Conv stack over the one-hot sequence encoding (20 input channels)."""
    c = config["OneHot_out_channels"]
    k = config["One_hot_kernel_size"]
    pool = config["OneHot_MaxPool_kernel_size"]
    return nn.ModuleList(
        [
            _conv(config, initi, 20, c, k),
            _conv(config, initi, c, c, k),
            nn.MaxPool1d(pool, pool),
            nn.Flatten(),
        ]
    )


def _build_mol_conv(config, initi):
    """Conv stack over molecular descriptors (13 input channels)."""
    c = config["Mol_out_channels_start"]
    k = config["Mol_kernel_size"]
    pool = config["Mol_MaxPool_kernel_size"]
    return nn.ModuleList(
        [
            _conv(config, initi, 13, c, k),
            _conv(config, initi, c, c, k),
            nn.MaxPool1d(pool, pool),
            _conv(config, initi, c, c // 2, k),
            _conv(config, initi, c // 2, c // 2, k),
            nn.MaxPool1d(pool, pool),
            _conv(config, initi, c // 2, c // 4, k),
            _conv(config, initi, c // 4, c // 4, k),
            nn.Flatten(),
        ]
    )


def _build_concat_dense(config, initi, input_size, n_layers=5):
    """Dense stack applied to the concatenated branch outputs.

    Produces the same ModuleList structure (and hence state-dict keys) as the
    original inline construction: one input projection + n_layers-1 hidden.
    """
    u = config["Concat_units"]
    layers = [_dense(config, initi, input_size, u)]
    layers.extend(_dense(config, initi, u, u) for _ in range(n_layers - 1))
    return nn.ModuleList(layers)


class IM2DeepMulti(L.LightningModule):
    """Two-output (multi-conformer) CCS predictor.

    Shares the IM2Deep feature trunk, but the dense head has no final linear
    unit; instead two independent Branch heads each produce one prediction.
    Module attribute names and ModuleList ordering match the original inline
    construction, so existing checkpoints keep loading.
    """

    def __init__(self, config, criterion):
        super().__init__()
        self.config = config
        self.criterion = criterion

        initi = self.configure_init()

        self.ConvAtomComp = _build_atom_conv(config, initi)
        self.ConvDiatomComp = _build_diatom_conv(config, initi)
        self.ConvGlobal = _build_global_dense(config, initi)
        self.OneHot = _build_one_hot_conv(config, initi)
        if config["add_X_mol"]:
            self.MolDesc = _build_mol_conv(config, initi)

        self.total_input_size = calculate_concat_shape(self.config)
        logger.debug(f"Total input size: {self.total_input_size}")

        self.Concat = _build_concat_dense(config, initi, self.total_input_size)

        self.concat_input_size = calculate_concat_shape(self.config)
        # Two independent heads, one per predicted conformer.
        self.branches = nn.ModuleList(
            [
                Branch(
                    self.config["Concat_units"],
                    config.get("BranchSize", 0),
                    add_layer=config.get("add_branch_layer", 0),
                )
                for _ in range(2)
            ]
        )

    def forward(self, atom_comp, diatom_comp, global_feats, one_hot, mol_desc=None):
        """Return the pair (y_hat1, y_hat2), each of shape (batch, 1)."""
        atom_comp = atom_comp.permute(0, 2, 1)
        diatom_comp = diatom_comp.permute(0, 2, 1)
        one_hot = one_hot.permute(0, 2, 1)

        for layer in self.ConvAtomComp:
            atom_comp = layer(atom_comp)
        for layer in self.ConvDiatomComp:
            diatom_comp = layer(diatom_comp)
        for layer in self.ConvGlobal:
            global_feats = layer(global_feats)
        for layer in self.OneHot:
            one_hot = layer(one_hot)

        if self.config["add_X_mol"]:
            for layer in self.MolDesc:
                mol_desc = layer(mol_desc)

        concatenated = torch.cat((atom_comp, diatom_comp, one_hot, global_feats), 1)
        if self.config["add_X_mol"]:
            concatenated = torch.cat((concatenated, mol_desc), 1)

        for layer in self.Concat:
            concatenated = layer(concatenated)

        return self.branches[0](concatenated), self.branches[1](concatenated)

    def _forward_batch(self, batch):
        """Unpack a batch, run the model; return (y_hat1, y_hat2, y1, y2)."""
        if self.config["add_X_mol"]:
            atom_comp, diatom_comp, global_feats, one_hot, y, mol_desc = batch
            y_hat1, y_hat2 = self(atom_comp, diatom_comp, global_feats, one_hot, mol_desc)
        else:
            atom_comp, diatom_comp, global_feats, one_hot, y = batch
            y_hat1, y_hat2 = self(atom_comp, diatom_comp, global_feats, one_hot)
        return y_hat1, y_hat2, y[:, 0], y[:, 1]

    def training_step(self, batch, batch_idx):
        y_hat1, y_hat2, y1, y2 = self._forward_batch(batch)
        loss = self.criterion(y1, y2, y_hat1, y_hat2)

        l1_norm = sum(p.abs().sum() for p in self.parameters())
        total_loss = loss + self.config["L1_alpha"] * l1_norm

        meanmae = MeanMAESorted(y1, y2, y_hat1, y_hat2)
        lowestmae = LowestMAESorted(y1, y2, y_hat1, y_hat2)

        # NOTE(review): logs the un-regularized loss but optimizes total_loss
        # (this preserves the original behavior).
        self.log("Train Loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log(
            "Train Mean MAE", meanmae, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        self.log(
            "Train Lowest MAE", lowestmae, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return total_loss

    def validation_step(self, batch, batch_idx):
        y_hat1, y_hat2, y1, y2 = self._forward_batch(batch)
        loss = self.criterion(y1, y2, y_hat1, y_hat2)
        meanmae = MeanMAESorted(y1, y2, y_hat1, y_hat2)
        lowestmae = LowestMAESorted(y1, y2, y_hat1, y_hat2)

        self.log("Val Loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("Val Mean MAE", meanmae, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log(
            "Val Lowest MAE", lowestmae, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def test_step(self, batch, batch_idx):
        y_hat1, y_hat2, y1, y2 = self._forward_batch(batch)
        loss = self.criterion(y1, y2, y_hat1, y_hat2)
        meanmae = MeanMAESorted(y1, y2, y_hat1, y_hat2)
        lowestmae = LowestMAESorted(y1, y2, y_hat1, y_hat2)

        self.log("Test Loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log(
            "Test Mean MAE", meanmae, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        self.log(
            "Test Lowest MAE", lowestmae, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
        """Return predictions stacked as (batch, 2).

        Fix: the original signature was predict_step(self, batch), which
        raises TypeError under Trainer.predict because Lightning also passes
        batch_idx (and possibly dataloader_idx); both are now accepted as
        optional and ignored, keeping direct predict_step(batch) calls valid.
        """
        y_hat1, y_hat2, _, _ = self._forward_batch(batch)
        return torch.hstack([y_hat1, y_hat2])

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.config["learning_rate"])

    def configure_init(self):
        """Map config["init"] to a torch init function (default: normal)."""
        init_name = self.config["init"]
        if not init_name or init_name == "normal":
            return nn.init.normal_
        if init_name == "xavier":
            return nn.init.xavier_normal_
        if init_name == "kaiming":
            return nn.init.kaiming_normal_
        # Fail fast instead of returning None and crashing later at call time.
        raise ValueError(f"Unknown init '{init_name}'")
class IM2DeepMultiTransfer(L.LightningModule):
    """Transfer-learning, two-output CCS predictor.

    Loads a pretrained single-output IM2Deep checkpoint as a backbone, reuses
    its feature extractors and all but the last layer of its dense head, and
    adds two fresh Branch heads (optionally with self-attention before and/or
    after the reused dense head).
    """

    def __init__(self, config, criterion):
        super(IM2DeepMultiTransfer, self).__init__()
        # TODO: config should be adapted in config file
        self.config = config
        self.criterion = criterion
        self.l1_alpha = config["L1_alpha"]

        # Load the IM2Deep model
        logger.debug("Loading backbone IM2Deep model")
        self.backbone = IM2Deep.load_from_checkpoint(
            config["backbone_SD_path"], config=config, criterion=criterion
        )

        # These attributes alias (do not copy) the backbone's submodules, so
        # their parameters stay registered under self.backbone and remain
        # trainable through self.parameters().
        self.ConvAtomComp = self.backbone.ConvAtomComp
        self.ConvDiatomComp = self.backbone.ConvDiatomComp
        self.ConvGlobal = self.backbone.ConvGlobal
        self.OneHot = self.backbone.OneHot

        if self.config.get("add_X_mol", False) == True:
            self.MolDesc = self.backbone.MolDesc

        # All Concat layers except the backbone's final 1-unit Linear.
        # NOTE(review): this is a plain Python list, not an nn.ModuleList —
        # the modules are only registered via self.backbone. The dropped
        # final Linear still contributes to self.parameters() (and hence to
        # the L1 penalty in training_step) — confirm that is intended.
        self.concat = list(self.backbone.Concat.children())[:-1]

        self.concat_input_size = calculate_concat_shape(self.config)
        try:
            self.output_size = config["Concat_units"]
        except KeyError:
            self.output_size = BASEMODELCONFIG["Concat_units"]

        # Optional attention modules; built for any truthy flag, but forward
        # only applies them when the flag == 1 (see NOTE in forward).
        if self.config.get("Use_attention_concat", False):
            self.SelfAttentionConcat = SelfAttention(
                self.concat_input_size, config.get("Concatheads", 1)
            )
        if self.config.get("Use_attention_output", False):
            self.SelfAttentionOutput = SelfAttention(
                config["Concat_units"], config.get("Outputheads", 1)
            )

        # Two fresh output heads, one per predicted conformer.
        self.branches = nn.ModuleList(
            [
                Branch(
                    config["Concat_units"],
                    config.get("BranchSize", None),
                    add_layer=config.get("add_branch_layer", 0),
                ),
                Branch(
                    config["Concat_units"],
                    config.get("BranchSize", None),
                    add_layer=config.get("add_branch_layer", 0),
                ),
            ]
        )

    def forward(self, atom_comp, diatom_comp, global_feats, one_hot, mol_desc=None):
        """Return the pair (y_hat1, y_hat2) of per-conformer predictions."""
        # Conv branches expect (batch, channels, seq).
        atom_comp = atom_comp.permute(0, 2, 1)
        diatom_comp = diatom_comp.permute(0, 2, 1)
        one_hot = one_hot.permute(0, 2, 1)

        for layer in self.ConvAtomComp:
            atom_comp = layer(atom_comp)

        for layer in self.ConvDiatomComp:
            diatom_comp = layer(diatom_comp)

        for layer in self.ConvGlobal:
            global_feats = layer(global_feats)

        for layer in self.OneHot:
            one_hot = layer(one_hot)

        if self.config["add_X_mol"]:
            for layer in self.MolDesc:
                mol_desc = layer(mol_desc)

        concatenated = torch.cat((atom_comp, diatom_comp, one_hot, global_feats), 1)

        if self.config["add_X_mol"]:
            concatenated = torch.cat((concatenated, mol_desc), 1)

        # NOTE(review): __init__ builds the attention modules for any truthy
        # flag but they are only used here when the flag == 1 — other truthy
        # values build an unused module; confirm the intended contract.
        # Attention operates on (batch, seq=1, features), hence the
        # unsqueeze/squeeze around the call.
        if self.config.get("Use_attention_concat", 0) == 1:
            concatenated = self.SelfAttentionConcat(concatenated.unsqueeze(1)).squeeze(1)

        for layer in self.concat:
            concatenated = layer(concatenated)

        if self.config.get("Use_attention_output", 0) == 1:
            concatenated = self.SelfAttentionOutput(concatenated.unsqueeze(1)).squeeze(1)

        y_hat1 = self.branches[0](concatenated)
        y_hat2 = self.branches[1](concatenated)

        return y_hat1, y_hat2

    def training_step(self, batch, batch_idx):
        """One optimization step: criterion + L1 penalty over all parameters."""
        if self.config["add_X_mol"]:
            atom_comp, diatom_comp, global_feats, one_hot, y, mol_desc = batch
            y_hat1, y_hat2 = self(atom_comp, diatom_comp, global_feats, one_hot, mol_desc)
        else:
            atom_comp, diatom_comp, global_feats, one_hot, y = batch
            y_hat1, y_hat2 = self(atom_comp, diatom_comp, global_feats, one_hot)

        # Targets carry both conformer values in columns 0 and 1.
        y1, y2 = y[:, 0], y[:, 1]

        loss = self.criterion(y1, y2, y_hat1, y_hat2)

        l1_norm = sum(p.abs().sum() for p in self.parameters())
        total_loss = loss + self.l1_alpha * l1_norm

        meanmae = MeanMAESorted(y1, y2, y_hat1, y_hat2)
        lowestmae = LowestMAESorted(y1, y2, y_hat1, y_hat2)

        # NOTE(review): unlike IM2DeepMulti, this logs the regularized
        # total_loss as "Train Loss" — confirm which convention is wanted.
        self.log(
            "Train Loss", total_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        self.log(
            "Train Mean MAE", meanmae, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        self.log(
            "Train Lowest MAE", lowestmae, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return total_loss

    def validation_step(self, batch, batch_idx):
        """Validation: raw criterion (no L1 penalty) plus the two MAE metrics."""
        if self.config["add_X_mol"]:
            atom_comp, diatom_comp, global_feats, one_hot, y, mol_desc = batch
            y_hat1, y_hat2 = self(atom_comp, diatom_comp, global_feats, one_hot, mol_desc)
        else:
            atom_comp, diatom_comp, global_feats, one_hot, y = batch
            y_hat1, y_hat2 = self(atom_comp, diatom_comp, global_feats, one_hot)

        y1, y2 = y[:, 0], y[:, 1]

        loss = self.criterion(y1, y2, y_hat1, y_hat2)

        meanmae = MeanMAESorted(y1, y2, y_hat1, y_hat2)
        lowestmae = LowestMAESorted(y1, y2, y_hat1, y_hat2)

        self.log("Val Loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("Val Mean MAE", meanmae, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log(
            "Val Lowest MAE", lowestmae, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def test_step(self, batch, batch_idx):
        """Test: same metrics as validation, logged under "Test ..." keys."""
        if self.config["add_X_mol"]:
            atom_comp, diatom_comp, global_feats, one_hot, y, mol_desc = batch
            y_hat1, y_hat2 = self(atom_comp, diatom_comp, global_feats, one_hot, mol_desc)
        else:
            atom_comp, diatom_comp, global_feats, one_hot, y = batch
            y_hat1, y_hat2 = self(atom_comp, diatom_comp, global_feats, one_hot)

        y1, y2 = y[:, 0], y[:, 1]

        loss = self.criterion(y1, y2, y_hat1, y_hat2)
        meanmae = MeanMAESorted(y1, y2, y_hat1, y_hat2)
        lowestmae = LowestMAESorted(y1, y2, y_hat1, y_hat2)

        self.log("Test Loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log(
            "Test Mean MAE", meanmae, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        self.log(
            "Test Lowest MAE", lowestmae, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def predict_step(self, batch, inference=False):
        """Return predictions stacked as (batch, 2).

        ``inference=True`` means the batch carries no target tensor.

        NOTE(review): Lightning's Trainer.predict calls
        predict_step(batch, batch_idx), so batch_idx would land in
        ``inference`` here (truthy for batch_idx >= 1) — confirm this method
        is only ever called directly, not via Trainer.predict.
        """
        if self.config["add_X_mol"]:
            if not inference:
                atom_comp, diatom_comp, global_feats, one_hot, y, mol_desc = batch
            else:
                atom_comp, diatom_comp, global_feats, one_hot, mol_desc = batch
            y_hat1, y_hat2 = self(atom_comp, diatom_comp, global_feats, one_hot, mol_desc)
        else:
            if not inference:
                atom_comp, diatom_comp, global_feats, one_hot, y = batch
            else:
                atom_comp, diatom_comp, global_feats, one_hot = batch
            y_hat1, y_hat2 = self(atom_comp, diatom_comp, global_feats, one_hot)
        return torch.hstack([y_hat1, y_hat2])

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config["learning_rate"])
        return optimizer
Load the IM2Deep model + logger.debug("Loading backbone IM2Deep model") + self.backbone = IM2Deep.load_from_checkpoint( + config["backbone_SD_path"], config=config, criterion=criterion + ) + + self.ConvAtomComp = self.backbone.ConvAtomComp + self.ConvDiatomComp = self.backbone.ConvDiatomComp + self.ConvGlobal = self.backbone.ConvGlobal + self.OneHot = self.backbone.OneHot + + if self.config.get("add_X_mol", False) == True: + self.MolDesc = self.backbone.MolDesc + + self.concat = self.backbone.Concat + + def forward(self, atom_comp, diatom_comp, global_feats, one_hot, mol_desc=None): + atom_comp = atom_comp.permute(0, 2, 1) + diatom_comp = diatom_comp.permute(0, 2, 1) + one_hot = one_hot.permute(0, 2, 1) + + for layer in self.ConvAtomComp: + atom_comp = layer(atom_comp) + + for layer in self.ConvDiatomComp: + diatom_comp = layer(diatom_comp) + + for layer in self.ConvGlobal: + global_feats = layer(global_feats) + + for layer in self.OneHot: + one_hot = layer(one_hot) + + if self.config["add_X_mol"]: + for layer in self.MolDesc: + mol_desc = layer(mol_desc) + + concatenated = torch.cat((atom_comp, diatom_comp, one_hot, global_feats), 1) + + if self.config["add_X_mol"]: + concatenated = torch.cat((concatenated, mol_desc), 1) + + for layer in self.concat: + concatenated = layer(concatenated) + + y_hat = concatenated + return y_hat + + def training_step(self, batch, batch_idx): + if self.config["add_X_mol"]: + atom_comp, diatom_comp, global_feats, one_hot, y, mol_desc = batch + y_hat = self(atom_comp, diatom_comp, global_feats, one_hot, mol_desc) + else: + atom_comp, diatom_comp, global_feats, one_hot, y = batch + y_hat = self(atom_comp, diatom_comp, global_feats, one_hot).squeeze(1) + + loss = self.criterion(y_hat, y) + + l1_norm = sum(p.abs().sum() for p in self.parameters()) + total_loss = loss + self.l1_alpha * l1_norm + + self.log( + "Train Loss", total_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True + ) + self.log( + "Train MAE", + self.mae(y_hat, 
y), + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return total_loss + + def validation_step(self, batch, batch_idx): + if self.config["add_X_mol"]: + atom_comp, diatom_comp, global_feats, one_hot, y, mol_desc = batch + y_hat = self(atom_comp, diatom_comp, global_feats, one_hot, mol_desc) + else: + atom_comp, diatom_comp, global_feats, one_hot, y = batch + y_hat = self(atom_comp, diatom_comp, global_feats, one_hot).squeeze(1) + + loss = self.criterion(y_hat, y) + + self.log("Validation Loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log( + "Validation MAE", + self.mae(y_hat, y), + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return loss + + def test_step(self, batch, batch_idx): + if self.config["add_X_mol"]: + atom_comp, diatom_comp, global_feats, one_hot, y, mol_desc = batch + y_hat = self(atom_comp, diatom_comp, global_feats, one_hot, mol_desc) + else: + atom_comp, diatom_comp, global_feats, one_hot, y = batch + y_hat = self(atom_comp, diatom_comp, global_feats, one_hot).squeeze(1) + + loss = self.criterion(y_hat, y) + + self.log("Test Loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log( + "Test MAE", + self.mae(y_hat, y), + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return loss + + def predict_step(self, batch, inference=False): + if self.config["add_X_mol"]: + if not inference: + atom_comp, diatom_comp, global_feats, one_hot, y, mol_desc = batch + else: + atom_comp, diatom_comp, global_feats, one_hot, mol_desc = batch + y_hat = self(atom_comp, diatom_comp, global_feats, one_hot, mol_desc).squeeze(1) + else: + if not inference: + atom_comp, diatom_comp, global_feats, one_hot, y = batch + else: + atom_comp, diatom_comp, global_feats, one_hot = batch + y_hat = self(atom_comp, diatom_comp, global_feats, one_hot).squeeze(1) + return y_hat + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), 
lr=self.config["learning_rate"]) + return optimizer + + +class FlexibleLossSorted(nn.Module): + def __init__(self, diversity_weight=0.1): + super(FlexibleLossSorted, self).__init__() + self.diversity_weight = diversity_weight + + def forward(self, y1, y2, y_hat1, y_hat2): + loss_fn = nn.L1Loss() + + # Sort the targets and predictions row-wise + targets = torch.stack([y1, y2], dim=1) + predictions = torch.stack([y_hat1, y_hat2], dim=1) + targets, _ = torch.sort(targets, dim=1) + predictions, _ = torch.sort(predictions, dim=1) + + target1 = targets[:, 0] + + target2 = targets[:, 1] + + prediction1 = predictions[:, 0] + prediction1 = prediction1.squeeze() + + prediction2 = predictions[:, 1] + + prediction2 = prediction2.squeeze() + + loss1 = loss_fn(prediction1.float(), target1.float()) + + loss2 = loss_fn(prediction2.float(), target2.float()) + + target_diff = torch.abs(target1 - target2) + + prediction_diff = torch.abs(prediction1 - prediction2) + + diff_loss = loss_fn(prediction_diff.float(), target_diff.float()) + + total_loss = (loss1 + loss2) + (self.diversity_weight * diff_loss) + + return total_loss + + +class FlexibleLoss(nn.Module): + def __init__(self, diversity_weight=0.1): + super(FlexibleLoss, self).__init__() + self.diversity_weight = diversity_weight + + def forward(self, y1, y2, y_hat1, y_hat2): + loss_fn = nn.L1Loss() + + loss1_to_1 = loss_fn(y_hat1, y1) + loss2_to_2 = loss_fn(y_hat2, y2) + loss1_to_2 = loss_fn(y_hat1, y2) + loss2_to_1 = loss_fn(y_hat2, y1) + + loss_dict = { + "1_to_1": loss1_to_1, + "2_to_2": loss2_to_2, + "1_to_2": loss1_to_2, + "2_to_1": loss2_to_1, + } + min_loss_key = min(loss_dict, key=loss_dict.get) + if "1_to" in min_loss_key: + if "to_1" in min_loss_key: + loss1 = loss1_to_1 + loss2 = loss2_to_2 + else: + loss1 = loss1_to_2 + loss2 = loss2_to_1 + else: + if "to_2" in min_loss_key: + loss1 = loss2_to_2 + loss2 = loss1_to_1 + else: + loss1 = loss2_to_1 + loss2 = loss1_to_2 + + target_diff = torch.abs(y1 - y2) + prediction_diff 
= torch.abs(y_hat1 - y_hat2) + + diff_loss = loss_fn(prediction_diff, target_diff) + + total_loss = (loss1 + loss2) + (self.diversity_weight * diff_loss) + + return total_loss + + +def MeanMAESorted(y1, y2, y_hat1, y_hat2): + targets = torch.stack([y1, y2], dim=1) + predictions = torch.stack([y_hat1, y_hat2], dim=1) + # predictions is shape [x,2,1] but should be [x,2] + predictions = predictions.squeeze() + + targets, _ = torch.sort(targets, dim=1) + predictions, _ = torch.sort(predictions, dim=1) + + target1 = targets[:, 0] + target2 = targets[:, 1] + + prediction1 = predictions[:, 0] + prediction2 = predictions[:, 1] + + mae1 = MAE(prediction1, target1) + mae2 = MAE(prediction2, target2) + + return (mae1 + mae2) / 2 + + +def LowestMAESorted(y1, y2, y_hat1, y_hat2): + targets = torch.stack([y1, y2], dim=1) + predictions = torch.stack([y_hat1, y_hat2], dim=1) + predictions = predictions.squeeze() + + targets, _ = torch.sort(targets, dim=1) + predictions, _ = torch.sort(predictions, dim=1) + + target1 = targets[:, 0] + target2 = targets[:, 1] + + prediction1 = predictions[:, 0] + prediction2 = predictions[:, 1] + + mae1 = MAE(prediction1, target1) + mae2 = MAE(prediction2, target2) + + return min(mae1, mae2) + + +def MeanPearsonRSorted(y1, y2, y_hat1, y_hat2): + targets = torch.stack([y1, y2], dim=1) + predictions = torch.stack([y_hat1, y_hat2], dim=1) + + targets, _ = torch.sort(targets, dim=1) + predictions, _ = torch.sort(predictions, dim=1) + + target1 = targets[:, 0] + target2 = targets[:, 1] + + prediction1 = predictions[:, 0] + prediction2 = predictions[:, 1] + + r1 = pearsonr(target1, prediction1)[0] + r2 = pearsonr(target2, prediction2)[0] + + return (r1 + r2) / 2 + + +def MeanMRE(y1, y2, y_hat1, y_hat2): + mre1 = torch.median(torch.abs((y_hat1 - y1) / y1)) + mre2 = torch.median(torch.abs((y_hat2 - y2) / y2)) + return (mre1 + mre2) / 2 + + +def calculate_concat_shape(config): + atom_comp_out_shape = (60 // (2 * config["AtomComp_MaxPool_kernel_size"])) * ( + 
config["AtomComp_out_channels_start"] // 4 + ) + logger.debug(f"AtomComp out shape: {atom_comp_out_shape}") + diatom_comp_out_shape = (30 // (config["DiatomComp_MaxPool_kernel_size"])) * ( + config["DiatomComp_out_channels_start"] // 2 + ) + logger.debug(f"DiatomComp out shape: {diatom_comp_out_shape}") + globals_out_shape = config["Global_units"] + logger.debug(f"Globals out shape: {globals_out_shape}") + onehot_comp_out_shape = (60 // (config["OneHot_MaxPool_kernel_size"])) * config[ + "OneHot_out_channels" + ] + logger.debug(f"OneHot out shape: {onehot_comp_out_shape}") + + if config["add_X_mol"]: + mol_desc_comp_out_shape = (60 // (2 * config["Mol_MaxPool_kernel_size"])) * ( + config["Mol_out_channels_start"] // 4 + ) + logger.debug(f"MolDesc out shape: {mol_desc_comp_out_shape}") + total_input_size = ( + atom_comp_out_shape + + diatom_comp_out_shape + + globals_out_shape + + onehot_comp_out_shape + + mol_desc_comp_out_shape + ) + + else: + total_input_size = ( + atom_comp_out_shape + diatom_comp_out_shape + globals_out_shape + onehot_comp_out_shape + ) + + return total_input_size diff --git a/im2deep/_exceptions.py b/im2deep/_exceptions.py index 525b5b0..896d0c7 100644 --- a/im2deep/_exceptions.py +++ b/im2deep/_exceptions.py @@ -13,41 +13,43 @@ class IM2DeepError(Exception): """ Base exception class for all IM2Deep-related errors. - - This exception serves as the base class for all custom exceptions + + This exception serves as the base class for all custom exceptions in the IM2Deep package, allowing users to catch all package-specific errors with a single except clause. - + Attributes: message (str): Human readable string describing the exception. - + Example: >>> try: ... predict_ccs(invalid_data) ... except IM2DeepError as e: ... print(f"IM2Deep error occurred: {e}") """ + pass class CalibrationError(IM2DeepError): """ Exception raised when calibration-related errors occur. 
- + This exception is raised when there are issues with calibration data, reference datasets, or calibration procedures that prevent successful CCS calibration. - + Common scenarios: - Insufficient overlapping peptides between calibration and reference data - Invalid calibration file format - Missing required columns in calibration data - Numerical issues during calibration calculation - + Example: >>> try: ... linear_calibration(pred_df, cal_df, ref_df) ... except CalibrationError as e: ... print(f"Calibration failed: {e}") """ + pass diff --git a/im2deep/_model_ops.py b/im2deep/_model_ops.py new file mode 100644 index 0000000..5a8b8cf --- /dev/null +++ b/im2deep/_model_ops.py @@ -0,0 +1,196 @@ +# TODO: evaluate whether these functions can just be imported from DeepLC +"""Training, predicting, and evaluating using IM2Deep (PyTorch).""" + +from __future__ import annotations +import copy +import logging +import warnings +from os import PathLike +from pathlib import Path + +import torch +from rich.progress import track +from torch.utils.data import DataLoader, Dataset +import lightning as L + +# Suppress PyTorch padding warning for conv1d with even kernels and odd dilation +warnings.filterwarnings( + "ignore", + message="Using padding='same' with even kernel lengths and odd dilation.*", + category=UserWarning, + module="torch.nn.modules.conv", +) + +LOGGER = logging.getLogger(__name__) + + +def load_model( + model: torch.nn.Module | PathLike | str | None = None, + device: str | None = None, +) -> torch.nn.Module: + """Load a model from a file or return a randomly initialized model if none is provided.""" + # If device is not specified, use the default device (GPU if available, else CPU) + selected_device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Load model from file if a path is provided + if isinstance(model, (str, Path)): + checkpoint = torch.load(model, weights_only=False, map_location=selected_device) + + # Handle different 
checkpoint formats + if isinstance(checkpoint, dict): + # If it's a dictionary, it might be a checkpoint with 'model' or 'state_dict' key + if "model" in checkpoint: + loaded_model = checkpoint["model"] + elif "state_dict" in checkpoint: + # Need to initialize model architecture first, then load state dict + # For now, just extract the state dict + LOGGER.error( + "Checkpoint contains state_dict but no model architecture. " + "This format is not yet supported. Please provide a full model checkpoint." + ) + raise NotImplementedError( + "Loading from state_dict-only checkpoints is not yet implemented." + ) + else: + # Assume the entire dict is the model (some formats do this) + LOGGER.warning( + "Checkpoint is a dict but format is unclear. Attempting to use as-is." + ) + loaded_model = checkpoint + else: + # Direct model object + loaded_model = checkpoint + + elif isinstance(model, torch.nn.Module): + loaded_model = model + elif model is None: + # TODO: Implement randomly initialized model; requires model architecture definition + raise NotImplementedError("Loading randomly initialized model is not implemented yet.") + else: + raise TypeError(f"Expected a PyTorch Module or a file path, got {type(model)} instead.") + + # Ensure the model is on the specified device + if isinstance(loaded_model, torch.nn.Module): + loaded_model.to(selected_device) + loaded_model.eval() # Set model to evaluation mode + else: + raise TypeError( + f"Loaded model is not a PyTorch Module, got {type(loaded_model)} instead. " + f"The checkpoint file may be in an incompatible format." + ) + + return loaded_model + + +def predict( + model: torch.nn.Module | PathLike | str | None = None, + data: Dataset | None = None, + multi=False, + device: str = "cpu", + batch_size: int = 512, + num_workers: int = 0, +) -> torch.Tensor: + """ + Predict using a trained model. + + Parameters + ---------- + model + Trained model or path to model file. + data + Dataset to predict on. 
+ device + Device to use for prediction. + batch_size + Batch size for prediction. + num_workers + Number of workers for data loading. + + Returns + ------- + torch.Tensor + Predictions. + + """ + # Check data first before loading model + if data is None: + raise ValueError("Data must be provided for prediction.") + + # TODO: implement custom model inference + LOGGER.debug("Loading model for prediction.") + model = _get_architecture( + multi=multi, + ).load_from_checkpoint( + checkpoint_path=model, + config=_get_model_config(multi=multi), + criterion=_get_loss_function(multi=multi), + ) + model.to(device) + LOGGER.debug(f"Model loaded on device: {device}") + + data_loader = DataLoader( + data, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + ) + LOGGER.debug("DataLoader created for prediction.") + LOGGER.debug("Starting prediction loop.") + predictions = _predict_loop(model, data_loader, device) + return predictions.cpu().detach() + + +def _predict_loop( + model: torch.nn.Module, + data_loader: DataLoader, + device: str, +) -> torch.Tensor: + model.eval() + all_predictions = [] + with torch.no_grad(): + for features, _ in track(data_loader, description="Predicting", transient=True): + features = [feature_tensor.to(device) for feature_tensor in features] + outputs = model(*features) + if not isinstance(outputs, tuple): + # Single output + all_predictions.append(outputs.cpu()) + else: + # Multi-output: stack both predictions side by side + stacked = torch.stack([outputs[0], outputs[1]], dim=1) + all_predictions.append(stacked.cpu()) + + return torch.cat(all_predictions, dim=0).squeeze() + + +def _get_architecture(multi: bool) -> L.LightningModule: + """Get the model architecture based on whether multi-output is needed.""" + if multi: + from im2deep._architecture import IM2DeepMultiTransfer + + return IM2DeepMultiTransfer + else: + from im2deep._architecture import IM2Deep + + return IM2Deep + + +def _get_model_config(multi: bool) -> dict: + 
"""Get the model configuration based on whether multi-output is needed.""" + if multi: + from im2deep.constants import DEFAULT_MULTI_CONFIG + + return DEFAULT_MULTI_CONFIG + else: + from im2deep.constants import DEFAULT_CONFIG + + return DEFAULT_CONFIG + + +def _get_loss_function(multi: bool) -> torch.nn.modules.loss._Loss | torch.nn.Module: + """Get the loss function based on whether multi-output is needed.""" + if multi: + from im2deep._architecture import FlexibleLossSorted + + return FlexibleLossSorted() + else: + return torch.nn.L1Loss() diff --git a/im2deep/calibrate.py b/im2deep/calibrate.py deleted file mode 100644 index fd8c551..0000000 --- a/im2deep/calibrate.py +++ /dev/null @@ -1,542 +0,0 @@ -""" -Calibration functions for CCS predictions in IM2Deep. - -This module provides functions for calibrating CCS predictions using reference datasets. Calibration is performed by calculating -shift factors based on overlapping peptides between calibration and reference data. - -The calibration process involves: -1. Finding overlapping peptide-charge pairs between calibration and reference datasets -2. Calculating mean CCS differences (shift factors) -3. Applying shifts to predictions either globally or per charge state - -Functions: - get_ccs_shift: Calculate global CCS shift factor for a specific charge state - get_ccs_shift_per_charge: Calculate CCS shift factors per charge state - calculate_ccs_shift: Wrapper function for shift calculation with validation - linear_calibration: Apply linear calibration to CCS predictions - -Example: - >>> calibrated_df = linear_calibration( - ... predictions_df, - ... calibration_df, - ... reference_df, - ... per_charge=True - ... 
) -""" - -from __future__ import annotations - -import logging -from typing import cast - -import numpy as np -import pandas as pd - -from im2deep._exceptions import CalibrationError - -LOGGER = logging.getLogger(__name__) - - -def _validate_calibration_inputs( - cal_df: pd.DataFrame, - reference_dataset: pd.DataFrame, - required_cal_columns: list | None = None, - required_ref_columns: list | None = None, -) -> None: - """ - Validate input dataframes for calibration functions. - - Parameters - ---------- - cal_df - Calibration dataset - reference_dataset - Reference dataset - required_cal_columns - Required columns for calibration dataset - required_ref_columns - Required columns for reference dataset - - Raises - ------ - CalibrationError - If validation fails - - """ - if cal_df.empty: - raise CalibrationError("Calibration dataset is empty") - if reference_dataset.empty: - raise CalibrationError("Reference dataset is empty") - - if required_cal_columns: - missing_cols = set(required_cal_columns) - set(cal_df.columns) - if missing_cols: - raise CalibrationError(f"Missing columns in calibration data: {missing_cols}") - - if required_ref_columns: - missing_cols = set(required_ref_columns) - set(reference_dataset.columns) - if missing_cols: - raise CalibrationError(f"Missing columns in reference data: {missing_cols}") - - -def get_ccs_shift( - cal_df: pd.DataFrame, reference_dataset: pd.DataFrame, use_charge_state: int = 2 -) -> float: - """ - Calculate CCS shift factor for a specific charge state. - - This function calculates a constant offset based on identical precursors - between calibration and reference datasets for a specific charge state. - The shift represents how much the calibration CCS values differ from - reference CCS values on average. - - Parameters - ---------- - cal_df - PSMs with CCS values. Must contain columns: 'sequence', 'charge', 'ccs_observed' - reference_dataset - Reference dataset with CCS values. 
Must contain columns: 'peptidoform', 'charge', 'CCS' - use_charge_state - Charge state to use for CCS shift calculation. Should be in range [2,4]. - - Returns - ------- - float - CCS shift factor. Positive values indicate calibration CCS is higher - than reference CCS on average. - - Raises - ------ - CalibrationError - If charge state is invalid or no overlapping peptides found - - Notes - ----- - The function: - 1. Filters both datasets to the specified charge state - 2. Merges on sequence and charge to find overlapping peptides - 3. Calculates mean difference: mean(ccs_observed - CCS_reference) - - Examples - -------- - >>> shift = get_ccs_shift(calibration_df, reference_df, use_charge_state=2) - >>> print(f"CCS shift factor: {shift:.2f} Ų") - - """ - # Validate inputs - _validate_calibration_inputs( - cal_df, - reference_dataset, - required_cal_columns=["sequence", "charge", "ccs_observed"], - required_ref_columns=["peptidoform", "charge", "CCS"], - ) - - if not 1 <= use_charge_state <= 6: - raise CalibrationError( - f"Invalid charge state {use_charge_state}. Should be between 1 and 6." 
- ) - - LOGGER.debug(f"Using charge state {use_charge_state} for CCS shift calculation.") - - # Filter data by charge state - reference_tmp = reference_dataset[reference_dataset["charge"] == use_charge_state] - df_tmp = cal_df[cal_df["charge"] == use_charge_state] - - if reference_tmp.empty: - LOGGER.warning(f"No reference data found for charge state {use_charge_state}") - return 0.0 - - if df_tmp.empty: - LOGGER.warning(f"No calibration data found for charge state {use_charge_state}") - return 0.0 - - # Merge datasets to find overlapping peptides - both = pd.merge( - left=reference_tmp, - right=df_tmp, - right_on=["sequence", "charge"], - left_on=["peptidoform", "charge"], - how="inner", - suffixes=("_ref", "_data"), - ) - - LOGGER.debug( - f"Calculating CCS shift based on {both.shape[0]} overlapping peptide-charge pairs " - f"between PSMs and reference dataset" - ) - - if both.empty: - LOGGER.warning("No overlapping peptides found between calibration and reference data") - return 0.0 - - if both.shape[0] < 10: - LOGGER.warning( - f"Only {both.shape[0]} overlapping peptides found. " - "Consider using more calibration data for reliable results." - ) - - # Calculate shift: how much calibration CCS is larger than reference CCS - shift = np.mean(both["ccs_observed"] - both["CCS"]) - - if abs(shift) > 100: # Sanity check for unreasonably large shifts - LOGGER.warning( - f"Large CCS shift detected ({shift:.2f} ƅ^2). " - "Please verify calibration and reference data quality." - ) - - return float(shift) - - -def get_ccs_shift_per_charge( - cal_df: pd.DataFrame, reference_dataset: pd.DataFrame -) -> dict[int, float]: - """ - Calculate CCS shift factors per charge state. - - This function calculates individual shift factors for each charge state - present in both calibration and reference datasets. This allows for - charge-specific calibration which often improves accuracy. - - Parameters - ---------- - cal_df - PSMs with CCS values. 
Must contain columns: 'sequence', 'charge', 'ccs_observed' - reference_dataset - Reference dataset with CCS values. Must contain columns: 'peptidoform', 'charge', 'CCS' - - Returns - ------- - Dict[int, float] - Dictionary mapping charge states to their shift factors. - Keys are charge states (int), values are shift factors (float). - - Raises - ------ - CalibrationError - If required columns are missing or no overlapping data found - - Notes - ----- - The function: - 1. Merges calibration and reference data on sequence and charge - 2. Groups by charge state - 3. Calculates mean difference for each charge state - - Charge states with insufficient data (< 5 overlapping peptides) will be - logged as warnings but still included in results. - - Examples - -------- - >>> shifts = get_ccs_shift_per_charge(calibration_df, reference_df) - >>> print(shifts) - {2: 5.2, 3: 3.8, 4: 2.1} - - """ - # Validate inputs - _validate_calibration_inputs( - cal_df, - reference_dataset, - required_cal_columns=["sequence", "charge", "ccs_observed"], - required_ref_columns=["peptidoform", "charge", "CCS"], - ) - - # Merge datasets to find overlapping peptides - both = pd.merge( - left=reference_dataset, - right=cal_df, - right_on=["sequence", "charge"], - left_on=["peptidoform", "charge"], - how="inner", - suffixes=("_ref", "_data"), - ) - - if both.empty: - raise CalibrationError( - "No overlapping peptides found between calibration and reference data" - ) - - LOGGER.debug(f"Found {both.shape[0]} total overlapping peptide-charge pairs") - - # Check data distribution across charge states - charge_counts = both.groupby("charge").size() - LOGGER.debug(f"Peptides per charge state: {charge_counts.to_dict()}") - - # Warn about charge states with low data - low_data_charges = charge_counts[charge_counts < 5].index.tolist() - if low_data_charges: - LOGGER.warning( - f"Charge states with <5 peptides: {low_data_charges}. " - "Consider using global calibration for these charges." 
- ) - - # Calculate shift per charge state - shift_dict = ( - both.groupby("charge").apply(lambda x: np.mean(x["ccs_observed"] - x["CCS"])).to_dict() - ) - - # Convert numpy types to native Python types for JSON serialization - shift_dict = {int(k): float(v) for k, v in shift_dict.items()} - - # Check for unreasonably large shifts - large_shifts = {k: v for k, v in shift_dict.items() if abs(v) > 100} - if large_shifts: - LOGGER.warning(f"Large CCS shifts detected: {large_shifts}. Please verify data quality.") - - return shift_dict - - -def calculate_ccs_shift( - cal_df: pd.DataFrame, - reference_dataset: pd.DataFrame, - per_charge: bool = True, - use_charge_state: int | None = None, -) -> float | dict[int, float]: - """ - Calculate CCS shift factors with validation and filtering. - - This is the main interface for calculating CCS shift factors. It provides - input validation, charge filtering, and can return either global or - per-charge shift factors. - - Parameters - ---------- - cal_df - PSMs with CCS values. Must contain columns: 'sequence', 'charge', 'ccs_observed' - reference_dataset - Reference dataset with CCS values. Must contain columns: 'peptidoform', 'charge', 'CCS' - per_charge - Whether to calculate shift factors per charge state. If False, calculates - a single global shift factor using the specified charge state. - use_charge_state - Charge state to use for global shift calculation when per_charge=False. - Should be in range [2,4]. Default is 2 if not specified. - - Returns - ------- - Union[float, Dict[int, float]] - If per_charge=True: Dictionary mapping charge states to shift factors - If per_charge=False: Single shift factor (float) - - Raises - ------ - CalibrationError - If validation fails or invalid parameters provided - - Notes - ----- - The function automatically filters out charges >6 as IM2Deep predictions - are not reliable for very high charge states. A warning is logged if - any peptides are filtered out. 
- - Examples - -------- - >>> # Per-charge calibration - >>> shifts = calculate_ccs_shift(cal_df, ref_df, per_charge=True) - >>> - >>> # Global calibration using charge 2 - >>> shift = calculate_ccs_shift(cal_df, ref_df, per_charge=False, use_charge_state=2) - - """ - # Validate inputs - _validate_calibration_inputs(cal_df, reference_dataset) - - if use_charge_state is not None and not 1 <= use_charge_state <= 6: - raise CalibrationError( - f"Invalid charge state {use_charge_state}. Should be between 1 and 6." - ) - - # Filter high charge states (IM2Deep predictions are unreliable >6) - original_size = len(cal_df) - cal_df = cal_df[cal_df["charge"] < 7].copy() - - if len(cal_df) < original_size: - filtered_count = original_size - len(cal_df) - LOGGER.info( - f"Filtered out {filtered_count} peptides with charge >6 " - "(predictions not reliable for z>6)" - ) - - if cal_df.empty: - raise CalibrationError("No valid calibration data remaining after filtering") - - if not per_charge: - # Global calibration using specified charge state - if use_charge_state is None: - use_charge_state = 2 - LOGGER.debug("No charge state specified for global calibration, using charge 2") - - shift_factor = get_ccs_shift(cal_df, reference_dataset, use_charge_state) - LOGGER.debug(f"Global CCS shift factor: {shift_factor:.3f}") - return shift_factor - else: - # Per-charge calibration - shift_factor_dict = get_ccs_shift_per_charge(cal_df, reference_dataset) - LOGGER.debug(f"CCS shift factors per charge: {shift_factor_dict}") - return shift_factor_dict - - -def linear_calibration( - preds_df: pd.DataFrame, - calibration_dataset: pd.DataFrame, - reference_dataset: pd.DataFrame, - per_charge: bool = True, - use_charge_state: int | None = None, -) -> pd.DataFrame: - """ - Calibrate CCS predictions using linear calibration. - - This function performs linear calibration of CCS predictions by applying - shift factors calculated from overlapping peptides between calibration - and reference datasets. 
Calibration can be applied globally or per charge state. - - Parameters - ---------- - preds_df - PSMs with CCS predictions. Must contain 'predicted_ccs' column. - Will be modified to include 'charge' and 'shift' columns. - calibration_dataset - Calibration dataset with observed CCS values. Must contain columns: - 'peptidoform', 'ccs_observed' - reference_dataset - Reference dataset with CCS values. Must contain columns: - 'peptidoform', 'CCS' - per_charge - Whether to calculate and apply shift factors per charge state. - If True, uses charge-specific calibration with fallback to global shift. - If False, applies single global shift factor. - use_charge_state - Charge state to use for global shift calculation when per_charge=False. - Default is 2 if not specified. - - Returns - ------- - pd.DataFrame - Calibrated PSMs with updated 'predicted_ccs' values and added 'shift' column. - - Raises - ------ - CalibrationError - If calibration fails due to data issues or missing columns - - Notes - ----- - The calibration process: - 1. Extracts sequence and charge information from peptidoforms - 2. Calculates shift factors from calibration vs reference data - 3. Applies shifts to predictions - 4. For per-charge calibration: uses charge-specific shifts with global fallback - - Per-charge calibration is recommended as it typically provides better accuracy - by accounting for charge-dependent systematic biases. - - Examples - -------- - >>> # Per-charge calibration (recommended) - >>> calibrated_df = linear_calibration( - ... predictions_df, - ... calibration_df, - ... reference_df, - ... per_charge=True - ... ) - >>> - >>> # Global calibration using charge 2 - >>> calibrated_df = linear_calibration( - ... predictions_df, - ... calibration_df, - ... reference_df, - ... per_charge=False, - ... use_charge_state=2 - ... 
) - """ - - LOGGER.info("Calibrating CCS values using linear calibration...") - - # Validate input dataframes - if preds_df.empty: - raise CalibrationError("Predictions dataframe is empty") - if "predicted_ccs" not in preds_df.columns: - raise CalibrationError("Predictions dataframe missing 'predicted_ccs' column") - - # Create working copy to avoid modifying original - preds_df = preds_df.copy() - calibration_dataset = calibration_dataset.copy() - reference_dataset = reference_dataset.copy() - - try: - # Extract sequence and charge from calibration peptidoforms - LOGGER.debug("Extracting sequence and charge from calibration peptidoforms...") - calibration_dataset["sequence"] = calibration_dataset["peptidoform"].apply( - lambda x: x.proforma.split("\\")[0] if hasattr(x, "proforma") else str(x).split("/")[0] - ) - calibration_dataset["charge"] = calibration_dataset["peptidoform"].apply( - lambda x: ( - x.precursor_charge if hasattr(x, "precursor_charge") else int(str(x).split("/")[1]) - ) - ) - - # Extract charge from reference peptidoforms - LOGGER.debug("Extracting charge from reference peptidoforms...") - reference_dataset["charge"] = reference_dataset["peptidoform"].apply( - lambda x: int(x.split("/")[1]) if isinstance(x, str) else x.precursor_charge - ) - - except (AttributeError, ValueError, IndexError) as e: - raise CalibrationError(f"Error parsing peptidoform data: {e}") from e - - if per_charge: - LOGGER.info("Calculating general shift factor for fallback...") - try: - general_shift = calculate_ccs_shift( - calibration_dataset, - reference_dataset, - per_charge=False, - use_charge_state=use_charge_state or 2, - ) - # per_charge=False returns float - general_shift = cast(float, general_shift) - except CalibrationError as e: - LOGGER.warning( - f"Could not calculate general shift factor: {e}. Using 0.0 as fallback." 
- ) - general_shift = 0.0 - - LOGGER.info("Calculating shift factors per charge state...") - shift_factor_dict = calculate_ccs_shift( - calibration_dataset, reference_dataset, per_charge=True - ) - # per_charge=True returns dict[int, float] - shift_factor_dict = cast(dict[int, float], shift_factor_dict) - - # Add charge information to predictions if not present - if "charge" not in preds_df.columns: - preds_df["charge"] = preds_df["peptidoform"].apply( - lambda x: x.precursor_charge if hasattr(x, "precursor_charge") else 2 - ) - - # Apply charge-specific shifts with fallback to general shift - preds_df["shift"] = preds_df["charge"].map(shift_factor_dict).fillna(general_shift) - preds_df["predicted_ccs"] = preds_df["predicted_ccs"] + preds_df["shift"] - - # Log calibration statistics - used_charges = set(shift_factor_dict.keys()) - fallback_charges = set(preds_df[preds_df["shift"] == general_shift]["charge"].unique()) - if fallback_charges: - LOGGER.info(f"Used charge-specific calibration for charges: {sorted(used_charges)}") - LOGGER.info(f"Used fallback calibration for charges: {sorted(fallback_charges)}") - - else: - # Global calibration - shift_factor = calculate_ccs_shift( - calibration_dataset, - reference_dataset, - per_charge=False, - use_charge_state=use_charge_state or 2, - ) - # per_charge=False returns floats - shift_factor = cast(float, shift_factor) - preds_df["predicted_ccs"] += shift_factor - preds_df["shift"] = shift_factor - LOGGER.info(f"Applied global shift factor: {shift_factor:.3f}") - - LOGGER.info("CCS values calibrated successfully.") - return preds_df diff --git a/im2deep/calibration.py b/im2deep/calibration.py new file mode 100644 index 0000000..999c110 --- /dev/null +++ b/im2deep/calibration.py @@ -0,0 +1,464 @@ +""" +CCS calibration utilities. + +This module provides calibration strategies to map predicted CCS values to the aligned target scale. 
+""" + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from typing import cast + +import pandas as pd +import numpy as np +from psm_utils import PSMList, Peptidoform + +from im2deep._exceptions import CalibrationError +from im2deep.utils import parse_input +from im2deep.constants import DEFAULT_REFERENCE_DATASET_PATH, DEFAULT_MULTI_REFERENCE_DATASET_PATH + +LOGGER = logging.getLogger(__name__) + + +class Calibration(ABC): + """Abstract base class for CCS calibration methods.""" + + @abstractmethod + def __init__(self, *args, **kwargs) -> None: + super().__init__() + + @property + @abstractmethod + def is_fitted(self) -> bool: + """Check if the calibration has been fitted.""" + ... + + @abstractmethod + def fit( + self, + target: PSMList, + source: PSMList, + ) -> None: + """Fit the calibration using target and source CCS values.""" + ... + + @abstractmethod + def transform( + self, + source: PSMList, + ) -> PSMList: + """Transform source CCS into the calibrated target space.""" + ... + + +class LinearCCSCalibration(Calibration): + """ + Linear calibration for CCS predictions. + + This class implements a simple linear calibration method for CCS predictions by + applying shift factors calculated from overlapping peptides between calibration + and reference datasets. Shift factor calculation can be performed globally or per + charge state. + + Parameters + ---------- + per_charge : bool, optional + Whether to calculate shift factors per charge state. Default is True. + use_charge_state : int or None, optional + Charge state to use for global shift calculation when per_charge is False. + Default is 2 if not specified. 
+ """ + + def __init__(self, per_charge: bool = True, use_charge_state: int | None = None) -> None: + super().__init__() + self.per_charge = per_charge + self.use_charge_state = use_charge_state + self.fitted = False + self.charge_shifts: dict[int, float] = {} + self.general_shift: float | None = None + self.used_charges: set[int] = set() + self.reference_psm_list: PSMList | None = None + + @property + def is_fitted(self) -> bool: + return self.fitted + + def fit( + self, + psm_df_target: pd.DataFrame, + psm_df_source: pd.DataFrame | None = None, + multi: bool = False, + ) -> None: + """Fit the calibration using target and source CCS values.""" + if psm_df_source is None: + LOGGER.debug("No reference PSMList provided, loading default reference dataset.") + psm_df_source = get_default_reference(multi=multi) + + LOGGER.debug("Calculating calibration parameters...") + + if self.per_charge: + # For per-charge calibration, calculate shifts for all charges + LOGGER.debug("Calculating shift factors per charge state...") + try: + self.charge_shifts = self.calculate_ccs_shift( + psm_df_target, + psm_df_source, + ) + LOGGER.debug(f"Calculated charge-specific shifts: {self.charge_shifts}") + except CalibrationError as e: + LOGGER.warning( + f"Could not calculate charge-specific shift factors: {e}. Using 0.0 as fallback." 
+ ) + self.charge_shifts = {charge: 0.0 for charge in range(1, 7)} + + # Set general shift as the mean of calculated charge shifts or charge 2 if available + if 2 in self.charge_shifts and self.charge_shifts[2] != 0.0: + self.general_shift = self.charge_shifts[2] + else: + # Use mean of non-zero charge shifts + available_shifts = [ + shift + for shift in self.charge_shifts.values() + if shift is not None and shift != 0.0 + ] + if available_shifts: + self.general_shift = float(np.mean(available_shifts)) + else: + self.general_shift = 0.0 + + # Fill in missing charge states with general shift + no_shift_calculated = [] + for charge in range(1, 7): + if ( + charge not in self.charge_shifts + or self.charge_shifts[charge] is None + or self.charge_shifts[charge] == 0.0 + ): + no_shift_calculated.append(charge) + self.charge_shifts[charge] = float(self.general_shift) + LOGGER.debug( + f"No shift factor calculated for charge states: {no_shift_calculated}. " + f"Using general shift: {self.general_shift:.3f}." + ) + else: + # For global calibration, calculate a single shift + try: + self.general_shift = self.calculate_ccs_shift( + psm_df_target, + psm_df_source, + ) + except CalibrationError as e: + LOGGER.warning( + f"Could not calculate general shift factor: {e}. Using 0.0 as fallback." 
+ ) + self.general_shift = 0.0 + self.charge_shifts = {charge: self.general_shift for charge in range(1, 7)} + + self.used_charges = set(self.charge_shifts.keys()) + self.fitted = True + LOGGER.debug(f"CCS shift factors per charge: {self.charge_shifts}") + + def transform( + self, + psm_df: pd.DataFrame, + ) -> np.ndarray: + """Transform source CCS into the calibrated target space.""" + if not self.is_fitted: + raise CalibrationError("Calibration has not been fitted yet.") + + LOGGER.debug("Applying calibration to source CCS values...") + + if "peptidoform" not in psm_df.columns: + raise CalibrationError("Input DataFrame must contain 'peptidoform' column.") + + if not "predicted_CCS_uncalibrated" in psm_df.columns and "metadata" in psm_df.columns: + psm_df["predicted_CCS_uncalibrated"] = psm_df["metadata"].apply( + lambda x: ( + x["predicted_CCS_uncalibrated"] + if "predicted_CCS_uncalibrated" in x + else np.nan + ) + ) + + # Extract charge from peptidoform column efficiently + psm_df["charge"] = psm_df["peptidoform"].apply( + lambda x: int(str(x).split("/")[-1]) if isinstance(x, str) else x.precursor_charge + ) + + if self.per_charge: + # Per-charge calibration using vectorized map operation + psm_df["shift"] = psm_df["charge"].map(self.charge_shifts).fillna(0.0) + else: + # Global calibration - use same shift for all + psm_df["shift"] = self.general_shift + + # Apply shift, handling both scalar and array CCS values (for multiconformer predictions) + def apply_shift(ccs_value, shift_value): + if isinstance(ccs_value, (list, np.ndarray)): + # Multiconformer: apply shift to each conformer + return np.array(ccs_value, dtype=np.float32) + shift_value + else: + # Single value + return float(ccs_value + shift_value) + + psm_df["calibrated_CCS"] = psm_df.apply( + lambda row: apply_shift(row["predicted_CCS_uncalibrated"], row["shift"]), axis=1 + ) + + # Return as numpy object array to preserve multiconformer arrays + predicted_ccs_calibrated = np.empty(len(psm_df), 
dtype=object) + predicted_ccs_calibrated[:] = psm_df["calibrated_CCS"].tolist() + + return predicted_ccs_calibrated + + def calculate_ccs_shift( + self, + target_df: pd.DataFrame, + source_df: pd.DataFrame, + ) -> dict[int, float] | float: + """ + Calculate CCS shift factors between target and source PSMLists. + + Parameters + ---------- + target_df + DataFrame containing peptidoforms and observed CCS values from the target PSMList. + source_df + DataFrame containing peptidoforms and predicted CCS values from the source PSMList. + + Returns + ------- + dict[int, float] | float + Shift factors per charge state if per_charge is True, otherwise a single shift factor. + + Raises + ------ + CalibrationError + If no overlapping peptides are found for shift calculation. + Notes + ----- + The function automatically filters out charges >6 as IM2Deep predictions are not reliable for higher charge states. + A warning is logged if any peptides are filtered out. + """ + if self.use_charge_state is not None and not 1 <= self.use_charge_state <= 6: + raise CalibrationError( + f"Invalid charge state {self.use_charge_state} for global shift calculation." + ) + + if not self.per_charge: + # Global calibration using specified charge state + if self.use_charge_state is None: + self.use_charge_state = 2 # Default charge state + LOGGER.debug( + "No charge state specified for global calibration. Using default charge state 2 for global shift calculation." 
+ ) + + shift_factor = self._compute_ccs_shift( + target_df, + source_df, + self.use_charge_state, + ) + LOGGER.debug(f"Global CCS shift factor: {shift_factor:.3f}") + return shift_factor + else: + # Per-charge calibration + shift_factor_dict = self._compute_ccs_shift_per_charge( + target_df, + source_df, + ) + + return shift_factor_dict + + @staticmethod + def _compute_ccs_shift( + target_df: pd.DataFrame, + source_df: pd.DataFrame, + charge_state: int, + ) -> float: + """Compute CCS shift for a specific charge state using DataFrame operations.""" + # Prepare DataFrames with proper columns + target_work = target_df.copy() + source_work = source_df.copy() + + # Extract peptide keys and charges + def get_peptide_key(pf): + if isinstance(pf, Peptidoform): + # For Peptidoform objects, convert proforma to string and strip charge suffix + return str(pf.proforma).rsplit("/", 1)[0] + else: + # For strings in format "PEPTIDE/charge", split off charge + return str(pf).rsplit("/", 1)[0] + + def get_charge(pf): + if isinstance(pf, Peptidoform): + return pf.precursor_charge + else: + return int(str(pf).split("/")[-1]) + + target_work["peptide_key"] = target_work["peptidoform"].apply(get_peptide_key) + target_work["charge"] = target_work["peptidoform"].apply(get_charge) + + # Extract CCS from metadata if it's not a direct column + if "CCS" not in target_work.columns and "metadata" in target_work.columns: + target_work["CCS"] = target_work["metadata"].apply( + lambda x: x.get("CCS", np.nan) if isinstance(x, dict) else np.nan + ) + + source_work["peptide_key"] = source_work["peptidoform"].apply(get_peptide_key) + source_work["charge"] = source_work["peptidoform"].apply(get_charge) + + # Filter by charge state + target_filtered = target_work[target_work["charge"] == charge_state].copy() + source_filtered = source_work[source_work["charge"] == charge_state].copy() + + # Merge on peptide key to find overlapping peptides + merged = pd.merge( + target_filtered[["peptide_key", "CCS"]], 
+ source_filtered[["peptide_key", "CCS"]], + on="peptide_key", + suffixes=("_target", "_source"), + ) + + LOGGER.debug( + f"Number of overlapping peptides for charge state {charge_state}: {len(merged)}" + ) + + num_overlapping = len(merged) + + LOGGER.debug( + f"Calculating CCS shift based on {num_overlapping} overlapping peptides for charge state {charge_state}." + ) + + if num_overlapping == 0: + LOGGER.warning(f"No overlapping peptides found for charge state {charge_state}.") + return 0.0 + + if num_overlapping < 10: + LOGGER.warning( + f"Only {num_overlapping} overlapping peptides found for charge state {charge_state}. " + "Shift calculation may be unreliable." + ) + + # Calculate shift as mean difference + shift = (merged["CCS_target"] - merged["CCS_source"]).mean() + + if abs(shift) > 100.0: + LOGGER.warning( + f"Unusually large CCS shift ({shift:.2f}) detected for charge state {charge_state}." + " Please verify the calibration datasets." + ) + + return float(shift) + + @staticmethod + def _compute_ccs_shift_per_charge( + target_df: pd.DataFrame, + source_df: pd.DataFrame, + ) -> dict[int, float]: + """ + Calculate CCS shift factors per charge state using DataFrame groupby. + + Parameters + ---------- + target_df + DataFrame with peptidoforms and observed CCS values from the target PSMList. + source_df + DataFrame with peptidoforms and predicted CCS values from the source PSMList. + + Returns + ------- + dict[int, float] + Shift factors per charge state. + + Raises + ------ + CalibrationError + If no overlapping peptides are found for any charge state. 
+ """ + # Prepare DataFrames with proper columns + target_work = target_df.copy() + source_work = source_df.copy() + + # Extract peptide keys and charges + def get_peptide_key(pf): + if isinstance(pf, Peptidoform): + # For Peptidoform objects, use proforma property which excludes charge + return str(pf.proforma).rsplit("/", 1)[0] + else: + # For strings in format "PEPTIDE/charge", split off charge + return str(pf).rsplit("/", 1)[0] + + def get_charge(pf): + if isinstance(pf, Peptidoform): + return pf.precursor_charge + else: + return int(str(pf).split("/")[-1]) + + target_work["peptide_key"] = target_work["peptidoform"].apply(get_peptide_key) + target_work["charge"] = target_work["peptidoform"].apply(get_charge) + + if "CCS" not in target_work.columns and "metadata" in target_work.columns: + target_work["CCS"] = target_work["metadata"].apply( + lambda x: x["CCS"] if "CCS" in x else np.nan + ) + + source_work["peptide_key"] = source_work["peptidoform"].apply(get_peptide_key) + source_work["charge"] = source_work["peptidoform"].apply(get_charge) + + # Merge on peptide key and charge to find overlapping peptides + merged = pd.merge( + target_work[["peptide_key", "charge", "CCS"]], + source_work[["peptide_key", "charge", "CCS"]], + on=["peptide_key", "charge"], + suffixes=("_target", "_source"), + ) + + if len(merged) == 0: + raise CalibrationError("No overlapping peptides found for shift calculation.") + + # Calculate shift per charge using groupby + merged["shift"] = merged["CCS_target"] - merged["CCS_source"] + shift_factors = merged.groupby("charge")["shift"].mean().to_dict() + + # Log information for each charge state + charge_counts = merged.groupby("charge").size() + for charge, count in charge_counts.items(): + if count < 10: + LOGGER.warning( + f"Only {count} overlapping peptides found for charge state {charge}. " + "Shift calculation may be unreliable." 
+ ) + if abs(shift_factors[charge]) > 100.0: + LOGGER.warning( + f"Unusually large CCS shift ({shift_factors[charge]:.2f}) detected for charge state {charge}." + " Please verify the calibration datasets." + ) + + if len(shift_factors) == 0: + raise CalibrationError("No CCS shift factors could be calculated.") + + return shift_factors + + +def get_default_reference(multi: bool = False) -> pd.DataFrame: + """ + Get the default reference DataFrame for calibration. + + Parameters + ---------- + multi + Whether to use the multi-charge reference dataset. + + Returns + ------- + pd.DataFrame + Default reference DataFrame with 'peptidoform' and 'CCS' columns. + """ + reference_data_path = ( + DEFAULT_MULTI_REFERENCE_DATASET_PATH if multi else DEFAULT_REFERENCE_DATASET_PATH + ) + LOGGER.info(f"Loading default reference dataset from {reference_data_path}") + # dataset is in .gz format, so we need to extract it + reference_df = pd.read_csv(reference_data_path, compression="gzip", keep_default_na=False) + return reference_df diff --git a/im2deep/constants.py b/im2deep/constants.py new file mode 100644 index 0000000..36299ca --- /dev/null +++ b/im2deep/constants.py @@ -0,0 +1,90 @@ +from pathlib import Path + +# Paths and names for default models and reference datasets +DEFAULT_MODEL_NAME = "IM2DeepUni.ckpt" +DEFAULT_MODEL = Path(__file__).resolve().parent / "models" / "TIMS" / DEFAULT_MODEL_NAME +DEFAULT_MULTI_MODEL_NAME = "IM2DeepMulti.ckpt" +DEFAULT_MULTI_MODEL = ( + Path(__file__).resolve().parent / "models" / "TIMS" / DEFAULT_MULTI_MODEL_NAME +) +DEFAULT_REFERENCE_DATASET_PATH = ( + Path(__file__).resolve().parent / "reference_data" / "reference_ccs.csv.gz" +) +DEFAULT_MULTI_REFERENCE_DATASET_PATH = ( + Path(__file__).parent / "reference_data" / "multi_reference_ccs.csv.gz" +) +MULTI_BACKBONE_PATH = Path(__file__).parent / "models" / "TIMS" / "multi_output_backbone.ckpt" + +# Constant values +SUMMARY_CONSTANT = 18509.8632163405 +MASS_GAS_N2 = 28.013 +TEMP = 31.85 +T_DIFF = 
273.15 + +# Default model configuration +DEFAULT_MULTI_CONFIG = { + "model_name": "IM2DeepMulti", + "batch_size": 16, + "learning_rate": 0.0001, + "AtomComp_kernel_size": 4, + "DiatomComp_kernel_size": 2, + "One_hot_kernel_size": 2, + "AtomComp_out_channels_start": 256, + "DiatomComp_out_channels_start": 128, + "Global_units": 16, + "OneHot_out_channels": 2, + "Concat_units": 128, + "AtomComp_MaxPool_kernel_size": 2, + "DiatomComp_MaxPool_kernel_size": 2, + "OneHot_MaxPool_kernel_size": 10, + "LRelu_negative_slope": 0.1, + "LRelu_saturation": 20, + "L1_alpha": 0.00001, + "delta": 0, + "device": 0, + "add_X_mol": False, + "init": "normal", + "backbone_SD_path": MULTI_BACKBONE_PATH, +} + +DEFAULT_CONFIG = { + "model_name": "IM2DeepTorch2026ChargeDupes", + "batch_size": 512, + "learning_rate": 0.001, + "AtomComp_kernel_size": 4, + "DiatomComp_kernel_size": 2, + "One_hot_kernel_size": 2, + "AtomComp_out_channels_start": 256, + "DiatomComp_out_channels_start": 128, + "Global_units": 16, + "OneHot_out_channels": 2, + "Concat_units": 128, + "AtomComp_MaxPool_kernel_size": 2, + "DiatomComp_MaxPool_kernel_size": 2, + "OneHot_MaxPool_kernel_size": 10, + "LRelu_negative_slope": 0.1, + "LRelu_saturation": 20, + "L1_alpha": 0.000005, + "delta": 0, + "device": 0, + "add_X_mol": False, + "init": "normal", +} + +BASEMODELCONFIG = { + "AtomComp_kernel_size": 4, + "DiatomComp_kernel_size": 4, + "One_hot_kernel_size": 4, + "AtomComp_out_channels_start": 356, + "DiatomComp_out_channels_start": 65, + "Global_units": 20, + "OneHot_out_channels": 1, + "Concat_units": 94, + "AtomComp_MaxPool_kernel_size": 2, + "DiatomComp_MaxPool_kernel_size": 2, + "OneHot_MaxPool_kernel_size": 10, + "LRelu_negative_slope": 0.013545684190756122, + "LRelu_saturation": 40, + "init": "normal", + "add_X_mol": False, +} diff --git a/im2deep/core.py b/im2deep/core.py new file mode 100644 index 0000000..07b846e --- /dev/null +++ b/im2deep/core.py @@ -0,0 +1,178 @@ +"""IM2Deep core functionality.""" + +from 
__future__ import annotations
+
+import logging
+from os import PathLike
+from pathlib import Path
+
+import numpy as np
+from psm_utils.psm_list import PSMList
+import torch
+from deeplc.data import DeepLCDataset
+
+from im2deep.utils import validate_psm_list
+from im2deep import _model_ops
+from im2deep.calibration import LinearCCSCalibration, Calibration
+from im2deep.constants import DEFAULT_MODEL, DEFAULT_MULTI_MODEL
+
+LOGGER = logging.getLogger(__name__)
+
+
+def predict(
+    psm_list: PSMList,
+    model: torch.nn.Module | PathLike | str | None = None,
+    multi=False,
+    predict_kwargs: dict | None = None,
+) -> np.ndarray:
+    """
+    Predict CCS values for a list of PSMs using a trained model.
+
+    Parameters
+    ----------
+    psm_list
+        List of PSMs to predict CCS values for.
+    model
+        Trained model or path to model file. If None, the default IM2Deep model is used.
+    predict_kwargs
+        Additional keyword arguments to pass to the prediction function.
+
+    Returns
+    -------
+    np.ndarray
+        CCS predictions.
+
+    """
+    LOGGER.info("Predicting CCS values using IM2Deep.")
+    psm_list = validate_psm_list(psm_list)
+    return _model_ops.predict(
+        # Parenthesized so an explicitly passed `model` is honored even when
+        # multi=True; without parentheses, `or` binds tighter than the
+        # conditional and `(model or DEFAULT_MODEL)` would be discarded.
+        model=model or (DEFAULT_MULTI_MODEL if multi else DEFAULT_MODEL),
+        data=DeepLCDataset.from_psm_list(psm_list, add_ccs_features=True),
+        multi=multi,
+        **(predict_kwargs or {}),
+        # TODO: check if "backbone" argument is needed for multi
+    ).numpy()
+
+
+def predict_and_calibrate(
+    psm_list: PSMList,
+    psm_list_cal: PSMList,
+    psm_list_reference: PSMList | None = None,
+    model: torch.nn.Module | PathLike | str | None = None,
+    calibration: Calibration | None = None,
+    multi: bool = False,
+    predict_kwargs: dict | None = None,
+    **kwargs,
+) -> np.ndarray:
+    """
+    Calibrate and predict CCS values for a list of PSMs using a reference PSM list.
+
+    Parameters
+    ----------
+    psm_list
+        List of PSMs to predict CCS values for.
+    psm_list_reference
+        Reference list of PSMs for calibration.
+    model
+        Trained model or path to model file. 
If None, the default IM2Deep model is used. + calibration + Calibration object to use for calibration. If None, LinearCCSCalibration is applied. + predict_kwargs + Additional keyword arguments to pass to the prediction function. + + Returns + ------- + np.ndarray + Calibrated CCS predictions. + + """ + # Predict initial CCS values + LOGGER.info("Predicting uncalibrated CCS values...") + psm_list = validate_psm_list(psm_list) + psm_list_cal = validate_psm_list(psm_list_cal, needs_target=True) + + predicted_ccs = predict( + psm_list=psm_list, + model=model, + multi=multi, + predict_kwargs=predict_kwargs, + ) + + # Assign the predicted CCS to the PSM metadata + for idx, psm in enumerate(psm_list): + psm.metadata["predicted_CCS_uncalibrated"] = predicted_ccs[idx] + + psm_df = psm_list.to_dataframe() + psm_df_cal = psm_list_cal.to_dataframe() + if psm_list_reference is not None: + psm_list_reference = validate_psm_list(psm_list_reference, needs_target=True) + psm_df_reference = psm_list_reference.to_dataframe() + else: + psm_df_reference = None + + # Perform calibration + if calibration is None: + LOGGER.info("No calibration provided, using LinearCCSCalibration by default.") + calibration = LinearCCSCalibration( + per_charge=kwargs.get("calibrate_per_charge", True), + use_charge_state=( + kwargs.get("use_charge_state", 2) + if not kwargs.get("calibrate_per_charge", True) + else None + ), + ) + elif not isinstance(calibration, Calibration): + raise TypeError( + f"Calibration must be an instance of Calibration, got {type(calibration)} instead." + ) + + if not calibration.is_fitted: + LOGGER.info("Fitting calibration...") + if any(psm_list_cal["is_decoy"]): + LOGGER.warning( + "Calibration PSM list contains decoy PSMs. " + "These will be ignored during calibration fitting." 
+ ) + calibration.fit( + psm_df_cal, + psm_df_reference, + multi=multi, + ) + else: + LOGGER.info("Calibration is already fitted, skipping fitting step.") + + # Apply calibration to predictions + predicted_ccs_calibrated = calibration.transform(psm_df) + + # Return as-is (already numpy array, may be object array for multiconformer) + return predicted_ccs_calibrated + + +def train( + psm_list, + model_save_path, + training_kwargs=None, +): + """ + Train a new IM2Deep model using the provided PSM list. + + Parameters + ---------- + psm_list + List of PSMs to use for training. + model_save_path + Path to save the trained model. + training_kwargs + Additional keyword arguments to pass to the training function. + + Returns + ------- + None + + """ + raise NotImplementedError( + "Training functionality is not yet implemented for IM2Deep. Use the IM2DeepTrainer package instead." + ) + + +# TODO: finetune and finetune_and_predict functions? diff --git a/im2deep/im2deep.py b/im2deep/im2deep.py deleted file mode 100644 index cb3215d..0000000 --- a/im2deep/im2deep.py +++ /dev/null @@ -1,476 +0,0 @@ -""" -Main CCS prediction module for IM2Deep. - -This module provides the core functionality for predicting Collisional Cross Section (CCS) -values for peptides using deep learning models. It supports both single-conformer and -multi-conformer predictions with optional calibration. 
- -The module handles: -- Loading and running neural network models for CCS prediction -- Calibrating predictions using reference datasets -- Converting between CCS and ion mobility -- Outputting results in various formats - -Functions: - predict_ccs: Main function for CCS prediction with optional calibration - -Dependencies: - - deeplc: For neural network model infrastructure - - psm_utils: For peptide data handling - - pandas/numpy: For data manipulation - -Example: - Basic CCS prediction: - >>> from im2deep.im2deep import predict_ccs - >>> predictions = predict_ccs(psm_list, calibration_data) - - Multi-conformer prediction: - >>> predictions = predict_ccs(psm_list, calibration_data, multi=True) -""" - -from __future__ import annotations - -import logging -from os import PathLike -from pathlib import Path -from typing import cast - -import pandas as pd -from deeplc import DeepLC -from psm_utils.psm_list import PSMList - -from im2deep._exceptions import IM2DeepError -from im2deep.calibrate import linear_calibration -from im2deep.utils import ccs2im - -LOGGER = logging.getLogger(__name__) -REFERENCE_DATASET_PATH = Path(__file__).parent / "reference_data" / "reference_ccs.zip" - - -def _validate_inputs(psm_list_pred: PSMList, output_file: str | PathLike | None = None) -> None: - """ - Validate input parameters for prediction. 
- - Parameters - ---------- - psm_list_pred - PSM list for prediction - output_file - Output file path - - Raises - ------ - IM2DeepError - If validation fails - """ - if not isinstance(psm_list_pred, PSMList): - raise IM2DeepError("psm_list_pred must be a PSMList instance") - - if len(psm_list_pred) == 0: - raise IM2DeepError("PSM list for prediction is empty") - - if output_file and not isinstance(output_file, (str, PathLike)): - raise IM2DeepError("output_file must be a string or PathLike object") - - -def _get_model_paths(model_name: str, use_single_model: bool) -> list[Path]: - """ - Get model file paths based on model name and configuration. - - Parameters - ---------- - model_name - Name of the model ('tims') - use_single_model - Whether to use single model or ensemble - - Returns - ------- - list[Path] - List of model file paths - - Raises - ------ - IM2DeepError - If model files not found - """ - if model_name == "tims": - path_model = Path(__file__).parent / "models" / "TIMS" - else: - raise IM2DeepError(f"Unsupported model name: {model_name}") - - if not path_model.exists(): - raise IM2DeepError(f"Model directory not found: {path_model}") - - path_model_list = list(path_model.glob("*.keras")) - - if not path_model_list: - raise IM2DeepError(f"No model files found in {path_model}") - - if use_single_model: - # Use the third model by default (index 2) for consistency - if len(path_model_list) > 2: - selected_model = path_model_list[2] - LOGGER.debug(f"Using single model: {selected_model}") - return [selected_model] - else: - LOGGER.warning("Less than 3 models available, using first model") - return [path_model_list[0]] - else: - LOGGER.debug(f"Using ensemble of {len(path_model_list)} models") - return path_model_list - - -def _write_output_file( - output_file: str | PathLike, - psm_list_pred_df: pd.DataFrame, - pred_df: pd.DataFrame | None = None, - ion_mobility: bool = False, - multi: bool = False, -) -> None: - """ - Write predictions to output file. 
- - Parameters - ---------- - output_file - Path to output file - psm_list_pred_df - DataFrame with predictions - pred_df - Multi-conformer predictions - ion_mobility - Whether to output ion mobility instead of CCS - multi - Whether multi-conformer predictions are included - """ - if multi and pred_df is None: - raise IM2DeepError("Multi-conformer predictions requested but pred_df is None") - else: - pred_df = cast(pd.DataFrame, pred_df) - try: - with open(output_file, "w", encoding="utf-8") as f: - # TODO: Consider using dictwriter or Pandas to_csv - if ion_mobility: - if multi: - f.write( - "modified_seq,charge,predicted IM single,predicted IM multi 1,predicted IM multi 2\n" - ) - for peptidoform, charge, IM_single, IM_multi_1, IM_multi_2 in zip( - psm_list_pred_df["peptidoform"], - psm_list_pred_df["charge"], - psm_list_pred_df["predicted_im"], - psm_list_pred_df["predicted_im_multi_1"], - psm_list_pred_df["predicted_im_multi_2"], - strict=True, - ): - f.write(f"{peptidoform},{charge},{IM_single},{IM_multi_1},{IM_multi_2}\n") - else: - f.write("modified_seq,charge,predicted IM\n") - for peptidoform, charge, IM in zip( - psm_list_pred_df["peptidoform"], - psm_list_pred_df["charge"], - psm_list_pred_df["predicted_im"], - strict=True, - ): - f.write(f"{peptidoform},{charge},{IM}\n") - else: - if multi: - f.write( - "modified_seq,charge,predicted CCS single,predicted CCS multi 1,predicted CCS multi 2\n" - ) - for peptidoform, charge, CCS_single, CCS_multi_1, CCS_multi_2 in zip( - psm_list_pred_df["peptidoform"], - psm_list_pred_df["charge"], - psm_list_pred_df["predicted_ccs"], - pred_df["predicted_ccs_multi_1"], - pred_df["predicted_ccs_multi_2"], - strict=True, - ): - f.write( - f"{peptidoform},{charge},{CCS_single},{CCS_multi_1},{CCS_multi_2}\n" - ) - else: - f.write("modified_seq,charge,predicted CCS\n") - for peptidoform, charge, CCS in zip( - psm_list_pred_df["peptidoform"], - psm_list_pred_df["charge"], - psm_list_pred_df["predicted_ccs"], - strict=True, - ): 
- f.write(f"{peptidoform},{charge},{CCS}\n") - - LOGGER.info(f"Results written to: {output_file}") - - except OSError as e: - raise IM2DeepError(f"Failed to write output file {output_file}: {e}") from e - - -def predict_ccs( - psm_list_pred: PSMList, - psm_list_cal: PSMList | pd.DataFrame | None = None, - file_reference: PathLike | None = None, - output_file: PathLike | None = None, - model_name: str = "tims", - multi: bool = False, - calibrate_per_charge: bool = True, - use_charge_state: int = 2, - use_single_model: bool = True, - n_jobs: int | None = None, - write_output: bool = False, - ion_mobility: bool = False, - pred_df: pd.DataFrame | None = None, - cal_df: pd.DataFrame | None = None, -) -> pd.Series | pd.DataFrame: - """ - Predict CCS values for peptides using IM2Deep models. - - This is the main function for CCS prediction. It can perform single-conformer - or multi-conformer predictions with optional calibration using reference datasets. - - Parameters - ---------- - psm_list_pred - PSM list containing peptides for CCS prediction. Each PSM should contain - a valid peptidoform with sequence and modifications. - psm_list_cal - PSM list or DataFrame for calibration with observed CCS values. - If PSMList: CCS values should be in metadata with key "CCS". - If DataFrame: should have "ccs_observed" column. - Required for calibration. Default is None (no calibration). - file_reference - Path to reference dataset file for calibration. Default uses built-in - reference dataset. - output_file - Path to write output predictions. If None, no file is written. - model_name - Name of the model to use. Currently only "tims" is supported. - multi - Whether to include multi-conformer predictions. Requires optional - dependencies (torch, im2deeptrainer). - calibrate_per_charge - Whether to perform calibration per charge state. If False, uses - global calibration with specified charge state. 
- use_charge_state - Charge state to use for global calibration when calibrate_per_charge=False. - Should be in range [2,4] for best results. - use_single_model - Whether to use a single model (faster) or ensemble of models (potentially - more accurate). Single model recommended for most applications. - n_jobs - Number of parallel jobs for model prediction. If None, uses all available CPUs. - write_output - Whether to write predictions to output file. - ion_mobility - Whether to output ion mobility (1/K0) instead of CCS values. - pred_df - Pre-computed prediction DataFrame (used internally). - cal_df - Pre-computed calibration DataFrame (used internally). - - Returns - ------- - pd.Series or pd.DataFrame - If ion_mobility=True: Series with predicted ion mobility values - If ion_mobility=False: Series with predicted CCS values - For multi-conformer predictions, additional columns are included. - - Raises - ------ - IM2DeepError - If prediction fails due to invalid inputs, missing models, or other errors. - - Notes - ----- - The prediction workflow: - 1. Validate inputs and load appropriate models - 2. Generate CCS predictions using neural networks - 3. Apply calibration if calibration data provided - 4. Optionally run multi-conformer predictions - 5. Convert to ion mobility if requested - 6. Write output file if requested - - Calibration is highly recommended for accurate predictions and requires a set of peptides with - known CCS values that overlap with the reference dataset. - - Examples - -------- - Basic CCS prediction without calibration: - >>> predictions = predict_ccs(psm_list) - - CCS prediction with calibration: - >>> predictions = predict_ccs(psm_list, psm_list_calibration) - - Multi-conformer prediction with ion mobility output: - >>> predictions = predict_ccs( - ... psm_list, - ... psm_list_calibration, - ... multi=True, - ... ion_mobility=True - ... ) - - Ensemble prediction with file output: - >>> predictions = predict_ccs( - ... psm_list, - ... 
psm_list_calibration, - ... use_single_model=False, - ... output_file="predictions.csv", - ... write_output=True - ... ) - """ - LOGGER.info("IM2Deep started.") - - # Validate inputs - _validate_inputs(psm_list_pred, output_file) - - # Load reference dataset - if file_reference is None: - file_reference = REFERENCE_DATASET_PATH - - try: - reference_dataset = pd.read_csv(file_reference) - LOGGER.debug(f"Loaded reference dataset with {len(reference_dataset)} entries") - except Exception as e: - raise IM2DeepError(f"Failed to load reference dataset from {file_reference}: {e}") from e - - if reference_dataset.empty: - raise IM2DeepError("Reference dataset is empty") - - # Get model paths - try: - path_model_list = _get_model_paths(model_name, use_single_model) - except Exception as e: - raise IM2DeepError(f"Failed to load models: {e}") from e - - # Initialize DeepLC for CCS prediction - try: - dlc = DeepLC(path_model=path_model_list, n_jobs=n_jobs, predict_ccs=True) - LOGGER.info("Predicting CCS values...") - preds = dlc.make_preds(psm_list=psm_list_pred, calibrate=False) - LOGGER.info(f"CCS values predicted for {len(preds)} peptides.") - except Exception as e: - raise IM2DeepError(f"CCS prediction failed: {e}") from e - - if len(preds) == 0: - raise IM2DeepError("No predictions generated") - - # Convert PSM list to DataFrame and add predictions - try: - psm_list_pred_df = psm_list_pred.to_dataframe() - psm_list_pred_df["predicted_ccs"] = preds - psm_list_pred_df["charge"] = psm_list_pred_df["peptidoform"].apply( - lambda x: x.precursor_charge - ) - except Exception as e: - raise IM2DeepError(f"Failed to process predictions: {e}") from e - - # Apply calibration if calibration data provided - pred_df = None - if psm_list_cal is not None: - try: - LOGGER.info("Applying calibration...") - - # Handle both PSMList and DataFrame input - if isinstance(psm_list_cal, pd.DataFrame): - # Input is already a DataFrame with ccs_observed column - psm_list_cal_df = psm_list_cal.copy() 
- if "ccs_observed" not in psm_list_cal_df.columns: - raise IM2DeepError( - "DataFrame calibration data must contain 'ccs_observed' column" - ) - else: - # Input is PSMList, extract CCS from metadata - ccs_values = [] - for psm in psm_list_cal: - if psm.metadata and "CCS" in psm.metadata: - ccs_values.append(float(psm.metadata["CCS"])) - else: - ccs_values.append(None) - - # Convert to DataFrame and add CCS values - psm_list_cal_df = psm_list_cal.to_dataframe() - psm_list_cal_df["ccs_observed"] = ccs_values - - # Filter out entries without CCS values - psm_list_cal_df = psm_list_cal_df[psm_list_cal_df["ccs_observed"].notnull()] - - if psm_list_cal_df.empty: - LOGGER.warning("No valid calibration data found (missing CCS values)") - else: - psm_list_pred_df = linear_calibration( - psm_list_pred_df, - calibration_dataset=psm_list_cal_df, - reference_dataset=reference_dataset, - per_charge=calibrate_per_charge, - use_charge_state=use_charge_state, - ) - LOGGER.info("Calibration applied successfully.") - - except Exception as e: - LOGGER.error(f"Calibration failed: {e}") - # Continue without calibration rather than failing completely - LOGGER.warning("Continuing without calibration") - - # Multi-conformer prediction - if multi: - try: - from im2deep.predict_multi import predict_multi - - LOGGER.info("Predicting multiconformer CCS values...") - pred_df = predict_multi( - psm_list_pred, - cal_df, - calibrate_per_charge, - use_charge_state, - ) - LOGGER.info("Multiconformational predictions completed.") - except ImportError as e: - raise IM2DeepError( - "Multi-conformer prediction requires optional dependencies. 
" - "Please install with: pip install 'im2deep[er]'" - ) from e - except Exception as e: - raise IM2DeepError(f"Multi-conformer prediction failed: {e}") from e - - # Convert to ion mobility if requested - if ion_mobility: - try: - mz_array = psm_list_pred_df["peptidoform"].apply(lambda x: x.theoretical_mz).to_numpy() - charge_array = psm_list_pred_df["charge"].to_numpy() - - psm_list_pred_df["predicted_im"] = ccs2im( - psm_list_pred_df["predicted_ccs"].to_numpy(), - mz_array, - charge_array, - ) - - if multi and pred_df is not None: - psm_list_pred_df["predicted_im_multi_1"] = ccs2im( - pred_df["predicted_ccs_multi_1"].to_numpy(), - mz_array, - charge_array, - ) - psm_list_pred_df["predicted_im_multi_2"] = ccs2im( - pred_df["predicted_ccs_multi_2"].to_numpy(), - mz_array, - charge_array, - ) - - except Exception as e: - raise IM2DeepError(f"Ion mobility conversion failed: {e}") from e - - # Write output file if requested - if write_output and output_file: - try: - _write_output_file(output_file, psm_list_pred_df, pred_df, ion_mobility, multi) - except Exception as e: - LOGGER.error(f"Failed to write output: {e}") - # Don't fail the entire prediction because of output issues - - LOGGER.info("IM2Deep finished!") - - # Return appropriate predictions - if ion_mobility: - return psm_list_pred_df["predicted_im"] - else: - return psm_list_pred_df["predicted_ccs"] diff --git a/im2deep/models/TIMS_multi/multi_output.ckpt b/im2deep/models/TIMS/IM2DeepMulti.ckpt similarity index 100% rename from im2deep/models/TIMS_multi/multi_output.ckpt rename to im2deep/models/TIMS/IM2DeepMulti.ckpt diff --git a/im2deep/models/TIMS/full_hc_peprec_CCS_v4_cb975cfdd4105f97efa0b3afffe075cc.keras b/im2deep/models/TIMS/IM2DeepUni.ckpt similarity index 56% rename from im2deep/models/TIMS/full_hc_peprec_CCS_v4_cb975cfdd4105f97efa0b3afffe075cc.keras rename to im2deep/models/TIMS/IM2DeepUni.ckpt index a0bf23f..26c6ec2 100644 Binary files 
a/im2deep/models/TIMS/full_hc_peprec_CCS_v4_cb975cfdd4105f97efa0b3afffe075cc.keras and b/im2deep/models/TIMS/IM2DeepUni.ckpt differ diff --git a/im2deep/models/TIMS/full_hc_peprec_CCS_v4_1fd8363d9af9dcad3be7553c39396960.keras b/im2deep/models/TIMS/full_hc_peprec_CCS_v4_1fd8363d9af9dcad3be7553c39396960.keras deleted file mode 100644 index 61ba2b9..0000000 Binary files a/im2deep/models/TIMS/full_hc_peprec_CCS_v4_1fd8363d9af9dcad3be7553c39396960.keras and /dev/null differ diff --git a/im2deep/models/TIMS/full_hc_peprec_CCS_v4_8c22d89667368f2f02ad996469ba157e.keras b/im2deep/models/TIMS/full_hc_peprec_CCS_v4_8c22d89667368f2f02ad996469ba157e.keras deleted file mode 100644 index eec433c..0000000 Binary files a/im2deep/models/TIMS/full_hc_peprec_CCS_v4_8c22d89667368f2f02ad996469ba157e.keras and /dev/null differ diff --git a/im2deep/models/TIMS_multi/multi_output_backbone.ckpt b/im2deep/models/TIMS/multi_output_backbone.ckpt similarity index 100% rename from im2deep/models/TIMS_multi/multi_output_backbone.ckpt rename to im2deep/models/TIMS/multi_output_backbone.ckpt diff --git a/im2deep/models/TIMS_multi/multi_output_pre_revision.ckpt b/im2deep/models/TIMS/multi_output_pre_revision.ckpt similarity index 100% rename from im2deep/models/TIMS_multi/multi_output_pre_revision.ckpt rename to im2deep/models/TIMS/multi_output_pre_revision.ckpt diff --git a/im2deep/predict_multi.py b/im2deep/predict_multi.py deleted file mode 100644 index 3e346a4..0000000 --- a/im2deep/predict_multi.py +++ /dev/null @@ -1,603 +0,0 @@ -""" -Multi-conformer CCS prediction module for IM2Deep. - -This module provides functionality for predicting CCS values for peptides that -can exist in multiple conformational states. It uses specialized neural network -models trained to predict multiple CCS values per peptide. - -The multi-conformer prediction pipeline: -1. Extract molecular features from peptide sequences -2. Run multi-output neural network models -3. 
Apply calibration using multi-conformer reference data -4. Return multiple CCS predictions per peptide - -Functions: - get_ccs_shift_multi: Calculate CCS shift for multi-conformer predictions - get_ccs_shift_per_charge_multi: Calculate per-charge shifts for multi predictions - calculate_ccs_shift_multi: Main shift calculation with validation - linear_calibration_multi: Apply calibration to multi-conformer predictions - predict_multi: Main function for multi-conformer CCS prediction - -Dependencies: - - torch: For neural network inference - - im2deeptrainer: For handling specialized im2deep models - - pandas/numpy: For data manipulation - -Note: - This module requires optional dependencies that can be installed with: - pip install 'im2deep[er]' -""" - -from __future__ import annotations - -import logging -from pathlib import Path -from typing import cast - -import numpy as np -import pandas as pd - -try: - import torch - from im2deeptrainer.extract_data import _get_matrices # TODO: Should be public function? - from im2deeptrainer.model import IM2DeepMultiTransfer - from im2deeptrainer.utils import FlexibleLossSorted - - TORCH_AVAILABLE = True -except ImportError: - # Optional dependencies not available - torch = None - IM2DeepMultiTransfer = None - _get_matrices = None - FlexibleLossSorted = None - TORCH_AVAILABLE = False - -from im2deep._exceptions import CalibrationError, IM2DeepError -from im2deep.utils import multi_config - -LOGGER = logging.getLogger(__name__) -MULTI_CKPT_PATH: Path = Path(__file__).parent / "models" / "TIMS_multi" / "multi_output.ckpt" -REFERENCE_DATASET_PATH: Path = Path(__file__).parent / "reference_data" / "multi_reference_ccs.gz" - - -def _validate_multi_inputs(df_cal: pd.DataFrame, reference_dataset: pd.DataFrame) -> None: - """ - Validate inputs for multi-conformer calibration. 
- - Parameters - ---------- - df_cal - Calibration dataset - reference_dataset - Reference dataset - - Raises - ------ - CalibrationError - If validation fails - """ - required_cal_cols = ["seq", "modifications", "charge", "CCS"] - required_ref_cols = ["seq", "modifications", "charge", "ccs_observed"] - - if df_cal.empty: - raise CalibrationError("Calibration dataset is empty") - if reference_dataset.empty: - raise CalibrationError("Reference dataset is empty") - - missing_cal = set(required_cal_cols) - set(df_cal.columns) - if missing_cal: - raise CalibrationError(f"Missing columns in calibration data: {missing_cal}") - - missing_ref = set(required_ref_cols) - set(reference_dataset.columns) - if missing_ref: - raise CalibrationError(f"Missing columns in reference data: {missing_ref}") - - -def get_ccs_shift_multi( - df_cal: pd.DataFrame, reference_dataset: pd.DataFrame, use_charge_state: int = 2 -) -> float: - """ - Calculate CCS shift factor for multi-conformer predictions. - - This function calculates a shift factor specifically for multi-conformer - predictions by comparing calibration data with reference data for a - specific charge state. - - Parameters - ---------- - df_cal - Calibration peptides with observed CCS values. Must contain columns: - 'seq', 'modifications', 'charge', 'ccs_observed' - reference_dataset - Reference dataset with known CCS values. Must contain columns: - 'seq', 'modifications', 'charge', 'CCS' - use_charge_state - Charge state to use for CCS shift calculation. Recommended range [2,4]. - - Returns - ------- - float - CCS shift factor for multi-conformer predictions. Positive values - indicate calibration CCS is higher than reference on average. 
- - Raises - ------ - CalibrationError - If inputs are invalid or no overlapping data found - - Notes - ----- - Multi-conformer shift calculation differs from single-conformer by: - - Using sequence + modifications for matching instead of peptidoform - - Typically having fewer overlapping peptides due to stricter matching - - Requiring specific reference data trained for multi-conformer models - - Examples - -------- - >>> shift = get_ccs_shift_multi(cal_df, ref_df, use_charge_state=2) - >>> print(f"Multi-conformer shift: {shift:.2f} Ų") - """ - _validate_multi_inputs(df_cal, reference_dataset) - - if not use_charge_state <= 6: - raise CalibrationError(f"Invalid charge state {use_charge_state}") - - LOGGER.debug( - f"Using charge state {use_charge_state} for calibration of multi-conformer predictions." - ) - - # Filter by charge state - reference_tmp = reference_dataset[reference_dataset["charge"] == use_charge_state] - df_tmp = df_cal[df_cal["charge"] == use_charge_state] - - if reference_tmp.empty or df_tmp.empty: - LOGGER.warning( - f"No data found for charge state {use_charge_state} in multi-conformer calibration" - ) - return 0.0 - - # Merge on sequence and modifications for multi-conformer matching - both = pd.merge( - left=reference_tmp, - right=df_tmp, - on=["seq", "modifications"], - how="inner", - suffixes=("_ref", "_data"), - ) - - LOGGER.debug(f"Head of overlapping peptides:\n{both.head()}") - - LOGGER.debug( - "" - f"Calculating CCS shift based on {both.shape[0]} overlapping peptide-charge pairs " - f"between PSMs and reference dataset." - ) - - if both.empty: - LOGGER.warning("No overlapping peptides found for multi-conformer calibration") - return 0.0 - - if both.shape[0] < 10: - LOGGER.warning( - f"Only {both.shape[0]} overlapping peptides found for multi-conformer calibration. " - "Results may be unreliable." 
- ) - - # Calculate mean shift - shift = np.mean(both["ccs_observed"] - both["CCS"]) - - if abs(shift) > 50: - LOGGER.warning(f"Large multi-conformer CCS shift detected ({shift:.2f} ƅ^2)") - - return float(shift) - - -def get_ccs_shift_per_charge_multi( - df_cal: pd.DataFrame, reference_dataset: pd.DataFrame -) -> dict[int, float]: - """ - Calculate CCS shift factors per charge state for multi-conformer predictions. - - This function calculates charge-specific shift factors for multi-conformer - predictions, allowing for more accurate calibration across different - charge states. - - Parameters - ---------- - df_cal - Calibration peptides with observed CCS values. Must contain columns: - 'seq', 'modifications', 'charge', 'ccs_observed' - reference_dataset - Reference dataset with known CCS values. Must contain columns: - 'seq', 'modifications', 'charge', 'CCS' - - Returns - ------- - Dict[int, float] - Dictionary mapping charge states to their shift factors. - - Raises - ------ - CalibrationError - If inputs are invalid or no overlapping data found - - Notes - ----- - Multi-conformer per-charge calibration: - - Matches peptides exactly on sequence, modifications, and charge - - Typically yields fewer matches than single-conformer calibration - - Provides charge-specific corrections for systematic biases - - Examples - -------- - >>> shifts = get_ccs_shift_per_charge_multi(cal_df, ref_df) - >>> print("Multi-conformer shifts:", shifts) - {2: 4.1, 3: 2.8, 4: 1.5} - """ - _validate_multi_inputs(df_cal, reference_dataset) - - # Merge datasets for exact matching - both = pd.merge( - left=reference_dataset, - right=df_cal, - on=["seq", "modifications", "charge"], - how="inner", - suffixes=("_ref", "_data"), - ) - - if both.empty: - raise CalibrationError( - "No overlapping peptides found for multi-conformer per-charge calibration" - ) - - LOGGER.debug( - f"Found {both.shape[0]} overlapping peptides for multi-conformer per-charge calibration" - ) - - # Check data 
distribution - charge_counts = both.groupby("charge").size() - LOGGER.debug(f"Multi-conformer peptides per charge: {charge_counts.to_dict()}") - - # Warn about insufficient data - low_data_charges = charge_counts[charge_counts < 5].index.tolist() - if low_data_charges: - LOGGER.warning( - f"Charge states with <5 peptides in multi-conformer calibration: {low_data_charges}" - ) - - # Calculate shifts per charge - shift_dict = ( - both.groupby("charge").apply(lambda x: np.mean(x["ccs_observed"] - x["CCS"])).to_dict() - ) - - # Convert to native Python types - shift_dict = {int(k): float(v) for k, v in shift_dict.items()} - - return shift_dict - - -def calculate_ccs_shift_multi( - df_cal: pd.DataFrame, - reference_dataset: pd.DataFrame, - per_charge: bool = True, - use_charge_state: int | None = None, -) -> float | dict[int, float]: - """ - Calculate CCS shift factors for multi-conformer predictions with validation. - - This is the main interface for calculating shift factors for multi-conformer - predictions. It provides input validation, charge filtering, and supports - both global and per-charge calibration modes. - - Parameters - ---------- - df_cal - Calibration peptides with observed CCS values. - reference_dataset - Reference dataset with known CCS values. - per_charge - Whether to calculate shift factors per charge state. - use_charge_state - Charge state for global calibration when per_charge=False. - Default is 2 if not specified. - - Returns - ------- - float | dict[int, float] - If per_charge=True: Dictionary of shift factors per charge - If per_charge=False: Single global shift factor - - Raises - ------ - CalibrationError - If validation fails or invalid parameters - - Notes - ----- - Multi-conformer models are typically trained for charges 2-4, so higher - charges are filtered out automatically. The function logs filtering actions - for transparency. 
- - Examples - -------- - >>> # Per-charge calibration (recommended) - >>> shifts = calculate_ccs_shift_multi(cal_df, ref_df, per_charge=True) - >>> - >>> # Global calibration - >>> shift = calculate_ccs_shift_multi(cal_df, ref_df, per_charge=False, use_charge_state=2) - """ - _validate_multi_inputs(df_cal, reference_dataset) - - if use_charge_state is not None and not use_charge_state <= 6: - raise CalibrationError(f"Invalid charge state {use_charge_state}") - - # Filter charge states (multi-conformer models typically work best for 2-4) - original_size = len(df_cal) - df_cal = df_cal[(df_cal["charge"] < 5)].copy() - - if len(df_cal) < original_size: - filtered_count = original_size - len(df_cal) - LOGGER.info( - f"Filtered {filtered_count} peptides outside charge range 2-4 " - "for multi-conformer calibration" - ) - - if df_cal.empty: - raise CalibrationError( - "No valid calibration data for multi-conformer prediction after filtering" - ) - - if not per_charge: - if use_charge_state is None: - use_charge_state = 2 - LOGGER.debug("Using charge 2 for global multi-conformer calibration") - - shift_factor = get_ccs_shift_multi(df_cal, reference_dataset, use_charge_state) - LOGGER.debug(f"Multi-conformer global shift factor: {shift_factor:.3f}") - return shift_factor - else: - shift_factor_dict = get_ccs_shift_per_charge_multi(df_cal, reference_dataset) - LOGGER.debug(f"Multi-conformer shift factors: {shift_factor_dict}") - return shift_factor_dict - - -def linear_calibration_multi( - df_pred: pd.DataFrame, - df_cal: pd.DataFrame, - reference_dataset: pd.DataFrame, - per_charge: bool = True, - use_charge_state: int | None = None, -) -> pd.DataFrame: - """ - Calibrate multi-conformer CCS predictions using linear calibration. - - This function applies linear calibration specifically designed for - multi-conformer CCS predictions. It calculates and applies shift factors - to both conformer predictions. 
- - Parameters - ---------- - df_pred - DataFrame with multi-conformer CCS predictions. Must contain columns: - 'predicted_ccs_multi_1', 'predicted_ccs_multi_2', 'peptidoform' - df_cal - Calibration dataset with observed CCS values. - reference_dataset - Reference dataset for multi-conformer calibration. - per_charge - Whether to apply calibration per charge state. - use_charge_state - Charge state for global calibration when per_charge=False. - - Returns - ------- - pd.DataFrame - DataFrame with calibrated multi-conformer predictions. - - Raises - ------ - CalibrationError - If calibration fails - - Notes - ----- - Multi-conformer calibration: - - Applies the same shift to both conformer predictions - - Uses specialized reference data for multi-conformer models - - Supports both global and per-charge calibration strategies - - The calibration preserves the relative differences between conformers - while correcting systematic biases. - - Examples - -------- - >>> calibrated_df = linear_calibration_multi( - ... pred_df, cal_df, ref_df, per_charge=True - ... 
) - """ - LOGGER.info("Calibrating multi-conformer predictions using linear calibration...") - - if df_pred.empty: - raise CalibrationError("Predictions dataframe is empty") - - required_cols = ["predicted_ccs_multi_1", "predicted_ccs_multi_2", "peptidoform"] - missing_cols = set(required_cols) - set(df_pred.columns) - if missing_cols: - raise CalibrationError(f"Missing columns in predictions: {missing_cols}") - - # Create working copy - df_pred = df_pred.copy() - - try: - if per_charge: - LOGGER.info("Generating general shift factor for multi-conformer predictions...") - general_shift = calculate_ccs_shift_multi( - df_cal, reference_dataset, per_charge=False, use_charge_state=use_charge_state or 2 - ) - general_shift = cast(float, general_shift) # per_charge=False returns float - - LOGGER.info("Getting shift factors per charge state for multi-conformer...") - df_pred["charge"] = df_pred["peptidoform"].apply(lambda x: x.precursor_charge) - shift_factor_dict = calculate_ccs_shift_multi( - df_cal, reference_dataset, per_charge=True - ) - # per_charge=True returns dict[int, float] - shift_factor_dict = cast(dict[int, float], shift_factor_dict) - - # Apply charge-specific shifts with fallback - df_pred["shift_multi"] = df_pred["charge"].map(shift_factor_dict).fillna(general_shift) - df_pred["predicted_ccs_multi_1"] = ( - df_pred["predicted_ccs_multi_1"] + df_pred["shift_multi"] - ) - df_pred["predicted_ccs_multi_2"] = ( - df_pred["predicted_ccs_multi_2"] + df_pred["shift_multi"] - ) - - else: - shift_factor = calculate_ccs_shift_multi( - df_cal, reference_dataset, per_charge=False, use_charge_state=use_charge_state or 2 - ) - shift_factor = cast(float, shift_factor) # per_charge=False returns float - df_pred["predicted_ccs_multi_1"] = df_pred["predicted_ccs_multi_1"] + shift_factor - df_pred["predicted_ccs_multi_2"] = df_pred["predicted_ccs_multi_2"] + shift_factor - df_pred["shift_multi"] = shift_factor - - LOGGER.info("Multi-conformer predictions calibrated 
successfully.") - return df_pred - - except Exception as e: - raise CalibrationError(f"Multi-conformer calibration failed: {e}") from e - - -def predict_multi( - df_pred_psm_list, - df_cal: pd.DataFrame | None, - calibrate_per_charge: bool, - use_charge_state: int, -) -> pd.DataFrame: - """ - Generate multi-conformer CCS predictions for peptides. - - This is the main function for multi-conformer CCS prediction. It loads - the specialized multi-output neural network model and generates predictions - for multiple conformational states of each peptide. - - Parameters - ---------- - df_pred_psm_list - PSM list containing peptides for prediction. - df_cal - Calibration dataset. If provided, predictions will be calibrated. - calibrate_per_charge - Whether to perform per-charge calibration. - use_charge_state - Charge state for global calibration. - - Returns - ------- - pd.DataFrame - DataFrame with columns 'predicted_ccs_multi_1' and 'predicted_ccs_multi_2' - containing CCS predictions for two conformational states. - - Raises - ------ - IM2DeepError - If multi-conformer prediction fails - - Notes - ----- - Multi-conformer prediction workflow: - 1. Extract molecular features using im2deeptrainer - 2. Load pre-trained multi-output model - 3. Generate predictions for two conformational states - 4. Apply calibration if calibration data provided - 5. Return predictions as DataFrame - - The model predicts two CCS values per peptide, representing the most - probable conformational states based on the training data. - - Examples - -------- - >>> multi_preds = predict_multi(psm_list, cal_df, True, 2) - >>> print(multi_preds.columns) - ['predicted_ccs_multi_1', 'predicted_ccs_multi_2'] - """ - # Check if optional dependencies are available - if not TORCH_AVAILABLE: - raise IM2DeepError( - "Multi-conformer prediction requires optional dependencies. 
" - "Please install with: pip install 'im2deep[er]'" - ) - - try: - # Initialize model components - criterion = FlexibleLossSorted() - - # Check if model file exists - if not MULTI_CKPT_PATH.exists(): - raise IM2DeepError(f"Multi-conformer model not found: {MULTI_CKPT_PATH}") - - model = IM2DeepMultiTransfer.load_from_checkpoint( - MULTI_CKPT_PATH, config=multi_config, criterion=criterion - ) - - LOGGER.debug("Multi-conformer model loaded successfully") - - # Extract molecular features - LOGGER.debug("Extracting molecular features for multi-conformer prediction...") - matrices = _get_matrices(df_pred_psm_list, inference=True) - - # Convert to tensors - tensors = {} - for key in matrices: - tensors[key] = torch.tensor(matrices[key]).type(torch.FloatTensor) - - # Create data loader - dataset = torch.utils.data.TensorDataset(*[tensors[key] for key in tensors]) - dataloader = torch.utils.data.DataLoader( - dataset, batch_size=multi_config["batch_size"], shuffle=False - ) - - # Generate predictions - model.eval() - with torch.no_grad(): - preds = [] - for batch in dataloader: - prediction = model.predict_step(batch, inference=True) - preds.append(prediction) - predictions = torch.cat(preds).numpy() - - LOGGER.debug(f"Generated multi-conformer predictions for {len(predictions)} peptides") - - # Convert PSM list to DataFrame and add predictions - df_pred = df_pred_psm_list.to_dataframe() - - if len(predictions) != len(df_pred): - raise IM2DeepError(f"Prediction count mismatch: {len(predictions)} vs {len(df_pred)}") - - df_pred["predicted_ccs_multi_1"] = predictions[:, 0] - df_pred["predicted_ccs_multi_2"] = predictions[:, 1] - - # Apply calibration if calibration data provided - if df_cal is not None: - try: - LOGGER.debug("Loading multi-conformer reference dataset...") - reference_dataset = pd.read_csv( - REFERENCE_DATASET_PATH, compression="gzip", keep_default_na=False - ) - - df_pred = linear_calibration_multi( - df_pred, - df_cal, - 
reference_dataset=reference_dataset, - per_charge=calibrate_per_charge, - use_charge_state=use_charge_state, - ) - except Exception as e: - LOGGER.warning(f"Multi-conformer calibration failed: {e}") - LOGGER.warning("Returning uncalibrated multi-conformer predictions") - - return df_pred[["predicted_ccs_multi_1", "predicted_ccs_multi_2"]] - - except Exception as e: - raise IM2DeepError(f"Multi-conformer prediction failed: {e}") from e diff --git a/im2deep/reference_data/multi_reference_ccs.gz b/im2deep/reference_data/multi_reference_ccs.csv.gz similarity index 61% rename from im2deep/reference_data/multi_reference_ccs.gz rename to im2deep/reference_data/multi_reference_ccs.csv.gz index 09f3c3b..8cf86a0 100644 Binary files a/im2deep/reference_data/multi_reference_ccs.gz and b/im2deep/reference_data/multi_reference_ccs.csv.gz differ diff --git a/im2deep/reference_data/reference_ccs.zip b/im2deep/reference_data/reference_ccs.csv.gz similarity index 99% rename from im2deep/reference_data/reference_ccs.zip rename to im2deep/reference_data/reference_ccs.csv.gz index 9b1dc3d..5e16071 100644 Binary files a/im2deep/reference_data/reference_ccs.zip and b/im2deep/reference_data/reference_ccs.csv.gz differ diff --git a/im2deep/utils.py b/im2deep/utils.py index 4485cd7..92a3290 100644 --- a/im2deep/utils.py +++ b/im2deep/utils.py @@ -15,23 +15,58 @@ from __future__ import annotations +import sys from pathlib import Path from typing import Any +import logging +from rich.text import Text +import gzip +import click import numpy as np +import psm_utils.io +import pandas as pd +from rich.console import Console +from rich.logging import RichHandler +from psm_utils.psm_list import PSMList +from psm_utils.psm import PSM -MULTI_BACKBONE_PATH = ( - Path(__file__).parent / "models" / "TIMS_multi" / "multi_output_backbone.ckpt" +from im2deep._exceptions import IM2DeepError +from im2deep.constants import ( + SUMMARY_CONSTANT, + MASS_GAS_N2, + TEMP, + T_DIFF, ) +console = Console() + +LOGGER 
= logging.getLogger(__name__) + + +def build_credits(): + """Build credits""" + text = Text() + text.append("\n") + text.append("IM2Deep\n", style="bold link https://github.com/compomics/im2deep") + text.append("Developed at CompOmics, VIB / Ghent University, Belgium.\n") + text.append("Please cite: ") + text.append( + "Devreese et al. Anal. Chem. (2025)", + style="link https://pubs.acs.org/doi/10.1021/acs.analchem.5c01142", + ) + text.append("\n") + text.stylize("cyan") + return text + def im2ccs( reverse_im: float | np.ndarray, mz: float | np.ndarray, charge: int | np.ndarray, - mass_gas: float = 28.013, - temp: float = 31.85, - t_diff: float = 273.15, + mass_gas: float = MASS_GAS_N2, + temp: float = TEMP, + t_diff: float = T_DIFF, ) -> float | np.ndarray: """ Convert reduced ion mobility to collisional cross section. @@ -100,7 +135,6 @@ def im2ccs( if temp <= -t_diff: raise ValueError("Temperature must be above absolute zero") - SUMMARY_CONSTANT = 18509.8632163405 reduced_mass = (mz * charge * mass_gas) / (mz * charge + mass_gas) return (SUMMARY_CONSTANT * charge) / (np.sqrt(reduced_mass * (temp + t_diff)) * 1 / reverse_im) @@ -109,9 +143,9 @@ def ccs2im( ccs: float | np.ndarray, mz: float | np.ndarray, charge: int | np.ndarray, - mass_gas: float = 28.013, - temp: float = 31.85, - t_diff: float = 273.15, + mass_gas: float = MASS_GAS_N2, + temp: float = TEMP, + t_diff: float = T_DIFF, ) -> float | np.ndarray: """ Convert collisional cross section to reduced ion mobility. 
@@ -178,34 +212,371 @@ def ccs2im( if temp <= -t_diff: raise ValueError("Temperature must be above absolute zero") - SUMMARY_CONSTANT = 18509.8632163405 reduced_mass = (mz * charge * mass_gas) / (mz * charge + mass_gas) return ((np.sqrt(reduced_mass * (temp + t_diff))) * ccs) / (SUMMARY_CONSTANT * charge) -# Configuration for multi-conformer model -multi_config: dict[str, Any] = { - "model_name": "IM2DeepMulti", - "batch_size": 16, - "learning_rate": 0.0001, - "AtomComp_kernel_size": 4, - "DiatomComp_kernel_size": 2, - "One_hot_kernel_size": 2, - "AtomComp_out_channels_start": 256, - "DiatomComp_out_channels_start": 128, - "Global_units": 16, - "OneHot_out_channels": 2, - "Concat_units": 128, - "AtomComp_MaxPool_kernel_size": 2, - "DiatomComp_MaxPool_kernel_size": 2, - "Mol_MaxPool_kernel_size": 2, - "OneHot_MaxPool_kernel_size": 10, - "LRelu_negative_slope": 0.1, - "LRelu_saturation": 20, - "L1_alpha": 0.00001, - "delta": 0, - "device": 0, - "add_X_mol": False, - "init": "normal", - "backbone_SD_path": MULTI_BACKBONE_PATH, -} +def parse_input( + input_file: str | Path | PSMList | pd.DataFrame, filetype: str | None = None +) -> PSMList: + """ + Parse input file or PSMList into a PSMList object. + + Parameters + ---------- + file_path : str, Path, or PSMList + Path to the input file or a PSMList object. + + Returns + ------- + PSMList + Parsed PSMList object. 
+ """ + if isinstance(input_file, PSMList): + LOGGER.debug(f"Parsed {len(input_file)} PSMs from provided PSMList.") + return input_file + + if isinstance(input_file, pd.DataFrame): + LOGGER.debug(f"Parsing PSMs from provided DataFrame with {len(input_file)} rows.") + list_of_precursors = [] + + # Check if it's legacy format (has seq/modifications/charge) or standard format (has peptidoform) + has_peptidoform = "peptidoform" in input_file.columns + has_legacy_cols = all( + col in input_file.columns for col in ["seq", "modifications", "charge"] + ) + + for idx, row in input_file.iterrows(): + try: + if has_peptidoform: + # Standard format with peptidoform column + precursor = PSM(peptidoform=row["peptidoform"], spectrum_id=idx) + elif has_legacy_cols: + # Legacy format - convert to peptidoform + peptidoform = psm_utils.io.peptide_record.peprec_to_proforma( + peptide=row["seq"], + modifications=row["modifications"], + charge=int(row["charge"]), + ) + precursor = PSM(peptidoform=peptidoform, spectrum_id=idx) + else: + LOGGER.warning("Row %d missing required columns. Skipping.", idx) + continue + + if "CCS" in row: + precursor.metadata["CCS"] = float(row["CCS"]) + list_of_precursors.append(precursor) + except Exception as e: + LOGGER.warning("Error parsing row %d: %s. 
Skipping.", idx, e) + continue + + if not list_of_precursors: + raise IM2DeepError("No valid PSMs could be parsed from the DataFrame.") + + psm_list = PSMList(psm_list=list_of_precursors) + LOGGER.debug(f"Parsed {len(psm_list)} PSMs from DataFrame.") + return psm_list + + if not isinstance(input_file, (str, Path, PSMList)): + raise TypeError("input_file must be a str, Path, or PSMList.") + + LOGGER.info("Reading PSMs from file: %s", input_file) + + # First, check if it's a legacy format by inspecting the header + is_legacy_format = False + try: + # Read first line to check column names + with open(input_file, "r") as f: + first_line = f.readline().strip() + + # Check if it has legacy format columns + if "seq" in first_line.lower() and "modifications" in first_line.lower(): + # Additional check: legacy format should NOT have standard PSM format columns + if not any( + col in first_line.lower() for col in ["peptidoform", "protein", "spectrum_id"] + ): + is_legacy_format = True + LOGGER.debug("Detected legacy internal format based on header.") + except Exception as e: + LOGGER.debug(f"Could not pre-check file format: {e}") + + # Parse based on detected format + if is_legacy_format: + psm_list = _parse_legacy_format(input_file) + else: + # Try to parse with psm_utils + try: + psm_list = psm_utils.io.read_file(input_file, filetype=filetype or "infer") + LOGGER.debug(f"Successfully read file using psm_utils.") + except Exception as e: + # If psm_utils fails, try legacy format as fallback + LOGGER.warning(f"Failed to read PSM file using psm_utils: {e}") + LOGGER.info("Attempting to read as legacy internal format.") + psm_list = _parse_legacy_format(input_file) + + LOGGER.debug(f"Parsed {len(psm_list)} PSMs from file.") + return psm_list + + +def _parse_legacy_format(input_file: str | Path) -> PSMList: + """ + Parse legacy internal format delimited file. + + Expected columns: seq, modifications, charge, and optionally CCS. + Supports CSV, TSV, and other delimited formats. 
+ + Parameters + ---------- + input_file : str or Path + Path to the legacy format file. + + Returns + ------- + PSMList + Parsed PSMList object. + + Raises + ------ + IM2DeepError + If required columns are missing or parsing fails. + """ + try: + # Use sep=None with engine='python' to auto-detect delimiter + df = pd.read_csv(input_file, sep=None, engine="python") + df = df.fillna("") # Replace NaN with empty strings + except Exception as e: + raise IM2DeepError(f"Failed to read file as delimited text: {e}") + + required_cols_legacy = ["seq", "modifications", "charge"] + missing_cols = set(required_cols_legacy) - set(df.columns) + + # Handle peprec format (uses 'peptide' instead of 'seq') + if "seq" not in df.columns and "peptide" in df.columns: + df.rename(columns={"peptide": "seq"}, inplace=True) + missing_cols = set(required_cols_legacy) - set(df.columns) + + if missing_cols: + raise IM2DeepError( + f"Legacy format file is missing required columns: {missing_cols}. " + f"Expected columns: seq (or peptide), modifications, charge" + ) + + has_ccs = "CCS" in df.columns + + list_of_precursors = [] + for idx, row in df.iterrows(): + metadata = {} + try: + peptidoform = psm_utils.io.peptide_record.peprec_to_proforma( + peptide=row["seq"], + modifications=row["modifications"], + charge=int(row["charge"]), + ) + if has_ccs: + metadata = {"CCS": float(row["CCS"])} + + LOGGER.debug(f"Parsed PSM: {peptidoform} with metadata: {metadata}") + precursor = PSM(peptidoform=peptidoform, metadata=metadata, spectrum_id=idx) + list_of_precursors.append(precursor) + except Exception as e: + LOGGER.warning("Error parsing row %d: %s. 
Skipping.", idx, e) + continue + + if not list_of_precursors: + raise IM2DeepError("No valid PSMs could be parsed from the legacy format file.") + + psm_list = PSMList(psm_list=list_of_precursors) + LOGGER.info(f"Successfully read {len(psm_list)} PSMs as legacy internal format.") + return psm_list + + +def validate_psm_list(psm_list: PSMList, needs_target: bool = False) -> PSMList: + """ + Validate that the PSM list contains necessary fields. And homogenizes the data. + Also removes entries with charge state higher than 6. + + Parameters + ---------- + psm_list : PSMList + The PSM list to validate. + needs_target : bool, optional + Whether target IM or CCS values are required. Default is False. + + Returns + ------- + PSMList + The validated and filtered PSM list. + """ + # Check if it's a PSMList + if not isinstance(psm_list, PSMList): + raise IM2DeepError( + f"Expected PSMList, got {type(psm_list).__name__}. " + "Please provide a valid PSMList object." + ) + + # Filter high charge states (IM2Deep predictions are not reliable for charges >6) + original_size = len(psm_list) + psm_list_filtered = PSMList( + psm_list=[psm for psm in psm_list if psm.peptidoform.precursor_charge <= 6] + ).copy() + + if len(psm_list_filtered) < original_size: + filtered_count = original_size - len(psm_list_filtered) + LOGGER.warning( + f"Filtered out {filtered_count} PSMs with charge states >6 for shift calculation.\n" + f"Predictions are not reliable for z>6." 
+ ) + + if len(psm_list_filtered) == 0: + raise IM2DeepError("No PSMs present in provided PSMLists.") + + all_has_targets = True + if needs_target: + # Check if PSMs have either ion_mobility or CCS + all_has_targets = all( + psm.ion_mobility is not None or psm.metadata.get("CCS") is not None + for psm in psm_list_filtered + ) + + # If ion_mobility is present, convert to CCS + for psm in psm_list_filtered: + if psm.ion_mobility is not None and psm.metadata.get("CCS") is None: + psm.metadata["CCS"] = im2ccs( + psm.ion_mobility, + psm.peptidoform.theoretical_mz, + psm.peptidoform.precursor_charge, + ) + # Ensure CCS is always stored as float + elif psm.metadata.get("CCS") is not None: + ccs_value = psm.metadata["CCS"] + if not isinstance(ccs_value, float): + psm.metadata["CCS"] = float(ccs_value) + + if needs_target and not all_has_targets: + raise IM2DeepError("PSMList must contain 'ion_mobility' or 'CCS' metadata for all PSMs.") + + return psm_list_filtered + + +class DefaultCommandGroup(click.Group): + """Custom Click Group that invokes a default command if no subcommand is specified.""" + + def __init__(self, *args, **kwargs): + self.default_command = kwargs.pop("default_command", None) + super().__init__(*args, **kwargs) + + def resolve_command(self, ctx, args): + try: + # Try to resolve the command normally + return super().resolve_command(ctx, args) + except click.UsageError: + # If it fails and we have a default command, use that + if self.default_command and args: + # Get the default command + cmd_name = self.default_command + cmd = self.commands.get(cmd_name) + if cmd: + return cmd_name, cmd, args + # Re-raise the error if no default or command not found + raise + + +def setup_logging(passed_level: str) -> None: + """ + Configure logging with Rich formatting. 
+ + Parameters + ---------- + passed_level : str + Logging level name (debug, info, warning, error, critical) + + Raises + ------ + ValueError + If invalid logging level provided + """ + log_mapping = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, + } + + if passed_level.lower() not in log_mapping: + raise ValueError( + f"Invalid log level: {passed_level}. " f"Should be one of {list(log_mapping.keys())}" + ) + + # Get the root logger and set its level + root_logger = logging.getLogger() + root_logger.setLevel(log_mapping[passed_level.lower()]) + + # Remove existing handlers to avoid duplicates + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + + # Add Rich handler + rich_handler = RichHandler( + rich_tracebacks=True, console=console, show_level=True, show_path=True + ) + rich_handler.setLevel(log_mapping[passed_level.lower()]) + root_logger.addHandler(rich_handler) + + # Also set the level for all existing loggers (including im2deep modules) + for logger_name in logging.Logger.manager.loggerDict: + if logger_name.startswith("im2deep"): + logger = logging.getLogger(logger_name) + logger.setLevel(log_mapping[passed_level.lower()]) + + +def infer_output_name( + input_filename: str, + output_name: str | None = None, +) -> Path: + """Infer output filename from input filename if output_filename was not defined.""" + if output_name: + return Path(output_name) + else: + input__filename = Path(input_filename) + return input__filename.with_name( + input__filename.stem + "_IM2Deep-predictions" + ).with_suffix("") + + +def write_output( + output_name: Path, predictions: np.ndarray, psm_list: PSMList, ion_mobility: bool = False +) -> None: + """ + Write the predictions to a CSV file. + + Parameters + ---------- + output_name : Path + The output file path. + predictions : np.ndarray + The predicted CCS values. + psm_list : PSMList + The original PSMList. 
+ ion_mobility : bool, optional + Whether to include ion mobility in the output. Default is False. + """ + output_data = [] + for idx, psm in enumerate(psm_list): + entry = { + "index": psm.spectrum_id, + "peptidoform": str(psm.peptidoform), + "predicted_CCS": predictions[idx], + } + if ion_mobility: + im_value = ccs2im( + predictions[idx], + psm.peptidoform.theoretical_mz, + psm.peptidoform.precursor_charge, + ) + entry["predicted_ion_mobility"] = im_value + output_data.append(entry) + + output_df = pd.DataFrame(output_data) + output_df.to_csv(output_name, index=False) + LOGGER.info(f"Predictions written to {output_name}") diff --git a/pyproject.toml b/pyproject.toml index b6e0977..baf4518 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ classifiers = [ ] dynamic = ["version"] requires-python = ">=3.10" -dependencies = ["click", "deeplc<4", "psm_utils", "pandas", "numpy", "rich"] +dependencies = ["click", "deeplc", "psm_utils", "pandas", "numpy", "rich", "torch", "lightning"] [project.optional-dependencies] dev = [ @@ -40,7 +40,6 @@ docs = [ "sphinx_rtd_theme>=1.2", "sphinx-autobuild>=2021.3", ] -er = ["im2deeptrainer", "torch==2.3.0"] [project.urls] GitHub = "https://github.com/CompOmics/IM2Deep" diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..2b1961b --- /dev/null +++ b/pytest.ini @@ -0,0 +1,17 @@ +[pytest] +minversion = 7.0 +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = + -ra + --strict-markers + --strict-config + --showlocals +markers = + integration: marks tests as integration tests (deselect with '-m "not integration"') + slow: marks tests as slow (deselect with '-m "not slow"') +filterwarnings = + ignore::DeprecationWarning + ignore::PendingDeprecationWarning diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..076a00d --- /dev/null +++ b/tests/README.md @@ -0,0 +1,106 @@ +# IM2Deep Test Suite + +This directory contains 
comprehensive tests for the IM2Deep package. + +## Test Structure + +- `conftest.py`: Pytest configuration and shared fixtures +- `test_calibration.py`: Tests for calibration module +- `test_utils.py`: Tests for utility functions +- `test_model_ops.py`: Tests for model operations +- `test_core.py`: Tests for core functionality +- `test_cli.py`: Tests for command-line interface +- `test_exceptions.py`: Tests for custom exceptions +- `test_integration.py`: Integration tests for end-to-end workflows + +## Running Tests + +### Run all tests +```bash +pytest +``` + +### Run specific test file +```bash +pytest tests/test_calibration.py +``` + +### Run with coverage +```bash +pytest --cov=im2deep --cov-report=html +``` + +### Run with verbose output +```bash +pytest -v +``` + +### Run only fast tests (skip integration tests) +```bash +pytest -m "not integration" +``` + +### Run only integration tests +```bash +pytest -m integration +``` + +## Test Categories + +### Unit Tests +- `test_calibration.py`: LinearCCSCalibration class methods +- `test_utils.py`: Input parsing, validation, and conversion functions +- `test_model_ops.py`: Model loading and prediction functions +- `test_core.py`: High-level prediction and calibration functions +- `test_cli.py`: Command-line interface and argument parsing +- `test_exceptions.py`: Custom exception classes + +### Integration Tests +- `test_integration.py`: End-to-end workflows and data consistency + +## Fixtures + +Common fixtures are defined in `conftest.py`: + +- `sample_psm_list`: Basic PSMList for testing +- `sample_psm_list_with_ccs`: PSMList with CCS values for calibration +- `sample_reference_psm_list`: Reference PSMList for calibration +- `sample_peptidoforms`: Array of Peptidoform objects +- `sample_ccs_values`: Array of CCS values +- `sample_predicted_ccs`: Array of predicted CCS values (single-output) +- `sample_predicted_ccs_multi`: Array of predicted CCS values (multi-output) +- `temp_model_path`: Temporary file path for 
model testing +- `sample_legacy_format_df`: DataFrame in legacy format +- `sample_peprec_format_df`: DataFrame in PEPREC format + +## Test Coverage + +The test suite aims to cover: + +- āœ… Input parsing and validation +- āœ… CCS calibration (per-charge and global) +- āœ… Single-output and multi-output predictions +- āœ… Model loading from various checkpoint formats +- āœ… Command-line interface and argument handling +- āœ… Error handling and custom exceptions +- āœ… Data consistency across pipeline +- āœ… Edge cases (single peptide, modified peptides, high charges, etc.) + +## Notes + +- Some integration tests require trained models and are skipped by default +- Tests use mocking for external dependencies (PyTorch Lightning, DeepLC) +- Multi-output prediction tests verify proper handling of tuple outputs +- Calibration tests verify broadcasting for both single and multi-output cases + +## Adding New Tests + +When adding new tests: + +1. Use appropriate fixtures from `conftest.py` +2. Group related tests in classes +3. Use descriptive test names starting with `test_` +4. Add docstrings explaining what each test verifies +5. Use `@pytest.mark.integration` for tests requiring trained models +6. Mock external dependencies when possible +7. 
Test both success and failure cases diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..4f46746 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test suite for IM2Deep package.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..0b08029 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,187 @@ +"""Pytest configuration and fixtures for IM2Deep tests.""" + +import pytest +import numpy as np +import pandas as pd +from pathlib import Path +from psm_utils import PSM, PSMList, Peptidoform + + +@pytest.fixture +def sample_psm_list(): + """Create a sample PSMList for testing.""" + psms = [ + PSM( + peptidoform=Peptidoform("PEPTIDE/2"), + spectrum_id="test_001", + run="test_run", + collection="test_collection", + is_decoy=False, + score=0.95, + ), + PSM( + peptidoform=Peptidoform("SEQUENCE/3"), + spectrum_id="test_002", + run="test_run", + collection="test_collection", + is_decoy=False, + score=0.92, + ), + PSM( + peptidoform=Peptidoform("TESTPEPTIDE/2"), + spectrum_id="test_003", + run="test_run", + collection="test_collection", + is_decoy=False, + score=0.88, + ), + ] + return PSMList(psm_list=psms) + + +@pytest.fixture +def sample_psm_list_with_ccs(): + """Create a sample PSMList with CCS values for calibration testing.""" + psms = [ + PSM( + peptidoform=Peptidoform("PEPTIDE/2"), + spectrum_id="cal_001", + run="cal_run", + collection="cal_collection", + is_decoy=False, + score=0.95, + retention_time=100.5, + metadata={"CCS": float(450.5)}, + ), + PSM( + peptidoform=Peptidoform("SEQUENCE/3"), + spectrum_id="cal_002", + run="cal_run", + collection="cal_collection", + is_decoy=False, + score=0.92, + retention_time=120.3, + metadata={"CCS": float(520.8)}, + ), + PSM( + peptidoform=Peptidoform("TESTPEPTIDE/2"), + spectrum_id="cal_003", + run="cal_run", + collection="cal_collection", + is_decoy=False, + score=0.88, + retention_time=135.7, + metadata={"CCS": float(480.2)}, + ), + PSM( + 
peptidoform=Peptidoform("ANOTHER/3"), + spectrum_id="cal_004", + run="cal_run", + collection="cal_collection", + is_decoy=False, + score=0.90, + retention_time=142.1, + metadata={"CCS": float(510.5)}, + ), + ] + return PSMList(psm_list=psms) + + +@pytest.fixture +def sample_reference_psm_list(): + """Create a sample reference PSMList for calibration.""" + psms = [ + PSM( + peptidoform=Peptidoform("PEPTIDE/2"), + spectrum_id="ref_001", + run="ref_run", + collection="ref_collection", + is_decoy=False, + metadata={"CCS": 455.0}, + ), + PSM( + peptidoform=Peptidoform("SEQUENCE/3"), + spectrum_id="ref_002", + run="ref_run", + collection="ref_collection", + is_decoy=False, + metadata={"CCS": 525.0}, + ), + PSM( + peptidoform=Peptidoform("TESTPEPTIDE/2"), + spectrum_id="ref_003", + run="ref_run", + collection="ref_collection", + is_decoy=False, + metadata={"CCS": 485.0}, + ), + PSM( + peptidoform=Peptidoform("REFERENCE/4"), + spectrum_id="ref_004", + run="ref_run", + collection="ref_collection", + is_decoy=False, + metadata={"CCS": 600.0}, + ), + ] + return PSMList(psm_list=psms) + + +@pytest.fixture +def sample_peptidoforms(): + """Create sample peptidoforms list.""" + return [ + Peptidoform("PEPTIDE/2"), + Peptidoform("SEQUENCE/3"), + Peptidoform("TESTPEPTIDE/2"), + ] + + +@pytest.fixture +def sample_ccs_values(): + """Create sample CCS values array.""" + return np.array([450.5, 520.8, 480.2], dtype=np.float32) + + +@pytest.fixture +def sample_predicted_ccs(): + """Create sample predicted CCS values.""" + return np.array([448.0, 516.0, 478.0], dtype=np.float32) + + +@pytest.fixture +def sample_predicted_ccs_multi(): + """Create sample multi-conformer predicted CCS values.""" + return np.array([[448.0, 452.0], [516.0, 524.0], [478.0, 482.0]], dtype=np.float32) + + +@pytest.fixture +def temp_model_path(tmp_path): + """Create a temporary model file path.""" + return tmp_path / "test_model.ckpt" + + +@pytest.fixture +def sample_legacy_format_df(): + """Create a sample 
DataFrame in legacy format.""" + return pd.DataFrame( + { + "seq": ["PEPTIDE", "SEQUENCE", "TESTPEPTIDE"], + "modifications": ["", "", ""], + "charge": [2, 3, 2], + "CCS": [450.5, 520.8, 480.2], + } + ) + + +@pytest.fixture +def sample_peprec_format_df(): + """Create a sample DataFrame in PEPREC format.""" + return pd.DataFrame( + { + "spec_id": ["test_001", "test_002", "test_003"], + "peptide": ["PEPTIDE", "SEQUENCE", "TESTPEPTIDE"], + "modifications": ["", "", ""], + "charge": [2, 3, 2], + } + ) diff --git a/tests/test_calibration.py b/tests/test_calibration.py new file mode 100644 index 0000000..bb3e286 --- /dev/null +++ b/tests/test_calibration.py @@ -0,0 +1,322 @@ +"""Tests for calibration module.""" + +import pytest +import numpy as np +import pandas as pd +from psm_utils import Peptidoform, PSM, PSMList + +from im2deep.calibration import LinearCCSCalibration, get_default_reference +from im2deep._exceptions import CalibrationError + + +class TestLinearCCSCalibration: + """Tests for LinearCCSCalibration class.""" + + def test_init_default(self): + """Test initialization with default parameters.""" + calibration = LinearCCSCalibration() + assert calibration.per_charge is True + assert calibration.use_charge_state is None + assert calibration.is_fitted is False + assert calibration.charge_shifts == {} + assert calibration.general_shift is None + + def test_init_custom(self): + """Test initialization with custom parameters.""" + calibration = LinearCCSCalibration(per_charge=False, use_charge_state=3) + assert calibration.per_charge is False + assert calibration.use_charge_state == 3 + assert calibration.is_fitted is False + + def test_fit_per_charge(self, sample_peptidoforms, sample_ccs_values, sample_predicted_ccs): + """Test fitting with per-charge calibration.""" + calibration = LinearCCSCalibration(per_charge=True) + + # Create DataFrames for target and source + target_df = pd.DataFrame( + { + "peptidoform": sample_peptidoforms, + "metadata": [{"CCS": ccs} 
for ccs in sample_ccs_values], + } + ) + + source_df = pd.DataFrame({"peptidoform": sample_peptidoforms, "CCS": sample_predicted_ccs}) + + calibration.fit( + psm_df_target=target_df, + psm_df_source=source_df, + ) + + assert calibration.is_fitted is True + assert len(calibration.charge_shifts) > 0 + assert calibration.general_shift is not None + + def test_fit_global(self, sample_peptidoforms, sample_ccs_values, sample_predicted_ccs): + """Test fitting with global calibration.""" + calibration = LinearCCSCalibration(per_charge=False, use_charge_state=2) + + target_df = pd.DataFrame( + { + "peptidoform": sample_peptidoforms, + "metadata": [{"CCS": ccs} for ccs in sample_ccs_values], + } + ) + + source_df = pd.DataFrame({"peptidoform": sample_peptidoforms, "CCS": sample_predicted_ccs}) + + calibration.fit( + psm_df_target=target_df, + psm_df_source=source_df, + ) + + assert calibration.is_fitted is True + assert calibration.general_shift is not None + assert isinstance(calibration.general_shift, float) + + def test_transform_single_output( + self, sample_peptidoforms, sample_ccs_values, sample_predicted_ccs + ): + """Test transforming single-output predictions.""" + calibration = LinearCCSCalibration(per_charge=True) + + target_df = pd.DataFrame( + { + "peptidoform": sample_peptidoforms, + "metadata": [{"CCS": ccs} for ccs in sample_ccs_values], + } + ) + + source_df = pd.DataFrame({"peptidoform": sample_peptidoforms, "CCS": sample_predicted_ccs}) + + calibration.fit( + psm_df_target=target_df, + psm_df_source=source_df, + ) + + # Transform with predictions in metadata + transform_df = pd.DataFrame( + { + "peptidoform": sample_peptidoforms, + "metadata": [ + {"predicted_CCS_uncalibrated": pred} for pred in sample_predicted_ccs + ], + } + ) + + calibrated = calibration.transform(transform_df) + + assert len(calibrated) == len(sample_predicted_ccs) + assert isinstance(calibrated, np.ndarray) + + def test_transform_multi_output( + self, sample_peptidoforms, 
sample_ccs_values, sample_predicted_ccs_multi + ): + """Test transforming multi-output predictions.""" + calibration = LinearCCSCalibration(per_charge=True) + + target_df = pd.DataFrame( + { + "peptidoform": sample_peptidoforms, + "metadata": [{"CCS": ccs} for ccs in sample_ccs_values], + } + ) + + source_df = pd.DataFrame( + {"peptidoform": sample_peptidoforms, "CCS": sample_ccs_values - 2.0} # Simulate shift + ) + + calibration.fit( + psm_df_target=target_df, + psm_df_source=source_df, + ) + + # Transform multi-output with arrays in metadata + transform_df = pd.DataFrame( + { + "peptidoform": sample_peptidoforms, + "metadata": [ + {"predicted_CCS_uncalibrated": pred} for pred in sample_predicted_ccs_multi + ], + } + ) + + calibrated = calibration.transform(transform_df) + + assert len(calibrated) == len(sample_predicted_ccs_multi) + assert isinstance(calibrated, np.ndarray) + # Check that arrays are preserved for multiconformer + assert isinstance(calibrated[0], np.ndarray) + assert len(calibrated[0]) == 2 # Two conformers + + def test_transform_not_fitted(self, sample_peptidoforms, sample_predicted_ccs): + """Test transform raises error when not fitted.""" + calibration = LinearCCSCalibration() + + transform_df = pd.DataFrame( + { + "peptidoform": sample_peptidoforms, + "metadata": [ + {"predicted_CCS_uncalibrated": pred} for pred in sample_predicted_ccs + ], + } + ) + + with pytest.raises(CalibrationError, match="not been fitted"): + calibration.transform(transform_df) + + def test_calculate_ccs_shift_no_overlap(self): + """Test shift calculation with no overlapping peptides.""" + calibration = LinearCCSCalibration(per_charge=False, use_charge_state=2) + + target_df = pd.DataFrame( + {"peptidoform": [Peptidoform("PEPTIDE/2")], "metadata": [{"CCS": 450.0}]} + ) + + source_df = pd.DataFrame({"peptidoform": [Peptidoform("DIFFERENT/2")], "CCS": [460.0]}) + + shift = calibration.calculate_ccs_shift(target_df, source_df) + + assert shift == 0.0 # No overlap returns 
0.0 + + def test_calculate_ccs_shift_with_overlap(self): + """Test shift calculation with overlapping peptides.""" + calibration = LinearCCSCalibration(per_charge=False, use_charge_state=2) + + peptidoforms = [Peptidoform("PEPTIDE/2"), Peptidoform("SEQUENCE/2")] + + target_df = pd.DataFrame( + {"peptidoform": peptidoforms, "metadata": [{"CCS": 450.0}, {"CCS": 520.0}]} + ) + + source_df = pd.DataFrame({"peptidoform": peptidoforms, "CCS": [445.0, 515.0]}) + + shift = calibration.calculate_ccs_shift(target_df, source_df) + + assert isinstance(shift, float) + assert abs(shift - 5.0) < 0.1 # Should be approximately 5.0 + + def test_compute_ccs_shift_per_charge(self): + """Test per-charge shift computation.""" + peptidoforms = [ + Peptidoform("PEPTIDE/2"), + Peptidoform("SEQUENCE/3"), + Peptidoform("TEST/2"), + ] + + target_df = pd.DataFrame( + { + "peptidoform": peptidoforms, + "metadata": [{"CCS": 450.0}, {"CCS": 520.0}, {"CCS": 480.0}], + } + ) + + source_df = pd.DataFrame({"peptidoform": peptidoforms, "CCS": [445.0, 515.0, 475.0]}) + + shifts = LinearCCSCalibration._compute_ccs_shift_per_charge(target_df, source_df) + + assert isinstance(shifts, dict) + assert 2 in shifts + assert 3 in shifts + assert abs(shifts[2] - 5.0) < 0.1 + assert abs(shifts[3] - 5.0) < 0.1 + + def test_fit_with_missing_charges(self, sample_peptidoforms, sample_ccs_values): + """Test that missing charges are filled with general shift.""" + calibration = LinearCCSCalibration(per_charge=True) + + target_df = pd.DataFrame( + { + "peptidoform": sample_peptidoforms, + "metadata": [{"CCS": ccs} for ccs in sample_ccs_values], + } + ) + + source_df = pd.DataFrame( + {"peptidoform": sample_peptidoforms, "CCS": sample_ccs_values - 5.0} + ) + + calibration.fit( + psm_df_target=target_df, + psm_df_source=source_df, + ) + + # Check that charges 1-6 are all filled + for charge in range(1, 7): + assert charge in calibration.charge_shifts + assert isinstance(calibration.charge_shifts[charge], float) + + def 
test_fit_invalid_charge_state(self, sample_peptidoforms, sample_ccs_values): + """Test that invalid charge state raises error.""" + calibration = LinearCCSCalibration(per_charge=False, use_charge_state=10) + + target_df = pd.DataFrame( + { + "peptidoform": sample_peptidoforms, + "metadata": [{"CCS": ccs} for ccs in sample_ccs_values], + } + ) + + source_df = pd.DataFrame({"peptidoform": sample_peptidoforms, "CCS": sample_ccs_values}) + + with pytest.raises(CalibrationError, match="Invalid charge state"): + calibration.calculate_ccs_shift(target_df, source_df) + + def test_shift_broadcasting(self, sample_peptidoforms): + """Test that shifts broadcast correctly for multi-output.""" + calibration = LinearCCSCalibration(per_charge=True) + + # Manually set charge shifts + calibration.charge_shifts = {2: 5.0, 3: 3.0} + calibration.general_shift = 4.0 + calibration.fitted = True + + # Test single output + single_pred = np.array([450.0, 520.0, 480.0], dtype=np.float32) + single_df = pd.DataFrame( + { + "peptidoform": sample_peptidoforms, + "metadata": [{"predicted_CCS_uncalibrated": pred} for pred in single_pred], + } + ) + single_cal = calibration.transform(single_df) + assert len(single_cal) == 3 + + # Test multi output + multi_pred = np.array([[450.0, 452.0], [520.0, 524.0], [480.0, 482.0]], dtype=np.float32) + multi_df = pd.DataFrame( + { + "peptidoform": sample_peptidoforms, + "metadata": [{"predicted_CCS_uncalibrated": pred} for pred in multi_pred], + } + ) + multi_cal = calibration.transform(multi_df) + assert len(multi_cal) == 3 + # Check arrays are preserved + assert isinstance(multi_cal[0], np.ndarray) + assert len(multi_cal[0]) == 2 + + def test_get_default_reference(self): + """Test loading default reference dataset.""" + try: + reference_df = get_default_reference(multi=False) + assert isinstance(reference_df, pd.DataFrame) + assert "peptidoform" in reference_df.columns + assert "CCS" in reference_df.columns + assert len(reference_df) > 0 + except 
FileNotFoundError: + pytest.skip("Default reference dataset not found") + + def test_large_shift_warning(self, caplog): + """Test that large shifts trigger a warning.""" + target_df = pd.DataFrame( + {"peptidoform": [Peptidoform("PEPTIDE/2")], "metadata": [{"CCS": 450.0}]} + ) + + source_df = pd.DataFrame( + {"peptidoform": [Peptidoform("PEPTIDE/2")], "CCS": [300.0]} # Large difference + ) + + shift = LinearCCSCalibration._compute_ccs_shift(target_df, source_df, 2) + + assert abs(shift) > 100 + assert any("unusually large" in record.message.lower() for record in caplog.records) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..ec2f809 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,244 @@ +"""Tests for CLI module.""" + +import pytest +from click.testing import CliRunner +from unittest.mock import patch, MagicMock +import tempfile +from pathlib import Path + +from im2deep.__main__ import cli, predict + + +class TestCLI: + """Tests for command-line interface.""" + + @pytest.fixture + def runner(self): + """Create a CLI runner.""" + return CliRunner() + + @pytest.fixture + def temp_input_file(self): + """Create a temporary input file.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: + f.write("seq,modifications,charge\n") + f.write("PEPTIDE,,2\n") + f.write("SEQUENCE,,3\n") + yield Path(f.name) + + @pytest.fixture + def temp_cal_file(self): + """Create a temporary calibration file.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: + f.write("seq,modifications,charge,CCS\n") + f.write("PEPTIDE,,2,450.5\n") + f.write("TESTPEP,,2,480.2\n") + yield Path(f.name) + + def test_cli_help(self, runner): + """Test CLI help message.""" + result = runner.invoke(cli, ["--help"]) + assert result.exit_code == 0 + assert "IM2Deep" in result.output or "predict" in result.output + + def test_cli_version(self, runner): + """Test CLI version flag.""" + result = runner.invoke(cli, 
["--version"]) + assert result.exit_code == 0 + + @patch("im2deep.core.predict_and_calibrate") + def test_predict_command_basic(self, mock_predict, runner, temp_input_file, temp_cal_file): + """Test basic predict command.""" + mock_predict.return_value = MagicMock() + + result = runner.invoke( + cli, + [ + "predict", + str(temp_input_file), + "--calibration-precursors", + str(temp_cal_file), + ], + ) + + # Check that it doesn't crash + assert "error" not in result.output.lower() or result.exit_code == 0 + + @patch("im2deep.core.predict_and_calibrate") + def test_predict_command_with_output( + self, mock_predict, runner, temp_input_file, temp_cal_file, tmp_path + ): + """Test predict command with output file.""" + mock_predict.return_value = MagicMock() + output_file = tmp_path / "output.csv" + + result = runner.invoke( + cli, + [ + "predict", + str(temp_input_file), + "--calibration-precursors", + str(temp_cal_file), + "--output-file", + str(output_file), + ], + ) + + # Should create output file or at least not crash + assert result.exit_code in [0, 1] # May fail due to mocking + + @patch("im2deep.core.predict_and_calibrate") + def test_predict_default_command(self, mock_predict, runner, temp_input_file, temp_cal_file): + """Test that predict is the default command.""" + mock_predict.return_value = MagicMock() + + # Call without explicit 'predict' subcommand + result = runner.invoke( + cli, + [ + str(temp_input_file), + "--calibration-precursors", + str(temp_cal_file), + ], + ) + + # Should work as if 'predict' was specified + assert result.exit_code in [0, 1] + + def test_predict_missing_input(self, runner): + """Test predict command without input file.""" + result = runner.invoke(cli, ["predict"]) + assert result.exit_code != 0 + assert "Missing argument" in result.output or "required" in result.output.lower() + + def test_predict_logging_level(self, runner, temp_input_file): + """Test predict command with different logging levels.""" + for level in ["debug", 
"info", "warning", "error"]: + result = runner.invoke( + cli, ["predict", str(temp_input_file), "--logging-level", level] + ) + # Should at least parse the argument + assert "Invalid value" not in result.output + + @patch("im2deep.core.predict_and_calibrate") + def test_predict_multi_flag(self, mock_predict, runner, temp_input_file, temp_cal_file): + """Test predict command with multi-conformer flag.""" + mock_predict.return_value = MagicMock() + + result = runner.invoke( + cli, + [ + "predict", + str(temp_input_file), + "--calibration-precursors", + str(temp_cal_file), + "--multi", + ], + ) + + # Check that multi flag is recognized + assert result.exit_code in [0, 1] + + @patch("im2deep.core.predict_and_calibrate") + def test_predict_per_charge_calibration( + self, mock_predict, runner, temp_input_file, temp_cal_file + ): + """Test predict command with per-charge calibration.""" + mock_predict.return_value = MagicMock() + + # Test with per-charge enabled (default is True, so just don't pass the flag) + result = runner.invoke( + cli, + [ + "predict", + str(temp_input_file), + "--calibration-precursors", + str(temp_cal_file), + ], + ) + + assert result.exit_code in [0, 1] + + @patch("im2deep.core.predict_and_calibrate") + def test_predict_global_calibration( + self, mock_predict, runner, temp_input_file, temp_cal_file + ): + """Test predict command with global calibration (per-charge disabled).""" + mock_predict.return_value = MagicMock() + + # Test with per-charge disabled + result = runner.invoke( + cli, + [ + "predict", + str(temp_input_file), + "--calibration-precursors", + str(temp_cal_file), + "--calibrate-per-charge", + "false", + ], + ) + + assert result.exit_code in [0, 1] + + def test_train_command_not_available(self, runner): + """Test that train command is not currently available.""" + result = runner.invoke(cli, ["train", "--help"]) + # Train is commented out, so this should fail + assert result.exit_code != 0 or "train" not in result.output.lower() + + 
+class TestSetupLogging: + """Tests for setup_logging function.""" + + def test_setup_logging_default(self): + """Test setup_logging with info level.""" + from im2deep.utils import setup_logging + import logging + + setup_logging("info") + + logger = logging.getLogger("im2deep") + assert logger.level == logging.INFO + + def test_setup_logging_debug(self): + """Test setup_logging with debug level.""" + from im2deep.utils import setup_logging + import logging + + setup_logging("debug") + + logger = logging.getLogger("im2deep") + assert logger.level == logging.DEBUG + + def test_setup_logging_warning(self): + """Test setup_logging with warning level.""" + from im2deep.utils import setup_logging + import logging + + setup_logging("warning") + + logger = logging.getLogger("im2deep") + assert logger.level == logging.WARNING + + def test_setup_logging_affects_submodules(self): + """Test that setup_logging affects all im2deep submodules.""" + from im2deep.utils import setup_logging + import logging + + setup_logging("debug") + + # Check root logger is set to debug + root_logger = logging.getLogger() + assert root_logger.level == logging.DEBUG + + +class TestDefaultCommandGroup: + """Tests for DefaultCommandGroup.""" + + def test_default_command_group_import(self): + """Test that DefaultCommandGroup can be imported.""" + from im2deep.utils import DefaultCommandGroup + import click + + assert issubclass(DefaultCommandGroup, click.Group) diff --git a/tests/test_constants.py b/tests/test_constants.py new file mode 100644 index 0000000..89ac6e3 --- /dev/null +++ b/tests/test_constants.py @@ -0,0 +1,74 @@ +"""Tests for constants module.""" + +import pytest +from pathlib import Path + + +class TestConstants: + """Tests for module constants.""" + + def test_default_model_path_exists(self): + """Test that DEFAULT_MODEL constant points to existing file.""" + from im2deep.constants import DEFAULT_MODEL + + if DEFAULT_MODEL is not None: + model_path = Path(DEFAULT_MODEL) + # Check if 
path is valid (may not exist in test environment) + assert isinstance(DEFAULT_MODEL, (str, Path)) + + def test_default_multi_model_path_exists(self): + """Test that DEFAULT_MULTI_MODEL constant points to existing file.""" + from im2deep.constants import DEFAULT_MULTI_MODEL + + if DEFAULT_MULTI_MODEL is not None: + model_path = Path(DEFAULT_MULTI_MODEL) + assert isinstance(DEFAULT_MULTI_MODEL, (str, Path)) + + def test_default_reference_dataset_path_exists(self): + """Test that default reference dataset path exists.""" + from im2deep.constants import DEFAULT_REFERENCE_DATASET_PATH + + if DEFAULT_REFERENCE_DATASET_PATH is not None: + dataset_path = Path(DEFAULT_REFERENCE_DATASET_PATH) + assert isinstance(DEFAULT_REFERENCE_DATASET_PATH, (str, Path)) + + def test_default_multi_reference_dataset_path_exists(self): + """Test that default multi reference dataset path exists.""" + from im2deep.constants import DEFAULT_MULTI_REFERENCE_DATASET_PATH + + if DEFAULT_MULTI_REFERENCE_DATASET_PATH is not None: + dataset_path = Path(DEFAULT_MULTI_REFERENCE_DATASET_PATH) + assert isinstance(DEFAULT_MULTI_REFERENCE_DATASET_PATH, (str, Path)) + + def test_default_config_exists(self): + """Test that DEFAULT_CONFIG constant exists.""" + from im2deep.constants import DEFAULT_CONFIG + + assert isinstance(DEFAULT_CONFIG, dict) + assert len(DEFAULT_CONFIG) > 0 + + def test_default_multi_config_exists(self): + """Test that DEFAULT_MULTI_CONFIG constant exists.""" + from im2deep.constants import DEFAULT_MULTI_CONFIG + + assert isinstance(DEFAULT_MULTI_CONFIG, dict) + assert len(DEFAULT_MULTI_CONFIG) > 0 + + def test_config_has_required_keys(self): + """Test that config dictionaries have required keys.""" + from im2deep.constants import DEFAULT_CONFIG + + # Check for common required keys + # (actual keys depend on model architecture) + assert isinstance(DEFAULT_CONFIG, dict) + + def test_constants_are_immutable(self): + """Test that constants should not be modified.""" + from im2deep import 
constants + + # Store original values + original_model = constants.DEFAULT_MODEL + + # Try to modify (this is just checking the pattern, not enforcement) + # In Python, constants are by convention, not enforced + assert hasattr(constants, "DEFAULT_MODEL") diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 0000000..2c0ab3b --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,210 @@ +"""Tests for core module.""" + +import pytest +import numpy as np +import torch +from unittest.mock import Mock, patch, MagicMock + +from im2deep import core +from im2deep._exceptions import IM2DeepError + + +class TestPredict: + """Tests for predict function.""" + + @patch("im2deep.core._model_ops.predict") + @patch("im2deep.core.DeepLCDataset") + def test_predict_basic(self, mock_dataset, mock_predict, sample_psm_list): + """Test basic prediction.""" + mock_dataset.from_psm_list.return_value = MagicMock() + mock_predict.return_value = torch.tensor([450.0, 520.0, 480.0]) + + predictions = core.predict(sample_psm_list) + + assert isinstance(predictions, np.ndarray) + assert len(predictions) == 3 + mock_dataset.from_psm_list.assert_called_once() + mock_predict.assert_called_once() + + @patch("im2deep.core._model_ops.predict") + @patch("im2deep.core.DeepLCDataset") + def test_predict_with_model(self, mock_dataset, mock_predict, sample_psm_list): + """Test prediction with custom model.""" + mock_dataset.from_psm_list.return_value = MagicMock() + mock_predict.return_value = torch.tensor([450.0, 520.0, 480.0]) + + custom_model = "custom_model.ckpt" + predictions = core.predict(sample_psm_list, model=custom_model) + + assert isinstance(predictions, np.ndarray) + # Check that custom model was passed + call_kwargs = mock_predict.call_args[1] + assert "model" in call_kwargs + + @patch("im2deep.core._model_ops.predict") + @patch("im2deep.core.DeepLCDataset") + def test_predict_multi(self, mock_dataset, mock_predict, sample_psm_list): + """Test prediction with multi-output 
model.""" + mock_dataset.from_psm_list.return_value = MagicMock() + mock_predict.return_value = torch.tensor([[450.0, 452.0], [520.0, 524.0], [480.0, 482.0]]) + + predictions = core.predict(sample_psm_list, multi=True) + + assert isinstance(predictions, np.ndarray) + assert predictions.shape == (3, 2) + + @patch("im2deep.core._model_ops.predict") + @patch("im2deep.core.DeepLCDataset") + def test_predict_with_kwargs(self, mock_dataset, mock_predict, sample_psm_list): + """Test prediction with additional kwargs.""" + mock_dataset.from_psm_list.return_value = MagicMock() + mock_predict.return_value = torch.tensor([450.0, 520.0, 480.0]) + + predictions = core.predict( + sample_psm_list, + predict_kwargs={"batch_size": 256, "device": "cpu"}, + ) + + assert isinstance(predictions, np.ndarray) + call_kwargs = mock_predict.call_args[1] + assert call_kwargs["batch_size"] == 256 + assert call_kwargs["device"] == "cpu" + + def test_predict_invalid_psm_list(self): + """Test prediction with invalid PSMList.""" + with pytest.raises(IM2DeepError): + core.predict([1, 2, 3]) + + +class TestPredictAndCalibrate: + """Tests for predict_and_calibrate function.""" + + @patch("im2deep.core.predict") + @patch("im2deep.core.LinearCCSCalibration") + def test_predict_and_calibrate_basic( + self, + mock_calibration_class, + mock_predict, + sample_psm_list, + sample_psm_list_with_ccs, + ): + """Test basic predict and calibrate.""" + # Mock predict to return predictions + mock_predict.return_value = np.array([448.0, 516.0, 478.0]) + + # Mock calibration + mock_calibration = MagicMock() + mock_calibration.is_fitted = False + mock_calibration.transform.return_value = np.array([450.0, 520.0, 480.0]) + mock_calibration_class.return_value = mock_calibration + + predictions = core.predict_and_calibrate( + psm_list=sample_psm_list, + psm_list_cal=sample_psm_list_with_ccs, + ) + + assert isinstance(predictions, np.ndarray) + assert len(predictions) == 3 + mock_calibration.fit.assert_called_once() + 
mock_calibration.transform.assert_called_once() + + @patch("im2deep.core.predict") + @patch("im2deep.core.LinearCCSCalibration") + def test_predict_and_calibrate_with_reference( + self, + mock_calibration_class, + mock_predict, + sample_psm_list, + sample_psm_list_with_ccs, + sample_reference_psm_list, + ): + """Test predict and calibrate with reference PSMList.""" + mock_predict.return_value = np.array([448.0, 516.0, 478.0]) + + mock_calibration = MagicMock() + mock_calibration.is_fitted = False + mock_calibration.transform.return_value = np.array([450.0, 520.0, 480.0]) + mock_calibration_class.return_value = mock_calibration + + predictions = core.predict_and_calibrate( + psm_list=sample_psm_list, + psm_list_cal=sample_psm_list_with_ccs, + psm_list_reference=sample_reference_psm_list, + ) + + assert isinstance(predictions, np.ndarray) + + @patch("im2deep.core.predict") + def test_predict_and_calibrate_custom_calibration( + self, + mock_predict, + sample_psm_list, + sample_psm_list_with_ccs, + ): + """Test predict and calibrate with custom calibration.""" + from im2deep.calibration import Calibration + + # Create a real mock class that inherits from Calibration + class MockCalibration(Calibration): + def __init__(self): + self._is_fitted = False + self.fit_called = False + self.transform_called = False + + @property + def is_fitted(self): + return self._is_fitted + + def fit(self, *args, **kwargs): + self._is_fitted = True + self.fit_called = True + + def transform(self, *args, **kwargs): + self.transform_called = True + return np.array([450.0, 520.0, 480.0]) + + mock_predict.return_value = np.array([448.0, 516.0, 478.0]) + + custom_calibration = MockCalibration() + + predictions = core.predict_and_calibrate( + psm_list=sample_psm_list, + psm_list_cal=sample_psm_list_with_ccs, + calibration=custom_calibration, + ) + + assert isinstance(predictions, np.ndarray) + assert custom_calibration.fit_called + assert custom_calibration.transform_called + + 
@patch("im2deep.core.predict") + def test_predict_and_calibrate_multi_output( + self, mock_predict, sample_psm_list, sample_psm_list_with_ccs + ): + """Test predict and calibrate with multi-output predictions.""" + mock_predict.return_value = np.array([[448.0, 452.0], [516.0, 524.0], [478.0, 482.0]]) + + with patch("im2deep.core.LinearCCSCalibration") as mock_cal_class: + mock_calibration = MagicMock() + mock_calibration.is_fitted = False + mock_calibration.transform.return_value = np.array( + [[450.0, 454.0], [520.0, 528.0], [480.0, 484.0]] + ) + mock_cal_class.return_value = mock_calibration + + predictions = core.predict_and_calibrate( + psm_list=sample_psm_list, + psm_list_cal=sample_psm_list_with_ccs, + multi=True, + ) + + assert isinstance(predictions, np.ndarray) + assert predictions.shape == (3, 2) + + def test_predict_and_calibrate_invalid_cal_psm(self, sample_psm_list): + """Test that calibration PSMList must have CCS values.""" + with pytest.raises(IM2DeepError): + core.predict_and_calibrate( + psm_list=sample_psm_list, + psm_list_cal=sample_psm_list, # Missing CCS values + ) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 0000000..dc717cb --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,50 @@ +"""Tests for exceptions module.""" + +import pytest +from im2deep._exceptions import ( + IM2DeepError, + CalibrationError, +) + + +class TestExceptions: + """Tests for custom exception classes.""" + + def test_im2deep_error(self): + """Test IM2DeepError can be raised and caught.""" + with pytest.raises(IM2DeepError, match="test error"): + raise IM2DeepError("test error") + + def test_calibration_error(self): + """Test CalibrationError inherits from IM2DeepError.""" + with pytest.raises(IM2DeepError): + raise CalibrationError("calibration failed") + + with pytest.raises(CalibrationError, match="calibration failed"): + raise CalibrationError("calibration failed") + + def test_exception_inheritance(self): + """Test 
that CalibrationError inherits from IM2DeepError.""" + assert issubclass(CalibrationError, IM2DeepError) + assert issubclass(CalibrationError, Exception) + + def test_exception_with_cause(self): + """Test exceptions can wrap other exceptions.""" + original_error = ValueError("original error") + + with pytest.raises(CalibrationError) as exc_info: + try: + raise original_error + except ValueError as e: + raise CalibrationError("wrapped error") from e + + assert exc_info.value.__cause__ is original_error + + def test_exception_messages(self): + """Test that exception messages are preserved.""" + message = "detailed error message with context" + + try: + raise IM2DeepError(message) + except IM2DeepError as e: + assert str(e) == message diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 0000000..a5df236 --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,333 @@ +"""Integration tests for IM2Deep package.""" + +import pytest +import numpy as np +import pandas as pd +from pathlib import Path +from psm_utils import PSM, PSMList, Peptidoform + +from im2deep import core +from im2deep.calibration import LinearCCSCalibration +from im2deep._exceptions import IM2DeepError + + +class TestEndToEndWorkflow: + """Integration tests for end-to-end workflows.""" + + @pytest.mark.integration + @pytest.mark.skipif(True, reason="Requires trained models") + def test_predict_workflow(self, sample_psm_list): + """Test complete prediction workflow.""" + # This would require actual trained models + predictions = core.predict(sample_psm_list) + + assert isinstance(predictions, np.ndarray) + assert len(predictions) == len(sample_psm_list) + assert np.all(predictions > 0) # CCS values should be positive + + @pytest.mark.integration + @pytest.mark.skipif(True, reason="Requires trained models") + def test_predict_and_calibrate_workflow(self, sample_psm_list, sample_psm_list_with_ccs): + """Test complete prediction and calibration workflow.""" + 
predictions = core.predict_and_calibrate( + psm_list=sample_psm_list, + psm_list_cal=sample_psm_list_with_ccs, + ) + + assert isinstance(predictions, np.ndarray) + assert len(predictions) == len(sample_psm_list) + assert np.all(predictions > 0) + + def test_calibration_workflow( + self, + sample_peptidoforms, + sample_ccs_values, + sample_predicted_ccs, + ): + """Test complete calibration workflow without prediction.""" + calibration = LinearCCSCalibration(per_charge=True) + + target_df = pd.DataFrame( + { + "peptidoform": sample_peptidoforms, + "metadata": [{"CCS": ccs} for ccs in sample_ccs_values], + } + ) + + source_df = pd.DataFrame({"peptidoform": sample_peptidoforms, "CCS": sample_predicted_ccs}) + + # Fit calibration + calibration.fit( + psm_df_target=target_df, + psm_df_source=source_df, + ) + + assert calibration.is_fitted + + # Transform predictions + transform_df = pd.DataFrame( + { + "peptidoform": sample_peptidoforms, + "metadata": [ + {"predicted_CCS_uncalibrated": pred} for pred in sample_predicted_ccs + ], + } + ) + + calibrated = calibration.transform(transform_df) + + assert len(calibrated) == len(sample_predicted_ccs) + # All values should be positive (scalars or arrays) + for val in calibrated: + if isinstance(val, np.ndarray): + assert np.all(val > 0) + else: + assert val > 0 + + # Calibrated values should be closer to targets (compare scalars) + calibrated_scalars = np.array( + [v if not isinstance(v, np.ndarray) else v[0] for v in calibrated] + ) + original_error = np.mean(np.abs(sample_predicted_ccs - sample_ccs_values)) + calibrated_error = np.mean(np.abs(calibrated_scalars - sample_ccs_values)) + assert calibrated_error <= original_error + + def test_multi_output_calibration_workflow( + self, sample_peptidoforms, sample_ccs_values, sample_predicted_ccs_multi + ): + """Test calibration workflow with multi-output predictions.""" + calibration = LinearCCSCalibration(per_charge=True) + + target_df = pd.DataFrame( + { + "peptidoform": 
sample_peptidoforms, + "metadata": [{"CCS": ccs} for ccs in sample_ccs_values], + } + ) + + source_df = pd.DataFrame( + {"peptidoform": sample_peptidoforms, "CCS": sample_ccs_values - 2.0} + ) + + # Fit with single output targets + calibration.fit( + psm_df_target=target_df, + psm_df_source=source_df, + ) + + # Transform multi-output predictions + transform_df = pd.DataFrame( + { + "peptidoform": sample_peptidoforms, + "metadata": [ + {"predicted_CCS_uncalibrated": pred} for pred in sample_predicted_ccs_multi + ], + } + ) + + calibrated = calibration.transform(transform_df) + + assert len(calibrated) == len(sample_predicted_ccs_multi) + # Check that arrays are preserved + for val in calibrated: + assert isinstance(val, np.ndarray) + assert len(val) == 2 + assert np.all(val > 0) + + @pytest.mark.integration + def test_file_parsing_to_prediction(self, tmp_path): + """Test complete workflow from file parsing to prediction.""" + from im2deep.utils import parse_input + + # Create test file + input_file = tmp_path / "input.csv" + with open(input_file, "w") as f: + f.write("seq,modifications,charge\n") + f.write("PEPTIDE,,2\n") + f.write("SEQUENCE,,3\n") + + # Parse input + psm_list = parse_input(input_file) + + assert isinstance(psm_list, PSMList) + assert len(psm_list) == 2 + + def test_charge_state_coverage(self): + """Test that calibration covers all relevant charge states.""" + # Create peptides with various charge states + peptidoforms = [ + Peptidoform("PEPTIDE/1"), + Peptidoform("PEPTIDE/2"), + Peptidoform("PEPTIDE/3"), + Peptidoform("PEPTIDE/4"), + Peptidoform("PEPTIDE/5"), + ] + ccs_target = np.array([300.0, 400.0, 500.0, 600.0, 700.0]) + ccs_source = ccs_target - 5.0 + + target_df = pd.DataFrame( + {"peptidoform": peptidoforms, "metadata": [{"CCS": ccs} for ccs in ccs_target]} + ) + + source_df = pd.DataFrame({"peptidoform": peptidoforms, "CCS": ccs_source}) + + calibration = LinearCCSCalibration(per_charge=True) + calibration.fit( + psm_df_target=target_df, + 
psm_df_source=source_df, + ) + + # All charges 1-6 should be covered (including extrapolation) + assert all(c in calibration.charge_shifts for c in range(1, 7)) + + def test_error_propagation(self, sample_psm_list): + """Test that errors propagate correctly through the workflow.""" + # Invalid PSMList should raise IM2DeepError + with pytest.raises(IM2DeepError): + core.predict(None) + + # Empty PSMList should raise error + with pytest.raises(IM2DeepError): + core.predict(PSMList(psm_list=[])) + + +class TestDataConsistency: + """Tests for data consistency across the pipeline.""" + + def test_psm_list_preservation(self, sample_psm_list): + """Test that PSMList properties are preserved through processing.""" + from im2deep.utils import validate_psm_list + + validated = validate_psm_list(sample_psm_list) + + # Check that all PSMs are preserved + assert len(validated) == len(sample_psm_list) + + # Check that peptidoforms are preserved + for orig, val in zip(sample_psm_list, validated): + assert orig.peptidoform == val.peptidoform + + def test_ccs_value_consistency(self, sample_psm_list_with_ccs): + """Test that CCS values remain consistent.""" + from im2deep.utils import validate_psm_list + + validated = validate_psm_list(sample_psm_list_with_ccs, needs_target=True) + + for orig, val in zip(sample_psm_list_with_ccs, validated): + assert orig.metadata["CCS"] == val.metadata["CCS"] + + def test_array_shape_consistency(self, sample_peptidoforms, sample_predicted_ccs): + """Test that array shapes remain consistent.""" + calibration = LinearCCSCalibration() + + # Set up a simple calibration + calibration.charge_shifts = {2: 5.0, 3: 3.0} + calibration.general_shift = 4.0 + calibration.fitted = True + + transform_df = pd.DataFrame( + { + "peptidoform": sample_peptidoforms, + "metadata": [ + {"predicted_CCS_uncalibrated": pred} for pred in sample_predicted_ccs + ], + } + ) + + result = calibration.transform(transform_df) + + assert len(result) == len(sample_predicted_ccs) + # 
Check that values are floats (not arrays for single output) + for val in result: + assert isinstance(val, (float, np.floating)) + + +class TestEdgeCases: + """Tests for edge cases and boundary conditions.""" + + def test_single_peptide(self): + """Test prediction with single peptide.""" + psm = PSM( + peptidoform=Peptidoform("PEPTIDE/2"), + spectrum_id="test_001", + run="test_run", + is_decoy=False, + ) + psm_list = PSMList(psm_list=[psm]) + + from im2deep.utils import validate_psm_list + + validated = validate_psm_list(psm_list) + + assert len(validated) == 1 + + def test_high_charge_state(self): + """Test handling of high charge states.""" + # Create PSMs with both valid and invalid charge states + psm_valid = PSM( + peptidoform=Peptidoform("PEPTIDE/3"), + spectrum_id="test_001", + run="test_run", + is_decoy=False, + ) + psm_high_charge = PSM( + peptidoform=Peptidoform("PEPTIDE/10"), + spectrum_id="test_002", + run="test_run", + is_decoy=False, + ) + psm_list = PSMList(psm_list=[psm_valid, psm_high_charge]) + + # Should filter out high charges (>6) but keep valid ones + from im2deep.utils import validate_psm_list + + validated = validate_psm_list(psm_list) + # After filtering, only the valid charge state should remain + assert len(validated) == 1 + assert validated[0].peptidoform.precursor_charge == 3 + + def test_modified_peptides(self): + """Test handling of modified peptides.""" + psm = PSM( + peptidoform=Peptidoform("PEP[+15.99]TIDE/2"), + spectrum_id="test_001", + run="test_run", + is_decoy=False, + ) + psm_list = PSMList(psm_list=[psm]) + + from im2deep.utils import validate_psm_list + + validated = validate_psm_list(psm_list) + assert len(validated) == 1 + + def test_very_long_peptide(self): + """Test handling of very long peptides.""" + long_seq = "A" * 100 + psm = PSM( + peptidoform=Peptidoform(f"{long_seq}/2"), + spectrum_id="test_001", + run="test_run", + is_decoy=False, + ) + psm_list = PSMList(psm_list=[psm]) + + from im2deep.utils import 
validate_psm_list + + validated = validate_psm_list(psm_list) + assert len(validated) == 1 + + def test_very_short_peptide(self): + """Test handling of very short peptides.""" + psm = PSM( + peptidoform=Peptidoform("AA/2"), + spectrum_id="test_001", + run="test_run", + is_decoy=False, + ) + psm_list = PSMList(psm_list=[psm]) + + from im2deep.utils import validate_psm_list + + validated = validate_psm_list(psm_list) + assert len(validated) == 1 diff --git a/tests/test_model_ops.py b/tests/test_model_ops.py new file mode 100644 index 0000000..bbf5ea1 --- /dev/null +++ b/tests/test_model_ops.py @@ -0,0 +1,268 @@ +"""Tests for model operations module.""" + +import pytest +import torch +import numpy as np +from pathlib import Path +from unittest.mock import Mock, patch, MagicMock + +from im2deep import _model_ops +from im2deep._exceptions import IM2DeepError + + +class TestLoadModel: + """Tests for load_model function.""" + + def test_load_model_from_path(self, temp_model_path): + """Test loading model from file path.""" + # Create a mock model + model = torch.nn.Linear(10, 1) + torch.save(model, temp_model_path) + + loaded_model = _model_ops.load_model(temp_model_path) + + assert isinstance(loaded_model, torch.nn.Module) + assert loaded_model.training is False # Should be in eval mode by default + + def test_load_model_from_module(self): + """Test loading model from existing module.""" + model = torch.nn.Linear(10, 1) + loaded_model = _model_ops.load_model(model) + + assert loaded_model is model + assert isinstance(loaded_model, torch.nn.Module) + + def test_load_model_none(self): + """Test loading model with None raises NotImplementedError.""" + with pytest.raises(NotImplementedError): + _model_ops.load_model(None) + + def test_load_model_invalid_type(self): + """Test loading model with invalid type raises TypeError.""" + with pytest.raises(TypeError): + _model_ops.load_model(12345) + + def test_load_model_dict_checkpoint(self, temp_model_path): + """Test loading 
model from dict checkpoint.""" + model = torch.nn.Linear(10, 1) + checkpoint = {"model": model, "epoch": 10} + torch.save(checkpoint, temp_model_path) + + loaded_model = _model_ops.load_model(temp_model_path) + + assert isinstance(loaded_model, torch.nn.Module) + + def test_load_model_state_dict_only(self, temp_model_path): + """Test loading model with state_dict only raises NotImplementedError.""" + model = torch.nn.Linear(10, 1) + checkpoint = {"state_dict": model.state_dict()} + torch.save(checkpoint, temp_model_path) + + with pytest.raises(NotImplementedError, match="state_dict"): + _model_ops.load_model(temp_model_path) + + def test_load_model_device_cpu(self, temp_model_path): + """Test loading model on CPU.""" + model = torch.nn.Linear(10, 1) + torch.save(model, temp_model_path) + + loaded_model = _model_ops.load_model(temp_model_path, device="cpu") + + # Check that model is on CPU + assert next(loaded_model.parameters()).device.type == "cpu" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_load_model_device_cuda(self, temp_model_path): + """Test loading model on CUDA.""" + model = torch.nn.Linear(10, 1) + torch.save(model, temp_model_path) + + loaded_model = _model_ops.load_model(temp_model_path, device="cuda") + + assert next(loaded_model.parameters()).device.type == "cuda" + + +class TestPredict: + """Tests for predict function.""" + + @patch("im2deep._model_ops._get_architecture") + @patch("im2deep._model_ops._get_model_config") + @patch("im2deep._model_ops._get_loss_function") + def test_predict_single_output(self, mock_loss, mock_config, mock_arch, sample_psm_list): + """Test prediction with single-output model.""" + # Create mock model + mock_model_instance = MagicMock() + mock_model_instance.eval.return_value = None + mock_model_instance.return_value = torch.tensor([[450.0], [520.0], [480.0]]) + + mock_arch.return_value.load_from_checkpoint.return_value = mock_model_instance + mock_config.return_value = 
{} + mock_loss.return_value = torch.nn.L1Loss() + + # Create mock dataset + mock_dataset = MagicMock() + mock_dataset.__len__.return_value = 3 + mock_dataset.__getitem__.return_value = ( + [torch.randn(10), torch.randn(5)], + torch.tensor([0.0]), + ) + + with patch("im2deep._model_ops._predict_loop") as mock_predict_loop: + mock_predict_loop.return_value = torch.tensor([450.0, 520.0, 480.0]) + + predictions = _model_ops.predict( + model="fake_model.ckpt", + data=mock_dataset, + multi=False, + ) + + assert isinstance(predictions, torch.Tensor) + assert len(predictions) == 3 + + @patch("im2deep._model_ops._get_architecture") + @patch("im2deep._model_ops._get_model_config") + @patch("im2deep._model_ops._get_loss_function") + def test_predict_multi_output(self, mock_loss, mock_config, mock_arch): + """Test prediction with multi-output model.""" + mock_model_instance = MagicMock() + mock_model_instance.eval.return_value = None + + mock_arch.return_value.load_from_checkpoint.return_value = mock_model_instance + mock_config.return_value = {} + mock_loss.return_value = MagicMock() + + mock_dataset = MagicMock() + mock_dataset.__len__.return_value = 3 + + with patch("im2deep._model_ops._predict_loop") as mock_predict_loop: + mock_predict_loop.return_value = torch.tensor( + [[450.0, 452.0], [520.0, 524.0], [480.0, 482.0]] + ) + + predictions = _model_ops.predict( + model="fake_model.ckpt", + data=mock_dataset, + multi=True, + ) + + assert isinstance(predictions, torch.Tensor) + assert predictions.shape == (3, 2) + + def test_predict_no_data(self): + """Test prediction without data raises ValueError.""" + with pytest.raises(ValueError, match="Data must be provided"): + _model_ops.predict(model="fake_model.ckpt", data=None) + + +class TestPredictLoop: + """Tests for _predict_loop function.""" + + def test_predict_loop_single_output(self): + """Test prediction loop with single-output model.""" + # Create mock model + model = MagicMock() + model.eval.return_value = None + 
model.return_value = torch.tensor([[450.0], [520.0]]) + + # Create mock data loader + mock_data = [ + ([torch.randn(2, 10), torch.randn(2, 5)], torch.zeros(2)), + ] + + with patch("im2deep._model_ops.track", return_value=mock_data): + predictions = _model_ops._predict_loop( + model=model, data_loader=mock_data, device="cpu" + ) + + assert isinstance(predictions, torch.Tensor) + + def test_predict_loop_multi_output(self): + """Test prediction loop with multi-output model.""" + # Create mock model that returns tuple + model = MagicMock() + model.eval.return_value = None + model.return_value = ( + torch.tensor([[450.0], [520.0]]), + torch.tensor([[452.0], [524.0]]), + ) + + mock_data = [ + ([torch.randn(2, 10), torch.randn(2, 5)], torch.zeros(2)), + ] + + with patch("im2deep._model_ops.track", return_value=mock_data): + predictions = _model_ops._predict_loop( + model=model, data_loader=mock_data, device="cpu" + ) + + assert isinstance(predictions, torch.Tensor) + # Should stack both outputs + + def test_predict_loop_no_grad(self): + """Test that prediction loop uses no_grad context.""" + model = torch.nn.Linear(10, 1) + model.eval() + + # Create data in the format expected by _predict_loop + # Each batch should be ([features], targets) where features is a list + mock_data = [ + ([torch.randn(2, 10)], torch.randn(2, 1)), + ([torch.randn(2, 10)], torch.randn(2, 1)), + ] + + # Mock track to return our mock data + with patch("im2deep._model_ops.track", return_value=mock_data): + predictions = _model_ops._predict_loop( + model=model, data_loader=mock_data, device="cpu" + ) + + assert not predictions.requires_grad + + +class TestGetArchitecture: + """Tests for _get_architecture function.""" + + @patch("im2deep._architecture.IM2Deep") + def test_get_architecture_single(self, mock_im2deep): + """Test getting single-output architecture.""" + arch = _model_ops._get_architecture(multi=False) + # Should import IM2Deep + assert arch is mock_im2deep + + 
@patch("im2deep._architecture.IM2DeepMultiTransfer") + def test_get_architecture_multi(self, mock_multi): + """Test getting multi-output architecture.""" + arch = _model_ops._get_architecture(multi=True) + # Should import IM2DeepMultiTransfer + assert arch is mock_multi + + +class TestGetModelConfig: + """Tests for _get_model_config function.""" + + def test_get_model_config_single(self): + """Test getting single-output model config.""" + config = _model_ops._get_model_config(multi=False) + assert isinstance(config, dict) + + def test_get_model_config_multi(self): + """Test getting multi-output model config.""" + config = _model_ops._get_model_config(multi=True) + assert isinstance(config, dict) + + +class TestGetLossFunction: + """Tests for _get_loss_function function.""" + + def test_get_loss_function_single(self): + """Test getting single-output loss function.""" + loss = _model_ops._get_loss_function(multi=False) + assert isinstance(loss, torch.nn.modules.loss._Loss) + + @patch("im2deep._architecture.FlexibleLossSorted") + def test_get_loss_function_multi(self, mock_loss): + """Test getting multi-output loss function.""" + mock_instance = MagicMock() + mock_loss.return_value = mock_instance + loss = _model_ops._get_loss_function(multi=True) + assert loss is mock_instance diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..9582cd6 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,221 @@ +"""Tests for utils module.""" + +import pytest +import pandas as pd +import numpy as np +from pathlib import Path +from psm_utils import PSM, PSMList, Peptidoform + +from im2deep.utils import ( + validate_psm_list, + parse_input, + ccs2im, + im2ccs, +) +from im2deep._exceptions import IM2DeepError + + +class TestValidatePSMList: + """Tests for validate_psm_list function.""" + + def test_validate_psm_list_valid(self, sample_psm_list): + """Test validation with valid PSMList.""" + result = validate_psm_list(sample_psm_list) + assert 
isinstance(result, PSMList) + assert len(result) == len(sample_psm_list) + + def test_validate_psm_list_with_ccs(self, sample_psm_list_with_ccs): + """Test validation with PSMList containing CCS values.""" + result = validate_psm_list(sample_psm_list_with_ccs, needs_target=True) + assert isinstance(result, PSMList) + for psm in result: + assert "CCS" in psm.metadata + # CCS should always be stored as float + assert isinstance(psm.metadata["CCS"], float) + + def test_validate_psm_list_missing_ccs(self, sample_psm_list): + """Test validation fails when CCS values are required but missing.""" + with pytest.raises(IM2DeepError, match="ion_mobility.*CCS.*metadata"): + validate_psm_list(sample_psm_list, needs_target=True) + + def test_validate_psm_list_empty(self): + """Test validation with empty PSMList.""" + empty_list = PSMList(psm_list=[]) + with pytest.raises(IM2DeepError, match="No PSMs present"): + validate_psm_list(empty_list) + + def test_validate_psm_list_not_psm_list(self): + """Test validation fails with non-PSMList input.""" + with pytest.raises(IM2DeepError, match="PSMList"): + validate_psm_list([1, 2, 3]) + + +class TestParseInput: + """Tests for parse_input function.""" + + def test_parse_input_psm_list(self, sample_psm_list): + """Test parsing PSMList input.""" + result = parse_input(sample_psm_list) + assert isinstance(result, PSMList) + assert len(result) == len(sample_psm_list) + + def test_parse_input_csv_file(self, tmp_path, sample_legacy_format_df): + """Test parsing CSV file.""" + csv_path = tmp_path / "test.csv" + sample_legacy_format_df.to_csv(csv_path, index=False) + + result = parse_input(csv_path) + assert isinstance(result, PSMList) + assert len(result) == len(sample_legacy_format_df) + + def test_parse_input_tsv_file(self, tmp_path, sample_legacy_format_df): + """Test parsing TSV file.""" + tsv_path = tmp_path / "test.tsv" + sample_legacy_format_df.to_csv(tsv_path, sep="\t", index=False) + + result = parse_input(tsv_path) + assert 
isinstance(result, PSMList) + assert len(result) == len(sample_legacy_format_df) + + def test_parse_input_peprec_format(self, tmp_path, sample_peprec_format_df): + """Test parsing PEPREC format file.""" + csv_path = tmp_path / "peprec.csv" + sample_peprec_format_df.to_csv(csv_path, index=False) + + result = parse_input(csv_path) + assert isinstance(result, PSMList) + assert len(result) == len(sample_peprec_format_df) + + def test_parse_input_with_modifications(self, tmp_path): + """Test parsing file with modifications.""" + df = pd.DataFrame( + { + "seq": ["PEPTIDE", "SEQUENCE"], + "modifications": ["1|Oxidation", ""], + "charge": [2, 3], + } + ) + csv_path = tmp_path / "test_mods.csv" + df.to_csv(csv_path, index=False) + + result = parse_input(csv_path) + assert isinstance(result, PSMList) + assert len(result) == 2 + + def test_parse_input_invalid_file(self, tmp_path): + """Test parsing non-existent file raises error.""" + fake_path = tmp_path / "nonexistent.csv" + with pytest.raises((FileNotFoundError, IM2DeepError)): + parse_input(fake_path) + + def test_parse_input_dataframe(self, sample_legacy_format_df): + """Test parsing DataFrame directly.""" + result = parse_input(sample_legacy_format_df) + assert isinstance(result, PSMList) + assert len(result) == len(sample_legacy_format_df) + + def test_parse_input_legacy_format_detection(self, tmp_path): + """Test that legacy format is properly detected.""" + df = pd.DataFrame( + { + "seq": ["PEPTIDE", "SEQUENCE"], + "modifications": ["", ""], + "charge": [2, 3], + "CCS": [450.5, 520.8], + } + ) + csv_path = tmp_path / "legacy.csv" + df.to_csv(csv_path, index=False) + + result = parse_input(csv_path) + assert isinstance(result, PSMList) + assert len(result) == 2 + # Check that CCS values are preserved in metadata + for psm in result: + assert "CCS" in psm.metadata + + +class TestCCSConversions: + """Tests for CCS and ion mobility conversion functions.""" + + def test_ccs2im_basic(self): + """Test basic CCS to ion 
mobility conversion.""" + ccs = 450.0 + charge = 2 + mz = 500.0 + im = ccs2im(ccs, charge, mz) + + assert isinstance(im, float) + assert im > 0 + + def test_ccs2im_array(self): + """Test CCS to ion mobility conversion with arrays.""" + ccs = np.array([450.0, 520.0, 480.0]) + charge = np.array([2, 3, 2]) + mz = np.array([500.0, 600.0, 550.0]) + + im = ccs2im(ccs, charge, mz) + + assert isinstance(im, np.ndarray) + assert len(im) == len(ccs) + assert np.all(im > 0) + + def test_im2ccs_basic(self): + """Test basic ion mobility to CCS conversion.""" + im = 1.0 + charge = 2 + mz = 500.0 + ccs = im2ccs(im, charge, mz) + + assert isinstance(ccs, float) + assert ccs > 0 + + def test_im2ccs_array(self): + """Test ion mobility to CCS conversion with arrays.""" + im = np.array([1.0, 1.2, 0.9]) + charge = np.array([2, 3, 2]) + mz = np.array([500.0, 600.0, 550.0]) + + ccs = im2ccs(im, charge, mz) + + assert isinstance(ccs, np.ndarray) + assert len(ccs) == len(im) + assert np.all(ccs > 0) + + def test_ccs2im_im2ccs_roundtrip(self): + """Test that CCS -> IM -> CCS conversion is consistent.""" + ccs_original = 450.0 + charge = 2 + mz = 500.0 + + im = ccs2im(ccs_original, charge, mz) + ccs_roundtrip = im2ccs(im, charge, mz) + + assert abs(ccs_roundtrip - ccs_original) < 0.01 + + def test_ccs2im_zero_values(self): + """Test handling of zero values.""" + with pytest.raises((ValueError, ZeroDivisionError)): + ccs2im(0, 2, 500.0) + + def test_im2ccs_zero_values(self): + """Test handling of zero values.""" + with pytest.raises((ValueError, ZeroDivisionError)): + im2ccs(0, 2, 500.0) + + def test_ccs2im_negative_values(self): + """Test handling of negative values.""" + # Function should raise ValueError for negative CCS values + with pytest.raises(ValueError, match="CCS must be positive"): + ccs2im(-450.0, 2, 500.0) + + def test_im2ccs_different_charges(self): + """Test conversions with different charge states.""" + im = 1.0 + mz = 500.0 + + ccs_z2 = im2ccs(im, 2, mz) + ccs_z3 = 
im2ccs(im, 3, mz) + + assert ccs_z2 != ccs_z3 + assert ccs_z2 > 0 and ccs_z3 > 0