diff --git a/plm_interpretability/data_module.py b/plm_interpretability/data_module.py
new file mode 100644
index 0000000..7b51f7a
--- /dev/null
+++ b/plm_interpretability/data_module.py
@@ -0,0 +1,36 @@
+import torch
+import polars as pr
+import pytorch_lightning as pl
+from torch.utils.data import Dataset
+from utils import train_val_test_split
+
+class PolarsDataset(Dataset):
+    def __init__(self, df):
+        self.df = df
+
+    def __len__(self):
+        return len(self.df)
+
+    def __getitem__(self, idx):
+        row = self.df.row(idx, named=True)
+        return {"Sequence": row["Sequence"], "Entry": row["Entry"]}
+
+# Data Module
+class SequenceDataModule(pl.LightningDataModule):
+    def __init__(self, data_path, batch_size):
+        super().__init__()
+        self.data_path = data_path
+        self.batch_size = batch_size
+
+    def setup(self, stage=None):
+        df = pr.read_parquet(self.data_path)
+        self.train_data, self.val_data, self.test_data = train_val_test_split(df)
+
+    def train_dataloader(self):
+        return torch.utils.data.DataLoader(PolarsDataset(self.train_data), batch_size=self.batch_size, shuffle=True)
+
+    def val_dataloader(self):
+        return torch.utils.data.DataLoader(PolarsDataset(self.val_data), batch_size=self.batch_size)
+
+    def test_dataloader(self):
+        return torch.utils.data.DataLoader(PolarsDataset(self.test_data), batch_size=self.batch_size)
diff --git a/plm_interpretability/esm_wrapper.py b/plm_interpretability/esm_wrapper.py
new file mode 100644
index 0000000..2e3a690
--- /dev/null
+++ b/plm_interpretability/esm_wrapper.py
@@ -0,0 +1,98 @@
+import torch
+import torch.nn as nn
+import pytorch_lightning as pl
+from esm.modules import ESM1bLayerNorm, RobertaLMHead, TransformerLayer
+
+class ESM2Model(pl.LightningModule):
+    def __init__(self, num_layers, embed_dim, attention_heads, alphabet, token_dropout):
+        super().__init__()
+        self.num_layers = num_layers
+        self.embed_dim = embed_dim
+        self.attention_heads = attention_heads
+        self.alphabet = alphabet
+        self.alphabet_size = len(alphabet)
+        self.batch_converter = self.alphabet.get_batch_converter()
+        self.padding_idx = alphabet.padding_idx
+        self.mask_idx = alphabet.mask_idx
+        self.cls_idx = alphabet.cls_idx
+        self.eos_idx = alphabet.eos_idx
+        self.prepend_bos = alphabet.prepend_bos
+        self.append_eos = alphabet.append_eos
+        self.token_dropout = token_dropout
+        self._init_submodules()
+
+    def _init_submodules(self):
+        self.embed_scale = 1
+        self.embed_tokens = nn.Embedding(
+            self.alphabet_size,
+            self.embed_dim,
+            padding_idx=self.padding_idx,
+        )
+
+        self.layers = nn.ModuleList(
+            [
+                TransformerLayer(
+                    self.embed_dim,
+                    4 * self.embed_dim,
+                    self.attention_heads,
+                    add_bias_kv=False,
+                    use_esm1b_layer_norm=True,
+                    use_rotary_embeddings=True,
+                )
+                for _ in range(self.num_layers)
+            ]
+        )
+
+        self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)
+
+        self.lm_head = RobertaLMHead(
+            embed_dim=self.embed_dim,
+            output_dim=self.alphabet_size,
+            weight=self.embed_tokens.weight,
+        )
+
+    def load_esm_ckpt(self, esm_pretrained):
+        ckpt = {}
+        model_data = torch.load(esm_pretrained)["model"]
+        for k in model_data:
+            if 'lm_head' in k:
+                ckpt[k.replace('encoder.','')] = model_data[k]
+            else:
+                ckpt[k.replace('encoder.sentence_encoder.','')] = model_data[k]
+        self.load_state_dict(ckpt)
+
+    def compose_input(self, list_tuple_seq):
+        _, _, batch_tokens = self.batch_converter(list_tuple_seq)
+        batch_tokens = batch_tokens.to(self.device)
+        return batch_tokens
+
+    def get_layer_activations(self, input, layer_idx):
+        if isinstance(input, str):
+            tokens = self.compose_input([('protein', input)])
+        elif isinstance(input, list):
+            tokens = self.compose_input([('protein', seq) for seq in input])
+        else:
+            tokens = input
+
+        x = self.embed_scale * self.embed_tokens(tokens)
+        x = x.transpose(0, 1)  # (B, T, E) => (T, B, E)
+        for _, layer in enumerate(self.layers[:layer_idx]):
+            x, attn = layer(
+                x,
+                self_attn_padding_mask=None,
+                need_head_weights=False,
+            )
+        return tokens, x.transpose(0, 1)
+
+    def get_sequence(self, x, layer_idx):
+        x = x.transpose(0, 1)  # (B, T, E) => (T, B, E)
+        for _, layer in enumerate(self.layers[layer_idx:]):
+            x, attn = layer(
+                x,
+                self_attn_padding_mask=None,
+                need_head_weights=False,
+            )
+        x = self.emb_layer_norm_after(x)
+        x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)
+        logits = self.lm_head(x)
+        return logits
\ No newline at end of file
diff --git a/plm_interpretability/requirements.txt b/plm_interpretability/requirements.txt
new file mode 100644
index 0000000..6da0a2e
--- /dev/null
+++ b/plm_interpretability/requirements.txt
@@ -0,0 +1,9 @@
+tqdm
+numpy
+torch
+fair-esm
+wandb
+pandas
+transformers
+polars
+lightning
\ No newline at end of file
diff --git a/plm_interpretability/sae_model.py b/plm_interpretability/sae_model.py
index c19895a..f1013e4 100644
--- a/plm_interpretability/sae_model.py
+++ b/plm_interpretability/sae_model.py
@@ -8,7 +8,7 @@
 from torch.nn import functional as F
 from transformers import PreTrainedModel, PreTrainedTokenizer
 
-from plm_interpretability.utils import get_layer_activations
+from utils import get_layer_activations
 
 
 class SparseAutoencoder(nn.Module):
@@ -230,6 +230,21 @@ def get_acts(self, x: torch.Tensor) -> torch.Tensor:
         pre_acts = x @ self.w_enc + self.b_enc
         latents = self.topK_activation(pre_acts, self.k)
         return latents
+
+    @torch.no_grad()
+    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        x, mu, std = self.LN(x)
+        x = x - self.b_pre
+        acts = x @ self.w_enc + self.b_enc
+        return acts, mu, std
+
+    @torch.no_grad()
+    def decode(self, acts: torch.Tensor, mu: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
+        latents = self.topK_activation(acts, self.k)
+
+        recons = latents @ self.w_dec + self.b_pre
+        recons = recons * std + mu
+        return recons
 
 
 def loss_fn(
@@ -292,4 +307,4 @@ def estimate_loss(
         test_losses.append(mse_loss.item())
         del layer_acts
 
-    return np.mean(test_losses)
+    return np.mean(test_losses)
\ No newline at end of file
diff --git a/plm_interpretability/sae_module.py b/plm_interpretability/sae_module.py
new file mode 100644
index 0000000..284b723
--- /dev/null
+++ b/plm_interpretability/sae_module.py
@@ -0,0 +1,68 @@
+import esm
+import torch
+import pytorch_lightning as pl
+from esm_wrapper import ESM2Model
+from sae_model import SparseAutoencoder, loss_fn
+import torch.nn.functional as F
+
+class SAELightningModule(pl.LightningModule):
+    def __init__(self, args):
+        super().__init__()
+        self.save_hyperparameters()
+        self.args = args
+        self.layer_to_use = args.layer_to_use
+        self.sae_model = SparseAutoencoder(args.d_model, args.d_hidden)
+        self.alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
+        esm2_model = ESM2Model(num_layers=33, embed_dim=args.d_model, attention_heads=20,
+                               alphabet=self.alphabet, token_dropout=False)
+        esm2_model.load_esm_ckpt(args.esm2_weight)
+        self.esm2_model = esm2_model
+        self.esm2_model.eval()
+        for param in self.esm2_model.parameters():
+            param.requires_grad = False
+
+    def forward(self, x):
+        return self.sae_model(x)
+
+    def training_step(self, batch, batch_idx):
+        seqs = batch["Sequence"]
+        batch_size = len(seqs)
+        with torch.no_grad():
+            tokens, esm_layer_acts = self.esm2_model.get_layer_activations(seqs, self.layer_to_use)
+        recons, auxk, num_dead = self(esm_layer_acts)
+        mse_loss, auxk_loss = loss_fn(esm_layer_acts, recons, auxk)
+        loss = mse_loss + auxk_loss
+        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
+        self.log('train_mse_loss', mse_loss, on_step=True, on_epoch=True, logger=True, batch_size=batch_size)
+        self.log('train_auxk_loss', auxk_loss, on_step=True, on_epoch=True, logger=True, batch_size=batch_size)
+        self.log('num_dead_neurons', num_dead, on_step=True, on_epoch=True, logger=True, batch_size=batch_size)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        seqs = batch["Sequence"]
+        batch_size = len(seqs)
+        with torch.no_grad():
+            tokens, esm_layer_acts = self.esm2_model.get_layer_activations(seqs, self.layer_to_use)
+        recons, auxk, num_dead = self(esm_layer_acts)
+        mse_loss, auxk_loss = loss_fn(esm_layer_acts, recons, auxk)
+        loss = mse_loss + auxk_loss
+        logits = self.esm2_model.get_sequence(recons, self.layer_to_use)
+        logits = logits.view(-1, logits.size(-1))
+        tokens = tokens.view(-1)
+        correct = (torch.argmax(logits, dim=-1) == tokens).sum().item()
+        total = tokens.size(0)
+
+        self.log('val_celoss', F.cross_entropy(logits, tokens).mean().item(), on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
+        self.log('val_acc', correct / total, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
+        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        return self.validation_step(batch, batch_idx)
+
+    def configure_optimizers(self):
+        return torch.optim.AdamW(self.parameters(), lr=self.args.lr)
+
+    def on_after_backward(self):
+        self.sae_model.norm_weights()
+        self.sae_model.norm_grad()
\ No newline at end of file
diff --git a/plm_interpretability/steering.ipynb b/plm_interpretability/steering.ipynb
new file mode 100644
index 0000000..53e9264
--- /dev/null
+++ b/plm_interpretability/steering.ipynb
@@ -0,0 +1,150 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 117,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import torch\n",
+    "import esm\n",
+    "from sae_model import SparseAutoencoder\n",
+    "from esm_wrapper import ESM2Model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 118,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "D_MODEL = 1280\n",
+    "D_HIDDEN = 4096\n",
+    "SEQUENCE = 'MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVVAAIVQDIAYLRSLGYNIVATPRGYVLAGG'\n",
+    "device = 'cuda:0'\n",
+    "weights_dir = '/global/cfs/cdirs/m4351/ml5045/interp_weights'\n",
+    "\n",
+    "esm2_weight = os.path.join(weights_dir, 'esm2_t33_650M_UR50D.pt')\n",
+    "sae_weight = os.path.join(weights_dir, 'sae_weights.pt')\n",
+    "alphabet = esm.data.Alphabet.from_architecture(\"ESM-1b\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "esm2_model = ESM2Model(num_layers=33, embed_dim=1280, attention_heads=20, \n",
+    "                       alphabet=alphabet, token_dropout=False)\n",
+    "esm2_model.load_esm_ckpt(esm2_weight)\n",
+    "esm2_model = esm2_model.to(device)\n",
+    "sae_model = SparseAutoencoder(D_MODEL, D_HIDDEN).to(device)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 120,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "_, embed = esm2_model.get_layer_activations(SEQUENCE, 24)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 126,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "acts, mu, std = sae_model.encode(embed[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 127,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "DIM = 220"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "acts.size(), mu.size(), std.size()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 129,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "updated_latents = sae_model.decode(acts, mu, std)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 130,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "logits = esm2_model.get_sequence(updated_latents.unsqueeze(0), 24)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "logits"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 132,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokens = torch.argmax(logits[:,1:-1,4:24], dim=-1) \n",
+    "sequences = [''.join([esm2_model.alphabet.all_toks[i+4] for i in sequence.tolist()]) for sequence in list(tokens)]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sequences[0]"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "genie_venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/plm_interpretability/training.py b/plm_interpretability/training.py
new file mode 100644
index 0000000..d82c233
--- /dev/null
+++ b/plm_interpretability/training.py
@@ -0,0 +1,74 @@
+import os
+import torch
+import wandb
+import argparse
+import pytorch_lightning as pl
+from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.callbacks import ModelCheckpoint
+from sae_module import SAELightningModule
+from data_module import SequenceDataModule
+
+os.environ["WANDB_CACHE_DIR"] = '/global/cfs/cdirs/m4351/ml5045/wandb'
+parser = argparse.ArgumentParser()
+
+parser.add_argument('--output-dir', type=str, required=True)
+parser.add_argument('--data-dir', type=str, required=True)
+parser.add_argument('--esm2-weight', type=str, default='/global/cfs/cdirs/m4351/ml5045/interp_weights/esm2_t33_650M_UR50D.pt')
+parser.add_argument('-l', '--layer-to-use', type=int, default=24)
+parser.add_argument('--d-model', type=int, default=1280)
+parser.add_argument('--d-hidden', type=int, default=32768)
+parser.add_argument('-b', '--batch-size', type=int, default=64)
+parser.add_argument('--lr', type=float, default=2e-3)
+parser.add_argument('--k', type=int, default=128)
+parser.add_argument('--auxk', type=int, default=256)
+parser.add_argument('--dead-steps-threshold', type=int, default=2000)
+parser.add_argument('-e', '--max-epochs', type=int, default=1)
+parser.add_argument('-d', '--num-devices', type=int, default=1)
+
+args = parser.parse_args()
+
+if not os.path.exists(args.output_dir):
+    os.mkdir(args.output_dir)
+
+wandb_logger = WandbLogger(project="interpretability_test",
+                           name="sae_opt",
+                           save_dir=os.path.join(args.output_dir, 'wandb'))
+
+# Initialize model and data module
+model = SAELightningModule(args)
+data_module = SequenceDataModule(args.data_dir, args.batch_size)
+
+# Set up callbacks
+checkpoint_callback = ModelCheckpoint(
+    dirpath=os.path.join(args.output_dir, 'checkpoints'),
+    filename='sae-{step}-{val_loss:.2f}',
+    save_top_k=3,
+    monitor='val_loss',
+    mode='min',
+    save_last=True
+)
+
+# Initialize Trainer
+trainer = pl.Trainer(
+    max_epochs=args.max_epochs,
+    accelerator='gpu',
+    devices=list(range(args.num_devices)),
+    strategy='ddp' if args.num_devices > 1 else 'auto',
+    logger=wandb_logger,
+    log_every_n_steps=10,
+    val_check_interval=10,
+    callbacks=[checkpoint_callback],
+    gradient_clip_val=1.0,
+)
+
+# Train the model
+trainer.fit(model, data_module)
+
+# Test the model
+trainer.test(model, data_module)
+
+# Close WandB run
+wandb.finish()
\ No newline at end of file