Skip to content

Commit

Permalink
memmap
Browse files Browse the repository at this point in the history
  • Loading branch information
liambai committed Oct 14, 2024
1 parent 016c743 commit 11a745d
Show file tree
Hide file tree
Showing 2 changed files with 4,182 additions and 28 deletions.
113 changes: 85 additions & 28 deletions plm_interpretability/latent_probe/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import random
import tempfile
import warnings
from dataclasses import dataclass
from multiprocessing import Pool
Expand Down Expand Up @@ -243,20 +244,35 @@ def make_examples_from_annotation_entries(
if i in positive_positions:
num_positive_examples += 1

logger.info(f"{num_positive_examples}/{len(examples)} positive/total examples")
logger.info(f"Make {len(examples)} examples ({num_positive_examples} positive)")
return examples


def run_logistic_regression_on_latent(
dim: int,
X_train: list[list[float]],
y_train: list[bool],
X_test: list[list[float]],
y_test: list[bool],
) -> tuple[int, float, float, float]:
def run_logistic_regression_on_latent(args):
(
dim,
X_train_filename,
y_train_filename,
X_test_filename,
y_test_filename,
shape_train,
shape_test,
) = args

# Load data from memory-mapped files
X_train = np.memmap(X_train_filename, dtype="float32", mode="r", shape=shape_train)
y_train = np.memmap(
y_train_filename, dtype="bool", mode="r", shape=(shape_train[0],)
)
X_test = np.memmap(X_test_filename, dtype="float32", mode="r", shape=shape_test)
y_test = np.memmap(y_test_filename, dtype="bool", mode="r", shape=(shape_test[0],))

X_train_dim = X_train[:, dim].reshape(-1, 1)
X_test_dim = X_test[:, dim].reshape(-1, 1)

model = LogisticRegression(class_weight="balanced")
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
model.fit(X_train_dim, y_train)
y_pred = model.predict(X_test_dim)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
Expand Down Expand Up @@ -361,29 +377,63 @@ def latent_probe(
# This is expected for most dimensions.
warnings.simplefilter("ignore")

X_train = [
[e["sae_acts"][dim] for dim in range(sae_dim)]
for e in train_examples
]
y_train = [e["target"] for e in train_examples]
X_test = [
[e["sae_acts"][dim] for dim in range(sae_dim)]
for e in test_examples
]
y_test = [e["target"] for e in test_examples]

run_func = functools.partial(
run_logistic_regression_on_latent,
X_train=X_train,
y_train=y_train,
X_test=X_test,
y_test=y_test,
X_train = np.array(
[e["sae_acts"] for e in train_examples], dtype="float32"
)
y_train = np.array([e["target"] for e in train_examples], dtype="bool")
X_test = np.array(
[e["sae_acts"] for e in test_examples], dtype="float32"
)
y_test = np.array([e["target"] for e in test_examples], dtype="bool")

# Create memory-mapped files
temp_dir = tempfile.mkdtemp()
X_train_filename = os.path.join(temp_dir, "X_train.dat")
y_train_filename = os.path.join(temp_dir, "y_train.dat")
X_test_filename = os.path.join(temp_dir, "X_test.dat")
y_test_filename = os.path.join(temp_dir, "y_test.dat")

X_train_mmap = np.memmap(
X_train_filename, dtype="float32", mode="w+", shape=X_train.shape
)
y_train_mmap = np.memmap(
y_train_filename, dtype="bool", mode="w+", shape=y_train.shape
)
X_test_mmap = np.memmap(
X_test_filename, dtype="float32", mode="w+", shape=X_test.shape
)
y_test_mmap = np.memmap(
y_test_filename, dtype="bool", mode="w+", shape=y_test.shape
)

X_train_mmap[:] = X_train[:]
y_train_mmap[:] = y_train[:]
X_test_mmap[:] = X_test[:]
y_test_mmap[:] = y_test[:]

X_train_mmap.flush()
y_train_mmap.flush()
X_test_mmap.flush()
y_test_mmap.flush()

# Create arguments for each process
args = [
(
dim,
X_train_filename,
y_train_filename,
X_test_filename,
y_test_filename,
X_train.shape,
X_test.shape,
)
for dim in range(sae_dim)
]

with Pool() as pool:
res_rows = list(
tqdm(
pool.imap(run_func, range(sae_dim)),
pool.imap(run_logistic_regression_on_latent, args),
total=sae_dim,
desc=(
"Logistic regression on each latent dimension for "
Expand All @@ -392,6 +442,13 @@ def latent_probe(
)
)

# Clean up temporary files
os.unlink(X_train_filename)
os.unlink(y_train_filename)
os.unlink(X_test_filename)
os.unlink(y_test_filename)
os.rmdir(temp_dir)

res_df = pd.DataFrame(
res_rows, columns=["dim", "precision", "recall", "f1"]
).sort_values(by="f1", ascending=False)
Expand Down
Loading

0 comments on commit 11a745d

Please sign in to comment.