Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
etowahadams committed Jan 15, 2025
2 parents badfa21 + d28f4dd commit 0c02561
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 13 deletions.
7 changes: 4 additions & 3 deletions interprot/make_viz_files/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import re
from pathlib import Path
from typing import Any

import click
import numpy as np
Expand Down Expand Up @@ -43,7 +44,7 @@ def get_esm_layer_acts(
"--output-dir",
type=click.Path(exists=True, dir_okay=True),
required=True,
help="Path to the sequences file containing AlphaFoldDB IDs",
help="Path to the output directory in which the JSON files will be written",
)
def make_viz_files(checkpoint_files: list[str], sequences_file: str, output_dir: Path):
"""
Expand Down Expand Up @@ -157,7 +158,7 @@ def make_viz_files(checkpoint_files: list[str], sequences_file: str, output_dir:


def write_viz_file(dim_info, dim, all_acts, df, range_names, output_dir: Path):
viz_file = {"ranges": {}}
viz_file: dict[str, Any] = {"ranges": {}}
# Write how common the dimension is
if "freq_active" in dim_info:
viz_file["freq_active"] = dim_info["freq_active"]
Expand All @@ -171,7 +172,7 @@ def write_viz_file(dim_info, dim, all_acts, df, range_names, output_dir: Path):
for range_name in range_names:
if range_name not in dim_info:
continue
range_examples = {
range_examples: dict[str, list] = {
"examples": [],
}
top_indices = dim_info[range_name]["indices"]
Expand Down
65 changes: 60 additions & 5 deletions viz/src/SAEConfigs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export type CuratedFeature = {
};

export type SAEConfig = {
baseUrl: string;
storagePath: string;
description: string;
numHiddenDims: number;
plmLayer: number;
Expand All @@ -25,10 +25,12 @@ export const CONTRIBUTORS: Record<string, string> = {
"James Michael Krieger": "http://github.com/jamesmkrieger",
};

/**
 * Root URL of the repository that hosts the visualization data files.
 *
 * The visualizer page builds a feature's file URL as
 * `${STORAGE_ROOT_URL}/${storagePath}/${feature}.json`, where `storagePath`
 * comes from the selected SAE config. The value therefore has no trailing
 * slash; the separator is added at the point of use.
 */
export const STORAGE_ROOT_URL =
"https://raw.githubusercontent.com/liambai/plm-interp-viz-data/refs/heads/main";

export const SAE_CONFIGS: Record<string, SAEConfig> = {
"SAE4096-L24": {
baseUrl:
"https://raw.githubusercontent.com/liambai/plm-interp-viz-data/refs/heads/main/esm2_plm1280_l24_sae4096_100Kseqs/",
storagePath: "esm2_plm1280_l24_sae4096_100Kseqs",
description:
"This SAE was trained on layer 24 of [ESM2-650M](https://huggingface.co/facebook/esm2_t33_650M_UR50D) using sequences from [UniRef50](https://www.uniprot.org/help/uniref) and has 4096 hidden dimensions. Click on a feature below to visualize its activation pattern.",
numHiddenDims: 4096,
Expand Down Expand Up @@ -412,9 +414,17 @@ export const SAE_CONFIGS: Record<string, SAEConfig> = {
defaultDim: 4000,
supportsCustomSequence: true,
},
"SAE4096-L33": {
storagePath: "esm2_plm1280_l33_sae4096_aux640",
description: "",
numHiddenDims: 4096,
plmLayer: 33,
defaultDim: 0,
supportsCustomSequence: false,
curated: [],
},
"SAE4096-L24-ab": {
baseUrl:
"https://raw.githubusercontent.com/liambai/plm-interp-viz-data/refs/heads/main/esm2_plm1280_l24_sae4096_k128_auxk512_antibody_seqs/",
storagePath: "esm2_plm1280_l24_sae4096_k128_auxk512_antibody_seqs",
description:
"This SAE was trained on layer 24 of [ESM2-650M](https://huggingface.co/facebook/esm2_t33_650M_UR50D) using antibody sequences from [PLAbDab](https://opig.stats.ox.ac.uk/webapps/plabdab/) and has 4096 hidden dimensions. Click on a feature below to visualize its activation pattern.",
numHiddenDims: 4096,
Expand Down Expand Up @@ -577,4 +587,49 @@ export const SAE_CONFIGS: Record<string, SAEConfig> = {
},
],
},
"SAE8192-L24-K16": {
storagePath: "k_sweep/esm2_plm1280_l24_sae8192_k16_auxk640",
description: "",
numHiddenDims: 8192,
plmLayer: 24,
defaultDim: 0,
supportsCustomSequence: false,
curated: [],
},
"SAE8192-L24-K32": {
storagePath: "k_sweep/esm2_plm1280_l24_sae8192_k32_auxk640",
description: "",
numHiddenDims: 8192,
plmLayer: 24,
defaultDim: 0,
supportsCustomSequence: false,
curated: [],
},
"SAE8192-L24-K64": {
storagePath: "k_sweep/esm2_plm1280_l24_sae8192_k64_auxk640",
description: "",
numHiddenDims: 8192,
plmLayer: 24,
defaultDim: 0,
supportsCustomSequence: false,
curated: [],
},
"SAE8192-L24-K128": {
storagePath: "k_sweep/esm2_plm1280_l24_sae8192_k128_auxk640",
description: "",
numHiddenDims: 8192,
plmLayer: 24,
defaultDim: 0,
supportsCustomSequence: false,
curated: [],
},
"SAE8192-L24-K256": {
storagePath: "k_sweep/esm2_plm1280_l24_sae8192_k256_auxk640",
description: "",
numHiddenDims: 8192,
plmLayer: 24,
defaultDim: 0,
supportsCustomSequence: false,
curated: [],
},
};
19 changes: 16 additions & 3 deletions viz/src/SAEVisualizerPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import proteinEmoji from "./protein.png";

import { SAEContext } from "./SAEContext";
import { NUM_SEQS_TO_DISPLAY } from "./config";
import { CONTRIBUTORS } from "./SAEConfigs";
import { CONTRIBUTORS, STORAGE_ROOT_URL } from "./SAEConfigs";
import SeqsViewer, { SeqWithSAEActs } from "./components/SeqsViewer";
import {
Accordion,
Expand All @@ -16,6 +16,7 @@ import {
AccordionTrigger,
} from "@/components/ui/accordion";
import Markdown from "@/components/Markdown";
import { Info } from "lucide-react";

const actRanges: [number, number][] = [
[0.75, 1],
Expand Down Expand Up @@ -80,7 +81,7 @@ const SAEVisualizerPage: React.FC = () => {
const [isDeadLatent, setIsDeadLatent] = useState(false);

useEffect(() => {
const fileURL = `${SAEConfig.baseUrl}${feature}.json`;
const fileURL = `${STORAGE_ROOT_URL}/${SAEConfig.storagePath}/${feature}.json`;

const loadData = async () => {
setFeatureStats(undefined);
Expand Down Expand Up @@ -145,7 +146,19 @@ const SAEVisualizerPage: React.FC = () => {
) : (
<div className="mt-3">
<Markdown>{descStr}</Markdown>
{SAEConfig?.supportsCustomSequence && <CustomSeqPlayground />}
{SAEConfig?.supportsCustomSequence ? (
<CustomSeqPlayground />
) : (
<div className="mb-10 mt-10 p-4 border-2 border-gray-200 rounded-lg bg-gray-50 flex items-center gap-4">
<div className="text-amber-600">
<Info className="h-6 w-6" />
</div>
<p className="text-gray-700">
Some of our SAEs support custom sequence inputs. This one currently does not. Try
a different model in the model dropdown to search and steer your own sequence.
</p>
</div>
)}
{isLoading ? (
<div className="flex items-center justify-center w-full mt-5">
<img
Expand Down
8 changes: 7 additions & 1 deletion viz/src/components/SAESidebar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ export default function SAESidebar() {
<Select
value={model}
onValueChange={(value) =>
navigate(`/sae-viz/${value}${feature !== undefined ? `/${feature}` : ""}`)
navigate(
`/sae-viz/${value}${
SAE_CONFIGS[value].defaultDim !== undefined
? `/${SAE_CONFIGS[value].defaultDim}`
: ""
}`
)
}
>
<SelectTrigger className="mb-3">
Expand Down
2 changes: 1 addition & 1 deletion viz/src/components/SeqsViewer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ export default function SeqsViewer({ seqs, title }: SeqsViewerProps) {
<br />
<span className="font-semibold">3Di:</span>{" "}
<Markdown>
Structural tokens describing geometric conformation, invented for
Structural tokens describing geometric conformation from
[Foldseek](https://www.nature.com/articles/s41587-023-01773-0)
</Markdown>
</p>
Expand Down

0 comments on commit 0c02561

Please sign in to comment.