Skip to content

Commit

Permalink
Use new visualization file format (#30)
Browse files Browse the repository at this point in the history
* extract basic features from file

* add title prop

* use new format

* simplify logic
  • Loading branch information
etowahadams authored Nov 26, 2024
1 parent 3cc924d commit 2e21bdc
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 30 deletions.
2 changes: 1 addition & 1 deletion viz/src/SAEConfigs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ export const CONTRIBUTORS: Record<string, string> = {
export const SAE_CONFIGS: Record<string, SAEConfig> = {
"SAE4096-L24": {
baseUrl:
"https://raw.githubusercontent.com/liambai/plm-interp-viz-data/refs/heads/main/esm2_plm1280_l24_sae4096_100Kseqs/",
"https://raw.githubusercontent.com/liambai/plm-interp-viz-data/refs/heads/main/esm2_plm1280_l24_sae4096_100Kseqs_quantile/",
numHiddenDims: 4096,
plmLayer: 24,
curated: [
Expand Down
129 changes: 106 additions & 23 deletions viz/src/SAEVisualizerPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,79 @@ import { useEffect, useState, useContext } from "react";
import MolstarMulti from "./components/MolstarMulti";
import CustomSeqPlayground from "./components/CustomSeqPlayground";
import { Navigate } from "react-router-dom";
import { Button } from "@/components/ui/button";
import proteinEmoji from "./protein.png";

import { SAEContext } from "./SAEContext";
import { NUM_SEQS_TO_DISPLAY } from "./config";
import { CONTRIBUTORS } from "./SAEConfigs";
import SeqsViewer, { SeqWithSAEActs } from "./components/SeqsViewer";
import { tokensToSequence } from "./utils";

const actRanges: [number, number][] = [
[0, 0.25],
[0.25, 0.5],
[0.5, 0.75],
[0.75, 1],
];

const rangeNames: string[] = actRanges.map(([start, end]) => `${start}-${end}`);

const TOP_RANGE = rangeNames[rangeNames.length - 1];
const BOTTOM_RANGE = rangeNames[0];

interface VizFile {
ranges: {
[key in `${number}-${number}`]: {
examples: SeqWithSAEActs[];
};
};
freq_active: number;
n_seqs: number;
top_pfam: string[];
max_act: number;
}

interface FeatureStats {
freq_active: number;
top_pfam: string[];
}

const SAEVisualizerPage: React.FC = () => {
const { selectedFeature, selectedModel, SAEConfig } = useContext(SAEContext);
const dimToCuratedMap = new Map(SAEConfig?.curated?.map((i) => [i.dim, i]) || []);
const [featureStats, setFeatureStats] = useState<FeatureStats>();

const [topFeatureData, setTopFeatureData] = useState<SeqWithSAEActs[]>([]);
const [bottomFeatureData, setBottomFeatureData] = useState<SeqWithSAEActs[]>([]);
// Toggle for showing the bottom Molstar
const [showBottomMolstar, setShowBottomMolstar] = useState(false);
const [isLoading, setIsLoading] = useState(true);

const [featureData, setFeatureData] = useState<SeqWithSAEActs[]>([]);
useEffect(() => {
const fileURL = `${SAEConfig.baseUrl}${selectedFeature}.json`;
// Reset the bottom Molstar visibility
setFeatureStats(undefined);
setShowBottomMolstar(false);
setIsLoading(true);

fetch(fileURL)
.then((response) => response.json())
.then((data) => {
// NOTE(liam): important data transformation
setFeatureData(
data
.slice(0, NUM_SEQS_TO_DISPLAY)
.map(
(seq: {
tokens_acts_list: number[];
tokens_list: number[];
alphafold_id: string;
}) => ({
sae_acts: seq.tokens_acts_list,
sequence: tokensToSequence(seq.tokens_list),
alphafold_id: seq.alphafold_id,
})
)
);
.then((data: VizFile) => {
if (TOP_RANGE in data["ranges"]) {
const topQuarter = data["ranges"][TOP_RANGE as `${number}-${number}`];
const examples = topQuarter["examples"];
setTopFeatureData(examples.slice(0, NUM_SEQS_TO_DISPLAY));
setFeatureStats({
freq_active: data["freq_active"],
top_pfam: data["top_pfam"],
});
}
if (BOTTOM_RANGE in data["ranges"]) {
const bottomQuarter = data["ranges"][BOTTOM_RANGE as `${number}-${number}`];
const examples = bottomQuarter["examples"];
setBottomFeatureData(examples.slice(0, NUM_SEQS_TO_DISPLAY));
}
setIsLoading(false);
});
}, [SAEConfig, selectedFeature]);

Expand Down Expand Up @@ -67,11 +107,54 @@ const SAEVisualizerPage: React.FC = () => {
return (
<>
<main className="text-left max-w-full overflow-x-auto">
<h1 className="text-3xl font-semibold md:mt-0 mt-16">Feature {selectedFeature}</h1>
{dimToCuratedMap.has(selectedFeature) && <div className="mt-3">{desc}</div>}
<div className="flex justify-between items-center mt-3 mb-3">
<h1 className="text-3xl font-semibold md:mt-0 mt-16">Feature {selectedFeature}</h1>
{featureStats && (
<div>
Activation frequency: {(featureStats.freq_active * 100).toFixed(2)}%
</div>
)}
</div>
<div>
{featureStats && featureStats.top_pfam.length > 0 && (
<div>
Highly activating Pfams:{" "}
{featureStats.top_pfam.map((pfam) => (
<a
key={pfam}
href={`https://www.ebi.ac.uk/interpro/entry/pfam/${pfam}`}
target="_blank"
rel="noreferrer"
>
<span key={pfam} className="px-2 py-1 bg-gray-200 rounded-md mx-1">
{pfam}
</span>
</a>
))}
</div>
)}
</div>
<div className="mt-3">{dimToCuratedMap.has(selectedFeature) && desc}</div>
{SAEConfig?.supportsCustomSequence && <CustomSeqPlayground feature={selectedFeature} />}
<SeqsViewer seqs={featureData} />
<MolstarMulti proteins={featureData} />
{isLoading ? (
<div className="flex items-center justify-center w-full mt-5">
<img src={proteinEmoji} alt="Loading..." className="w-12 h-12 animate-wiggle mb-4" />
</div>
) : (
<>
<SeqsViewer seqs={topFeatureData} title={"Top activating sequences"} />
<MolstarMulti proteins={topFeatureData} />
<SeqsViewer seqs={bottomFeatureData} title={"Sequences with max activation < 0.25"} />
<Button
onClick={() => setShowBottomMolstar(!showBottomMolstar)}
variant="outline"
className="mb-3 mt-3"
>
{showBottomMolstar ? "Hide" : "Show"} structures
</Button>
{showBottomMolstar && <MolstarMulti proteins={bottomFeatureData} />}
</>
)}
</main>
</>
);
Expand Down
14 changes: 11 additions & 3 deletions viz/src/components/MolstarMulti.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,17 @@ const MolstarMulti: React.FC<MolstarViewerProps> = ({ proteins }) => {
alt={`Protein ${protein.alphafold_id}`}
className="w-full h-full object-cover"
/>
<div className="absolute bottom-2 left-2 bg-black bg-opacity-50 text-white px-2 py-1 rounded text-sm">
{protein.alphafold_id}
</div>
<a
href={`https://uniprot.org/uniprot/${protein.uniprot_id}`}
target="_blank"
rel="noreferrer"
>
<div className="absolute bottom-2 left-2 bg-black bg-opacity-50 text-white px-2 py-1 rounded text-sm">
{protein.name.length > 30
? protein.name.substring(0, 32) + "..."
: protein.name}
</div>
</a>
</TooltipTrigger>
<TooltipContent>Click to interact with the structure</TooltipContent>
</Tooltip>
Expand Down
9 changes: 6 additions & 3 deletions viz/src/components/SeqsViewer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@ import { Skeleton } from "@/components/ui/skeleton";
export interface SeqWithSAEActs {
sequence: string;
sae_acts: Array<number>;
alphafold_id?: string;
alphafold_id: string;
uniprot_id: string;
name: string;
}

interface SeqsViewerProps {
seqs: SeqWithSAEActs[];
title: string;
}

export default function SeqsViewer({ seqs }: SeqsViewerProps) {
export default function SeqsViewer({ seqs, title }: SeqsViewerProps) {
const [alignmentMode, setAlignmentMode] = useState<"first_act" | "max_act" | "msa">("first_act");
const [alignedSeqs, setAlignedSeqs] = useState<SeqWithSAEActs[]>(seqs);
const [isAligning, setIsAligning] = useState(false);
Expand Down Expand Up @@ -164,7 +167,7 @@ export default function SeqsViewer({ seqs }: SeqsViewerProps) {
return (
<>
<div className="flex items-center gap-4 mt-8 justify-between flex-wrap">
<h2 className="text-2xl font-semibold">Top activating sequences</h2>
<h2 className="text-2xl font-semibold">{title}</h2>
<div className="hidden sm:flex items-center gap-4">
<ToggleGroup
type="single"
Expand Down

0 comments on commit 2e21bdc

Please sign in to comment.