Commit

Improve
liambai committed Oct 14, 2024
1 parent 11a745d commit e2692d1
Showing 2 changed files with 89 additions and 93 deletions.
180 changes: 88 additions & 92 deletions plm_interpretability/latent_probe/__main__.py
@@ -20,9 +20,7 @@
 from plm_interpretability.sae_model import SparseAutoencoder
 from plm_interpretability.utils import get_layer_activations, parse_swissprot_annotation
 
-logging.basicConfig(
-    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
-)
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)


@@ -117,11 +115,13 @@ def get_sae_acts(
     sae_model: SparseAutoencoder,
     plm_layer: int,
 ) -> np.ndarray[float]:
+    """
+    Returns a (len(seq), sae_dim) array of SAE activations.
+    """
     esm_layer_acts = get_layer_activations(
         tokenizer=tokenizer, plm=plm_model, seqs=[seq], layer=plm_layer
     )[0]
     sae_acts = sae_model.get_acts(esm_layer_acts)[1:-1]  # Trim BOS and EOS tokens
+    assert sae_acts.shape[0] == len(seq)
     return sae_acts.cpu().numpy()

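A quick aside on the shape contract the new assert pins down: the ESM tokenizer brackets every sequence with BOS and EOS tokens, so the layer activations have len(seq) + 2 rows, and trimming the first and last rows leaves exactly one SAE activation row per residue. A minimal illustration with a dummy tensor (the real input comes from get_layer_activations; 1280 is just ESM-2 650M's hidden size):

    import torch

    seq = "MKTAYIA"  # Hypothetical toy sequence
    # Stand-in for pLM activations: len(seq) + 2 rows due to BOS and EOS tokens
    esm_layer_acts = torch.randn(len(seq) + 2, 1280)
    sae_acts = esm_layer_acts[1:-1]  # Trim BOS and EOS rows, as in get_sae_acts
    assert sae_acts.shape[0] == len(seq)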

@@ -154,22 +154,16 @@ def get_annotation_entries_for_class(
             entries = [e for e in entries if class_name in e.get("note", "")]
         if len(entries) > 0:
             seq_to_annotation_entries[seq] = entries
-    logger.info(
-        f"Found {len(seq_to_annotation_entries)} sequences with class {class_name}"
-    )
+    logger.info(f"Found {len(seq_to_annotation_entries)} sequences with class {class_name}")
 
     if len(seq_to_annotation_entries) > max_seqs_per_task:
         logger.warning(
             f"Since max_seqs_per_task={max_seqs_per_task}, using a random "
             f"sample of {max_seqs_per_task} sequences."
         )
-        subset_seqs = random.sample(
-            list(seq_to_annotation_entries.keys()), max_seqs_per_task
-        )
+        subset_seqs = random.sample(list(seq_to_annotation_entries.keys()), max_seqs_per_task)
         seq_to_annotation_entries = {
-            seq: entries
-            for seq, entries in seq_to_annotation_entries.items()
-            if seq in subset_seqs
+            seq: entries for seq, entries in seq_to_annotation_entries.items() if seq in subset_seqs
         }
 
     return seq_to_annotation_entries
@@ -244,7 +238,7 @@ def make_examples_from_annotation_entries(
             if i in positive_positions:
                 num_positive_examples += 1
 
-    logger.info(f"Make {len(examples)} examples ({num_positive_examples} positive)")
+    logger.info(f"Made {len(examples)} examples ({num_positive_examples} positive)")
     return examples

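For orientation, the labeling scheme these counters imply: each residue position becomes one example, positive iff it falls inside an annotated range. A minimal sketch under that assumption (the entry format and values are illustrative, not from the source):

    # Hypothetical annotation entry covering residue positions 3..5 inclusive
    entries = [{"start": 3, "end": 5}]
    positive_positions = {i for e in entries for i in range(e["start"], e["end"] + 1)}

    seq = "MKTAYIAKQR"
    examples = [{"position": i, "target": i in positive_positions} for i in range(len(seq))]
    num_positive_examples = sum(e["target"] for e in examples)
    assert num_positive_examples == 3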

@@ -261,9 +255,7 @@ def run_logistic_regression_on_latent(args):
 
     # Load data from memory-mapped files
     X_train = np.memmap(X_train_filename, dtype="float32", mode="r", shape=shape_train)
-    y_train = np.memmap(
-        y_train_filename, dtype="bool", mode="r", shape=(shape_train[0],)
-    )
+    y_train = np.memmap(y_train_filename, dtype="bool", mode="r", shape=(shape_train[0],))
    X_test = np.memmap(X_test_filename, dtype="float32", mode="r", shape=shape_test)
    y_test = np.memmap(y_test_filename, dtype="bool", mode="r", shape=(shape_test[0],))

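Context for the memmap loading above: each Pool worker receives only small, picklable arguments (filenames and shapes) and re-opens the shared arrays read-only, so the full activation matrix is never pickled or copied per worker. A self-contained sketch of the pattern (names and sizes are illustrative, not from the source):

    import os
    import tempfile
    from multiprocessing import Pool

    import numpy as np

    def worker(args):
        filename, shape, dim = args
        # Re-open the shared data read-only; no large arrays are pickled
        X = np.memmap(filename, dtype="float32", mode="r", shape=shape)
        return dim, float(X[:, dim].mean())

    if __name__ == "__main__":
        with tempfile.TemporaryDirectory() as temp_dir:
            filename = os.path.join(temp_dir, "X.dat")
            X = np.memmap(filename, dtype="float32", mode="w+", shape=(100, 8))
            X[:] = np.random.rand(100, 8)
            X.flush()  # Persist so workers can map the data from disk
            with Pool() as pool:
                print(pool.map(worker, [(filename, (100, 8), d) for d in range(8)]))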
@@ -286,12 +278,8 @@ def run_logistic_regression_on_latent(args):
     required=True,
     help="Path to the SAE checkpoint file",
 )
-@click.option(
-    "--sae-dim", type=int, required=True, help="Dimension of the sparse autoencoder"
-)
-@click.option(
-    "--plm-dim", type=int, required=True, help="Dimension of the protein language model"
-)
+@click.option("--sae-dim", type=int, required=True, help="Dimension of the sparse autoencoder")
+@click.option("--plm-dim", type=int, required=True, help="Dimension of the protein language model")
 @click.option(
     "--plm-layer",
     type=int,
@@ -310,6 +298,12 @@ def run_logistic_regression_on_latent(args):
     required=True,
     help="Path to the output directory",
 )
+@click.option(
+    "--annotation-names",
+    type=click.STRING,
+    multiple=True,
+    help="List of annotation names to process. If not provided, all annotations will be processed.",
+)
 @click.option(
     "--max-seqs-per-task",
     type=int,
@@ -323,6 +317,7 @@ def latent_probe(
     plm_layer: int,
     swissprot_tsv: str,
     output_dir: str,
+    annotation_names: list[str],
     max_seqs_per_task: int,
 ):
     """
@@ -334,15 +329,16 @@ def latent_probe(
 
     # Load pLM and SAE
     tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
-    plm_model = (
-        EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device).eval()
-    )
+    plm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device).eval()
     sae_model = SparseAutoencoder(plm_dim, sae_dim).to(device)
     sae_model.load_state_dict(torch.load(sae_checkpoint, map_location=device))
 
     df = pd.read_csv(swissprot_tsv, sep="\t")
 
     for annotation in RESIDUE_ANNOTATIONS:
+        if annotation_names and annotation.name not in annotation_names:
+            continue
+
         logger.info(f"Processing annotation: {annotation.name}")
         os.makedirs(os.path.join(output_dir, annotation.name), exist_ok=True)

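A hedged usage example for the new flag: since it is declared with multiple=True, it can be repeated once per annotation to process. The invocation below uses click's test runner; the import path follows the module layout shown here, and all paths, dimensions, and annotation names are hypothetical placeholders:

    from click.testing import CliRunner

    from plm_interpretability.latent_probe.__main__ import latent_probe

    runner = CliRunner()
    result = runner.invoke(
        latent_probe,
        [
            "--sae-checkpoint", "checkpoints/sae.pt",  # Placeholder path
            "--sae-dim", "4096",
            "--plm-dim", "1280",
            "--plm-layer", "24",
            "--swissprot-tsv", "swissprot.tsv",
            "--output-dir", "results",
            # Repeat the flag to select several annotations
            "--annotation-names", "DNA binding",
            "--annotation-names", "Active site",
        ],
    )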
@@ -377,77 +373,77 @@ def latent_probe(
                 # This is expected for most dimensions.
                 warnings.simplefilter("ignore")
 
-                X_train = np.array(
-                    [e["sae_acts"] for e in train_examples], dtype="float32"
-                )
+                X_train = np.array([e["sae_acts"] for e in train_examples], dtype="float32")
                 y_train = np.array([e["target"] for e in train_examples], dtype="bool")
-                X_test = np.array(
-                    [e["sae_acts"] for e in test_examples], dtype="float32"
-                )
+                X_test = np.array([e["sae_acts"] for e in test_examples], dtype="float32")
                 y_test = np.array([e["target"] for e in test_examples], dtype="bool")
 
-                # Create memory-mapped files
-                temp_dir = tempfile.mkdtemp()
-                X_train_filename = os.path.join(temp_dir, "X_train.dat")
-                y_train_filename = os.path.join(temp_dir, "y_train.dat")
-                X_test_filename = os.path.join(temp_dir, "X_test.dat")
-                y_test_filename = os.path.join(temp_dir, "y_test.dat")
-
-                X_train_mmap = np.memmap(
-                    X_train_filename, dtype="float32", mode="w+", shape=X_train.shape
-                )
-                y_train_mmap = np.memmap(
-                    y_train_filename, dtype="bool", mode="w+", shape=y_train.shape
-                )
-                X_test_mmap = np.memmap(
-                    X_test_filename, dtype="float32", mode="w+", shape=X_test.shape
-                )
-                y_test_mmap = np.memmap(
-                    y_test_filename, dtype="bool", mode="w+", shape=y_test.shape
-                )
-
-                X_train_mmap[:] = X_train[:]
-                y_train_mmap[:] = y_train[:]
-                X_test_mmap[:] = X_test[:]
-                y_test_mmap[:] = y_test[:]
-
-                X_train_mmap.flush()
-                y_train_mmap.flush()
-                X_test_mmap.flush()
-                y_test_mmap.flush()
-
-                # Create arguments for each process
-                args = [
-                    (
-                        dim,
-                        X_train_filename,
-                        y_train_filename,
-                        X_test_filename,
-                        y_test_filename,
-                        X_train.shape,
-                        X_test.shape,
-                    )
-                    for dim in range(sae_dim)
-                ]
-
-                with Pool() as pool:
-                    res_rows = list(
-                        tqdm(
-                            pool.imap(run_logistic_regression_on_latent, args),
-                            total=sae_dim,
-                            desc=(
-                                "Logistic regression on each latent dimension for "
-                                f"{annotation.name}: {class_name}"
-                            ),
-                        )
-                    )
-
-                # Clean up temporary files
-                os.unlink(X_train_filename)
-                os.unlink(y_train_filename)
-                os.unlink(X_test_filename)
-                os.unlink(y_test_filename)
-                os.rmdir(temp_dir)
+                with tempfile.TemporaryDirectory() as temp_dir:
+                    X_train_filename = os.path.join(temp_dir, "X_train.dat")
+                    y_train_filename = os.path.join(temp_dir, "y_train.dat")
+                    X_test_filename = os.path.join(temp_dir, "X_test.dat")
+                    y_test_filename = os.path.join(temp_dir, "y_test.dat")
+
+                    with (
+                        np.memmap(
+                            X_train_filename,
+                            dtype="float32",
+                            mode="w+",
+                            shape=X_train.shape,
+                        ) as X_train_mmap,
+                        np.memmap(
+                            y_train_filename,
+                            dtype="bool",
+                            mode="w+",
+                            shape=y_train.shape,
+                        ) as y_train_mmap,
+                        np.memmap(
+                            X_test_filename,
+                            dtype="float32",
+                            mode="w+",
+                            shape=X_test.shape,
+                        ) as X_test_mmap,
+                        np.memmap(
+                            y_test_filename, dtype="bool", mode="w+", shape=y_test.shape
+                        ) as y_test_mmap,
+                    ):
+                        X_train_mmap[:] = X_train[:]
+                        y_train_mmap[:] = y_train[:]
+                        X_test_mmap[:] = X_test[:]
+                        y_test_mmap[:] = y_test[:]
+
+                        X_train_mmap.flush()
+                        y_train_mmap.flush()
+                        X_test_mmap.flush()
+                        y_test_mmap.flush()
+
+                    # Create arguments for each process
+                    args = [
+                        (
+                            dim,
+                            X_train_filename,
+                            y_train_filename,
+                            X_test_filename,
+                            y_test_filename,
+                            X_train.shape,
+                            X_test.shape,
+                        )
+                        for dim in range(sae_dim)
+                    ]
+
+                    with Pool() as pool:
+                        res_rows = list(
+                            tqdm(
+                                pool.imap(run_logistic_regression_on_latent, args),
+                                total=sae_dim,
+                                desc=(
+                                    "Running logistic regression on each latent for "
+                                    f"{annotation.name}: {class_name}"
+                                ),
+                            )
+                        )
 
                 res_df = pd.DataFrame(
                     res_rows, columns=["dim", "precision", "recall", "f1"]
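One caveat worth flagging on the new block above: to my knowledge, numpy's memmap is a plain ndarray subclass and does not implement the context-manager protocol, so "with np.memmap(...) as ..." would raise at runtime on current NumPy releases. A minimal sketch of an equivalent arrangement under that assumption, shown for one of the four arrays, using a plain assignment plus an explicit flush:

    import os
    import tempfile

    import numpy as np

    # Assumption: np.memmap cannot be used in a with-statement, so create it
    # directly; flush() persists the data so workers can re-open it by filename.
    temp_dir = tempfile.mkdtemp()
    X_train = np.random.rand(10, 4).astype("float32")  # Illustrative stand-in
    X_train_filename = os.path.join(temp_dir, "X_train.dat")

    X_train_mmap = np.memmap(X_train_filename, dtype="float32", mode="w+", shape=X_train.shape)
    X_train_mmap[:] = X_train[:]
    X_train_mmap.flush()
    del X_train_mmap  # Deleting the memmap flushes pending writes and frees the handle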
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -34,7 +34,7 @@ requires = ["setuptools", "wheel"]
 build-backend = "setuptools.build_meta"
 
 [tool.ruff]
-line-length = 88
+line-length = 100
 target-version = "py310"
 
 [tool.ruff.lint]
