Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ jobs:
- name: Test CLI - WLS growth
run: |
bash test/test-growth-wls.sh
- name: Test CLI - WLS growth, multiple references
run: |
bash test/test-growth-wls-multiref.sh
- name: Test CLI - OLS growth
run: |
bash test/test-growth-ols.sh
Expand Down
15 changes: 11 additions & 4 deletions bartab/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ def _check_control(
ref,
name: str,
check_presence: bool = False,
sep: str = "::"
sep: str = "::",
loose: bool = False
):
from pandas.api.types import is_string_dtype
new_col = f"__is_{name}__"
if ref is None:
df[new_col] = False
Expand All @@ -55,10 +57,15 @@ def _check_control(
elif col in df.index.names:
vals_to_check = df.index.get_level_values(col)
elif col == "__index__":
vals_to_check = df.index.values
vals_to_check = df.index
else:
raise KeyError(f"Column {col} not in data")
df[new_col] = vals_to_check == ref
if not loose:
df[new_col] = vals_to_check == ref
elif is_string_dtype(vals_to_check):
df[new_col] = vals_to_check.str.startswith(ref)
else:
raise ValueError(f"Cannot do loose reference matching on non-string types: {vals_to_check}")
n_refs = df[new_col].sum()
if n_refs == 0 and check_presence:
raise ValueError(f"No reference samples '{ref}' identified in '{col}'")
Expand Down Expand Up @@ -147,7 +154,7 @@ def load_anndata(
strain_meta = strain_meta.set_index("__index__").loc[counts_wide.index].copy()

sample_meta = _check_control(sample_meta, timepoint_column, t0, "t0", check_presence=True)
strain_meta = _check_control(strain_meta, "__index__", reference, "reference", check_presence=True)
strain_meta = _check_control(strain_meta, "__index__", reference, "reference", check_presence=True, loose=True)
strain_meta = _check_control(strain_meta, "__index__", spike, "spike")

adata = AnnData(
Expand Down
32 changes: 27 additions & 5 deletions bartab/models/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,22 @@ def fit(
results_df.set_index(_index_col),
how="left",
validate="one_to_one",
left_index=True, right_index=True
left_index=True,
right_index=True,
)
adata.uns["models_fitted"].append(name)
return adata


class AnnDataWLSModel(AnnDataModel, WLSModel):
def fit(self, adata, *args, weight_kwargs=None, concentration_key: str = "__inducer__", **kwargs):
def fit(
self,
adata,
*args,
weight_kwargs: None = None,
concentration_key: str = "__inducer__",
**kwargs
):
return super().fit(
adata,
*args,
Expand All @@ -113,7 +121,14 @@ def fit(self, adata, *args, weight_kwargs=None, concentration_key: str = "__indu


class AnnDataOLSModel(AnnDataModel, OLSModel):
def fit(self, adata, *args, weights=None, concentration_key: str = "__inducer__", **kwargs):
def fit(
self,
adata,
*args,
weights: None = None,
concentration_key: str = "__inducer__",
**kwargs
):
return super().fit(
adata,
*args,
Expand All @@ -125,8 +140,15 @@ def fit(self, adata, *args, weights=None, concentration_key: str = "__inducer__"


class AnnDataHillModel(AnnDataModel, HillFitnessModel):
def fit(self, adata, *args, concentration: str, weight_kwargs=None, concentration_key: str = "__inducer__", **kwargs):
print(adata.var[concentration_key].values)
def fit(
self,
adata,
*args,
concentration: str,
weight_kwargs: None = None,
concentration_key: str = "__inducer__",
**kwargs
):
return super().fit(
adata,
*args,
Expand Down
36 changes: 19 additions & 17 deletions bartab/models/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Base classes for fitness models."""
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
from typing import Any
from collections.abc import Callable, Iterable, Mapping
from abc import ABC, abstractmethod

from carabiner import print_err
Expand Down Expand Up @@ -50,7 +51,7 @@ def _calculate_weights_matrix(
def calculate_weights_matrix(
self,
Y: ArrayLike,
weights: Optional[ArrayLike] = None,
weights: ArrayLike | None = None,
**kwargs
) -> np.ndarray:
if weights is None:
Expand All @@ -63,15 +64,14 @@ def fit_obs(
self,
y: ArrayLike,
x: ArrayLike,
valid: Optional[ArrayLike] = None,
weights: Optional[ArrayLike] = None,
param_names: Optional[str] = None,
valid: ArrayLike | None = None,
weights: ArrayLike | None = None,
param_names: str | None = None,
**kwargs
) -> Dict[str, Union[float, int]]:
) -> dict[str, float | int]:
n_orig = len(y)
if valid is None:
valid = np.ones_like(y, dtype=bool)
y_full = y.copy() # keep for padding later
y = y[valid]
x = x[valid]
if weights is not None:
Expand All @@ -96,21 +96,23 @@ def fit_obs(
f"{k}_ci_{j}": _v
for k, v in cis.items()
for j, _v in zip(["low", "high"], v)
} | {"nobs": y.shape[0]} | other
} | {
"nobs": y.shape[0],
} | other
return self._fitness_transform(result), (x_out, y_out, preds_full)

def fit(
self,
Y: ArrayLike,
x: ArrayLike,
valid: Optional[ArrayLike] = None,
weights: Optional[ArrayLike] = None,
param_names: Optional[str] = None,
groups: Optional[ArrayLike] = None,
valid: ArrayLike | None = None,
weights: ArrayLike | None = None,
param_names: str | None = None,
groups: ArrayLike | None = None,
min_obs: int = 3,
weight_kwargs: Optional[Mapping[str, Any]] = None,
weight_kwargs: Mapping[str, Any] | None = None,
**kwargs
) -> List[Dict[str, Union[float, int]]]:
) -> list[dict[str, float | int]]:

if valid is None:
valid = np.ones(Y.shape[0], dtype=bool)
Expand Down Expand Up @@ -201,8 +203,8 @@ def _fit(
self,
y: ArrayLike,
x: ArrayLike,
param_names: Optional[Iterable[str]] = None,
method: Optional[Callable] = None,
param_names: Iterable[str] | None = None,
method: Callable | None = None,
**kwargs
):
if method is None:
Expand Down Expand Up @@ -232,7 +234,7 @@ def _fit(
model_fn: Callable,
init_params: ArrayLike,
weights: ArrayLike,
param_names: Optional[Iterable[str]] = None,
param_names: Iterable[str] | None = None,
**kwargs
):
from scipy.optimize import curve_fit
Expand Down
7 changes: 4 additions & 3 deletions bartab/models/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def _delta_method_weights(
raw: ArrayLike, # (n_strains, n_samples) raw counts + pseudocount
control_mask: ArrayLike, # (n_strains,)
dispersion: ArrayLike, # (n_strains,) per-strain alpha
groups: Optional[ArrayLike] = None
) -> np.ndarray:
"""

Expand All @@ -31,8 +32,8 @@ def _delta_method_weights(
True

"""
ref_counts = raw[control_mask, :].squeeze(axis=0) # (n_samples,)
ref_disp = dispersion[control_mask].squeeze() # scalar
ref_counts = raw[control_mask, :].sum(axis=0) # (n_samples,)
ref_disp = _estimate_dispersion_mom(ref_counts[None], groups) # scalar
var_y = (
1. / raw + dispersion[:, None] # (n_strains, n_samples)
+ 1. / ref_counts + ref_disp
Expand Down Expand Up @@ -101,7 +102,7 @@ def _calculate_weights_matrix(
raw,
groups,
)
return _delta_method_weights(raw, control_mask, dispersion) # (n_strains, n_samples)
return _delta_method_weights(raw, control_mask, dispersion, groups) # (n_strains, n_samples)

@staticmethod
def _fitness_transform(results: Mapping[str, float]) -> Dict[str, float]:
Expand Down
9 changes: 7 additions & 2 deletions bartab/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ def scatter(
**kwargs,
)
if "label" in scatter_opts:
add_legend(ax)
if all([
len(scatter_opts["label"]) > 0,
scatter_opts["label"][0] != "_",
]):
add_legend(ax)
return ax


Expand All @@ -69,6 +73,7 @@ def _avoid_color_collision(i, avoid=None):
else:
return i


def volcano(
adata,
model_name,
Expand All @@ -94,7 +99,7 @@ def volcano(
ax=ax,
x=x,
y=y,
data=df.query("strain_id.str.startswith(@control_prefix)"),
data=df.query("__row_index__.str.startswith(@control_prefix)"),
scatter_opts={
"facecolor": "dimgrey",
"edgecolor": "none",
Expand Down
9 changes: 6 additions & 3 deletions bartab/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,16 @@ def compute_log_ratios(
ref_mask = adata.obs["__is_reference__"].values # (n_strains,)
spike_mask = adata.obs["__is_spike__"].values # (n_strains,)

if ref_mask.sum() != 1:
raise ValueError(f"Expected exactly 1 reference strain, found {ref_mask.sum()}")
n_ref = ref_mask.sum()
if n_ref == 0:
raise ValueError("No reference barcodes found")
if n_ref > 0:
print_err(f"Found {n_ref} reference barcodes: {', '.join(adata.obs.index[ref_mask].astype(str))}")
if t0_mask.sum() == 0:
raise ValueError("No t0 samples found")

log_X = np.log(X) # (n_strains, n_samples)
log_ref = log_X[ref_mask, :] # (1, n_samples,)
log_ref = np.log(np.sum(X[ref_mask, :], axis=0, keepdims=True)) # (1, n_samples,)

# log(c_i / c_wt) at every sample
log_ratio_to_ref = log_X - log_ref # (n_strains, n_samples)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "bartab"
version = "0.0.6"
version = "0.0.7"
authors = [
{ name="Eachan Johnson", email="eachan.johnson@crick.ac.uk" },
]
Expand Down
26 changes: 26 additions & 0 deletions test/test-growth-wls-multiref.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/usr/bin/env bash
#
# CLI smoke test: WLS growth fit using a reference that matches several barcodes.
#
# Runs `bartab fit` on the bundled test inputs with `--reference "mut"`, then
# runs `bartab plot` on the resulting .h5ad file. The reference value is
# presumably treated as a prefix so that it selects multiple barcodes (the
# plot step highlights mut_A .. mut_D) -- confirm against bartab's reference
# matching rules.
#
# Must be run from the repository root: all paths below are relative to it.

# -e: abort on the first failing command
# -u: treat unset variables as errors
# -o pipefail: a pipeline fails if any stage fails
# -x: echo each command to stderr so CI logs show exactly what ran
set -euox pipefail

INPUT="test/inputs/test_count.csv"
SAMPLE_SHEET="test/inputs/test_sample_meta.csv"
STRAINS="test/inputs/test_strain_meta.csv"

# Single source of truth for output locations: the file written by `fit` is
# the same file read by `plot`, so the path is defined once.
OUTPUT_DIR="test/outputs"
RESULTS="$OUTPUT_DIR/wls-growth-results-multiref.h5ad"
PLOT_PREFIX="$OUTPUT_DIR/wls-growth-plot-multiref"

# Allow the script to run on its own, without relying on an earlier test
# having created the output directory.
mkdir -p "$OUTPUT_DIR"

bartab fit \
    "$INPUT" \
    --output "$RESULTS" \
    --sample-sheet "$SAMPLE_SHEET" \
    --barcode-sheet "$STRAINS" \
    --barcode-column "strain_id" \
    --culture-column "replicate" \
    --growth-column "growth" \
    --growth-type "density" \
    --reference "mut" \
    --spike-name "spike"

bartab plot \
    "$RESULTS" \
    --highlight mut_A mut_B mut_C mut_D \
    -o "$PLOT_PREFIX"

# Progress message goes to stderr, keeping stdout free of log noise.
>&2 echo "[$(date)] Done!"
Loading