diff --git a/src/athena/neighborhood/base_estimators.py b/src/athena/neighborhood/base_estimators.py index b5e70c3..32fb6c0 100644 --- a/src/athena/neighborhood/base_estimators.py +++ b/src/athena/neighborhood/base_estimators.py @@ -54,7 +54,8 @@ class Interactions: VALID_MODES = ['classic', 'histoCAT', 'proportion'] VALID_PREDICTION_TYPES = ['pvalue', 'observation', 'diff'] - def __init__(self, ad: AnnData, attr: str = 'meta_id', mode: str = 'classic', n_permutations: int = 500, + def __init__(self, ad: AnnData, attr: str = 'meta_id', mode: str = 'classic', aggregation: str = 'mean', + n_permutations: int = 500, random_seed=42, alpha: float = .01, graph_key: str = 'knn'): """Estimator to quantify interaction strength between different species in the sample. @@ -62,14 +63,15 @@ def __init__(self, ad: AnnData, attr: str = 'meta_id', mode: str = 'classic', n_ so: SpatialOmics spl: Sample for which to compute the interaction strength attr: Categorical feature in ad.obs to use for the grouping - mode: One of {classic, histoCAT, proportion}, see notes + mode: One of {classic, proportion}, see notes + aggregation: How to aggregate the observed interactions for a given source node n_permutations: Number of permutations to compute p-values and the interactions strength score (mode diff) random_seed: Random seed for permutations alpha: Threshold for significance graph_key: Specifies the graph representation to use in so.G[spl] if `local=True`. Notes: - classic and histoCAT are python implementations of the corresponding methods published by the Bodenmiller lab at UZH. + `classic` counts for each pair-wise interaction the number of edges between the two species. The proportion method is similar to the classic method but normalises the score by the number of edges and is thus bound [0,1]. 
""" @@ -79,6 +81,7 @@ def __init__(self, ad: AnnData, attr: str = 'meta_id', mode: str = 'classic', n_ self.attr: str = attr self.data: pd.Series = ad.obs[attr] self.mode: str = mode + self.aggregation: str = aggregation self.n_perm: int = int(n_permutations) self.random_seed = random_seed self.rng = np.random.default_rng(random_seed) @@ -122,7 +125,8 @@ def fit(self, prediction_type: str = 'observation') -> None: raise ValueError(f'invalid mode {self.mode}. Available modes are {self.VALID_MODES}') node_interactions = get_node_interactions(self.g, self.data) - obs_interaction = get_interaction_score(node_interactions, relative_freq=relative_freq, observed=observed) + obs_interaction = get_interaction_score(node_interactions, aggregation=self.aggregation, + relative_freq=relative_freq, observed=observed) self.obs_interaction = obs_interaction.set_index(['source_label', 'target_label']) if not prediction_type == 'observation': diff --git a/src/athena/neighborhood/estimators.py b/src/athena/neighborhood/estimators.py index b435097..594894c 100644 --- a/src/athena/neighborhood/estimators.py +++ b/src/athena/neighborhood/estimators.py @@ -11,7 +11,7 @@ # %% def interactions(ad: AnnData, *, attr: str, - mode: str = 'classic', prediction_type: str = 'observation', + mode: str = 'classic', prediction_type: str = 'observation', aggregation: str = 'mean', n_permutations: int = 100, random_seed=42, alpha: float = .01, key_added: str = None, graph_key: str = 'knn', @@ -28,6 +28,7 @@ def interactions(ad: AnnData, *, attr: str, n_permutations: Number of permutations to compute p-values and the interactions strength score (mode diff) random_seed: Random seed for permutations alpha: Threshold for significance + aggregation: Aggregation passed to pandas `GroupBy.agg` to combine the scores of each (source_label, target_label) pair, e.g. 'mean' or 'sum' prediction_type: One of {observation, pvalue, diff}, see Notes key_added: Key added to SpatialOmics.uns[spl][metric][key_added] graph_key: Specifies the graph representation to use in ad.obsp if `local=True`. 
@@ -44,12 +45,12 @@ def interactions(ad: AnnData, *, attr: str, # NOTE: uns_path = f'{spl}/interactions/' if key_added is None: - key_added = f'interaction_{attr}_{mode}_{prediction_type}_{graph_key}' + key_added = f'interaction_{attr}_{mode}_{prediction_type}_{aggregation}_{graph_key}' if random_seed is None: random_seed = 42 - estimator = Interactions(ad=ad, attr=attr, mode=mode, n_permutations=n_permutations, + estimator = Interactions(ad=ad, attr=attr, mode=mode, aggregation=aggregation, n_permutations=n_permutations, random_seed=random_seed, alpha=alpha, graph_key=graph_key) estimator.fit(prediction_type=prediction_type) diff --git a/src/athena/neighborhood/utils.py b/src/athena/neighborhood/utils.py index b61990f..d1a5bd8 100644 --- a/src/athena/neighborhood/utils.py +++ b/src/athena/neighborhood/utils.py @@ -38,7 +38,7 @@ def get_node_interactions(g: nx.Graph, data: pd.Series = None): return node_interactions -def get_interaction_score(interactions, relative_freq=False, observed=False): +def get_interaction_score(interactions, aggregation='mean', relative_freq=False, observed=False): # NOTE: this is not necessarily len(source_labels) == len(g) since only source nodes with neighbors are included source_label = interactions[['source', 'source_label']].drop_duplicates().set_index('source') source_label = source_label.squeeze() @@ -53,14 +53,14 @@ def get_interaction_score(interactions, relative_freq=False, observed=False): source2target_label['relative_freq'] = source2target_label['counts'] / source2target_label['n_neigh'] label2label = source2target_label\ .groupby(['source_label', 'target_label'], observed=observed)['relative_freq'] \ - .agg('mean') \ + .agg(aggregation) \ .rename('score') \ .fillna(0) \ .reset_index() else: label2label = source2target_label \ .groupby(['source_label', 'target_label'], observed=observed)['counts'] \ - .agg('mean') \ + .agg(aggregation) \ .rename('score') \ .fillna(0) \ .reset_index() diff --git 
a/tests/neighborhood/test_interaction.py b/tests/neighborhood/test_interaction.py index ca7c29d..16cb9eb 100644 --- a/tests/neighborhood/test_interaction.py +++ b/tests/neighborhood/test_interaction.py @@ -1,2 +1,13 @@ import athena as ath -ad = ath.dataset.imc() \ No newline at end of file +import anndata +anndata.settings.allow_write_nullable_strings = True +import pickle + +with open('/work/FAC/FBM/DBC/mrapsoma/prometex/data/PCa/anndatas/raw_bc2/231204_ibl_x2y4_29_11.pkl', 'rb') as f: + ad = pickle.load(file=f) + +ath.neigh.interactions(ad=ad, attr='label', graph_key='radius_32', + mode='classic', prediction_type='observation', aggregation='sum') + +ath.neigh.interactions(ad=ad, attr='label', graph_key='radius_32', + mode='classic', prediction_type='observation', aggregation='mean')