Commit

training with esm val metrics
haewonc committed Oct 15, 2024
1 parent e2692d1 commit db0d1ec
Showing 7 changed files with 449 additions and 2 deletions.
36 changes: 36 additions & 0 deletions plm_interpretability/data_module.py
@@ -0,0 +1,36 @@
import torch
import polars as pr
import pytorch_lightning as pl
from torch.utils.data import Dataset
from utils import train_val_test_split

class PolarsDataset(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.row(idx, named=True)
        return {"Sequence": row["Sequence"], "Entry": row["Entry"]}

# Data Module
class SequenceDataModule(pl.LightningDataModule):
    def __init__(self, data_path, batch_size):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size

    def setup(self, stage=None):
        df = pr.read_parquet(self.data_path)
        self.train_data, self.val_data, self.test_data = train_val_test_split(df)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(PolarsDataset(self.train_data), batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(PolarsDataset(self.val_data), batch_size=self.batch_size)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(PolarsDataset(self.test_data), batch_size=self.batch_size)
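A minimal usage sketch for the data module (not part of the commit), assuming utils.train_val_test_split returns three Polars DataFrames and the parquet file has "Sequence" and "Entry" columns; the path and batch size below are placeholders:

# Sketch only: illustrative values, not from the commit.
dm = SequenceDataModule(data_path="swissprot.parquet", batch_size=8)
dm.setup()
for batch in dm.train_dataloader():
    # The default collate_fn groups string fields into lists, so each batch is
    # {"Sequence": [seq_1, ..., seq_B], "Entry": [id_1, ..., id_B]}.
    print(len(batch["Sequence"]), batch["Entry"][:2])
    break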
98 changes: 98 additions & 0 deletions plm_interpretability/esm_wrapper.py
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import pytorch_lightning as pl
from esm.modules import ESM1bLayerNorm, RobertaLMHead, TransformerLayer

class ESM2Model(pl.LightningModule):
    def __init__(self, num_layers, embed_dim, attention_heads, alphabet, token_dropout):
        super().__init__()
        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.attention_heads = attention_heads
        self.alphabet = alphabet
        self.alphabet_size = len(alphabet)
        self.batch_converter = self.alphabet.get_batch_converter()
        self.padding_idx = alphabet.padding_idx
        self.mask_idx = alphabet.mask_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx
        self.prepend_bos = alphabet.prepend_bos
        self.append_eos = alphabet.append_eos
        self.token_dropout = token_dropout
        self._init_submodules()

    def _init_submodules(self):
        self.embed_scale = 1
        self.embed_tokens = nn.Embedding(
            self.alphabet_size,
            self.embed_dim,
            padding_idx=self.padding_idx,
        )

        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    self.embed_dim,
                    4 * self.embed_dim,
                    self.attention_heads,
                    add_bias_kv=False,
                    use_esm1b_layer_norm=True,
                    use_rotary_embeddings=True,
                )
                for _ in range(self.num_layers)
            ]
        )

        self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)

        self.lm_head = RobertaLMHead(
            embed_dim=self.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )

    def load_esm_ckpt(self, esm_pretrained):
        # Strip the fairseq prefixes so the pretrained weights match this module's parameter names.
        ckpt = {}
        model_data = torch.load(esm_pretrained)["model"]
        for k in model_data:
            if 'lm_head' in k:
                ckpt[k.replace('encoder.', '')] = model_data[k]
            else:
                ckpt[k.replace('encoder.sentence_encoder.', '')] = model_data[k]
        self.load_state_dict(ckpt)

    def compose_input(self, list_tuple_seq):
        _, _, batch_tokens = self.batch_converter(list_tuple_seq)
        batch_tokens = batch_tokens.to(self.device)
        return batch_tokens

    def get_layer_activations(self, input, layer_idx):
        # Run the embedding and only the first `layer_idx` transformer layers.
        if isinstance(input, str):
            tokens = self.compose_input([('protein', input)])
        elif isinstance(input, list):
            tokens = self.compose_input([('protein', seq) for seq in input])
        else:
            tokens = input

        x = self.embed_scale * self.embed_tokens(tokens)
        x = x.transpose(0, 1)  # (B, T, E) => (T, B, E)
        for layer in self.layers[:layer_idx]:
            x, _ = layer(
                x,
                self_attn_padding_mask=None,
                need_head_weights=False,
            )
        return tokens, x.transpose(0, 1)

    def get_sequence(self, x, layer_idx):
        # Resume the forward pass from `layer_idx` and decode to per-token logits.
        x = x.transpose(0, 1)  # (B, T, E) => (T, B, E)
        for layer in self.layers[layer_idx:]:
            x, _ = layer(
                x,
                self_attn_padding_mask=None,
                need_head_weights=False,
            )
        x = self.emb_layer_norm_after(x)
        x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)
        logits = self.lm_head(x)
        return logits
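A sketch (not part of the commit) of how the two halves of the wrapper compose: get_layer_activations runs the embedding plus the first layer_idx transformer layers, and get_sequence resumes from that layer through the final layer norm and LM head. The checkpoint path, sequence, and layer index are placeholders, and the 650M-parameter ESM2 shape (33 layers, 1280-dim embeddings, 20 heads) is assumed from the training module later in this commit:

import esm
import torch

alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
model = ESM2Model(num_layers=33, embed_dim=1280, attention_heads=20,
                  alphabet=alphabet, token_dropout=False)
model.load_esm_ckpt("esm2_t33_650M_UR50D.pt")  # placeholder checkpoint path
model.eval()

with torch.no_grad():
    tokens, acts = model.get_layer_activations("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", layer_idx=24)
    logits = model.get_sequence(acts, layer_idx=24)  # resume from the same layer
print(tokens.shape, acts.shape, logits.shape)  # (1, T), (1, T, 1280), (1, T, alphabet_size)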
9 changes: 9 additions & 0 deletions plm_interpretability/requirements.txt
@@ -0,0 +1,9 @@
tqdm
numpy
torch
fair-esm
wandb
pandas
transformers
polars
lightning
19 changes: 17 additions & 2 deletions plm_interpretability/sae_model.py
Expand Up @@ -8,7 +8,7 @@
from torch.nn import functional as F
from transformers import PreTrainedModel, PreTrainedTokenizer

-from plm_interpretability.utils import get_layer_activations
+from utils import get_layer_activations


class SparseAutoencoder(nn.Module):
@@ -230,6 +230,21 @@ def get_acts(self, x: torch.Tensor) -> torch.Tensor:
        pre_acts = x @ self.w_enc + self.b_enc
        latents = self.topK_activation(pre_acts, self.k)
        return latents

    @torch.no_grad()
    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x, mu, std = self.LN(x)
        x = x - self.b_pre
        acts = x @ self.w_enc + self.b_enc
        return acts, mu, std

    @torch.no_grad()
    def decode(self, acts: torch.Tensor, mu: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
        latents = self.topK_activation(acts, self.k)
        recons = latents @ self.w_dec + self.b_pre
        recons = recons * std + mu
        return recons


def loss_fn(
@@ -292,4 +307,4 @@ def estimate_loss(
        test_losses.append(mse_loss.item())

        del layer_acts
-    return np.mean(test_losses)
+    return np.mean(test_losses)
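The new encode/decode pair splits the SAE forward pass so the dense pre-activations, along with the layer-norm statistics needed to undo the normalization, are exposed before top-k sparsification and reconstruction. A round-trip sketch (not part of the commit), assuming the SparseAutoencoder(d_model, d_hidden) constructor used in sae_module.py and placeholder sizes:

import torch

sae = SparseAutoencoder(1280, 16384)   # placeholder (d_model, d_hidden)
x = torch.randn(4, 128, 1280)          # (batch, tokens, d_model) ESM activations

acts, mu, std = sae.encode(x)       # dense pre-activations plus LayerNorm statistics
recons = sae.decode(acts, mu, std)  # top-k sparsify, decode, then undo the normalization
print(acts.shape, recons.shape)     # (4, 128, 16384) and (4, 128, 1280)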
68 changes: 68 additions & 0 deletions plm_interpretability/sae_module.py
@@ -0,0 +1,68 @@
import esm
import torch
import pytorch_lightning as pl
from esm_wrapper import ESM2Model
from sae_model import SparseAutoencoder, loss_fn
import torch.nn.functional as F

class SAELightningModule(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()
        self.args = args
        self.layer_to_use = args.layer_to_use
        self.sae_model = SparseAutoencoder(args.d_model, args.d_hidden)
        self.alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
        esm2_model = ESM2Model(num_layers=33, embed_dim=args.d_model, attention_heads=20,
                               alphabet=self.alphabet, token_dropout=False)
        esm2_model.load_esm_ckpt(args.esm2_weight)
        self.esm2_model = esm2_model
        self.esm2_model.eval()
        # Freeze the ESM2 backbone; only the SAE is trained.
        for param in self.esm2_model.parameters():
            param.requires_grad = False

    def forward(self, x):
        return self.sae_model(x)

    def training_step(self, batch, batch_idx):
        seqs = batch["Sequence"]
        batch_size = len(seqs)
        with torch.no_grad():
            tokens, esm_layer_acts = self.esm2_model.get_layer_activations(seqs, self.layer_to_use)
        recons, auxk, num_dead = self(esm_layer_acts)
        mse_loss, auxk_loss = loss_fn(esm_layer_acts, recons, auxk)
        loss = mse_loss + auxk_loss
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
        self.log('train_mse_loss', mse_loss, on_step=True, on_epoch=True, logger=True, batch_size=batch_size)
        self.log('train_auxk_loss', auxk_loss, on_step=True, on_epoch=True, logger=True, batch_size=batch_size)
        self.log('num_dead_neurons', num_dead, on_step=True, on_epoch=True, logger=True, batch_size=batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        seqs = batch["Sequence"]
        batch_size = len(seqs)
        with torch.no_grad():
            tokens, esm_layer_acts = self.esm2_model.get_layer_activations(seqs, self.layer_to_use)
        recons, auxk, num_dead = self(esm_layer_acts)
        mse_loss, auxk_loss = loss_fn(esm_layer_acts, recons, auxk)
        loss = mse_loss + auxk_loss
        # Feed the SAE reconstruction back through the remaining ESM2 layers and score
        # how well the original tokens are still predicted (the new validation metrics).
        logits = self.esm2_model.get_sequence(recons, self.layer_to_use)
        logits = logits.view(-1, logits.size(-1))
        tokens = tokens.view(-1)
        correct = (torch.argmax(logits, dim=-1) == tokens).sum().item()
        total = tokens.size(0)

        self.log('val_celoss', F.cross_entropy(logits, tokens).mean().item(), on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
        self.log('val_acc', correct / total, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
        return loss

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.args.lr)

    def on_after_backward(self):
        self.sae_model.norm_weights()
        self.sae_model.norm_grad()
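A sketch of how the new modules would be wired together for a training run (the launch script itself is not part of this commit). The args fields d_model, d_hidden, layer_to_use, esm2_weight, and lr come from the code above; data_path and batch_size are hypothetical names used here only to feed SequenceDataModule, and all values are placeholders:

import argparse
import pytorch_lightning as pl
from data_module import SequenceDataModule
from sae_module import SAELightningModule

# Placeholder values; only the attribute names referenced by SAELightningModule matter.
args = argparse.Namespace(d_model=1280, d_hidden=16384, layer_to_use=24,
                          esm2_weight="esm2_t33_650M_UR50D.pt", lr=2e-4,
                          data_path="swissprot.parquet", batch_size=8)

model = SAELightningModule(args)
data = SequenceDataModule(args.data_path, args.batch_size)
trainer = pl.Trainer(max_epochs=1, log_every_n_steps=10)
trainer.fit(model, datamodule=data)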