Skip to content

Commit

Permalink
Support multi-chain PDB (#27)
Browse files Browse the repository at this point in the history
* Support multi-chain PDB

* Clean

* Multi chain coloring

* Clean up data structure

* Touchups
  • Loading branch information
liambai authored Nov 17, 2024
1 parent 9860db4 commit ec92585
Show file tree
Hide file tree
Showing 12 changed files with 890 additions and 401 deletions.
9 changes: 7 additions & 2 deletions viz/src/SAEContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,14 @@ export const SAEProvider = ({ children }: { children: React.ReactNode }) => {
setSelectedFeature(feature);
if (feature !== undefined) {
const seqMatch = path.match(/\?seq=([^&]+)/);
const pdbMatch = path.match(/\?pdb=([^&]+)/);
const seq = seqMatch ? seqMatch[1] : "";
const seqParam = seq ? `?seq=${seq}` : "";
navigate(`/sae-viz/${selectedModel}/${feature}${seqParam}`);
const pdb = pdbMatch ? pdbMatch[1] : "";
const queryParams = [];
if (seq) queryParams.push(`seq=${seq}`);
if (pdb) queryParams.push(`pdb=${pdb}`);
const queryString = queryParams.length ? `?${queryParams.join("&")}` : "";
navigate(`/sae-viz/${selectedModel}/${feature}${queryString}`);
}
},
SAEConfig: SAE_CONFIGS[selectedModel],
Expand Down
200 changes: 138 additions & 62 deletions viz/src/components/CustomSeqPlayground.tsx
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import { useState, useEffect, useRef, useCallback } from "react";
import { Button } from "@/components/ui/button";
import { Slider } from "@/components/ui/slider";
import { isPDBID, getPDBSequence } from "@/utils.ts";
import {
isPDBID,
isProteinSequence,
AminoAcidSequence,
ProteinActivationsData,
constructProteinActivationsDataFromSequence,
constructProteinActivationsDataFromPDBID,
} from "@/utils.ts";
import CustomStructureViewer from "./CustomStructureViewer";
import { getSAEDimActivations, getSteeredSequence } from "@/runpod.ts";
import SeqInput from "./SeqInput";
import SeqInput, { ValidSeqInput } from "./SeqInput";
import { useSearchParams } from "react-router-dom";
import FullSeqViewer from "./FullSeqViewer";
import FullSeqsViewer from "./FullSeqsViewer";
import PDBStructureViewer from "./PDBStructureViewer";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";

interface CustomSeqPlaygroundProps {
feature: number;
Expand All @@ -19,70 +28,100 @@ enum PlaygroundState {
}

const initialState = {
customSeqActivations: [] as number[],
customSeq: "",
inputProteinActivations: {} as ProteinActivationsData,
proteinInput: "",
playgroundState: PlaygroundState.IDLE,
steeredSeq: "",
steerMultiplier: 1,
steeredActivations: [] as number[],
} as const;

const CustomSeqPlayground = ({ feature }: CustomSeqPlaygroundProps) => {
const [customSeqActivations, setCustomSeqActivations] = useState<number[]>(
initialState.customSeqActivations
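// Per-chain activation data: each entry of `chains` carries that chain's `activations` (indexed by residue),
// to account for the fact that PDB structures may have multiple chains.
// NOTE(review): when the user is not inputting a PDB ID this presumably holds a single chain — confirm in utils.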
const [inputProteinActivations, setInputProteinActivations] = useState<ProteinActivationsData>(
initialState.inputProteinActivations
);
const [customSeq, setCustomSeq] = useState<string>(initialState.customSeq);
const [proteinInput, setCustomSeqInput] = useState<string>(initialState.proteinInput);
const submittedInputRef = useRef<ValidSeqInput | undefined>(undefined);

const [playgroundState, setViewerState] = useState<PlaygroundState>(initialState.playgroundState);
const [steeredSeq, setSteeredSeq] = useState<string>(initialState.steeredSeq);
const [steerMultiplier, setSteerMultiplier] = useState<number>(initialState.steerMultiplier);
const [steeredActivations, setSteeredActivations] = useState<number[]>(
initialState.steeredActivations
);
const submittedSeqRef = useRef<string>("");
const pdbIdRef = useRef<string | undefined>(undefined);

const [searchParams, setSearchParams] = useSearchParams();

// Reset all state when feature changes
useEffect(() => {
setCustomSeqActivations(initialState.customSeqActivations);
setCustomSeq(searchParams.get("seq") || initialState.customSeq);
setViewerState(initialState.playgroundState);
setSteeredSeq(initialState.steeredSeq);
setSteerMultiplier(initialState.steerMultiplier);
setSteeredActivations(initialState.steeredActivations);
}, [feature, searchParams]);
// Somewhat hacky way to enable URL <-> state syncing (there's probably a better way):
// - If URL changes (e.g. user navigates to a new shared link), set this to "url"
// and set state to the URL
// - If the state changes (e.g. user submits a new sequence), set this to "state"
// and set the URL
// Keeping track of this source of update makes it easier in useEffect to distinguish
// which case we're in and avoid circular updates.
const lastInputUpdateSource = useRef<"url" | "state" | null>(null);

const handleSubmit = useCallback(async () => {
setViewerState(PlaygroundState.LOADING_SAE_ACTIVATIONS);
const handleSubmit = useCallback(
async (submittedInput: ValidSeqInput) => {
lastInputUpdateSource.current = "state";
setViewerState(PlaygroundState.LOADING_SAE_ACTIVATIONS);

// Reset some states related to downstream actions
setCustomSeqActivations(initialState.customSeqActivations);
// Reset some states related to downstream actions
setInputProteinActivations(initialState.inputProteinActivations);
setSteeredSeq(initialState.steeredSeq);
setSteerMultiplier(initialState.steerMultiplier);
setSteeredActivations(initialState.steeredActivations);

submittedInputRef.current = submittedInput;
if (isPDBID(submittedInput)) {
setInputProteinActivations(
await constructProteinActivationsDataFromPDBID(submittedInput, feature)
);
setSearchParams({ pdb: submittedInput });
} else {
setInputProteinActivations(
await constructProteinActivationsDataFromSequence(submittedInput, feature)
);
setSearchParams({ seq: submittedInput });
}
},
[setSearchParams, feature]
);

// Reset some states when the user navigates to a new feature
useEffect(() => {
setInputProteinActivations(initialState.inputProteinActivations);
setViewerState(initialState.playgroundState);
setSteeredSeq(initialState.steeredSeq);
setSteerMultiplier(initialState.steerMultiplier);
setSteeredActivations(initialState.steeredActivations);

submittedSeqRef.current = customSeq;
setSearchParams({ seq: submittedSeqRef.current });

if (isPDBID(submittedSeqRef.current)) {
pdbIdRef.current = submittedSeqRef.current;
submittedSeqRef.current = await getPDBSequence(submittedSeqRef.current);
if (submittedInputRef.current) {
handleSubmit(submittedInputRef.current);
}
}, [feature, handleSubmit]);

const saeActivations = await getSAEDimActivations({
sequence: submittedSeqRef.current,
dim: feature,
});
setCustomSeqActivations(saeActivations);
}, [customSeq, setSearchParams, feature]);

// Automatically submit when a pdb or seq URL param is present
useEffect(() => {
const urlPdbId = searchParams.get("pdb");
const urlSeq = searchParams.get("seq");
if (urlSeq && customSeq === urlSeq && customSeqActivations.length === 0) {
handleSubmit();

// If the last update did not come from state (e.g. user navigated to a new link), submit the
// PDB ID or sequence found in the URL and update the state
if (lastInputUpdateSource.current !== "state") {
lastInputUpdateSource.current = "url";
if (urlPdbId && isPDBID(urlPdbId) && submittedInputRef.current !== urlPdbId) {
setCustomSeqInput(urlPdbId);
handleSubmit(urlPdbId);
} else if (urlSeq && isProteinSequence(urlSeq) && submittedInputRef.current !== urlSeq) {
setCustomSeqInput(urlSeq);
handleSubmit(urlSeq);
}
}
}, [searchParams, customSeq, customSeqActivations.length, handleSubmit]);

lastInputUpdateSource.current = null;
}, [searchParams, handleSubmit]);

const handleSteer = async () => {
setViewerState(PlaygroundState.LOADING_STEERED_SEQUENCE);
Expand All @@ -92,7 +131,7 @@ const CustomSeqPlayground = ({ feature }: CustomSeqPlaygroundProps) => {
setSteeredSeq(initialState.steeredSeq);

const steeredSeq = await getSteeredSequence({
sequence: submittedSeqRef.current,
sequence: submittedInputRef.current!, // Steering controls only appear after this ref is set
dim: feature,
multiplier: steerMultiplier,
});
Expand All @@ -106,42 +145,56 @@ const CustomSeqPlayground = ({ feature }: CustomSeqPlaygroundProps) => {
<div>
<div className="mt-5">
<SeqInput
sequence={customSeq}
setSequence={setCustomSeq}
input={proteinInput}
setInput={setCustomSeqInput}
onSubmit={handleSubmit}
loading={playgroundState === PlaygroundState.LOADING_SAE_ACTIVATIONS}
buttonText="Submit"
onClear={() => {
setCustomSeq("");
setCustomSeqInput("");
setSearchParams({});
}}
/>
</div>

{/* Once we have SAE activations, display sequence and structure */}
{customSeqActivations.length > 0 && (
{submittedInputRef.current && Object.keys(inputProteinActivations).length > 0 && (
<>
<div className="overflow-x-auto my-4">
<FullSeqViewer sequence={submittedSeqRef.current} activations={customSeqActivations} />
</div>
{customSeqActivations.every((act) => act === 0) && (
<Card className="my-4">
<CardHeader>
<CardTitle>SAE activations on input protein</CardTitle>
</CardHeader>
<CardContent>
<FullSeqsViewer proteinActivationsData={inputProteinActivations} />
</CardContent>
</Card>
{inputProteinActivations.chains.every((chain) =>
chain.activations.every((activation) => activation === 0)
) && (
<p className="text-sm mb-2">
This feature did not activate on your sequence. Try a sequence more similar to the
ones below.
</p>
)}
<CustomStructureViewer
viewerId="custom-viewer"
seq={submittedSeqRef.current}
pdbId={pdbIdRef.current}
activations={customSeqActivations}
onLoad={onStructureLoad}
/>

{isPDBID(submittedInputRef.current) ? (
<PDBStructureViewer
viewerId="custom-viewer"
proteinActivationsData={inputProteinActivations}
onLoad={onStructureLoad}
/>
) : (
<CustomStructureViewer
viewerId="custom-viewer"
proteinActivationsData={inputProteinActivations}
onLoad={onStructureLoad}
/>
)}
</>
)}

{/* Once we have SAE activations and the first structure has loaded, render the steering controls */}
{customSeqActivations.length > 0 &&
{Object.keys(inputProteinActivations).length > 0 &&
playgroundState !== PlaygroundState.LOADING_SAE_ACTIVATIONS && (
<div className="mt-5">
<h3 className="text-xl font-bold mb-4">Sequence Editing via Steering</h3>
Expand Down Expand Up @@ -227,13 +280,36 @@ const CustomSeqPlayground = ({ feature }: CustomSeqPlaygroundProps) => {

{steeredActivations.length > 0 && (
<>
<div className="overflow-x-auto my-4">
<FullSeqViewer sequence={steeredSeq} activations={steeredActivations} />
</div>
<Card className="my-4">
<CardHeader>
<CardTitle>SAE activations on steered protein</CardTitle>
</CardHeader>
<CardContent>
<FullSeqsViewer
proteinActivationsData={{
chains: [
{
id: "Unknown",
sequence: steeredSeq as AminoAcidSequence,
activations: steeredActivations,
},
],
}}
/>
</CardContent>
</Card>

<CustomStructureViewer
viewerId="steered-viewer"
seq={steeredSeq}
activations={steeredActivations}
proteinActivationsData={{
chains: [
{
id: "Unknown",
sequence: steeredSeq as AminoAcidSequence,
activations: steeredActivations,
},
],
}}
onLoad={onStructureLoad}
/>
</>
Expand Down
Loading

0 comments on commit ec92585

Please sign in to comment.