Commit 8ce29b1

liambai committed Oct 15, 2024
1 parent ec78ba9 commit 8ce29b1
Showing 11 changed files with 527 additions and 333 deletions.
5 changes: 0 additions & 5 deletions plm_interpretability/latent_probe/README.md

This file was deleted.

13 changes: 13 additions & 0 deletions plm_interpretability/logistic_regression_probe/README.md
@@ -0,0 +1,13 @@
# Binary classification probes

Logistic regression probes fit on SAE latent activations to predict residue-level SwissProt annotations.

## Single latent

```bash
logistic_regression_probe single_latent \
    --sae-checkpoint plm_interpretability/checkpoints/l24_plm1280_sae4096_k128_100k.pt \
    --sae-dim 4096 \
    --plm-dim 1280 \
    --plm-layer 24 \
    --swissprot-tsv plm_interpretability/logistic_regression_probe/data/swissprot.tsv \
    --output-dir plm_interpretability/logistic_regression_probe/results
```

## All latents

```bash
logistic_regression_probe all_latents \
    --sae-checkpoint plm_interpretability/checkpoints/l24_plm1280_sae4096_k128_100k.pt \
    --sae-dim 4096 \
    --plm-dim 1280 \
    --plm-layer 24 \
    --swissprot-tsv plm_interpretability/logistic_regression_probe/data/swissprot.tsv \
    --output-file plm_interpretability/logistic_regression_probe/results/all_latents.csv
```
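
A quick way to inspect the results afterwards; the column names come from `all_latents.py` below, and the path is whatever was passed to `--output-file`:

```python
import pandas as pd

res = pd.read_csv("plm_interpretability/logistic_regression_probe/results/all_latents.csv")
# Best (annotation, class) probing tasks by F1
print(res.sort_values("f1", ascending=False).head(10))
```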
17 changes: 17 additions & 0 deletions plm_interpretability/logistic_regression_probe/__main__.py
@@ -0,0 +1,17 @@
import click

from plm_interpretability.logistic_regression_probe.all_latents import all_latents
from plm_interpretability.logistic_regression_probe.single_latent import single_latent


@click.group()
def cli():
    """A tool for running logistic regression probes on SAE latents"""
    pass


cli.add_command(single_latent)
cli.add_command(all_latents)

if __name__ == "__main__":
    cli()
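
Because the package defines a `__main__.py`, the command group is also reachable with `python -m` (assuming the package is installed in the current environment):

```bash
python -m plm_interpretability.logistic_regression_probe --help
python -m plm_interpretability.logistic_regression_probe all_latents --help
```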
137 changes: 137 additions & 0 deletions plm_interpretability/logistic_regression_probe/all_latents.py
@@ -0,0 +1,137 @@
import gc
import warnings

import click
import numpy as np
import pandas as pd
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, EsmModel

from plm_interpretability.logistic_regression_probe.annotations import RESIDUE_ANNOTATIONS
from plm_interpretability.logistic_regression_probe.logging import logger
from plm_interpretability.logistic_regression_probe.utils import (
    get_annotation_entries_for_class,
    make_examples_from_annotation_entries,
)
from plm_interpretability.sae_model import SparseAutoencoder


@click.command()
@click.option(
    "--sae-checkpoint",
    type=click.Path(exists=True),
    required=True,
    help="Path to the SAE checkpoint file",
)
@click.option("--sae-dim", type=int, required=True, help="Dimension of the sparse autoencoder")
@click.option("--plm-dim", type=int, required=True, help="Dimension of the protein language model")
@click.option(
    "--plm-layer",
    type=int,
    required=True,
    help="Layer of the protein language model to use",
)
@click.option(
    "--swissprot-tsv",
    type=click.Path(exists=True),
    required=True,
    help="Path to the SwissProt TSV file",
)
@click.option(
    "--output-file",
    type=click.Path(),
    required=True,
    help="Path to the output CSV file",
)
@click.option(
    "--annotation-names",
    type=click.STRING,
    multiple=True,
    help="List of annotation names to process. If not provided, all annotations will be processed.",
)
@click.option(
    "--max-seqs-per-task",
    type=int,
    default=1000,
    help="Maximum number of sequences to use for a given logistic regression task",
)
def all_latents(
    sae_checkpoint: str,
    sae_dim: int,
    plm_dim: int,
    plm_layer: int,
    swissprot_tsv: str,
    output_file: str,
    annotation_names: list[str],
    max_seqs_per_task: int,
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.debug(f"Using device: {device}")

    # Load pLM and SAE
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    plm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device).eval()
    sae_model = SparseAutoencoder(plm_dim, sae_dim).to(device)
    sae_model.load_state_dict(torch.load(sae_checkpoint, map_location=device))

    df = pd.read_csv(swissprot_tsv, sep="\t")

    res_rows = []
    for annotation in RESIDUE_ANNOTATIONS:
        if annotation_names and annotation.name not in annotation_names:
            continue

        logger.info(f"Processing annotation: {annotation.name}")

        for class_name in annotation.class_names:
            seq_to_annotation_entries = get_annotation_entries_for_class(
                df, annotation, class_name, max_seqs_per_task
            )
            examples = make_examples_from_annotation_entries(
                seq_to_annotation_entries=seq_to_annotation_entries,
                tokenizer=tokenizer,
                plm_model=plm_model,
                sae_model=sae_model,
                plm_layer=plm_layer,
            )
            train_examples, test_examples = train_test_split(
                examples,
                test_size=0.1,
                random_state=42,
                stratify=[e["target"] for e in examples],
            )
            with warnings.catch_warnings():
                # LogisticRegression throws warnings when it can't converge.
                # This is expected for most dimensions.
                warnings.simplefilter("ignore")

                X_train = np.array([e["sae_acts"] for e in train_examples], dtype="float32")
                y_train = np.array([e["target"] for e in train_examples], dtype="bool")
                X_test = np.array([e["sae_acts"] for e in test_examples], dtype="float32")
                y_test = np.array([e["target"] for e in test_examples], dtype="bool")

                model = LogisticRegression(class_weight="balanced")
                model.fit(X_train, y_train)
                y_pred = model.predict(X_test)
                precision = precision_score(y_test, y_pred)
                recall = recall_score(y_test, y_pred)
                f1 = f1_score(y_test, y_pred)

                logger.info(f"Results: {precision}, {recall}, {f1}")
                res_rows.append((annotation.name, class_name, precision, recall, f1))

            # Re-write the full results table after every class so partial
            # progress survives a crash or interruption.
            res_df = pd.DataFrame(
                res_rows, columns=["annotation", "class", "precision", "recall", "f1"]
            )
            res_df.to_csv(output_file, index=False)
            logger.info(f"Results saved to {output_file}")

            del seq_to_annotation_entries, examples, train_examples, test_examples
            gc.collect()


if __name__ == "__main__":
    all_latents()
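
`utils.py` is among the 11 changed files but is not rendered on this page, so here is a minimal sketch of what `make_examples_from_annotation_entries` plausibly does, reconstructed only from how it is called above. The SAE `encode()` method, the BOS-token offset, and the 1-indexed inclusive `start`/`end` fields are assumptions (the last one is suggested by the example entry in `annotations.py` below):

```python
import torch


def make_examples_sketch(seq_to_annotation_entries, tokenizer, plm_model, sae_model, plm_layer):
    """Hypothetical reconstruction: one example per residue, labeled by whether
    that residue falls inside any annotation entry for its sequence."""
    examples = []
    for seq, entries in seq_to_annotation_entries.items():
        inputs = tokenizer(seq, return_tensors="pt").to(plm_model.device)
        with torch.no_grad():
            # hidden_states[plm_layer] has shape (1, seq_len + 2, plm_dim), incl. BOS/EOS
            hidden = plm_model(**inputs, output_hidden_states=True).hidden_states[plm_layer]
            sae_acts = sae_model.encode(hidden.squeeze(0))  # encode() is an assumed API
        # Assume entries carry 1-indexed, inclusive start/end positions
        annotated = {i for e in entries for i in range(e["start"] - 1, e["end"])}
        for pos in range(len(seq)):
            examples.append(
                {
                    "sae_acts": sae_acts[pos + 1].cpu().numpy(),  # +1 skips the BOS token
                    "target": pos in annotated,
                }
            )
    return examples
```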
84 changes: 84 additions & 0 deletions plm_interpretability/logistic_regression_probe/annotations.py
@@ -0,0 +1,84 @@
from dataclasses import dataclass


@dataclass
class ResidueAnnotation:
    name: str
    swissprot_header: str
    class_names: list[str]

    # Class name used to indicate that we don't care about the annotation class,
    # as long as the annotation exists. E.g. signal peptide annotations look like
    # `{'start': 1, 'end': 24, 'evidence': 'ECO:0000255'}`, so we just classify
    # whether the residue is part of a signal peptide or not.
    ALL_CLASSES = "all"


RESIDUE_ANNOTATIONS = [
    ResidueAnnotation(
        name="DNA binding",
        swissprot_header="DNA_BIND",
        class_names=["H-T-H motif", "Homeobox", "Nuclear receptor", "HMG box"],
    ),
    ResidueAnnotation(
        name="Motif",
        swissprot_header="MOTIF",
        class_names=[
            "Nuclear localization signal",
            "Nuclear export signal",
            "DEAD box",
            "Cell attachment site",
            "JAMM motif",
            "SH3-binding",
            "Cysteine switch",
        ],
    ),
    ResidueAnnotation(
        name="Topological domain",
        swissprot_header="TOPO_DOM",
        class_names=[
            "Cytoplasmic",
            "Extracellular",
            "Lumenal",
            "Periplasmic",
            "Mitochondrial intermembrane",
            "Mitochondrial matrix",
            "Virion surface",
            "Intravirion",
        ],
    ),
    ResidueAnnotation(
        name="Domain [FT]",
        swissprot_header="DOMAIN",
        class_names=[
            "Protein kinase",
            "tr-type G",
            "Radical SAM core",
            "ABC transporter",
            "Helicase ATP-binding",
            "Glutamine amidotransferase type-1",
            "ATP-grasp",
            "S4 RNA-binding",
        ],
    ),
    ResidueAnnotation(
        name="Active site",
        swissprot_header="ACT_SITE",
        class_names=[
            "Proton acceptor",
            "Proton donor",
            "Nucleophile",
            "Charge relay system",
        ],
    ),
    ResidueAnnotation(
        name="Signal peptide",
        swissprot_header="SIGNAL",
        class_names=[ResidueAnnotation.ALL_CLASSES],
    ),
    ResidueAnnotation(
        name="Transit peptide",
        swissprot_header="TRANSIT",
        class_names=[ResidueAnnotation.ALL_CLASSES, "Mitochondrion", "Chloroplast"],
    ),
]
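
The filtering helper that consumes this registry (`get_annotation_entries_for_class` in `utils.py`) is not shown on this page. As an illustration only, the `ALL_CLASSES` sentinel presumably short-circuits class matching along these lines; the `note` key for the class label is a guess about the parsed SwissProt entry format:

```python
def entry_matches_class(entry: dict, class_name: str) -> bool:
    # With ALL_CLASSES, the mere presence of an entry under the annotation's
    # swissprot_header counts as a positive example.
    if class_name == ResidueAnnotation.ALL_CLASSES:
        return True
    return entry.get("note") == class_name
```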
9 changes: 9 additions & 0 deletions plm_interpretability/logistic_regression_probe/logging.py
@@ -0,0 +1,9 @@
import logging


def configure_logging():
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    return logging.getLogger(__name__)


logger = configure_logging()
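
For completeness, the import pattern the other modules use (as seen in `all_latents.py` above):

```python
from plm_interpretability.logistic_regression_probe.logging import logger

logger.info("probe run started")
```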