Skip to content

Commit

Permalink
fix: rename to interprot
Browse files Browse the repository at this point in the history
  • Loading branch information
etowahadams committed Oct 31, 2024
1 parent 0c5c554 commit e9fb86b
Show file tree
Hide file tree
Showing 48 changed files with 119 additions and 119 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# pLM Interpretability
# InterProt

This repo contains tools used to interpret protein language models. `viz` contains the frontend app for visualizing SAE features. `plm_interpretability` is a python package containing tools for SAE training and interpretation.
This repo contains tools used to interpret protein language models. `viz` contains the frontend app for visualizing SAE features. `interprot` is a python package containing tools for SAE training and interpretation.

## The visualizer

Expand All @@ -16,16 +16,16 @@ pnpm run dev

```bash
docker compose build
docker compose run --rm plm-interpretability bash
docker compose run --rm interprot bash
pytest
```

### Running commands

Each directory under `plm_interpretability` contains a command-line tool. For example, `make_viz_files` takes in an SAE checkpoint and generates JSON files containing SAE activations used to serve the visualizer. You can run it with
Each directory under `interprot` contains a command-line tool. For example, `make_viz_files` takes in an SAE checkpoint and generates JSON files containing SAE activations used to serve the visualizer. You can run it with

```bash
cd plm_interpretability
cd interprot
python -m make_viz_files \
--checkpoint-files <path to checkpoint> \
--output-dir <path to output directory where the JSON files will be saved>
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
services:
plm-interpretability:
interprot:
build: .
volumes:
- .:/app
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ Use the following command to produce a CSV file matching a desired secondary str
by the regex `E{3,12}[T]{2,5}E{3,12}`.

```bash
autointerp pdb2labels --dssp-file plm_interpretability/autointerp/data/ss.txt --ss-patterns "E{3,12}[T]{2,5}E{3,12}" --out-path "plm_interpretability/autointerp/results/labels/E{3,12}[T]{2,5}E{3,12}_labels.csv"
autointerp pdb2labels --dssp-file interprot/autointerp/data/ss.txt --ss-patterns "E{3,12}[T]{2,5}E{3,12}" --out-path "interprot/autointerp/results/labels/E{3,12}[T]{2,5}E{3,12}_labels.csv"
```

### Step 2: Produce a CSV file that scores each SAE dimension on its ability to discriminate against the label

```bash
autointerp labels2latents --labels-csv "plm_interpretability/autointerp/results/labels/E{3,12}[T]{2,5}E{3,12}_labels.csv" --sae-checkpoint plm_interpretability/checkpoints/l24_plm1280_sae4096_k128_211k.pt --plm-dim 1280 --plm-layer 24 --sae-dim 4096 --out-path "plm_interpretability/autointerp/results/l24_plm1280_sae4096_k128_211k/E{3,12}[T]{2,5}E{3,12}_mapping.csv"
autointerp labels2latents --labels-csv "interprot/autointerp/results/labels/E{3,12}[T]{2,5}E{3,12}_labels.csv" --sae-checkpoint interprot/checkpoints/l24_plm1280_sae4096_k128_211k.pt --plm-dim 1280 --plm-layer 24 --sae-dim 4096 --out-path "interprot/autointerp/results/l24_plm1280_sae4096_k128_211k/E{3,12}[T]{2,5}E{3,12}_mapping.csv"
```
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import click

from plm_interpretability.autointerp.labels2latents import labels2latents
from plm_interpretability.autointerp.pdb2labels import pdb2labels
from interprot.autointerp.labels2latents import labels2latents
from interprot.autointerp.pdb2labels import pdb2labels


@click.group()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from tqdm import tqdm
from transformers import AutoTokenizer, EsmModel

from plm_interpretability.sae_model import SparseAutoencoder
from plm_interpretability.utils import get_layer_activations
from interprot.sae_model import SparseAutoencoder
from interprot.utils import get_layer_activations


def compute_scores_matrix(
Expand Down
File renamed without changes.
30 changes: 30 additions & 0 deletions interprot/autointerp/run_all.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Convenience script to run all autointerp experiments.

PLM_DIM=1280
PLM_LAYER=24

for motif in "E{3,12}[T]{2,5}E{3,12}" "H{4,40}[TS]{1,12}H{4,40}"; do
autointerp labels2latents \
--labels-csv "interprot/autointerp/results/labels/${motif}_labels.csv" \
--sae-checkpoint "interprot/checkpoints/l${PLM_LAYER}_plm${PLM_DIM}_sae4096_k128_100k.pt" \
--plm-dim $PLM_DIM \
--plm-layer $PLM_LAYER \
--sae-dim 4096 \
--out-path "interprot/autointerp/results/l${PLM_LAYER}_plm${PLM_DIM}_sae4096_k128_100k/${motif}_mapping.csv"

autointerp labels2latents \
--labels-csv "interprot/autointerp/results/labels/${motif}_labels.csv" \
--sae-checkpoint "interprot/checkpoints/l${PLM_LAYER}_plm${PLM_DIM}_sae4096_k128_211k.pt" \
--plm-dim $PLM_DIM \
--plm-layer $PLM_LAYER \
--sae-dim 4096 \
--out-path "interprot/autointerp/results/l${PLM_LAYER}_plm${PLM_DIM}_sae4096_k128_211k/${motif}_mapping.csv"

autointerp labels2latents \
--labels-csv "interprot/autointerp/results/labels/${motif}_labels.csv" \
--sae-checkpoint "interprot/checkpoints/l${PLM_LAYER}_plm${PLM_DIM}_sae32768_k128_100k.pt" \
--plm-dim $PLM_DIM \
--plm-layer $PLM_LAYER \
--sae-dim 32768 \
--out-path "interprot/autointerp/results/l${PLM_LAYER}_plm${PLM_DIM}_sae32768_k128_100k/${motif}_mapping.csv"
done
File renamed without changes.
File renamed without changes.
29 changes: 29 additions & 0 deletions interprot/logistic_regression_probe/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Binary classification probes

## Single latent

```bash
logistic_regression_probe single-latent \
--sae-checkpoint interprot/checkpoints/l24_plm1280_sae4096_k128_100k.pt \
--sae-dim 4096 \
--plm-dim 1280 \
--plm-layer 24 \
--swissprot-tsv interprot/logistic_regression_probe/data/swissprot.tsv \
--output-dir interprot/logistic_regression_probe/results \
--max-seqs-per-task 5 \
--annotation-names "DNA binding"
```

## All latents

```bash
logistic_regression_probe all-latents \
--sae-checkpoint interprot/checkpoints/l24_plm1280_sae4096_k128_100k.pt \
--sae-dim 4096 \
--plm-dim 1280 \
--plm-layer 24 \
--swissprot-tsv interprot/logistic_regression_probe/data/swissprot.tsv \
--output-file interprot/logistic_regression_probe/results/all_latents.csv \
--max-seqs-per-task 5 \
--annotation-names "DNA binding"
```
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import click

from plm_interpretability.logistic_regression_probe.all_latents import all_latents
from plm_interpretability.logistic_regression_probe.single_latent import single_latent
from interprot.logistic_regression_probe.all_latents import all_latents
from interprot.logistic_regression_probe.single_latent import single_latent


@click.group()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
from sklearn.metrics import f1_score, precision_score, recall_score
from transformers import AutoTokenizer, EsmModel

from plm_interpretability.logistic_regression_probe.annotations import (
from interprot.logistic_regression_probe.annotations import (
RESIDUE_ANNOTATION_NAMES,
RESIDUE_ANNOTATIONS,
)
from plm_interpretability.logistic_regression_probe.logging import logger
from plm_interpretability.logistic_regression_probe.utils import (
from interprot.logistic_regression_probe.logging import logger
from interprot.logistic_regression_probe.utils import (
prepare_arrays_for_logistic_regression,
)
from plm_interpretability.sae_model import SparseAutoencoder
from interprot.sae_model import SparseAutoencoder


@click.command()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
from tqdm import tqdm
from transformers import AutoTokenizer, EsmModel

from plm_interpretability.logistic_regression_probe.annotations import (
from interprot.logistic_regression_probe.annotations import (
RESIDUE_ANNOTATION_NAMES,
RESIDUE_ANNOTATIONS,
)
from plm_interpretability.logistic_regression_probe.logging import logger
from plm_interpretability.logistic_regression_probe.utils import (
from interprot.logistic_regression_probe.logging import logger
from interprot.logistic_regression_probe.utils import (
prepare_arrays_for_logistic_regression,
)
from plm_interpretability.sae_model import SparseAutoencoder
from interprot.sae_model import SparseAutoencoder


def augment_df_with_aa_identity(df: pd.DataFrame) -> pd.DataFrame:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from tqdm import tqdm
from transformers import AutoTokenizer, EsmModel

from plm_interpretability.logistic_regression_probe.annotations import ResidueAnnotation
from plm_interpretability.logistic_regression_probe.logging import logger
from plm_interpretability.sae_model import SparseAutoencoder
from plm_interpretability.utils import get_layer_activations, parse_swissprot_annotation
from interprot.logistic_regression_probe.annotations import ResidueAnnotation
from interprot.logistic_regression_probe.logging import logger
from interprot.sae_model import SparseAutoencoder
from interprot.utils import get_layer_activations, parse_swissprot_annotation

MAX_SEQ_LEN = 1000

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from tqdm import tqdm
from transformers import AutoTokenizer, EsmModel

from plm_interpretability.sae_model import SparseAutoencoder
from plm_interpretability.utils import get_layer_activations
from interprot.sae_model import SparseAutoencoder
from interprot.utils import get_layer_activations

OUTPUT_ROOT_DIR = "viz_files"
NUM_SEQS_PER_DIM = 12
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
30 changes: 0 additions & 30 deletions plm_interpretability/autointerp/run_all.sh

This file was deleted.

29 changes: 0 additions & 29 deletions plm_interpretability/logistic_regression_probe/README.md

This file was deleted.

24 changes: 12 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[project]
name = "plm-interpretability"
name = "interprot"
version = "0.1.0"
description = "Mechanistic interpretability tools for protein language models"
authors = [
Expand All @@ -21,19 +21,19 @@ readme = "README.md"

[tool.setuptools]
packages = [
"plm_interpretability",
"plm_interpretability.autointerp",
"plm_interpretability.logistic_regression_probe",
"plm_interpretability.make_viz_files",
"interprot",
"interprot.autointerp",
"interprot.logistic_regression_probe",
"interprot.make_viz_files",
]

[project.scripts]
autointerp = "plm_interpretability.autointerp.__main__:cli"
logistic_regression_probe = "plm_interpretability.logistic_regression_probe.__main__:cli"
make_viz_files = "plm_interpretability.make_viz_files.__main__:make_viz_files"
autointerp = "interprot.autointerp.__main__:cli"
logistic_regression_probe = "interprot.logistic_regression_probe.__main__:cli"
make_viz_files = "interprot.make_viz_files.__main__:make_viz_files"

[project.urls]
Homepage = "https://github.com/etowahadams/plm-interpretability"
Homepage = "https://github.com/etowahadams/interprot"

[build-system]
requires = ["hatchling"]
Expand Down Expand Up @@ -61,7 +61,7 @@ dev = [

[tool.hatch.build.targets.wheel]
packages = [
"plm_interpretability",
"plm_interpretability.autointerp",
"plm_interpretability.logistic_regression_probe"
"interprot",
"interprot.autointerp",
"interprot.logistic_regression_probe"
]
18 changes: 9 additions & 9 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ certifi==2024.8.30
charset-normalizer==3.3.2
# via requests
click==8.1.7
# via plm-interpretability (pyproject.toml)
# via interprot (pyproject.toml)
filelock==3.16.0
# via
# huggingface-hub
Expand Down Expand Up @@ -40,7 +40,7 @@ networkx==3.3
numpy==1.26.4
# via
# pandas
# plm-interpretability (pyproject.toml)
# interprot (pyproject.toml)
# scikit-learn
# scipy
# transformers
Expand All @@ -50,13 +50,13 @@ packaging==24.1
# pytest
# transformers
pandas==2.2.3
# via plm-interpretability (pyproject.toml)
# via interprot (pyproject.toml)
pluggy==1.5.0
# via pytest
polars==1.7.1
# via plm-interpretability (pyproject.toml)
# via interprot (pyproject.toml)
pytest==8.3.3
# via plm-interpretability (pyproject.toml)
# via interprot (pyproject.toml)
python-dateutil==2.9.0.post0
# via pandas
pytz==2024.2
Expand All @@ -74,7 +74,7 @@ requests==2.32.3
safetensors==0.4.5
# via transformers
scikit-learn==1.5.2
# via plm-interpretability (pyproject.toml)
# via interprot (pyproject.toml)
scipy==1.14.1
# via scikit-learn
six==1.16.0
Expand All @@ -86,14 +86,14 @@ threadpoolctl==3.5.0
tokenizers==0.19.1
# via transformers
torch==2.2.2
# via plm-interpretability (pyproject.toml)
# via interprot (pyproject.toml)
tqdm==4.66.5
# via
# huggingface-hub
# plm-interpretability (pyproject.toml)
# interprot (pyproject.toml)
# transformers
transformers==4.44.2
# via plm-interpretability (pyproject.toml)
# via interprot (pyproject.toml)
typing-extensions==4.12.2
# via
# huggingface-hub
Expand Down
Loading

0 comments on commit e9fb86b

Please sign in to comment.