Skip to content

Commit

Permalink
Some cleanup and testing
Browse files Browse the repository at this point in the history
  • Loading branch information
liambai committed Oct 14, 2024
1 parent 7d28009 commit 140cf79
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 53 deletions.
128 changes: 75 additions & 53 deletions plm_interpretability/latent_probe/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,25 @@
class ResidueAnnotation:
name: str
swissprot_header: str
values: list[str]
class_names: list[str]

# Value used to indicate that we don't care about the notes on the annotation,
# Class name used to indicate that we don't care about the annotation class,
# as long as the annotation exists. E.g. signal peptides annotations look like
# `{'start': 1, 'end': 24, 'evidence': 'ECO:0000255'}`, so we just classify
# whether the residue is part of a signal peptide or not.
ALL = "all"
ALL_CLASSES = "all"


RESIDUE_ANNOTATIONS = [
ResidueAnnotation(
name="DNA binding",
swissprot_header="DNA_BIND",
values=["H-T-H motif", "Homeobox", "Nuclear receptor", "HMG box"],
class_names=["H-T-H motif", "Homeobox", "Nuclear receptor", "HMG box"],
),
ResidueAnnotation(
name="Motif",
swissprot_header="MOTIF",
values=[
class_names=[
"Nuclear localization signal",
"Nuclear export signal",
"DEAD box",
Expand All @@ -59,7 +59,7 @@ class ResidueAnnotation:
ResidueAnnotation(
name="Topological domain",
swissprot_header="TOPO_DOM",
values=[
class_names=[
"Cytoplasmic",
"Extracellular",
"Lumenal",
Expand All @@ -73,7 +73,7 @@ class ResidueAnnotation:
ResidueAnnotation(
name="Domain [FT]",
swissprot_header="DOMAIN",
values=[
class_names=[
"Protein kinase",
"tr-type G",
"Radical SAM core",
Expand All @@ -87,7 +87,7 @@ class ResidueAnnotation:
ResidueAnnotation(
name="Active site",
swissprot_header="ACT_SITE",
values=[
class_names=[
"Proton acceptor",
"Proton donor",
"Nucleophile",
Expand All @@ -97,12 +97,12 @@ class ResidueAnnotation:
ResidueAnnotation(
name="Signal peptide",
swissprot_header="SIGNAL",
values=[ResidueAnnotation.ALL],
class_names=[ResidueAnnotation.ALL_CLASSES],
),
ResidueAnnotation(
name="Transit peptide",
swissprot_header="TRANSIT",
values=[ResidueAnnotation.ALL, "Mitochondrion", "Chloroplast"],
class_names=[ResidueAnnotation.ALL_CLASSES, "Mitochondrion", "Chloroplast"],
),
]

Expand All @@ -123,6 +123,56 @@ def get_sae_acts(
return sae_acts.cpu().numpy()


def get_annotation_entries_for_class(
swissprot_df: pd.DataFrame,
annotation: ResidueAnnotation,
class_name: str,
max_seqs_per_task: int,
) -> dict[str, list[dict]]:
"""
Map each sequence to a list of annotations entries like:
{
"AAA": [
{"start": 1, "end": 24, "note": "H-T-H motif"},
{"start": 100, "end": 120, "note": "Homeobox"},
],
...
}
Downsample to max_seqs_per_task if necessary.
"""
seq_to_annotation_entries = {}
for _, row in swissprot_df[swissprot_df[annotation.name].notna()].iterrows():
seq = row["Sequence"]
entries = parse_swissprot_annotation(
row[annotation.name], header=annotation.swissprot_header
)
if class_name != ResidueAnnotation.ALL_CLASSES:
# The note field is sometimes like "Homeobox", "Homeobox 1", etc.,
# so use string `in` to check.
entries = [e for e in entries if class_name in e.get("note", "")]
if len(entries) > 0:
seq_to_annotation_entries[seq] = entries
logger.info(
f"Found {len(seq_to_annotation_entries)} sequences with class {class_name}"
)

if len(seq_to_annotation_entries) > max_seqs_per_task:
logger.warning(
f"Since max_seqs_per_task={max_seqs_per_task}, using a random "
f"sample of {max_seqs_per_task} sequences."
)
subset_seqs = random.sample(
list(seq_to_annotation_entries.keys()), max_seqs_per_task
)
seq_to_annotation_entries = {
seq: entries
for seq, entries in seq_to_annotation_entries.items()
if seq in subset_seqs
}

return seq_to_annotation_entries


def make_examples_from_annotation_entries(
seq_to_annotation_entries: dict[str, list[dict]],
tokenizer: AutoTokenizer,
Expand All @@ -144,14 +194,15 @@ def make_examples_from_annotation_entries(
Create an example for each residue in each sequence where:
Input: SAE activation at the residue position
Target: Boolean indicating whether the residue is annotated with
the label, e.g. whether it falls within the motif labeled "H-T-H motif".
Target: Boolean indicating whether the residue has an annotation with
of given class, e.g. whether it falls within a motif of class
"H-T-H motif".
Returns a list of dicts like:
```
[
{
"sae_acts": [0.1, 0.2, 0.3, ...], # A number for each hidden dim
"sae_acts": [0.1, 0.2, 0.3, ...], # A number for each latent
"target": True,
},
{
Expand Down Expand Up @@ -262,49 +313,15 @@ def latent_probe(
logger.info(f"Processing annotation: {annotation.name}")
os.makedirs(os.path.join(output_dir, annotation.name), exist_ok=True)

for label in annotation.values:
output_path = os.path.join(output_dir, annotation.name, f"{label}.csv")
for class_name in annotation.class_names:
output_path = os.path.join(output_dir, annotation.name, f"{class_name}.csv")
if os.path.exists(output_path):
logger.warning(f"Skipping {output_path} because it already exists")
continue

# First, map each sequence to a list of annotations entries like:
# {
# "AAA": [
# {"start": 1, "end": 24, "note": "H-T-H motif"},
# {"start": 100, "end": 120, "note": "Homeobox"},
# ],
# }
seq_to_annotation_entries = {}
for _, row in df[df[annotation.name].notna()].iterrows():
seq = row["Sequence"]
entries = parse_swissprot_annotation(
row[annotation.name], header=annotation.swissprot_header
)
if label != ResidueAnnotation.ALL:
# The note field is sometimes like "Homeobox", "Homeobox 1", etc.,
# so use string `in` to check.
entries = [e for e in entries if label in e.get("note", "")]
if len(entries) > 0:
seq_to_annotation_entries[seq] = entries
logger.info(
f"Found {len(seq_to_annotation_entries)} sequences with label {label}"
seq_to_annotation_entries = get_annotation_entries_for_class(
df, annotation, class_name, max_seqs_per_task
)

if len(seq_to_annotation_entries) > max_seqs_per_task:
logger.warning(
f"Since max_seqs_per_task={max_seqs_per_task}, using a random "
f"sample of {max_seqs_per_task} sequences."
)
subset_seqs = random.sample(
list(seq_to_annotation_entries.keys()), max_seqs_per_task
)
seq_to_annotation_entries = {
seq: entries
for seq, entries in seq_to_annotation_entries.items()
if seq in subset_seqs
}

examples = make_examples_from_annotation_entries(
seq_to_annotation_entries=seq_to_annotation_entries,
tokenizer=tokenizer,
Expand All @@ -313,20 +330,25 @@ def latent_probe(
plm_layer=plm_layer,
)

# Run logistic regression for each dimension where the input is a number
# – the SAE activation of a fixed dimension at a fixed position – and
# the target is the binary target.
train_examples, test_examples = train_test_split(
examples,
test_size=0.1,
random_state=42,
stratify=[e["target"] for e in examples],
)

with warnings.catch_warnings():
# LogisticRegression throws warnings when it can't converge.
# This is expected for most dimensions.
warnings.simplefilter("ignore")

res_rows = []
for dim in range(sae_dim):
for dim in tqdm(
range(sae_dim),
desc="Logistic regression on each latent dimension",
):
model = LogisticRegression(class_weight="balanced")
X_train = [[e["sae_acts"][dim]] for e in train_examples]
y_train = [e["target"] for e in train_examples]
Expand Down
56 changes: 56 additions & 0 deletions tests/test_latent_probe.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import unittest
from unittest.mock import Mock, patch

import pandas as pd

from plm_interpretability.latent_probe.__main__ import (
ResidueAnnotation,
get_annotation_entries_for_class,
make_examples_from_annotation_entries,
)

Expand Down Expand Up @@ -78,3 +82,55 @@ def test_make_examples_from_annotation_entries(self, mock_get_sae_acts):
sae_model=mock_sae_model,
plm_layer=24,
)

def test_get_annotation_entries_for_class(self):
mock_df = pd.DataFrame(
{
"Sequence": ["ABCDEF", "GHIJKL", "MNOPQR"],
"DNA binding": [
'DNA_BIND 1..3; /note="H-T-H motif"',
'DNA_BIND 2..4; /note="Homeobox"',
'DNA_BIND 1..6; /note="Nuclear receptor"',
],
}
)

annotation = ResidueAnnotation(
name="DNA binding",
swissprot_header="DNA_BIND",
class_names=["H-T-H motif", "Homeobox", "Nuclear receptor"],
)

result = get_annotation_entries_for_class(
mock_df, annotation, "H-T-H motif", max_seqs_per_task=10
)
self.assertEqual(len(result), 1)
self.assertIn("ABCDEF", result)
self.assertEqual(
result["ABCDEF"], [{"start": 1, "end": 3, "note": "H-T-H motif"}]
)

result = get_annotation_entries_for_class(
mock_df, annotation, "Homeobox", max_seqs_per_task=10
)
self.assertEqual(len(result), 1)
self.assertIn("GHIJKL", result)
self.assertEqual(result["GHIJKL"], [{"start": 2, "end": 4, "note": "Homeobox"}])

result = get_annotation_entries_for_class(
mock_df, annotation, ResidueAnnotation.ALL_CLASSES, max_seqs_per_task=10
)
self.assertEqual(len(result), 3)
self.assertIn("ABCDEF", result)
self.assertIn("GHIJKL", result)
self.assertIn("MNOPQR", result)

result = get_annotation_entries_for_class(
mock_df, annotation, ResidueAnnotation.ALL_CLASSES, max_seqs_per_task=2
)
self.assertEqual(len(result), 2)

result = get_annotation_entries_for_class(
mock_df, annotation, "Non-existent", max_seqs_per_task=10
)
self.assertEqual(len(result), 0)

0 comments on commit 140cf79

Please sign in to comment.