From 844cd99f58cca2fc4964075a0fe7c49e10b10bf9 Mon Sep 17 00:00:00 2001
From: Liam
Date: Sun, 22 Dec 2024 14:02:34 +0800
Subject: [PATCH] Enable steering for antibody SAE

---
 interprot/endpoints/steer_feature/handler.py | 76 +++++++++++++-------
 viz/src/SAEVisualizerPage.tsx                |  4 +-
 viz/src/components/CustomSeqPlayground.tsx   | 36 +++++-----
 viz/src/runpod.ts                            |  3 +-
 4 files changed, 69 insertions(+), 50 deletions(-)

diff --git a/interprot/endpoints/steer_feature/handler.py b/interprot/endpoints/steer_feature/handler.py
index d517180..71eb387 100644
--- a/interprot/endpoints/steer_feature/handler.py
+++ b/interprot/endpoints/steer_feature/handler.py
@@ -22,6 +22,10 @@ logger = logging.getLogger(__name__)
 
 WEIGHTS_DIR = "/weights"
 
+SAE_NAME_TO_CHECKPOINT = {
+    "SAE4096-L24": "esm2_plm1280_l24_sae4096_100Kseqs.pt",
+    "SAE4096-L24-ab": "esm2_plm1280_l24_sae4096_k128_auxk512_antibody_seqs.ckpt",
+}
 
 
 class SparseAutoencoder(nn.Module):
@@ -351,43 +355,61 @@ def get_sequence(self, x, layer_idx):
         return logits
 
 
-# Load your model
-def load_models(sae_checkpoint: str):
-    pattern = r"plm(\d+).*?l(\d+).*?sae(\d+)"
-    matches = re.search(pattern, sae_checkpoint)
-    if matches:
-        plm_dim, _, sae_dim = map(int, matches.groups())
-    else:
-        raise ValueError("Checkpoint file must be named in the format plm_l_sae")
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    # Load ESM2 model
-    logger.info(f"Loading ESM2 model with plm_dim={plm_dim}")
-    alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
-    esm2_model = ESM2Model(
-        num_layers=33, embed_dim=plm_dim, attention_heads=20, alphabet=alphabet, token_dropout=False
-    )
-    esm2_weights = os.path.join(WEIGHTS_DIR, "esm2_t33_650M_UR50D.pt")
-    esm2_model.load_esm_ckpt(esm2_weights)
-    esm2_model = esm2_model.to(device)
-
-    # Load SAE model
-    logger.info(f"Loading SAE model with sae_dim={sae_dim}")
-    sae_model = SparseAutoencoder(plm_dim, sae_dim).to(device)
-    sae_weights = os.path.join(WEIGHTS_DIR, sae_checkpoint)
-    sae_model.load_state_dict(torch.load(sae_weights))
+def load_models():
+    sae_name_to_model = {}
+    for sae_name, sae_checkpoint in SAE_NAME_TO_CHECKPOINT.items():
+        pattern = r"plm(\d+).*?l(\d+).*?sae(\d+)"
+        matches = re.search(pattern, sae_checkpoint)
+        if matches:
+            plm_dim, _, sae_dim = map(int, matches.groups())
+        else:
+            raise ValueError("Checkpoint file must be named in the format plm_l_sae")
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Load ESM2 model
+        logger.info(f"Loading ESM2 model with plm_dim={plm_dim}")
+        alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
+        esm2_model = ESM2Model(
+            num_layers=33,
+            embed_dim=plm_dim,
+            attention_heads=20,
+            alphabet=alphabet,
+            token_dropout=False,
+        )
+        esm2_weights = os.path.join(WEIGHTS_DIR, "esm2_t33_650M_UR50D.pt")
+        esm2_model.load_esm_ckpt(esm2_weights)
+        esm2_model = esm2_model.to(device)
+
+        # Load SAE models
+        logger.info(f"Loading SAE model {sae_name}")
+        sae_model = SparseAutoencoder(plm_dim, sae_dim).to(device)
+        sae_weights = os.path.join(WEIGHTS_DIR, sae_checkpoint)
+        # Support different checkpoint formats
+        try:
+            sae_model.load_state_dict(torch.load(sae_weights))
+        except Exception:
+            sae_model.load_state_dict(
+                {
+                    k.replace("sae_model.", ""): v
+                    for k, v in torch.load(sae_weights)["state_dict"].items()
+                }
+            )
+        sae_name_to_model[sae_name] = sae_model
 
     logger.info("Models loaded successfully")
-    return esm2_model, sae_model
+    return esm2_model, sae_name_to_model
 
 
 def handler(event):
     try:
         input_data = event["input"]
         seq = input_data["sequence"]
+        sae_name = input_data["sae_name"]
         dim = input_data["dim"]
         multiplier = input_data["multiplier"]
 
+        sae_model = sae_name_to_model[sae_name]
+
         # First, get ESM layer 24 activations, encode it with SAE to get a (L, 4096) tensor
         _, esm_layer_acts = esm2_model.get_layer_activations(seq, 24)
         sae_latents, mu, std = sae_model.encode(esm_layer_acts[0])
@@ -420,5 +442,5 @@ def handler(event):
         return {"status": "error", "error": str(e)}
 
 
-esm2_model, sae_model = load_models("esm2_plm1280_l24_sae4096_100Kseqs.pt")
+esm2_model, sae_name_to_model = load_models()
 runpod.serverless.start({"handler": handler})
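Note for reviewers: the handler now dispatches on a required sae_name key in the
request input, so every client call must name one of the SAE_NAME_TO_CHECKPOINT
entries. Below is a minimal sketch of a raw call against the updated worker,
assuming the standard RunPod serverless /runsync route; the endpoint ID, API key,
and the example sequence/dim/multiplier values are placeholders, not part of
this patch:

    // Hypothetical smoke test; ENDPOINT_ID and RUNPOD_API_KEY are placeholders.
    const ENDPOINT_ID = "your-endpoint-id";
    const RUNPOD_API_KEY = "your-api-key";

    async function steerOnce(): Promise<void> {
      const response = await fetch(`https://api.runpod.ai/v2/${ENDPOINT_ID}/runsync`, {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
          Authorization: `Bearer ${RUNPOD_API_KEY}`,
        },
        body: JSON.stringify({
          // Keys mirror exactly what handler() reads from event["input"].
          input: {
            sequence: "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVAT", // example only
            sae_name: "SAE4096-L24-ab", // must be a key of SAE_NAME_TO_CHECKPOINT
            dim: 1234,       // SAE latent to steer (example)
            multiplier: 2.0, // steering strength (example)
          },
        }),
      });
      console.log(await response.json());
    }

    steerOnce().catch(console.error);
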
input_data["sequence"] + sae_name = input_data["sae_name"] dim = input_data["dim"] multiplier = input_data["multiplier"] + sae_model = sae_name_to_model[sae_name] + # First, get ESM layer 24 activations, encode it with SAE to get a (L, 4096) tensor _, esm_layer_acts = esm2_model.get_layer_activations(seq, 24) sae_latents, mu, std = sae_model.encode(esm_layer_acts[0]) @@ -420,5 +442,5 @@ def handler(event): return {"status": "error", "error": str(e)} -esm2_model, sae_model = load_models("esm2_plm1280_l24_sae4096_100Kseqs.pt") +esm2_model, sae_name_to_model = load_models() runpod.serverless.start({"handler": handler}) diff --git a/viz/src/SAEVisualizerPage.tsx b/viz/src/SAEVisualizerPage.tsx index 16a759f..cdd3497 100644 --- a/viz/src/SAEVisualizerPage.tsx +++ b/viz/src/SAEVisualizerPage.tsx @@ -145,9 +145,7 @@ const SAEVisualizerPage: React.FC = () => { ) : (
{descStr} - {SAEConfig?.supportsCustomSequence && ( - - )} + {SAEConfig?.supportsCustomSequence && } {isLoading ? (
+const CustomSeqPlayground = () => {
   const [inputProteinActivations, setInputProteinActivations] = useState(
     initialState.inputProteinActivations
   );
@@ -49,11 +45,12 @@ const CustomSeqPlayground = ({ feature, saeName }: CustomSeqPlaygroundProps) =>
   const [steeredActivations, setSteeredActivations] = useState(
     initialState.steeredActivations
   );
+  const { feature, model } = useContext(SAEContext);
 
   const [urlState, setUrlState] = useUrlState<{ pdb?: string; seq?: string }>();
 
   const handleSubmit = useCallback(
-    async (submittedInput: ValidSeqInput) => {
+    async (submittedInput: ValidSeqInput, feature: number, model: string) => {
       setViewerState(PlaygroundState.LOADING_SAE_ACTIVATIONS);
       setInputProteinActivations(initialState.inputProteinActivations);
@@ -63,15 +60,15 @@ const CustomSeqPlayground = ({ feature, saeName }: CustomSeqPlaygroundProps) =>
 
       if (isPDBID(submittedInput)) {
         setInputProteinActivations(
-          await constructProteinActivationsDataFromPDBID(submittedInput, feature, saeName)
+          await constructProteinActivationsDataFromPDBID(submittedInput, feature, model)
         );
       } else {
         setInputProteinActivations(
-          await constructProteinActivationsDataFromSequence(submittedInput, feature, saeName)
+          await constructProteinActivationsDataFromSequence(submittedInput, feature, model)
        );
       }
     },
-    [feature, saeName]
+    []
   );
 
   // Reset some states when the user navigates to a new feature
@@ -85,16 +82,17 @@ const CustomSeqPlayground = ({ feature, saeName }: CustomSeqPlaygroundProps) =>
 
   // If an input is set in the URL, submit it
   useEffect(() => {
+    if (feature === undefined) return;
     if (urlState.pdb && isPDBID(urlState.pdb)) {
       setCustomSeqInput(urlState.pdb);
-      handleSubmit(urlState.pdb);
+      handleSubmit(urlState.pdb, feature, model);
     } else if (urlState.seq && isProteinSequence(urlState.seq)) {
       setCustomSeqInput(urlState.seq);
-      handleSubmit(urlState.seq);
+      handleSubmit(urlState.seq, feature, model);
     }
-  }, [urlState.pdb, urlState.seq, setCustomSeqInput, handleSubmit]);
+  }, [urlState.pdb, urlState.seq, setCustomSeqInput, handleSubmit, feature, model]);
 
-  const handleSteer = async () => {
+  const handleSteer = async (input: string, feature: number, model: string) => {
     setViewerState(PlaygroundState.LOADING_STEERED_SEQUENCE);
 
     // Reset some states related to downstream actions
@@ -102,18 +100,20 @@ const CustomSeqPlayground = ({ feature, saeName }: CustomSeqPlaygroundProps) =>
     setSteeredSeq(initialState.steeredSeq);
 
     const steeredSeq = await getSteeredSequence({
-      sequence: proteinInput,
+      sequence: input,
       dim: feature,
       multiplier: steerMultiplier,
+      sae_name: model,
     });
     setSteeredSeq(steeredSeq);
     setSteeredActivations(
-      await getSAEDimActivations({ sequence: steeredSeq, dim: feature, sae_name: saeName })
+      await getSAEDimActivations({ sequence: steeredSeq, dim: feature, sae_name: model })
     );
   };
 
   const onStructureLoad = useCallback(() => setViewerState(PlaygroundState.IDLE), []);
 
+  if (feature === undefined) return null;
   return (
@@ -251,7 +251,7 @@ const CustomSeqPlayground = ({ feature, saeName }: CustomSeqPlaygroundProps) =>
         {/* Steer button */}
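The steer-button hunk and the viz/src/runpod.ts change are truncated above.
Judging from the new call sites in CustomSeqPlayground.tsx, the client helper
presumably just gains a sae_name field that is forwarded in the request payload.
A sketch under that assumption; the interface name, URL constant, and response
shape below are guesses rather than the actual diff:

    // Sketch only: names here are inferred from the call sites, not from runpod.ts.
    const STEER_ENDPOINT_URL = "https://api.runpod.ai/v2/<endpoint-id>/runsync"; // placeholder

    interface SteeredSequenceRequest {
      sequence: string;
      dim: number;
      multiplier: number;
      sae_name: string; // new: selects which SAE the worker steers with
    }

    async function getSteeredSequence(request: SteeredSequenceRequest): Promise<string> {
      // Forward the whole request, sae_name included, as the RunPod "input" payload.
      const response = await fetch(STEER_ENDPOINT_URL, {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ input: request }),
      });
      const data = await response.json();
      return data.output.steered_sequence; // assumed response shape
    }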