Skip to content

Commit

Permalink
Improve search for antibodies (#34)
Browse files Browse the repository at this point in the history
* Improve search for antibodies

* Chain selector

* Curate
  • Loading branch information
liambai authored Dec 15, 2024
1 parent 370f84b commit 61a4638
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 69 deletions.
24 changes: 24 additions & 0 deletions viz/src/SAEConfigs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,30 @@ export const SAE_CONFIGS: Record<string, SAEConfig> = {
defaultDim: 0,
supportsCustomSequence: true,
curated: [
{
name: "H2",
dim: 2817,
desc: "Activates on the H2 CDR loop",
group: "CDR",
},
{
name: "H1 start",
dim: 3369,
desc: "Activates on the first amino acid in the H1 CDR loop",
group: "CDR",
},
{
name: "H1 end",
dim: 2295,
desc: "Activates on the last amino acid in the H1 CDR loop",
group: "CDR",
},
{
name: "H3 start",
dim: 923,
desc: "Activates on the first amino acid in the H3 CDR loop",
group: "CDR",
},
{
name: "beta sheet alternating",
dim: 305,
Expand Down
200 changes: 139 additions & 61 deletions viz/src/components/CustomSeqSearchPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@ import { getSAEAllDimsActivations } from "@/runpod.ts";
import SeqInput, { ValidSeqInput } from "./SeqInput";
import { EXAMPLE_SEQS_FOR_SEARCH } from "./ui/ExampleSeqsForSearch";
import { Input } from "@/components/ui/input";
import { isPDBID, isProteinSequence, AminoAcidSequence, getPDBChainsData } from "@/utils";
import {
isPDBID,
isProteinSequence,
AminoAcidSequence,
getPDBChainsData,
PDBChainsData,
} from "@/utils";
import { useUrlState } from "@/hooks/useUrlState";
import { SAEContext } from "@/SAEContext";
import {
Expand Down Expand Up @@ -53,8 +59,6 @@ export default function CustomSeqSearchPage() {
const [startPos, setStartPos] = useState<number | undefined>();
const [endPos, setEndPos] = useState<number | undefined>();

const [warning, setWarning] = useState<string | undefined>(undefined);

const [minPercentActivation, setMinPercentActivation] = useState<number | undefined>();
const [maxPercentActivation, setMaxPercentActivation] = useState<number | undefined>(
DEFAULT_MAX_PERCENT_ACTIVATION
Expand All @@ -69,12 +73,14 @@ export default function CustomSeqSearchPage() {

const [isFilterOpen, setIsFilterOpen] = useState(false);

const [chains, setChains] = useState<PDBChainsData[]>([]);
const [selectedChain, setSelectedChain] = useState<string>("");

useEffect(() => {
setTempStartPos(startPos);
setTempEndPos(endPos);
setTempMinPercentActivation(minPercentActivation);
setTempMaxPercentActivation(maxPercentActivation);
setWarning(undefined);
}, [startPos, endPos, minPercentActivation, maxPercentActivation]);

const applyFilters = () => {
Expand All @@ -101,9 +107,7 @@ export default function CustomSeqSearchPage() {

const filteredResults = searchResults.filter((result) => {
if (!startPos && !endPos && !minPercentActivation && !maxPercentActivation) return true;

const hasActivationInRange = result.sae_acts.some((act, index) => {
const pos = index + 1;
const hasActivationInRange = result.sae_acts.some((act, pos) => {
const afterStart = !startPos || pos >= startPos;
const beforeEnd = !endPos || pos <= endPos;
return act > 0 && afterStart && beforeEnd;
Expand Down Expand Up @@ -137,29 +141,49 @@ export default function CustomSeqSearchPage() {
setIsLoading(true);
setInput(submittedInput);

let warning = "";
let seq: AminoAcidSequence;
if (isPDBID(submittedInput)) {
const pdbChainsData = await getPDBChainsData(submittedInput);
if (pdbChainsData.length > 1) {
warning = "PDB entry contains multiple chains. Only the first chain is considered.";
setChains(pdbChainsData);

// If there's only one chain, use it directly
if (pdbChainsData.length === 1) {
submittedSeqRef.current = pdbChainsData[0].sequence;
setSearchResults(
await getSAEAllDimsActivations({
sequence: pdbChainsData[0].sequence,
sae_name: model,
})
);
setUrlInput("pdb", submittedInput);
} else {
// If there are multiple chains, set the first one as default
setSelectedChain(pdbChainsData[0].id);
submittedSeqRef.current = pdbChainsData[0].sequence;
setSearchResults(
await getSAEAllDimsActivations({
sequence: pdbChainsData[0].sequence,
sae_name: model,
})
);
setUrlInput("pdb", submittedInput);
}
seq = pdbChainsData[0].sequence;
setUrlInput("pdb", submittedInput);
} else {
seq = submittedInput;
submittedSeqRef.current = submittedInput;
setSearchResults(
await getSAEAllDimsActivations({
sequence: submittedInput,
sae_name: model,
})
);
setUrlInput("seq", submittedInput);
setChains([]);
}

submittedSeqRef.current = seq;
setSearchResults(await getSAEAllDimsActivations({ sequence: seq, sae_name: model }));
setIsLoading(false);

setStartPos(undefined);
setEndPos(undefined);
setMinPercentActivation(undefined);
setMaxPercentActivation(DEFAULT_MAX_PERCENT_ACTIVATION);
setWarning(warning);
},
[model, setUrlInput]
);
Expand All @@ -174,6 +198,69 @@ export default function CustomSeqSearchPage() {
}
}, [urlInput, handleSearch]);

// Returns a sorted copy of `results`, ordered from highest to lowest score.
// The score is computed only over the 0-based, inclusive [startPos, endPos]
// window of each result's `sae_acts`; when a bound is unset the window
// extends to the corresponding end of the sequence.
//
// Sort modes (selected by `sortBy`):
//   "max"            - largest activation inside the window
//   "mean"           - average activation inside the window
//   "mean_activated" - average over the strictly positive activations only
// Any other `sortBy` value returns an unsorted copy.
const sortResults = useCallback(
  (results: Array<{ dim: number; sae_acts: number[] }>) => {
    const start = startPos ?? 0;
    const end = endPos ?? results[0]?.sae_acts.length ?? 0;

    const average = (values: number[]): number =>
      // An empty selection averages to 0 rather than 0/0 = NaN, so the
      // comparator below never sees NaN.
      values.length ? values.reduce((sum, val) => sum + val, 0) / values.length : 0;

    // Resolve the scoring function once instead of branching per comparison.
    const scoreWindow = ((): ((window: number[]) => number) | undefined => {
      switch (sortBy) {
        case "max":
          // Math.max() of an empty window is -Infinity, which sorts last.
          return (window) => Math.max(...window);
        case "mean":
          return (window) => average(window);
        case "mean_activated":
          return (window) => average(window.filter((val) => val > 0));
        default:
          return undefined;
      }
    })();

    if (scoreWindow === undefined) {
      return [...results];
    }

    // Score every result exactly once, then sort on the cached score. The
    // explicit three-way comparison stays well defined even when both scores
    // are -Infinity (where `b - a` would yield NaN).
    return results
      .map((result) => ({
        result,
        score: scoreWindow(result.sae_acts.slice(start, end + 1)),
      }))
      .sort((a, b) => (b.score > a.score ? 1 : b.score < a.score ? -1 : 0))
      .map(({ result }) => result);
  },
  [sortBy, startPos, endPos]
);

// Keeps the displayed results ordered: re-applies the current sort whenever
// the position window, the sort callback, or the number of results changes.
useEffect(() => {
  const hasResults = searchResults.length > 0;
  if (!hasResults) return;

  // Functional update so the sort always runs on the latest results array.
  setSearchResults((current) => sortResults(current));
}, [startPos, endPos, sortResults, searchResults.length]);

// Fetches SAE activations for the currently selected chain whenever the
// selected chain, the list of chains, or the model changes.
useEffect(() => {
  if (!selectedChain || chains.length === 0) return undefined;

  const chain = chains.find((c) => c.id === selectedChain);
  if (!chain) return undefined;

  // Set by the cleanup below once this effect run has been superseded, so a
  // slow response for a previous chain/model cannot overwrite newer results.
  let isStale = false;

  submittedSeqRef.current = chain.sequence;
  getSAEAllDimsActivations({
    sequence: chain.sequence,
    sae_name: model,
  })
    .then((results) => {
      if (!isStale) setSearchResults(results);
    })
    .catch((error: unknown) => {
      // Surface the failure instead of leaving an unhandled promise rejection.
      console.error("Error fetching SAE activations for selected chain:", error);
    });

  return () => {
    isStale = true;
  };
}, [selectedChain, chains, model]);

return (
<main
className={`min-h-screen w-full overflow-x-hidden ${
Expand All @@ -196,7 +283,22 @@ export default function CustomSeqSearchPage() {
</div>

<div className="flex flex-col gap-2 mt-4 text-left">
{warning && <div className="text-sm text-yellow-500">{warning}</div>}
{chains.length > 1 && (
<div className="flex items-center gap-2 mb-4">
<Select value={selectedChain} onValueChange={setSelectedChain}>
<SelectTrigger className="w-[200px]">
<SelectValue placeholder="Select chain" />
</SelectTrigger>
<SelectContent>
{chains.map((chain) => (
<SelectItem key={chain.id} value={chain.id}>
Chain {chain.id}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
)}

{searchResults.length > 0 && (
<>
Expand All @@ -223,15 +325,20 @@ export default function CustomSeqSearchPage() {

<div className="p-4 space-y-4">
<div className="space-y-2">
<label className="text-sm font-medium">position range</label>
<label className="text-sm font-medium">
position range{" "}
<div className="font-normal text-muted-foreground mt-1">
(inclusive, will show up as bold in results)
</div>
</label>
<div className="flex items-center gap-2">
<Input
type="number"
className="w-24 text-sm"
placeholder="start"
min={1}
min={0}
max={submittedSeqRef.current?.length}
value={tempStartPos || ""}
value={tempStartPos !== undefined ? tempStartPos : ""}
onChange={(e) => {
const val = e.target.value ? parseInt(e.target.value) : undefined;
setTempStartPos(val);
Expand All @@ -242,9 +349,9 @@ export default function CustomSeqSearchPage() {
type="number"
className="w-24 text-sm"
placeholder="end"
min={1}
min={0}
max={submittedSeqRef.current?.length}
value={tempEndPos || ""}
value={tempEndPos !== undefined ? tempEndPos : ""}
onChange={(e) => {
const val = e.target.value ? parseInt(e.target.value) : undefined;
setTempEndPos(val);
Expand Down Expand Up @@ -308,45 +415,14 @@ export default function CustomSeqSearchPage() {
onValueChange={(value) => {
setSortBy(value);
setCurrentPage(1);
setSearchResults((prevResults) => {
const sortedResults = [...prevResults];
switch (value) {
case "max":
sortedResults.sort(
(a, b) => Math.max(...b.sae_acts) - Math.max(...a.sae_acts)
);
break;
case "mean":
sortedResults.sort((a, b) => {
const meanA =
a.sae_acts.reduce((sum, val) => sum + val, 0) / a.sae_acts.length;
const meanB =
b.sae_acts.reduce((sum, val) => sum + val, 0) / b.sae_acts.length;
return meanB - meanA;
});
break;
case "mean_activated":
sortedResults.sort((a, b) => {
const activatedA = a.sae_acts.filter((val) => val > 0);
const activatedB = b.sae_acts.filter((val) => val > 0);
const meanA = activatedA.length
? activatedA.reduce((sum, val) => sum + val, 0) /
activatedA.length
: 0;
const meanB = activatedB.length
? activatedB.reduce((sum, val) => sum + val, 0) /
activatedB.length
: 0;
return meanB - meanA;
});
break;
}
return sortedResults;
});
setSearchResults((prevResults) => sortResults(prevResults));
}}
>
<SelectTrigger className="w-full sm:w-[320px]">
<SelectValue placeholder="Sort by..." />
<SelectTrigger className="w-full sm:w-[350px]">
<div className="flex items-center gap-2">
<span className="text-sm text-muted-foreground">Sort by:</span>
<SelectValue placeholder="Choose sorting method" />
</div>
</SelectTrigger>
<SelectContent>
<SelectItem value="max">max activation across</SelectItem>
Expand Down Expand Up @@ -374,6 +450,8 @@ export default function CustomSeqSearchPage() {
},
],
}}
highlightStart={startPos}
highlightEnd={endPos}
/>
))}
<Pagination>
Expand Down
23 changes: 18 additions & 5 deletions viz/src/components/FullSeqsViewer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@ import { Separator } from "@/components/ui/separator";

interface FullSeqViewerProps {
proteinActivationsData: ProteinActivationsData;
highlightStart?: number;
highlightEnd?: number;
}

const FullSeqsViewer: React.FC<FullSeqViewerProps> = ({ proteinActivationsData }) => {
const FullSeqsViewer: React.FC<FullSeqViewerProps> = ({
proteinActivationsData,
highlightStart,
highlightEnd,
}) => {
const maxValue = useMemo(
() => Math.max(...proteinActivationsData.chains.map((chain) => Math.max(...chain.activations))),
[proteinActivationsData]
Expand Down Expand Up @@ -48,10 +54,16 @@ const FullSeqsViewer: React.FC<FullSeqViewerProps> = ({ proteinActivationsData }
fontFamily: "monospace",
}}
>
{chain.sequence.split("").map((char, index) => {
const color = redColorMapHex(chain.activations[index], maxValue);
{chain.sequence.split("").map((char, pos) => {
const color = redColorMapHex(chain.activations[pos], maxValue);
const isHighlighted =
highlightStart === undefined
? pos <= (highlightEnd ?? -1)
: highlightEnd === undefined
? pos >= highlightStart
: pos >= highlightStart && pos <= highlightEnd;
return (
<TooltipProvider key={`token-${index}`} delayDuration={100}>
<TooltipProvider key={`token-${pos}`} delayDuration={100}>
<Tooltip>
<TooltipTrigger>
<span
Expand All @@ -61,13 +73,14 @@ const FullSeqsViewer: React.FC<FullSeqViewerProps> = ({ proteinActivationsData }
display: "inline-block",
width: "10px",
textAlign: "center",
fontWeight: isHighlighted ? "bold" : "normal",
}}
>
{char}
</span>
</TooltipTrigger>
<TooltipContent>
Position: {index}, SAE Activation: {chain.activations[index]?.toFixed(3)}
Position: {pos}, SAE Activation: {chain.activations[pos]?.toFixed(3)}
</TooltipContent>
</Tooltip>
</TooltipProvider>
Expand Down
4 changes: 2 additions & 2 deletions viz/src/components/PDBStructureViewer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,8 @@ const PDBStructureViewer = ({
StructureCache[proteinActivationsData.pdbId] = pdbData;
renderViewer(pdbData);
} catch (error) {
console.error("Error folding sequence:", error);
setError("An error occurred while folding the sequence with ESMFold.");
console.error("Error loading structure:", error);
setError("An error occurred while loading the structure from PDB.");
}
};

Expand Down
Loading

0 comments on commit 61a4638

Please sign in to comment.