Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
liambai committed Oct 15, 2024
1 parent e2692d1 commit 3c286dc
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4,168 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

This repo contains tools used to interpret protein language models. `viz` contains the frontend app for visualizing SAE features. `plm_interpretability` is a python package containing tools for SAE training and interpretation.

## Running the visualizer
## The visualizer

```bash
cd viz
pnpm install
pnpm run dev
```

## Running the auto-interpretation pipeline
## The python package

### Docker setup

Expand Down
120 changes: 51 additions & 69 deletions plm_interpretability/latent_probe/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,17 +242,15 @@ def make_examples_from_annotation_entries(
return examples


def run_logistic_regression_on_latent(args):
(
dim,
X_train_filename,
y_train_filename,
X_test_filename,
y_test_filename,
shape_train,
shape_test,
) = args

def run_logistic_regression_on_latent(
dim: int,
X_train_filename: str,
y_train_filename: str,
X_test_filename: str,
y_test_filename: str,
shape_train: tuple[int, int],
shape_test: tuple[int, int],
) -> tuple[int, float, float, float]:
# Load data from memory-mapped files
X_train = np.memmap(X_train_filename, dtype="float32", mode="r", shape=shape_train)
y_train = np.memmap(y_train_filename, dtype="bool", mode="r", shape=(shape_train[0],))
Expand Down Expand Up @@ -379,71 +377,55 @@ def latent_probe(
y_test = np.array([e["target"] for e in test_examples], dtype="bool")

# Create memory-mapped files
temp_dir = tempfile.mkdtemp()
with tempfile.TemporaryDirectory() as temp_dir:
X_train_filename = os.path.join(temp_dir, "X_train.dat")
y_train_filename = os.path.join(temp_dir, "y_train.dat")
X_test_filename = os.path.join(temp_dir, "X_test.dat")
y_test_filename = os.path.join(temp_dir, "y_test.dat")

with (
np.memmap(
X_train_filename,
dtype="float32",
mode="w+",
shape=X_train.shape,
) as X_train_mmap,
np.memmap(
y_train_filename,
dtype="bool",
mode="w+",
shape=y_train.shape,
) as y_train_mmap,
np.memmap(
X_test_filename,
dtype="float32",
mode="w+",
shape=X_test.shape,
) as X_test_mmap,
np.memmap(
y_test_filename, dtype="bool", mode="w+", shape=y_test.shape
) as y_test_mmap,
):
X_train_mmap[:] = X_train[:]
y_train_mmap[:] = y_train[:]
X_test_mmap[:] = X_test[:]
y_test_mmap[:] = y_test[:]

X_train_mmap.flush()
y_train_mmap.flush()
X_test_mmap.flush()
y_test_mmap.flush()

# Create arguments for each process
args = [
(
dim,
X_train_filename,
y_train_filename,
X_test_filename,
y_test_filename,
X_train.shape,
X_test.shape,
)
for dim in range(sae_dim)
]

with Pool() as pool:
res_rows = list(
tqdm(
pool.imap(run_logistic_regression_on_latent, args),
total=sae_dim,
desc=(
"Running logistic regression on each latent for "
f"{annotation.name}: {class_name}"
),
)
X_train_mmap = np.memmap(
X_train_filename, dtype="float32", mode="w+", shape=X_train.shape
)
y_train_mmap = np.memmap(
y_train_filename, dtype="bool", mode="w+", shape=y_train.shape
)
X_test_mmap = np.memmap(
X_test_filename, dtype="float32", mode="w+", shape=X_test.shape
)
y_test_mmap = np.memmap(
y_test_filename, dtype="bool", mode="w+", shape=y_test.shape
)

X_train_mmap[:] = X_train[:]
y_train_mmap[:] = y_train[:]
X_test_mmap[:] = X_test[:]
y_test_mmap[:] = y_test[:]

X_train_mmap.flush()
y_train_mmap.flush()
X_test_mmap.flush()
y_test_mmap.flush()

run_func = functools.partial(
run_logistic_regression_on_latent,
X_train_filename=X_train_filename,
y_train_filename=y_train_filename,
X_test_filename=X_test_filename,
y_test_filename=y_test_filename,
shape_train=X_train.shape,
shape_test=X_test.shape,
)
with Pool() as pool:
res_rows = list(
tqdm(
pool.imap(run_func, range(sae_dim)),
total=sae_dim,
desc=(
"Logistic regression on each latent dimension for "
f"{annotation.name}: {class_name}"
),
)
)

res_df = pd.DataFrame(
res_rows, columns=["dim", "precision", "recall", "f1"]
Expand Down
Loading

0 comments on commit 3c286dc

Please sign in to comment.