Skip to content

Commit

Permalink
Simplify navigation
Browse files Browse the repository at this point in the history
  • Loading branch information
liambai committed Dec 7, 2024
1 parent 32dd451 commit 716e95c
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 98 deletions.
67 changes: 12 additions & 55 deletions viz/src/SAEContext.tsx
Original file line number Diff line number Diff line change
@@ -1,76 +1,33 @@
import { createContext, useState, useEffect } from "react";
import { createContext } from "react";
import { SAE_CONFIGS, SAEConfig } from "./SAEConfigs";
import { DEFAULT_SAE_MODEL } from "./config";
import { SidebarProvider } from "@/components/ui/sidebar";
import { useNavigate } from "react-router-dom";
import SAESidebar from "./components/SAESidebar";
import { useParams } from "react-router-dom";

interface SAEContextType {
selectedModel: string;
setSelectedModel: (model: string) => void;
selectedFeature: number | undefined;
setSelectedFeature: (feature: number | undefined) => void;
model: string;
feature: number | undefined;
SAEConfig: SAEConfig;
}

export const SAEContext = createContext<SAEContextType>({
selectedModel: DEFAULT_SAE_MODEL,
setSelectedModel: () => {},
model: DEFAULT_SAE_MODEL,
SAEConfig: SAE_CONFIGS[DEFAULT_SAE_MODEL],
selectedFeature: SAE_CONFIGS[DEFAULT_SAE_MODEL].defaultDim,
setSelectedFeature: () => {},
feature: SAE_CONFIGS[DEFAULT_SAE_MODEL].defaultDim,
});

export const SAEProvider = ({ children }: { children: React.ReactNode }) => {
const path = window.location.hash.substring(1);

const modelMatch = path.match(/\/(?:sae-viz\/)?([^/]+)/);
const featureMatch = path.match(/\/(?:sae-viz\/)?[^/]+\/(\d+)/);

const urlModel = modelMatch ? modelMatch[1] : DEFAULT_SAE_MODEL;
const [selectedModel, setSelectedModel] = useState<string>(
SAE_CONFIGS[urlModel] ? urlModel : DEFAULT_SAE_MODEL
);

const urlFeature = featureMatch ? Number(featureMatch[1]) : undefined;
const [selectedFeature, setSelectedFeature] = useState<number | undefined>(urlFeature);

const navigate = useNavigate();

useEffect(() => {
if (urlFeature !== undefined && urlFeature !== selectedFeature) {
setSelectedFeature(urlFeature);
}
}, [urlFeature, selectedFeature]);
const { model: modelParam, feature: featureParam } = useParams();
const model = modelParam ? modelParam : DEFAULT_SAE_MODEL;
const feature = featureParam ? parseInt(featureParam) : undefined;

return (
<SAEContext.Provider
value={{
selectedModel: selectedModel,
setSelectedModel: (model: string) => {
setSelectedModel(model);
if (selectedFeature !== undefined) {
navigate(`/sae-viz/${model}/${selectedFeature}`);
} else {
navigate(`/sae-viz/${model}`);
}
},
selectedFeature: selectedFeature,
setSelectedFeature: (feature: number | undefined) => {
setSelectedFeature(feature);
if (feature !== undefined) {
const seqMatch = path.match(/\?seq=([^&]+)/);
const pdbMatch = path.match(/\?pdb=([^&]+)/);
const seq = seqMatch ? seqMatch[1] : "";
const pdb = pdbMatch ? pdbMatch[1] : "";
const queryParams = [];
if (seq) queryParams.push(`seq=${seq}`);
if (pdb) queryParams.push(`pdb=${pdb}`);
const queryString = queryParams.length ? `?${queryParams.join("&")}` : "";
navigate(`/sae-viz/${selectedModel}/${feature}${queryString}`);
}
},
SAEConfig: SAE_CONFIGS[selectedModel],
model: model,
feature: feature,
SAEConfig: SAE_CONFIGS[model],
}}
>
<SidebarProvider>
Expand Down
22 changes: 11 additions & 11 deletions viz/src/SAEVisualizerPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ const processData = (data: VizFile) => {
};

const SAEVisualizerPage: React.FC = () => {
const { selectedFeature, selectedModel, SAEConfig } = useContext(SAEContext);
const { feature, model, SAEConfig } = useContext(SAEContext);
const dimToCuratedMap = new Map(SAEConfig?.curated?.map((i) => [i.dim, i]) || []);
const [featureStats, setFeatureStats] = useState<FeatureStats>();

Expand All @@ -79,7 +79,7 @@ const SAEVisualizerPage: React.FC = () => {
const [isDeadLatent, setIsDeadLatent] = useState(false);

useEffect(() => {
const fileURL = `${SAEConfig.baseUrl}${selectedFeature}.json`;
const fileURL = `${SAEConfig.baseUrl}${feature}.json`;

const loadData = async () => {
setFeatureStats(undefined);
Expand All @@ -100,17 +100,17 @@ const SAEVisualizerPage: React.FC = () => {
};

loadData();
}, [SAEConfig, selectedFeature]);
}, [SAEConfig, feature]);

if (selectedFeature === undefined) {
return <Navigate to={`/sae-viz/${selectedModel}`} />;
if (feature === undefined) {
return <Navigate to={`/sae-viz/${model}`} />;
}
let desc = <>{dimToCuratedMap.get(selectedFeature)?.desc}</>;
const contributor = dimToCuratedMap.get(selectedFeature)?.contributor;
let desc = <>{dimToCuratedMap.get(feature)?.desc}</>;
const contributor = dimToCuratedMap.get(feature)?.contributor;
if (contributor && contributor in CONTRIBUTORS) {
desc = (
<div className="flex flex-col gap-2">
<p>{dimToCuratedMap.get(selectedFeature)?.desc}</p>
<p>{dimToCuratedMap.get(feature)?.desc}</p>
<p>
This feature was identified by{" "}
<a
Expand All @@ -131,7 +131,7 @@ const SAEVisualizerPage: React.FC = () => {
<>
<main className="text-left max-w-full overflow-x-auto w-full">
<div className="flex justify-between items-center mt-3 mb-3">
<h1 className="text-3xl font-semibold md:mt-0 mt-16">Feature {selectedFeature}</h1>
<h1 className="text-3xl font-semibold md:mt-0 mt-16">Feature {feature}</h1>
{featureStats && (
<div>Activation frequency: {(featureStats.freq_active * 100).toFixed(2)}%</div>
)}
Expand Down Expand Up @@ -159,9 +159,9 @@ const SAEVisualizerPage: React.FC = () => {
<div className="mt-3">This is a dead latent. It does not activate on any sequence.</div>
) : (
<>
<div className="mt-3">{dimToCuratedMap.has(selectedFeature) && desc}</div>
<div className="mt-3">{dimToCuratedMap.has(feature) && desc}</div>
{SAEConfig?.supportsCustomSequence && (
<CustomSeqPlayground feature={selectedFeature} saeName={selectedModel} />
<CustomSeqPlayground feature={feature} saeName={model} />
)}
{isLoading ? (
<div className="flex items-center justify-center w-full mt-5">
Expand Down
7 changes: 3 additions & 4 deletions viz/src/components/CustomSeqSearchPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import { isPDBID, isProteinSequence, AminoAcidSequence, getPDBChainsData } from
import { useUrlState } from "@/hooks/useUrlState";
import { SAEContext } from "@/SAEContext";
export default function CustomSeqSearchPage() {
const { selectedModel } = useContext(SAEContext);
const { model } = useContext(SAEContext);

const { urlInput, setUrlInput } = useUrlState();
const [input, setInput] = useState<string>("");
Expand Down Expand Up @@ -70,7 +70,6 @@ export default function CustomSeqSearchPage() {

const handleSearch = useCallback(
async (submittedInput: ValidSeqInput) => {
console.log("handling search");
setIsLoading(true);
setInput(submittedInput);

Expand All @@ -88,13 +87,13 @@ export default function CustomSeqSearchPage() {
}

submittedSeqRef.current = seq;
setSearchResults(await getSAEAllDimsActivations({ sequence: seq, sae_name: selectedModel }));
setSearchResults(await getSAEAllDimsActivations({ sequence: seq, sae_name: model }));
setIsLoading(false);

setStartPos(undefined);
setEndPos(undefined);
},
[selectedModel, setUrlInput]
[model, setUrlInput]
);

useEffect(() => {
Expand Down
30 changes: 14 additions & 16 deletions viz/src/components/SAEFeatureCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { Card, CardContent, CardTitle, CardHeader, CardDescription } from "@/com
import { SAEContext } from "../SAEContext";
import FullSeqsViewer from "./FullSeqsViewer";
import { ProteinActivationsData } from "@/utils";
import { Link, useLocation } from "react-router-dom";

export default function SAEFeatureCard({
dim,
Expand All @@ -11,24 +12,21 @@ export default function SAEFeatureCard({
dim: number;
proteinActivationsData: ProteinActivationsData;
}) {
const { SAEConfig, setSelectedFeature } = useContext(SAEContext);
const { SAEConfig } = useContext(SAEContext);
const location = useLocation();
const desc = SAEConfig.curated?.find((f) => f.dim === dim)?.desc;

return (
<Card
key={dim}
className="cursor-pointer"
onClick={() => {
setSelectedFeature(dim);
}}
>
<CardHeader>
<CardTitle className="text-left">Feature {dim}</CardTitle>
{desc && <CardDescription>{desc}</CardDescription>}
</CardHeader>
<CardContent>
<FullSeqsViewer proteinActivationsData={proteinActivationsData} />
</CardContent>
</Card>
<Link to={`${dim}${location.search}`} className="block">
<Card key={dim} className="cursor-pointer">
<CardHeader>
<CardTitle className="text-left">Feature {dim}</CardTitle>
{desc && <CardDescription>{desc}</CardDescription>}
</CardHeader>
<CardContent>
<FullSeqsViewer proteinActivationsData={proteinActivationsData} />
</CardContent>
</Card>
</Link>
);
}
21 changes: 9 additions & 12 deletions viz/src/components/SAESidebar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,19 @@ import {
import { Separator } from "@/components/ui/separator";
import HomeNavigator from "@/components/HomeNavigator";
import { Toggle } from "@/components/ui/toggle";
import { useNavigate } from "react-router-dom";
import { Button } from "@/components/ui/button";
import { Dices, Search } from "lucide-react";
import { SAEContext } from "../SAEContext";
import Markdown from "markdown-to-jsx";
import { usePreserveQueryParamsNavigate } from "@/hooks/useNagivateWithQueryParams";

export default function SAESidebar() {
const { setOpenMobile } = useSidebar();
const navigate = useNavigate();
const { selectedModel, setSelectedModel, selectedFeature, setSelectedFeature, SAEConfig } =
useContext(SAEContext);
const navigate = usePreserveQueryParamsNavigate();
const { model, feature, SAEConfig } = useContext(SAEContext);

const handleFeatureChange = (feature: number) => {
setSelectedFeature(feature);
navigate(`/sae-viz/${model}/${feature}`);
setOpenMobile(false);
};

Expand All @@ -47,17 +46,16 @@ export default function SAESidebar() {
<>
<div className="fixed flex items-center justify-between top-0 w-full bg-background border-b border-border z-50 py-4 px-6 md:hidden left-0 right-0">
<SidebarTrigger />
<Search onClick={() => navigate(`/sae-viz/${selectedModel}`)} />
<Search onClick={() => navigate(`/sae-viz/${model}`)} />
</div>
<Sidebar>
<SidebarHeader>
<div className="m-3">
<HomeNavigator />
</div>
<Select
value={selectedModel}
value={model}
onValueChange={(value) => {
setSelectedModel(value);
navigate(`/sae-viz/${value}/${SAE_CONFIGS[value].defaultDim}`);
}}
>
Expand Down Expand Up @@ -100,8 +98,7 @@ export default function SAESidebar() {
variant="outline"
className="mb-3 mx-3 whitespace-normal text-left h-auto py-2"
onClick={() => {
navigate(`/sae-viz/${selectedModel}`);
setSelectedFeature(undefined);
navigate(`/sae-viz/${model}`);
setOpenMobile(false);
}}
>
Expand All @@ -122,7 +119,7 @@ export default function SAESidebar() {
key={`feature-${c.dim}`}
style={{ width: "100%", paddingLeft: 20, textAlign: "left" }}
className="justify-start"
pressed={selectedFeature === c.dim}
pressed={feature === c.dim}
onPressedChange={() => handleFeatureChange(c.dim)}
>
{c.name}
Expand All @@ -138,7 +135,7 @@ export default function SAESidebar() {
key={`feature-${i}`}
style={{ width: "100%", paddingLeft: 20 }}
className="justify-start"
pressed={selectedFeature === i}
pressed={feature === i}
onPressedChange={() => handleFeatureChange(i)}
>
{i}
Expand Down
10 changes: 10 additions & 0 deletions viz/src/hooks/useNagivateWithQueryParams.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import { useLocation, useNavigate } from "react-router-dom";

export const usePreserveQueryParamsNavigate = () => {
const navigate = useNavigate();
const { search } = useLocation();

return (path: string) => {
navigate(path + search);
};
};

0 comments on commit 716e95c

Please sign in to comment.