Commit: 11 changed files with 527 additions and 333 deletions.
This file was deleted.

New file (13 additions, 0 deletions; path not shown):
# Binary classification probes

## Single latent

```bash
logistic_regression_probe single_latent \
  --sae-checkpoint plm_interpretability/checkpoints/l24_plm1280_sae4096_k128_100k.pt \
  --sae-dim 4096 \
  --plm-dim 1280 \
  --plm-layer 24 \
  --swissprot-tsv plm_interpretability/logistic_regression_probe/data/swissprot.tsv \
  --output-dir plm_interpretability/logistic_regression_probe/results
```

## All latents

```bash
logistic_regression_probe all_latents \
  --sae-checkpoint plm_interpretability/checkpoints/l24_plm1280_sae4096_k128_100k.pt \
  --sae-dim 4096 \
  --plm-dim 1280 \
  --plm-layer 24 \
  --swissprot-tsv plm_interpretability/logistic_regression_probe/data/swissprot.tsv \
  --output-file plm_interpretability/logistic_regression_probe/results/all_latents.csv
```
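Note that `all_latents` takes `--output-file` rather than `--output-dir`: it streams results to a single CSV with columns `annotation`, `class`, `precision`, `recall`, and `f1` (see `all_latents.py` below). The output path above is illustrative.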
plm_interpretability/logistic_regression_probe/__main__.py (17 additions, 0 deletions):
```python
import click

from plm_interpretability.logistic_regression_probe.all_latents import all_latents
from plm_interpretability.logistic_regression_probe.single_latent import single_latent


@click.group()
def cli():
    """A tool for running logistic regression probes on SAE latents"""
    pass


cli.add_command(single_latent)
cli.add_command(all_latents)

if __name__ == "__main__":
    cli()
```
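Because the package ships this `__main__.py`, the same command group is also reachable without the console script, e.g. `python -m plm_interpretability.logistic_regression_probe single_latent --help`.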
plm_interpretability/logistic_regression_probe/all_latents.py (137 additions, 0 deletions):
```python
import gc
import warnings

import click
import numpy as np
import pandas as pd
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, EsmModel

from plm_interpretability.logistic_regression_probe.annotations import RESIDUE_ANNOTATIONS
from plm_interpretability.logistic_regression_probe.logging import logger
from plm_interpretability.logistic_regression_probe.utils import (
    get_annotation_entries_for_class,
    make_examples_from_annotation_entries,
)
from plm_interpretability.sae_model import SparseAutoencoder


@click.command()
@click.option(
    "--sae-checkpoint",
    type=click.Path(exists=True),
    required=True,
    help="Path to the SAE checkpoint file",
)
@click.option("--sae-dim", type=int, required=True, help="Dimension of the sparse autoencoder")
@click.option("--plm-dim", type=int, required=True, help="Dimension of the protein language model")
@click.option(
    "--plm-layer",
    type=int,
    required=True,
    help="Layer of the protein language model to use",
)
@click.option(
    "--swissprot-tsv",
    type=click.Path(exists=True),
    required=True,
    help="Path to the SwissProt TSV file",
)
@click.option(
    "--output-file",
    type=click.Path(),
    required=True,
    help="Path to the output file",
)
@click.option(
    "--annotation-names",
    type=click.STRING,
    multiple=True,
    help="List of annotation names to process. If not provided, all annotations will be processed.",
)
@click.option(
    "--max-seqs-per-task",
    type=int,
    default=1000,
    help="Maximum number of sequences to use for a given logistic regression task",
)
def all_latents(
    sae_checkpoint: str,
    sae_dim: int,
    plm_dim: int,
    plm_layer: int,
    swissprot_tsv: str,
    output_file: str,
    annotation_names: list[str],
    max_seqs_per_task: int,
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.debug(f"Using device: {device}")

    # Load pLM and SAE
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    plm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device).eval()
    sae_model = SparseAutoencoder(plm_dim, sae_dim).to(device)
    sae_model.load_state_dict(torch.load(sae_checkpoint, map_location=device))

    df = pd.read_csv(swissprot_tsv, sep="\t")

    # Train one probe per (annotation, class) pair, using the full SAE
    # activation vector as features.
    res_rows = []
    for annotation in RESIDUE_ANNOTATIONS:
        if annotation_names and annotation.name not in annotation_names:
            continue

        logger.info(f"Processing annotation: {annotation.name}")

        for class_name in annotation.class_names:
            seq_to_annotation_entries = get_annotation_entries_for_class(
                df, annotation, class_name, max_seqs_per_task
            )
            examples = make_examples_from_annotation_entries(
                seq_to_annotation_entries=seq_to_annotation_entries,
                tokenizer=tokenizer,
                plm_model=plm_model,
                sae_model=sae_model,
                plm_layer=plm_layer,
            )
            train_examples, test_examples = train_test_split(
                examples,
                test_size=0.1,
                random_state=42,
                stratify=[e["target"] for e in examples],
            )
            with warnings.catch_warnings():
                # LogisticRegression throws warnings when it can't converge.
                # This is expected for most dimensions.
                warnings.simplefilter("ignore")

                X_train = np.array([e["sae_acts"] for e in train_examples], dtype="float32")
                y_train = np.array([e["target"] for e in train_examples], dtype="bool")
                X_test = np.array([e["sae_acts"] for e in test_examples], dtype="float32")
                y_test = np.array([e["target"] for e in test_examples], dtype="bool")

                model = LogisticRegression(class_weight="balanced")
                model.fit(X_train, y_train)
                y_pred = model.predict(X_test)
                precision = precision_score(y_test, y_pred)
                recall = recall_score(y_test, y_pred)
                f1 = f1_score(y_test, y_pred)

                logger.info(f"Results: precision={precision:.3f}, recall={recall:.3f}, f1={f1:.3f}")
                res_rows.append((annotation.name, class_name, precision, recall, f1))

                # Write results incrementally after each class.
                res_df = pd.DataFrame(
                    res_rows, columns=["annotation", "class", "precision", "recall", "f1"]
                )
                res_df.to_csv(output_file, index=False)
                logger.info(f"Results saved to {output_file}")

            # Free per-class data before the next iteration.
            del seq_to_annotation_entries, examples, train_examples, test_examples
            gc.collect()


if __name__ == "__main__":
    all_latents()
```
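The helpers `get_annotation_entries_for_class` and `make_examples_from_annotation_entries` live in `utils.py`, which is not among the files shown here. From the call sites, each example is a dict holding a per-residue SAE activation vector under `sae_acts` and a boolean `target`. Below is a minimal sketch of how such examples could be built, assuming the SAE exposes an `encode` method; the actual `utils.py` and `SparseAutoencoder` APIs may differ.

```python
# Hypothetical sketch, not the repository's actual utils.py.
import torch


def make_examples_sketch(seq_to_annotation_entries, tokenizer, plm_model, sae_model, plm_layer):
    examples = []
    for seq, entries in seq_to_annotation_entries.items():
        inputs = tokenizer(seq, return_tensors="pt").to(plm_model.device)
        with torch.no_grad():
            # Per-residue hidden states at the chosen pLM layer,
            # dropping the BOS/EOS special tokens.
            hidden_states = plm_model(**inputs, output_hidden_states=True).hidden_states
            residue_reprs = hidden_states[plm_layer][0, 1:-1]
            sae_acts = sae_model.encode(residue_reprs)  # assumed encoder method
        # Mark residues covered by any annotation entry as positives
        # (SwissProt start/end positions are 1-based and inclusive).
        positive = set()
        for entry in entries:
            positive.update(range(entry["start"] - 1, entry["end"]))
        for i, acts in enumerate(sae_acts):
            examples.append({"sae_acts": acts.cpu().numpy(), "target": i in positive})
    return examples
```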
plm_interpretability/logistic_regression_probe/annotations.py (84 additions, 0 deletions):
```python
from dataclasses import dataclass


@dataclass
class ResidueAnnotation:
    name: str
    swissprot_header: str
    class_names: list[str]

    # Class name used to indicate that we don't care about the annotation class,
    # as long as the annotation exists. E.g. signal peptide annotations look like
    # `{'start': 1, 'end': 24, 'evidence': 'ECO:0000255'}`, so we just classify
    # whether the residue is part of a signal peptide or not.
    ALL_CLASSES = "all"


RESIDUE_ANNOTATIONS = [
    ResidueAnnotation(
        name="DNA binding",
        swissprot_header="DNA_BIND",
        class_names=["H-T-H motif", "Homeobox", "Nuclear receptor", "HMG box"],
    ),
    ResidueAnnotation(
        name="Motif",
        swissprot_header="MOTIF",
        class_names=[
            "Nuclear localization signal",
            "Nuclear export signal",
            "DEAD box",
            "Cell attachment site",
            "JAMM motif",
            "SH3-binding",
            "Cysteine switch",
        ],
    ),
    ResidueAnnotation(
        name="Topological domain",
        swissprot_header="TOPO_DOM",
        class_names=[
            "Cytoplasmic",
            "Extracellular",
            "Lumenal",
            "Periplasmic",
            "Mitochondrial intermembrane",
            "Mitochondrial matrix",
            "Virion surface",
            "Intravirion",
        ],
    ),
    ResidueAnnotation(
        name="Domain [FT]",
        swissprot_header="DOMAIN",
        class_names=[
            "Protein kinase",
            "tr-type G",
            "Radical SAM core",
            "ABC transporter",
            "Helicase ATP-binding",
            "Glutamine amidotransferase type-1",
            "ATP-grasp",
            "S4 RNA-binding",
        ],
    ),
    ResidueAnnotation(
        name="Active site",
        swissprot_header="ACT_SITE",
        class_names=[
            "Proton acceptor",
            "Proton donor",
            "Nucleophile",
            "Charge relay system",
        ],
    ),
    ResidueAnnotation(
        name="Signal peptide",
        swissprot_header="SIGNAL",
        class_names=[ResidueAnnotation.ALL_CLASSES],
    ),
    ResidueAnnotation(
        name="Transit peptide",
        swissprot_header="TRANSIT",
        class_names=[ResidueAnnotation.ALL_CLASSES, "Mitochondrion", "Chloroplast"],
    ),
]
```
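For example, the `Signal peptide` annotation uses only `ResidueAnnotation.ALL_CLASSES`: any `SIGNAL` entry counts as a positive regardless of its class label. A hypothetical sketch of that matching rule follows; the real filtering lives in `utils.py` (not shown here), and the `note` key for an entry's class label is an assumption.

```python
from plm_interpretability.logistic_regression_probe.annotations import (
    RESIDUE_ANNOTATIONS,
    ResidueAnnotation,
)


def entry_matches_class(entry: dict, class_name: str) -> bool:
    # With ALL_CLASSES, any entry of this annotation type is a positive;
    # otherwise the entry's own class label must match. (Hypothetical
    # helper; "note" as the class-label key is an assumption.)
    if class_name == ResidueAnnotation.ALL_CLASSES:
        return True
    return entry.get("note") == class_name


signal = next(a for a in RESIDUE_ANNOTATIONS if a.name == "Signal peptide")
entry = {"start": 1, "end": 24, "evidence": "ECO:0000255"}
assert entry_matches_class(entry, signal.class_names[0])  # "all" matches any entry
```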
plm_interpretability/logistic_regression_probe/logging.py (9 additions, 0 deletions):
```python
import logging


def configure_logging():
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    return logging.getLogger(__name__)


logger = configure_logging()
```
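Because `logger = configure_logging()` runs at import time and Python caches modules, every command that imports this module shares one configuration: `basicConfig` installs the root handler on first import and is a no-op on subsequent calls.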