diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index a2aad1c..5537d33 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -49,6 +49,9 @@ jobs: - name: Test CLI - WLS growth run: | bash test/test-growth-wls.sh + - name: Test CLI - WLS growth, multiple references + run: | + bash test/test-growth-wls-multiref.sh - name: Test CLI - OLS growth run: | bash test/test-growth-ols.sh diff --git a/bartab/io.py b/bartab/io.py index 329a39e..377b7ea 100644 --- a/bartab/io.py +++ b/bartab/io.py @@ -42,8 +42,10 @@ def _check_control( ref, name: str, check_presence: bool = False, - sep: str = "::" + sep: str = "::", + loose: bool = False ): + from pandas.api.types import is_string_dtype new_col = f"__is_{name}__" if ref is None: df[new_col] = False @@ -55,10 +57,15 @@ def _check_control( elif col in df.index.names: vals_to_check = df.index.get_level_values(col) elif col == "__index__": - vals_to_check = df.index.values + vals_to_check = df.index else: raise KeyError(f"Column {col} not in data") - df[new_col] = vals_to_check == ref + if not loose: + df[new_col] = vals_to_check == ref + elif is_string_dtype(vals_to_check): + df[new_col] = vals_to_check.str.startswith(ref) + else: + raise ValueError(f"Cannot do loose reference matching on non-string types: {vals_to_check}") n_refs = df[new_col].sum() if n_refs == 0 and check_presence: raise ValueError(f"No reference samples '{ref}' identified in '{col}'") @@ -147,7 +154,7 @@ def load_anndata( strain_meta = strain_meta.set_index("__index__").loc[counts_wide.index].copy() sample_meta = _check_control(sample_meta, timepoint_column, t0, "t0", check_presence=True) - strain_meta = _check_control(strain_meta, "__index__", reference, "reference", check_presence=True) + strain_meta = _check_control(strain_meta, "__index__", reference, "reference", check_presence=True, loose=True) strain_meta = _check_control(strain_meta, "__index__", spike, "spike") adata = AnnData( diff --git a/bartab/models/anndata.py b/bartab/models/anndata.py index dc8906c..3294b05 100644 --- a/bartab/models/anndata.py +++ b/bartab/models/anndata.py @@ -87,14 +87,22 @@ def fit( results_df.set_index(_index_col), how="left", validate="one_to_one", - left_index=True, right_index=True + left_index=True, + right_index=True, ) adata.uns["models_fitted"].append(name) return adata class AnnDataWLSModel(AnnDataModel, WLSModel): - def fit(self, adata, *args, weight_kwargs=None, concentration_key: str = "__inducer__", **kwargs): + def fit( + self, + adata, + *args, + weight_kwargs: None = None, + concentration_key: str = "__inducer__", + **kwargs + ): return super().fit( adata, *args, @@ -113,7 +121,14 @@ def fit(self, adata, *args, weight_kwargs=None, concentration_key: str = "__indu class AnnDataOLSModel(AnnDataModel, OLSModel): - def fit(self, adata, *args, weights=None, concentration_key: str = "__inducer__", **kwargs): + def fit( + self, + adata, + *args, + weights: None = None, + concentration_key: str = "__inducer__", + **kwargs + ): return super().fit( adata, *args, @@ -125,8 +140,15 @@ def fit(self, adata, *args, weights=None, concentration_key: str = "__inducer__" class AnnDataHillModel(AnnDataModel, HillFitnessModel): - def fit(self, adata, *args, concentration: str, weight_kwargs=None, concentration_key: str = "__inducer__", **kwargs): - print(adata.var[concentration_key].values) + def fit( + self, + adata, + *args, + concentration: str, + weight_kwargs: None = None, + concentration_key: str = "__inducer__", + **kwargs + ): return super().fit( adata, *args, diff --git a/bartab/models/base.py b/bartab/models/base.py index 31c0722..4ca52a3 100644 --- a/bartab/models/base.py +++ b/bartab/models/base.py @@ -1,5 +1,6 @@ """Base classes for fitness models.""" -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union +from typing import Any +from collections.abc import Callable, Iterable, Mapping from abc import ABC, abstractmethod from carabiner import print_err @@ -50,7 +51,7 @@ def _calculate_weights_matrix( def calculate_weights_matrix( self, Y: ArrayLike, - weights: Optional[ArrayLike] = None, + weights: ArrayLike | None = None, **kwargs ) -> np.ndarray: if weights is None: @@ -63,15 +64,14 @@ def fit_obs( self, y: ArrayLike, x: ArrayLike, - valid: Optional[ArrayLike] = None, - weights: Optional[ArrayLike] = None, - param_names: Optional[str] = None, + valid: ArrayLike | None = None, + weights: ArrayLike | None = None, + param_names: str | None = None, **kwargs - ) -> Dict[str, Union[float, int]]: + ) -> dict[str, float | int]: n_orig = len(y) if valid is None: valid = np.ones_like(y, dtype=bool) - y_full = y.copy() # keep for padding later y = y[valid] x = x[valid] if weights is not None: @@ -96,21 +96,23 @@ def fit_obs( f"{k}_ci_{j}": _v for k, v in cis.items() for j, _v in zip(["low", "high"], v) - } | {"nobs": y.shape[0]} | other + } | { + "nobs": y.shape[0], + } | other return self._fitness_transform(result), (x_out, y_out, preds_full) def fit( self, Y: ArrayLike, x: ArrayLike, - valid: Optional[ArrayLike] = None, - weights: Optional[ArrayLike] = None, - param_names: Optional[str] = None, - groups: Optional[ArrayLike] = None, + valid: ArrayLike | None = None, + weights: ArrayLike | None = None, + param_names: str | None = None, + groups: ArrayLike | None = None, min_obs: int = 3, - weight_kwargs: Optional[Mapping[str, Any]] = None, + weight_kwargs: Mapping[str, Any] | None = None, **kwargs - ) -> List[Dict[str, Union[float, int]]]: + ) -> list[dict[str, float | int]]: if valid is None: valid = np.ones(Y.shape[0], dtype=bool) @@ -201,8 +203,8 @@ def _fit( self, y: ArrayLike, x: ArrayLike, - param_names: Optional[Iterable[str]] = None, - method: Optional[Callable] = None, + param_names: Iterable[str] | None = None, + method: Callable | None = None, **kwargs ): if method is None: @@ -232,7 +234,7 @@ def _fit( model_fn: Callable, init_params: ArrayLike, weights: ArrayLike, - param_names: Optional[Iterable[str]] = None, + param_names: Iterable[str] | None = None, **kwargs ): from scipy.optimize import curve_fit diff --git a/bartab/models/linear.py b/bartab/models/linear.py index d67aadc..f3575fb 100644 --- a/bartab/models/linear.py +++ b/bartab/models/linear.py @@ -11,6 +11,7 @@ def _delta_method_weights( raw: ArrayLike, # (n_strains, n_samples) raw counts + pseudocount control_mask: ArrayLike, # (n_strains,) dispersion: ArrayLike, # (n_strains,) per-strain alpha + groups: Optional[ArrayLike] = None ) -> np.ndarray: """ @@ -31,8 +32,8 @@ def _delta_method_weights( True """ - ref_counts = raw[control_mask, :].squeeze(axis=0) # (n_samples,) - ref_disp = dispersion[control_mask].squeeze() # scalar + ref_counts = raw[control_mask, :].sum(axis=0) # (n_samples,) + ref_disp = _estimate_dispersion_mom(ref_counts[None], groups) # scalar var_y = ( 1. / raw + dispersion[:, None] # (n_strains, n_samples) + 1. / ref_counts + ref_disp @@ -101,7 +102,7 @@ def _calculate_weights_matrix( raw, groups, ) - return _delta_method_weights(raw, control_mask, dispersion) # (n_strains, n_samples) + return _delta_method_weights(raw, control_mask, dispersion, groups) # (n_strains, n_samples) @staticmethod def _fitness_transform(results: Mapping[str, float]) -> Dict[str, float]: diff --git a/bartab/plotting.py b/bartab/plotting.py index d126434..3de9fbc 100644 --- a/bartab/plotting.py +++ b/bartab/plotting.py @@ -57,7 +57,11 @@ def scatter( **kwargs, ) if "label" in scatter_opts: - add_legend(ax) + if all([ + len(scatter_opts["label"]) > 0, + scatter_opts["label"][0] != "_", + ]): + add_legend(ax) return ax @@ -69,6 +73,7 @@ def _avoid_color_collision(i, avoid=None): else: return i + def volcano( adata, model_name, @@ -94,7 +99,7 @@ def volcano( ax=ax, x=x, y=y, - data=df.query("strain_id.str.startswith(@control_prefix)"), + data=df.query("__row_index__.str.startswith(@control_prefix)"), scatter_opts={ "facecolor": "dimgrey", "edgecolor": "none", diff --git a/bartab/transforms.py b/bartab/transforms.py index bc40d7c..3708d81 100644 --- a/bartab/transforms.py +++ b/bartab/transforms.py @@ -47,13 +47,16 @@ def compute_log_ratios( ref_mask = adata.obs["__is_reference__"].values # (n_strains,) spike_mask = adata.obs["__is_spike__"].values # (n_strains,) - if ref_mask.sum() != 1: - raise ValueError(f"Expected exactly 1 reference strain, found {ref_mask.sum()}") + n_ref = ref_mask.sum() + if n_ref == 0: + raise ValueError("No reference barcodes found") + if n_ref > 0: + print_err(f"Found {n_ref} reference barcodes: {', '.join(adata.obs.index[ref_mask].astype(str))}") if t0_mask.sum() == 0: raise ValueError("No t0 samples found") log_X = np.log(X) # (n_strains, n_samples) - log_ref = log_X[ref_mask, :] # (1, n_samples,) + log_ref = np.log(np.sum(X[ref_mask, :], axis=0, keepdims=True)) # (1, n_samples,) # log(c_i / c_wt) at every sample log_ratio_to_ref = log_X - log_ref # (n_strains, n_samples) diff --git a/pyproject.toml b/pyproject.toml index 443c454..43ff747 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "bartab" -version = "0.0.6" +version = "0.0.7" authors = [ { name="Eachan Johnson", email="eachan.johnson@crick.ac.uk" }, ] diff --git a/test/test-growth-wls-multiref.sh b/test/test-growth-wls-multiref.sh new file mode 100644 index 0000000..539f019 --- /dev/null +++ b/test/test-growth-wls-multiref.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash + +set -euox pipefail + +INPUT="test/inputs/test_count.csv" +SAMPLE_SHEET="test/inputs/test_sample_meta.csv" +STRAINS="test/inputs/test_strain_meta.csv" + +bartab fit \ + "$INPUT" \ + --output "test/outputs/wls-growth-results-multiref.h5ad" \ + --sample-sheet "$SAMPLE_SHEET" \ + --barcode-sheet "$STRAINS" \ + --barcode-column "strain_id" \ + --culture-column "replicate" \ + --growth-column "growth" \ + --growth-type "density" \ + --reference "mut" \ + --spike-name "spike" + +bartab plot \ + "test/outputs/wls-growth-results-multiref.h5ad" \ + --highlight mut_A mut_B mut_C mut_D \ + -o test/outputs/wls-growth-plot-multiref + +>&2 echo "[$(date)] Done!"