Skip to content

Commit

Permalink
Enable steering for antibody SAE
Browse files Browse the repository at this point in the history
  • Loading branch information
liambai committed Dec 22, 2024
1 parent e65d6b3 commit 844cd99
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 50 deletions.
76 changes: 49 additions & 27 deletions interprot/endpoints/steer_feature/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
logger = logging.getLogger(__name__)

WEIGHTS_DIR = "/weights"
# Maps each supported SAE name (the "sae_name" field of a request) to the file name
# of its checkpoint; load_models() joins the file name onto WEIGHTS_DIR.
# Each file name must contain plm<n>, l<n> and sae<n> fields, which load_models()
# parses with a regex and rejects with a ValueError when they are missing.
SAE_NAME_TO_CHECKPOINT = {
"SAE4096-L24": "esm2_plm1280_l24_sae4096_100Kseqs.pt",
"SAE4096-L24-ab": "esm2_plm1280_l24_sae4096_k128_auxk512_antibody_seqs.ckpt",
}


class SparseAutoencoder(nn.Module):
Expand Down Expand Up @@ -351,43 +355,61 @@ def get_sequence(self, x, layer_idx):
return logits


# Load your model
def load_models(sae_checkpoint: str):
pattern = r"plm(\d+).*?l(\d+).*?sae(\d+)"
matches = re.search(pattern, sae_checkpoint)
if matches:
plm_dim, _, sae_dim = map(int, matches.groups())
else:
raise ValueError("Checkpoint file must be named in the format plm<n>_l<n>_sae<n>")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ESM2 model
logger.info(f"Loading ESM2 model with plm_dim={plm_dim}")
alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
esm2_model = ESM2Model(
num_layers=33, embed_dim=plm_dim, attention_heads=20, alphabet=alphabet, token_dropout=False
)
esm2_weights = os.path.join(WEIGHTS_DIR, "esm2_t33_650M_UR50D.pt")
esm2_model.load_esm_ckpt(esm2_weights)
esm2_model = esm2_model.to(device)

# Load SAE model
logger.info(f"Loading SAE model with sae_dim={sae_dim}")
sae_model = SparseAutoencoder(plm_dim, sae_dim).to(device)
sae_weights = os.path.join(WEIGHTS_DIR, sae_checkpoint)
sae_model.load_state_dict(torch.load(sae_weights))
def load_models():
sae_name_to_model = {}
for sae_name, sae_checkpoint in SAE_NAME_TO_CHECKPOINT.items():
pattern = r"plm(\d+).*?l(\d+).*?sae(\d+)"
matches = re.search(pattern, sae_checkpoint)
if matches:
plm_dim, _, sae_dim = map(int, matches.groups())
else:
raise ValueError("Checkpoint file must be named in the format plm<n>_l<n>_sae<n>")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ESM2 model
logger.info(f"Loading ESM2 model with plm_dim={plm_dim}")
alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
esm2_model = ESM2Model(
num_layers=33,
embed_dim=plm_dim,
attention_heads=20,
alphabet=alphabet,
token_dropout=False,
)
esm2_weights = os.path.join(WEIGHTS_DIR, "esm2_t33_650M_UR50D.pt")
esm2_model.load_esm_ckpt(esm2_weights)
esm2_model = esm2_model.to(device)

# Load the SAE model for this checkpoint
logger.info(f"Loading SAE model {sae_name}")
sae_model = SparseAutoencoder(plm_dim, sae_dim).to(device)
sae_weights = os.path.join(WEIGHTS_DIR, sae_checkpoint)
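# Support two checkpoint formats: a bare state dict, or a dict whose "state_dict"
# entry holds the weights under keys prefixed with "sae_model."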
try:
sae_model.load_state_dict(torch.load(sae_weights))
except Exception:
sae_model.load_state_dict(
{
k.replace("sae_model.", ""): v
for k, v in torch.load(sae_weights)["state_dict"].items()
}
)
sae_name_to_model[sae_name] = sae_model

logger.info("Models loaded successfully")
return esm2_model, sae_model
return esm2_model, sae_name_to_model


def handler(event):
try:
input_data = event["input"]
seq = input_data["sequence"]
sae_name = input_data["sae_name"]
dim = input_data["dim"]
multiplier = input_data["multiplier"]

sae_model = sae_name_to_model[sae_name]

# First, get the ESM layer 24 activations and encode them with the SAE to get a (L, 4096) tensor
_, esm_layer_acts = esm2_model.get_layer_activations(seq, 24)
sae_latents, mu, std = sae_model.encode(esm_layer_acts[0])
Expand Down Expand Up @@ -420,5 +442,5 @@ def handler(event):
return {"status": "error", "error": str(e)}


esm2_model, sae_model = load_models("esm2_plm1280_l24_sae4096_100Kseqs.pt")
esm2_model, sae_name_to_model = load_models()
runpod.serverless.start({"handler": handler})
4 changes: 1 addition & 3 deletions viz/src/SAEVisualizerPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,7 @@ const SAEVisualizerPage: React.FC = () => {
) : (
<div className="mt-3">
<Markdown>{descStr}</Markdown>
{SAEConfig?.supportsCustomSequence && (
<CustomSeqPlayground feature={feature} saeName={model} />
)}
{SAEConfig?.supportsCustomSequence && <CustomSeqPlayground />}
{isLoading ? (
<div className="flex items-center justify-center w-full mt-5">
<img
Expand Down
36 changes: 18 additions & 18 deletions viz/src/components/CustomSeqPlayground.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { useState, useEffect, useCallback } from "react";
import { useState, useEffect, useCallback, useContext } from "react";
import { Button } from "@/components/ui/button";
import { Slider } from "@/components/ui/slider";
import {
Expand All @@ -16,11 +16,7 @@ import useUrlState from "@/hooks/useUrlState";
import FullSeqsViewer from "./FullSeqsViewer";
import PDBStructureViewer from "./PDBStructureViewer";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";

interface CustomSeqPlaygroundProps {
feature: number;
saeName: string;
}
import { SAEContext } from "@/SAEContext";

enum PlaygroundState {
IDLE,
Expand All @@ -37,7 +33,7 @@ const initialState = {
steeredActivations: [] as number[],
} as const;

const CustomSeqPlayground = ({ feature, saeName }: CustomSeqPlaygroundProps) => {
const CustomSeqPlayground = () => {
const [inputProteinActivations, setInputProteinActivations] = useState<ProteinActivationsData>(
initialState.inputProteinActivations
);
Expand All @@ -49,11 +45,12 @@ const CustomSeqPlayground = ({ feature, saeName }: CustomSeqPlaygroundProps) =>
const [steeredActivations, setSteeredActivations] = useState<number[]>(
initialState.steeredActivations
);
const { feature, model } = useContext(SAEContext);

const [urlState, setUrlState] = useUrlState<{ pdb?: string; seq?: string }>();

const handleSubmit = useCallback(
async (submittedInput: ValidSeqInput) => {
async (submittedInput: ValidSeqInput, feature: number, model: string) => {
setViewerState(PlaygroundState.LOADING_SAE_ACTIVATIONS);

setInputProteinActivations(initialState.inputProteinActivations);
Expand All @@ -63,15 +60,15 @@ const CustomSeqPlayground = ({ feature, saeName }: CustomSeqPlaygroundProps) =>

if (isPDBID(submittedInput)) {
setInputProteinActivations(
await constructProteinActivationsDataFromPDBID(submittedInput, feature, saeName)
await constructProteinActivationsDataFromPDBID(submittedInput, feature, model)
);
} else {
setInputProteinActivations(
await constructProteinActivationsDataFromSequence(submittedInput, feature, saeName)
await constructProteinActivationsDataFromSequence(submittedInput, feature, model)
);
}
},
[feature, saeName]
[]
);

// Reset some states when the user navigates to a new feature
Expand All @@ -85,35 +82,38 @@ const CustomSeqPlayground = ({ feature, saeName }: CustomSeqPlaygroundProps) =>

// If an input is set in the URL, submit it
useEffect(() => {
if (feature === undefined) return;
if (urlState.pdb && isPDBID(urlState.pdb)) {
setCustomSeqInput(urlState.pdb);
handleSubmit(urlState.pdb);
handleSubmit(urlState.pdb, feature, model);
} else if (urlState.seq && isProteinSequence(urlState.seq)) {
setCustomSeqInput(urlState.seq);
handleSubmit(urlState.seq);
handleSubmit(urlState.seq, feature, model);
}
}, [urlState.pdb, urlState.seq, setCustomSeqInput, handleSubmit]);
}, [urlState.pdb, urlState.seq, setCustomSeqInput, handleSubmit, feature, model]);

const handleSteer = async () => {
const handleSteer = async (input: string, feature: number, model: string) => {
setViewerState(PlaygroundState.LOADING_STEERED_SEQUENCE);

// Reset some states related to downstream actions
setSteeredActivations(initialState.steeredActivations);
setSteeredSeq(initialState.steeredSeq);

const steeredSeq = await getSteeredSequence({
sequence: proteinInput,
sequence: input,
dim: feature,
multiplier: steerMultiplier,
sae_name: model,
});
setSteeredSeq(steeredSeq);
setSteeredActivations(
await getSAEDimActivations({ sequence: steeredSeq, dim: feature, sae_name: saeName })
await getSAEDimActivations({ sequence: steeredSeq, dim: feature, sae_name: model })
);
};

const onStructureLoad = useCallback(() => setViewerState(PlaygroundState.IDLE), []);

if (feature === undefined) return null;
return (
<div className="mb-6">
<div className="mt-5">
Expand Down Expand Up @@ -251,7 +251,7 @@ const CustomSeqPlayground = ({ feature, saeName }: CustomSeqPlaygroundProps) =>

{/* Steer button */}
<Button
onClick={handleSteer}
onClick={() => handleSteer(proteinInput, feature, model)}
disabled={playgroundState === PlaygroundState.LOADING_STEERED_SEQUENCE}
className="w-full sm:w-auto min-w-24"
>
Expand Down
3 changes: 1 addition & 2 deletions viz/src/runpod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ type RunpodSAEAllDimsActivationsInput = {
};

type RunpodSteeringInput = {
// TODO: support different SAE models
// sae_name: string
sae_name: string;
sequence: string;
dim: number;
multiplier: number;
Expand Down

0 comments on commit 844cd99

Please sign in to comment.