diff --git a/interprot/endpoints/sae_inference/handler.py b/interprot/endpoints/sae_inference/handler.py index 5d44722..cd6ab46 100644 --- a/interprot/endpoints/sae_inference/handler.py +++ b/interprot/endpoints/sae_inference/handler.py @@ -421,6 +421,7 @@ def handler(event): logger.info(f"esm_layer_acts: {esm_layer_acts.shape}") sae_model = sae_name_to_model[sae_name] + print(f"sae_model: {sae_model}") sae_acts = sae_model.get_acts(esm_layer_acts)[1:-1] logger.info(f"sae_acts: {sae_acts.shape}") diff --git a/interprot/endpoints/steer_feature/Dockerfile b/interprot/endpoints/steer_feature/Dockerfile index e0e1444..184c8ec 100644 --- a/interprot/endpoints/steer_feature/Dockerfile +++ b/interprot/endpoints/steer_feature/Dockerfile @@ -17,6 +17,24 @@ RUN wget https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt RUN gdown https://drive.google.com/uc?id=1LtDUfcWQEwPuTdd127HEb_A2jQdRyKrU # Weights for esm2_plm1280_l24_sae4096_k128_auxk512_antibody_seqs.ckpt RUN gdown https://drive.google.com/uc?id=19aCVCVLleTc4QSiXZsi5hPqrE21duk6q +# Weights for esm2_plm1280_l4_sae4096_k64.ckpt +RUN gdown https://drive.google.com/uc?id=1yrfhQ4Qtcpe2v9oeiBl4csklcnGbamNp +# Weights for esm2_plm1280_l8_sae4096_k64_auxk640.ckpt +RUN gdown https://drive.google.com/uc?id=1m30OvYHZmtdI8l6F1GsWr8TnZr_0Q_ax +# Weights for esm2_plm1280_l12_sae4096_k64.ckpt +RUN gdown https://drive.google.com/uc?id=1UA9Y6EV9cgY-HtNjz9n46DE4nUCJ8S1U +# Weights for esm2_plm1280_l16_sae4096_k64_auxk640.ckpt +RUN gdown https://drive.google.com/uc?id=1f_kHYqrV9qw-RKgQUBX-p5hDntwSkANd +# Weights for esm2_plm1280_l20_sae4096_k64.ckpt +RUN gdown https://drive.google.com/uc?id=1W_2sU3V4zTw0crG0fduKNdJpkk7CXgsd +# Weights for esm2_plm1280_l24_sae4096_k64_auxk640.ckpt +RUN gdown https://drive.google.com/uc?id=1QfcQLWBH5t2Bt975bbRNS33fUPGpaFJN +# Weights for esm2_plm1280_l28_sae4096_k64.ckpt +RUN gdown https://drive.google.com/uc?id=1wvyl0yb4kGbnlMYQsJpl7JSmNDLnoNpu +# Weights for esm2_plm1280_l32_sae4096_k64.ckpt +RUN gdown https://drive.google.com/uc?id=1LXwEnDsgLpyCILyTrQv_W2yTLwCV-6IP +# Weights for esm2_plm1280_l33_sae4096_aux640.ckpt +RUN gdown https://drive.google.com/uc?id=1Ly7IQjAp3UcPOgQLCgV6BQiwknV32VZU WORKDIR / diff --git a/interprot/endpoints/steer_feature/handler.py b/interprot/endpoints/steer_feature/handler.py index 71eb387..221ffab 100644 --- a/interprot/endpoints/steer_feature/handler.py +++ b/interprot/endpoints/steer_feature/handler.py @@ -25,6 +25,15 @@ SAE_NAME_TO_CHECKPOINT = { "SAE4096-L24": "esm2_plm1280_l24_sae4096_100Kseqs.pt", "SAE4096-L24-ab": "esm2_plm1280_l24_sae4096_k128_auxk512_antibody_seqs.ckpt", + "SAE4096-L4": "esm2_plm1280_l4_sae4096_k64.ckpt", + "SAE4096-L8": "esm2_plm1280_l8_sae4096_k64_auxk640.ckpt", + "SAE4096-L12": "esm2_plm1280_l12_sae4096_k64.ckpt", + "SAE4096-L16": "esm2_plm1280_l16_sae4096_k64_auxk640.ckpt", + "SAE4096-L20": "esm2_plm1280_l20_sae4096_k64.ckpt", + "SAE4096-L24-v2": "esm2_plm1280_l24_sae4096_k64_auxk640.ckpt", + "SAE4096-L28": "esm2_plm1280_l28_sae4096_k64.ckpt", + "SAE4096-L32": "esm2_plm1280_l32_sae4096_k64.ckpt", + "SAE4096-L33": "esm2_plm1280_l33_sae4096_aux640.ckpt", } @@ -409,6 +418,7 @@ def handler(event): multiplier = input_data["multiplier"] sae_model = sae_name_to_model[sae_name] + print(f"sae_model: {sae_model}") # First, get ESM layer 24 activations, encode it with SAE to get a (L, 4096) tensor _, esm_layer_acts = esm2_model.get_layer_activations(seq, 24) diff --git a/viz/src/SAEConfigs.ts b/viz/src/SAEConfigs.ts index 2dc4d3e..d1cc47b 100644 --- a/viz/src/SAEConfigs.ts +++ b/viz/src/SAEConfigs.ts @@ -420,7 +420,7 @@ export const SAE_CONFIGS: Record = { numHiddenDims: 4096, plmLayer: 4, defaultDim: 0, - supportsCustomSequence: false, + supportsCustomSequence: true, curated: [], }, "SAE4096-L8": { @@ -429,7 +429,7 @@ export const SAE_CONFIGS: Record = { numHiddenDims: 4096, plmLayer: 8, defaultDim: 0, - supportsCustomSequence: false, + supportsCustomSequence: true, curated: [], }, "SAE4096-L12": { @@ -438,7 +438,7 @@ export const SAE_CONFIGS: Record = { numHiddenDims: 4096, plmLayer: 12, defaultDim: 0, - supportsCustomSequence: false, + supportsCustomSequence: true, curated: [], }, "SAE4096-L16": { @@ -447,7 +447,7 @@ export const SAE_CONFIGS: Record = { numHiddenDims: 4096, plmLayer: 16, defaultDim: 0, - supportsCustomSequence: false, + supportsCustomSequence: true, curated: [], }, "SAE4096-L20": { @@ -456,7 +456,7 @@ export const SAE_CONFIGS: Record = { numHiddenDims: 4096, plmLayer: 20, defaultDim: 0, - supportsCustomSequence: false, + supportsCustomSequence: true, curated: [], }, "SAE4096-L24-v2": { @@ -466,7 +466,7 @@ export const SAE_CONFIGS: Record = { numHiddenDims: 4096, plmLayer: 24, defaultDim: 0, - supportsCustomSequence: false, + supportsCustomSequence: true, curated: [], }, "SAE4096-L28": { @@ -475,7 +475,7 @@ export const SAE_CONFIGS: Record = { numHiddenDims: 4096, plmLayer: 28, defaultDim: 0, - supportsCustomSequence: false, + supportsCustomSequence: true, curated: [], }, "SAE4096-L32": { @@ -484,7 +484,7 @@ export const SAE_CONFIGS: Record = { numHiddenDims: 4096, plmLayer: 32, defaultDim: 0, - supportsCustomSequence: false, + supportsCustomSequence: true, curated: [], }, "SAE4096-L33": { @@ -493,7 +493,7 @@ export const SAE_CONFIGS: Record = { numHiddenDims: 4096, plmLayer: 33, defaultDim: 0, - supportsCustomSequence: false, + supportsCustomSequence: true, curated: [], }, "SAE4096-L24-ab": { @@ -666,7 +666,7 @@ export const SAE_CONFIGS: Record = { numHiddenDims: 8192, plmLayer: 24, defaultDim: 0, - supportsCustomSequence: false, + supportsCustomSequence: true, curated: [], }, "SAE8192-L24-K32": { @@ -675,7 +675,7 @@ export const SAE_CONFIGS: Record = { numHiddenDims: 8192, plmLayer: 24, defaultDim: 0, - supportsCustomSequence: false, + supportsCustomSequence: true, curated: [], }, "SAE8192-L24-K64": { @@ -684,7 +684,7 @@ export const SAE_CONFIGS: Record = { numHiddenDims: 8192, plmLayer: 24, defaultDim: 0, - supportsCustomSequence: false, + supportsCustomSequence: true, curated: [], }, "SAE8192-L24-K128": { @@ -693,7 +693,7 @@ export const SAE_CONFIGS: Record = { numHiddenDims: 8192, plmLayer: 24, defaultDim: 0, - supportsCustomSequence: false, + supportsCustomSequence: true, curated: [], }, "SAE8192-L24-K256": { @@ -702,7 +702,7 @@ export const SAE_CONFIGS: Record = { numHiddenDims: 8192, plmLayer: 24, defaultDim: 0, - supportsCustomSequence: false, + supportsCustomSequence: true, curated: [], }, }; diff --git a/viz/src/runpod.ts b/viz/src/runpod.ts index c389ad3..f6cbbbc 100644 --- a/viz/src/runpod.ts +++ b/viz/src/runpod.ts @@ -72,7 +72,7 @@ export async function getSAEDimActivations(input: RunpodSAEDimActivationsInput): } // Both caches have missed, call API. - const data = await postRunpod(input, "jrzmm3fq54zjuy"); + const data = await postRunpod(input, "fad3whi01lfmh2"); SAEDimActivationsCache[dimCacheKey] = data.tokens_acts_list; return data.tokens_acts_list; } @@ -84,7 +84,7 @@ export async function getSAEAllDimsActivations( if (cacheKey in SAEAllDimsActivationsCache) { return SAEAllDimsActivationsCache[cacheKey]; } - const data = await postRunpod(input, "jrzmm3fq54zjuy"); + const data = await postRunpod(input, "fad3whi01lfmh2"); SAEAllDimsActivationsCache[cacheKey] = data.token_acts_list_by_active_dim; return data.token_acts_list_by_active_dim; }