Skip to content

Commit

Permalink
Updates for new models
Browse files Browse the repository at this point in the history
  • Loading branch information
liambai committed Jan 25, 2025
1 parent 59d5dbb commit f133121
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 16 deletions.
1 change: 1 addition & 0 deletions interprot/endpoints/sae_inference/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ def handler(event):
logger.info(f"esm_layer_acts: {esm_layer_acts.shape}")

sae_model = sae_name_to_model[sae_name]
print(f"sae_model: {sae_model}")
sae_acts = sae_model.get_acts(esm_layer_acts)[1:-1]
logger.info(f"sae_acts: {sae_acts.shape}")

Expand Down
18 changes: 18 additions & 0 deletions interprot/endpoints/steer_feature/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,24 @@ RUN wget https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt
RUN gdown https://drive.google.com/uc?id=1LtDUfcWQEwPuTdd127HEb_A2jQdRyKrU
# Weights for esm2_plm1280_l24_sae4096_k128_auxk512_antibody_seqs.ckpt
RUN gdown https://drive.google.com/uc?id=19aCVCVLleTc4QSiXZsi5hPqrE21duk6q
# Weights for esm2_plm1280_l4_sae4096_k64.ckpt
RUN gdown https://drive.google.com/uc?id=1yrfhQ4Qtcpe2v9oeiBl4csklcnGbamNp
# Weights for esm2_plm1280_l8_sae4096_k64_auxk640.ckpt
RUN gdown https://drive.google.com/uc?id=1m30OvYHZmtdI8l6F1GsWr8TnZr_0Q_ax
# Weights for esm2_plm1280_l12_sae4096_k64.ckpt
RUN gdown https://drive.google.com/uc?id=1UA9Y6EV9cgY-HtNjz9n46DE4nUCJ8S1U
# Weights for esm2_plm1280_l16_sae4096_k64_auxk640.ckpt
RUN gdown https://drive.google.com/uc?id=1f_kHYqrV9qw-RKgQUBX-p5hDntwSkANd
# Weights for esm2_plm1280_l20_sae4096_k64.ckpt
RUN gdown https://drive.google.com/uc?id=1W_2sU3V4zTw0crG0fduKNdJpkk7CXgsd
# Weights for esm2_plm1280_l24_sae4096_k64_auxk640.ckpt
RUN gdown https://drive.google.com/uc?id=1QfcQLWBH5t2Bt975bbRNS33fUPGpaFJN
# Weights for esm2_plm1280_l28_sae4096_k64.ckpt
RUN gdown https://drive.google.com/uc?id=1wvyl0yb4kGbnlMYQsJpl7JSmNDLnoNpu
# Weights for esm2_plm1280_l32_sae4096_k64.ckpt
RUN gdown https://drive.google.com/uc?id=1LXwEnDsgLpyCILyTrQv_W2yTLwCV-6IP
# Weights for esm2_plm1280_l33_sae4096_aux640.ckpt
RUN gdown https://drive.google.com/uc?id=1Ly7IQjAp3UcPOgQLCgV6BQiwknV32VZU

WORKDIR /

Expand Down
10 changes: 10 additions & 0 deletions interprot/endpoints/steer_feature/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@
# Maps each SAE name to the filename of its checkpoint.
# The per-layer k64 entries name the same files as the "Weights for ..." comments
# in interprot/endpoints/steer_feature/Dockerfile.
# NOTE(review): the names appear to mirror the SAE_CONFIGS keys in
# viz/src/SAEConfigs.ts -- confirm they stay in sync when adding a model.
SAE_NAME_TO_CHECKPOINT = {
"SAE4096-L24": "esm2_plm1280_l24_sae4096_100Kseqs.pt",
"SAE4096-L24-ab": "esm2_plm1280_l24_sae4096_k128_auxk512_antibody_seqs.ckpt",
# One SAE per ESM layer (l4 ... l33), all with 4096 hidden dims per the filenames.
"SAE4096-L4": "esm2_plm1280_l4_sae4096_k64.ckpt",
"SAE4096-L8": "esm2_plm1280_l8_sae4096_k64_auxk640.ckpt",
"SAE4096-L12": "esm2_plm1280_l12_sae4096_k64.ckpt",
"SAE4096-L16": "esm2_plm1280_l16_sae4096_k64_auxk640.ckpt",
"SAE4096-L20": "esm2_plm1280_l20_sae4096_k64.ckpt",
# "-v2" distinguishes this layer-24 checkpoint from the "SAE4096-L24" entry above.
"SAE4096-L24-v2": "esm2_plm1280_l24_sae4096_k64_auxk640.ckpt",
"SAE4096-L28": "esm2_plm1280_l28_sae4096_k64.ckpt",
"SAE4096-L32": "esm2_plm1280_l32_sae4096_k64.ckpt",
# NOTE(review): spelled "aux640" here, not "auxk640" as in the other entries; it
# matches the Dockerfile comment, but confirm against the actual downloaded filename.
"SAE4096-L33": "esm2_plm1280_l33_sae4096_aux640.ckpt",
}


Expand Down Expand Up @@ -409,6 +418,7 @@ def handler(event):
multiplier = input_data["multiplier"]

sae_model = sae_name_to_model[sae_name]
print(f"sae_model: {sae_model}")

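# First, get ESM layer 24 activations and encode them with the SAE to get a (L, 4096) tensor.
# NOTE(review): the layer is hard-coded to 24 below, but SAE_NAME_TO_CHECKPOINT now also lists SAEs named for layers 4-33 -- confirm whether the layer should depend on sae_name.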
_, esm_layer_acts = esm2_model.get_layer_activations(seq, 24)
Expand Down
28 changes: 14 additions & 14 deletions viz/src/SAEConfigs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ export const SAE_CONFIGS: Record<string, SAEConfig> = {
numHiddenDims: 4096,
plmLayer: 4,
defaultDim: 0,
supportsCustomSequence: false,
supportsCustomSequence: true,
curated: [],
},
"SAE4096-L8": {
Expand All @@ -429,7 +429,7 @@ export const SAE_CONFIGS: Record<string, SAEConfig> = {
numHiddenDims: 4096,
plmLayer: 8,
defaultDim: 0,
supportsCustomSequence: false,
supportsCustomSequence: true,
curated: [],
},
"SAE4096-L12": {
Expand All @@ -438,7 +438,7 @@ export const SAE_CONFIGS: Record<string, SAEConfig> = {
numHiddenDims: 4096,
plmLayer: 12,
defaultDim: 0,
supportsCustomSequence: false,
supportsCustomSequence: true,
curated: [],
},
"SAE4096-L16": {
Expand All @@ -447,7 +447,7 @@ export const SAE_CONFIGS: Record<string, SAEConfig> = {
numHiddenDims: 4096,
plmLayer: 16,
defaultDim: 0,
supportsCustomSequence: false,
supportsCustomSequence: true,
curated: [],
},
"SAE4096-L20": {
Expand All @@ -456,7 +456,7 @@ export const SAE_CONFIGS: Record<string, SAEConfig> = {
numHiddenDims: 4096,
plmLayer: 20,
defaultDim: 0,
supportsCustomSequence: false,
supportsCustomSequence: true,
curated: [],
},
"SAE4096-L24-v2": {
Expand All @@ -466,7 +466,7 @@ export const SAE_CONFIGS: Record<string, SAEConfig> = {
numHiddenDims: 4096,
plmLayer: 24,
defaultDim: 0,
supportsCustomSequence: false,
supportsCustomSequence: true,
curated: [],
},
"SAE4096-L28": {
Expand All @@ -475,7 +475,7 @@ export const SAE_CONFIGS: Record<string, SAEConfig> = {
numHiddenDims: 4096,
plmLayer: 28,
defaultDim: 0,
supportsCustomSequence: false,
supportsCustomSequence: true,
curated: [],
},
"SAE4096-L32": {
Expand All @@ -484,7 +484,7 @@ export const SAE_CONFIGS: Record<string, SAEConfig> = {
numHiddenDims: 4096,
plmLayer: 32,
defaultDim: 0,
supportsCustomSequence: false,
supportsCustomSequence: true,
curated: [],
},
"SAE4096-L33": {
Expand All @@ -493,7 +493,7 @@ export const SAE_CONFIGS: Record<string, SAEConfig> = {
numHiddenDims: 4096,
plmLayer: 33,
defaultDim: 0,
supportsCustomSequence: false,
supportsCustomSequence: true,
curated: [],
},
"SAE4096-L24-ab": {
Expand Down Expand Up @@ -666,7 +666,7 @@ export const SAE_CONFIGS: Record<string, SAEConfig> = {
numHiddenDims: 8192,
plmLayer: 24,
defaultDim: 0,
supportsCustomSequence: false,
supportsCustomSequence: true,
curated: [],
},
"SAE8192-L24-K32": {
Expand All @@ -675,7 +675,7 @@ export const SAE_CONFIGS: Record<string, SAEConfig> = {
numHiddenDims: 8192,
plmLayer: 24,
defaultDim: 0,
supportsCustomSequence: false,
supportsCustomSequence: true,
curated: [],
},
"SAE8192-L24-K64": {
Expand All @@ -684,7 +684,7 @@ export const SAE_CONFIGS: Record<string, SAEConfig> = {
numHiddenDims: 8192,
plmLayer: 24,
defaultDim: 0,
supportsCustomSequence: false,
supportsCustomSequence: true,
curated: [],
},
"SAE8192-L24-K128": {
Expand All @@ -693,7 +693,7 @@ export const SAE_CONFIGS: Record<string, SAEConfig> = {
numHiddenDims: 8192,
plmLayer: 24,
defaultDim: 0,
supportsCustomSequence: false,
supportsCustomSequence: true,
curated: [],
},
"SAE8192-L24-K256": {
Expand All @@ -702,7 +702,7 @@ export const SAE_CONFIGS: Record<string, SAEConfig> = {
numHiddenDims: 8192,
plmLayer: 24,
defaultDim: 0,
supportsCustomSequence: false,
supportsCustomSequence: true,
curated: [],
},
};
4 changes: 2 additions & 2 deletions viz/src/runpod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ export async function getSAEDimActivations(input: RunpodSAEDimActivationsInput):
}

// Both caches have missed, call API.
const data = await postRunpod(input, "jrzmm3fq54zjuy");
const data = await postRunpod(input, "fad3whi01lfmh2");
SAEDimActivationsCache[dimCacheKey] = data.tokens_acts_list;
return data.tokens_acts_list;
}
Expand All @@ -84,7 +84,7 @@ export async function getSAEAllDimsActivations(
if (cacheKey in SAEAllDimsActivationsCache) {
return SAEAllDimsActivationsCache[cacheKey];
}
const data = await postRunpod(input, "jrzmm3fq54zjuy");
const data = await postRunpod(input, "fad3whi01lfmh2");
SAEAllDimsActivationsCache[cacheKey] = data.token_acts_list_by_active_dim;
return data.token_acts_list_by_active_dim;
}
Expand Down

0 comments on commit f133121

Please sign in to comment.