Skip to content

Commit

Permalink
Add searching against all features (#22)
Browse files Browse the repository at this point in the history
* Add searching against all features

* Cleanup

* Style search bars and update backend

* Search

* Fixup

* Refactor runpod calling logic

* Refactor sequence input component

* Input sequence validation

* Touchups and mobile

* Touchup

* Add sort

* Add description

* Improve folding error

* UX stuff

* Add examples

* Add seq to feature url

* Clean
  • Loading branch information
liambai authored Nov 6, 2024
1 parent a55c023 commit cbe2934
Show file tree
Hide file tree
Showing 22 changed files with 948 additions and 323 deletions.
28 changes: 23 additions & 5 deletions interprot/sae_inference/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,23 +382,41 @@ def load_models(sae_checkpoint: str):


def handler(event):
logger.info(f"starting handler with event: {event}")
try:
input_data = event["input"]
seq = input_data["sequence"]
dim = input_data["dim"]
dim = input_data.get("dim")
_, esm_layer_acts = esm2_model.get_layer_activations(seq, 24)
esm_layer_acts = esm_layer_acts[0].float()
logger.info(f"esm_layer_acts: {esm_layer_acts.shape}")

sae_acts = sae_model.get_acts(esm_layer_acts)[1:-1]
logger.info(f"sae_acts: {sae_acts.shape}")
sae_dim_acts = sae_acts[:, dim].cpu().numpy()

data = {}
if dim is not None:
sae_dim_acts = sae_acts[:, dim].cpu().numpy()
data["tokens_acts_list"] = [round(float(act), 1) for act in sae_dim_acts]
else:
max_acts, _ = torch.max(sae_acts, dim=0)
sorted_dims = torch.argsort(max_acts, descending=True)
active_dims = sorted_dims[max_acts[sorted_dims] > 0]
sae_acts_by_active_dim = sae_acts[:, active_dims].cpu().numpy()

data["token_acts_list_by_active_dim"] = [
{
"dim": int(active_dims[dim_idx].item()),
"sae_acts": [
round(float(act), 1) for act in sae_acts_by_active_dim[:, dim_idx]
],
}
for dim_idx in range(sae_acts_by_active_dim.shape[1])
]

return {
"status": "success",
"data": {
"tokens_acts_list": [round(float(act), 1) for act in sae_dim_acts],
},
"data": data,
}
except Exception as e:
logger.error(f"Traceback: {traceback.format_exc()}")
Expand Down
33 changes: 29 additions & 4 deletions viz/src/App.tsx
Original file line number Diff line number Diff line change
@@ -1,17 +1,42 @@
import "./App.css";
import { createHashRouter, RouterProvider } from "react-router-dom";
import { createHashRouter, RouterProvider, Navigate } from "react-router-dom";
import LandingPage from "./components/LandingPage";
import SAEVisualizer from "./SAEVisualizer";
import SAEVisualizerPage from "./SAEVisualizerPage";
import ErrorBoundary from "./components/ErrorBoundary";
import CustomSeqSearchPage from "./components/CustomSeqSearchPage";
import { SAEProvider } from "./SAEContext";
import { DEFAULT_SAE_MODEL } from "./config";
import { SAE_CONFIGS } from "./SAEConfigs";

const router = createHashRouter([
{
path: "/",
element: <LandingPage />,
},
{
path: "/sae-viz/:model?/:feature?",
element: <SAEVisualizer />,
path: "/sae-viz",
element: (
<Navigate
to={`/sae-viz/${DEFAULT_SAE_MODEL}/${SAE_CONFIGS[DEFAULT_SAE_MODEL].defaultDim}`}
replace
/>
),
},
{
path: "/sae-viz/:model/",
element: (
<SAEProvider>
<CustomSeqSearchPage />
</SAEProvider>
),
},
{
path: "/sae-viz/:model/:feature",
element: (
<SAEProvider>
<SAEVisualizerPage />
</SAEProvider>
),
},
]);

Expand Down
74 changes: 74 additions & 0 deletions viz/src/SAEContext.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import { createContext, useState, useEffect } from "react";
import { SAE_CONFIGS, SAEConfig } from "./SAEConfigs";
import { DEFAULT_SAE_MODEL } from "./config";
import { SidebarProvider } from "@/components/ui/sidebar";
import { useNavigate } from "react-router-dom";
import SAESidebar from "./components/SAESidebar";

interface SAEContextType {
selectedModel: string;
setSelectedModel: (model: string) => void;
selectedFeature: number | undefined;
setSelectedFeature: (feature: number | undefined) => void;
SAEConfig: SAEConfig;
}

export const SAEContext = createContext<SAEContextType>({
selectedModel: DEFAULT_SAE_MODEL,
setSelectedModel: () => {},
SAEConfig: SAE_CONFIGS[DEFAULT_SAE_MODEL],
selectedFeature: SAE_CONFIGS[DEFAULT_SAE_MODEL].defaultDim,
setSelectedFeature: () => {},
});

export const SAEProvider = ({ children }: { children: React.ReactNode }) => {
const path = window.location.hash.substring(1);

const modelMatch = path.match(/\/(?:sae-viz\/)?([^/]+)/);
const featureMatch = path.match(/\/(?:sae-viz\/)?[^/]+\/(\d+)/);

const urlModel = modelMatch ? modelMatch[1] : DEFAULT_SAE_MODEL;
const [selectedModel, setSelectedModel] = useState<string>(
SAE_CONFIGS[urlModel] ? urlModel : DEFAULT_SAE_MODEL
);

const urlFeature = featureMatch ? Number(featureMatch[1]) : undefined;
const [selectedFeature, setSelectedFeature] = useState<number | undefined>(urlFeature);

const navigate = useNavigate();

useEffect(() => {
if (urlFeature !== undefined && urlFeature !== selectedFeature) {
setSelectedFeature(urlFeature);
}
}, [urlFeature, selectedFeature]);

return (
<SAEContext.Provider
value={{
selectedModel: selectedModel,
setSelectedModel: (model: string) => {
setSelectedModel(model);
if (selectedFeature !== undefined) {
navigate(`/sae-viz/${model}/${selectedFeature}`);
} else {
navigate(`/sae-viz/${model}`);
}
},
selectedFeature: selectedFeature,
setSelectedFeature: (feature: number | undefined) => {
setSelectedFeature(feature);
if (feature !== undefined) {
navigate(`/sae-viz/${selectedModel}/${feature}`);
}
},
SAEConfig: SAE_CONFIGS[selectedModel],
}}
>
<SidebarProvider>
<SAESidebar />
{children}
</SidebarProvider>
</SAEContext.Provider>
);
};
212 changes: 0 additions & 212 deletions viz/src/SAEVisualizer.tsx

This file was deleted.

Loading

0 comments on commit cbe2934

Please sign in to comment.