Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/athena/neighborhood/base_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,24 @@ class Interactions:
VALID_MODES = ['classic', 'histoCAT', 'proportion']
VALID_PREDICTION_TYPES = ['pvalue', 'observation', 'diff']

def __init__(self, ad: AnnData, attr: str = 'meta_id', mode: str = 'classic', n_permutations: int = 500,
def __init__(self, ad: AnnData, attr: str = 'meta_id', mode: str = 'classic', aggregation: str = 'mean',
n_permutations: int = 500,
random_seed=42, alpha: float = .01, graph_key: str = 'knn'):
"""Estimator to quantify interaction strength between different species in the sample.

Args:
so: SpatialOmics
spl: Sample for which to compute the interaction strength
attr: Categorical feature in ad.obs to use for the grouping
mode: One of {classic, histoCAT, proportion}, see notes
mode: One of {classic, proportion}, see notes
aggregation: How to aggregate the observed interactions for a given source node
n_permutations: Number of permutations used to compute p-values and the interaction strength score (prediction type `diff`)
random_seed: Random seed for permutations
alpha: Threshold for significance
graph_key: Specifies the graph representation to use in so.G[spl] if `local=True`.

Notes:
`classic` and `histoCAT` are Python implementations of the corresponding methods published by the Bodenmiller lab at UZH.
`classic` counts for each pair-wise interaction the number of edges between the two species.
The `proportion` method is similar to the `classic` method but normalises the score by the number of edges and is thus bounded in [0, 1].
"""

Expand All @@ -79,6 +81,7 @@ def __init__(self, ad: AnnData, attr: str = 'meta_id', mode: str = 'classic', n_
self.attr: str = attr
self.data: pd.Series = ad.obs[attr]
self.mode: str = mode
self.aggregation: str = aggregation
self.n_perm: int = int(n_permutations)
self.random_seed = random_seed
self.rng = np.random.default_rng(random_seed)
Expand Down Expand Up @@ -122,7 +125,8 @@ def fit(self, prediction_type: str = 'observation') -> None:
raise ValueError(f'invalid mode {self.mode}. Available modes are {self.VALID_MODES}')

node_interactions = get_node_interactions(self.g, self.data)
obs_interaction = get_interaction_score(node_interactions, relative_freq=relative_freq, observed=observed)
obs_interaction = get_interaction_score(node_interactions, aggregation=self.aggregation,
relative_freq=relative_freq, observed=observed)
self.obs_interaction = obs_interaction.set_index(['source_label', 'target_label'])

if not prediction_type == 'observation':
Expand Down
7 changes: 4 additions & 3 deletions src/athena/neighborhood/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

# %%
def interactions(ad: AnnData, *, attr: str,
mode: str = 'classic', prediction_type: str = 'observation',
mode: str = 'classic', prediction_type: str = 'observation', aggregation: str = 'mean',
n_permutations: int = 100,
random_seed=42, alpha: float = .01, key_added: str = None,
graph_key: str = 'knn',
Expand All @@ -28,6 +28,7 @@ def interactions(ad: AnnData, *, attr: str,
n_permutations: Number of permutations used to compute p-values and the interaction strength score (prediction type `diff`)
random_seed: Random seed for permutations
alpha: Threshold for significance
aggregation: How to aggregate the observed interactions for a given source node, e.g. 'mean' or 'sum'
prediction_type: One of {observation, pvalue, diff}, see Notes
key_added: Key added to SpatialOmics.uns[spl][metric][key_added]
graph_key: Specifies the graph representation to use in ad.obsp if `local=True`.
Expand All @@ -44,12 +45,12 @@ def interactions(ad: AnnData, *, attr: str,

# NOTE: uns_path = f'{spl}/interactions/'
if key_added is None:
key_added = f'interaction_{attr}_{mode}_{prediction_type}_{graph_key}'
key_added = f'interaction_{attr}_{mode}_{prediction_type}_{aggregation}_{graph_key}'

if random_seed is None:
random_seed = 42

estimator = Interactions(ad=ad, attr=attr, mode=mode, n_permutations=n_permutations,
estimator = Interactions(ad=ad, attr=attr, mode=mode, aggregation=aggregation, n_permutations=n_permutations,
random_seed=random_seed, alpha=alpha, graph_key=graph_key)

estimator.fit(prediction_type=prediction_type)
Expand Down
6 changes: 3 additions & 3 deletions src/athena/neighborhood/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_node_interactions(g: nx.Graph, data: pd.Series = None):
return node_interactions


def get_interaction_score(interactions, relative_freq=False, observed=False):
def get_interaction_score(interactions, aggregation='mean', relative_freq=False, observed=False):
# NOTE: this is not necessarily len(source_labels) == len(g) since only source nodes with neighbors are included
source_label = interactions[['source', 'source_label']].drop_duplicates().set_index('source')
source_label = source_label.squeeze()
Expand All @@ -53,14 +53,14 @@ def get_interaction_score(interactions, relative_freq=False, observed=False):
source2target_label['relative_freq'] = source2target_label['counts'] / source2target_label['n_neigh']
label2label = source2target_label\
.groupby(['source_label', 'target_label'], observed=observed)['relative_freq'] \
.agg('mean') \
.agg(aggregation) \
.rename('score') \
.fillna(0) \
.reset_index()
else:
label2label = source2target_label \
.groupby(['source_label', 'target_label'], observed=observed)['counts'] \
.agg('mean') \
.agg(aggregation) \
.rename('score') \
.fillna(0) \
.reset_index()
Expand Down
13 changes: 12 additions & 1 deletion tests/neighborhood/test_interaction.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,13 @@
import athena as ath
ad = ath.dataset.imc()
import anndata
anndata.settings.allow_write_nullable_strings = True
import pickle

with open('/work/FAC/FBM/DBC/mrapsoma/prometex/data/PCa/anndatas/raw_bc2/231204_ibl_x2y4_29_11.pkl', 'rb') as f:
ad = pickle.load(file=f)

ath.neigh.interactions(ad=ad, attr='label', graph_key='radius_32',
mode='classic', prediction_type='observation', aggregation='sum')

ath.neigh.interactions(ad=ad, attr='label', graph_key='radius_32',
mode='classic', prediction_type='observation', aggregation='mean')