Skip to content

Commit

Permalink
New viz file with ranges (#28)
Browse files Browse the repository at this point in the history
* fix: sequence

* family info

* fix: uniprot id

* fix: max
  • Loading branch information
etowahadams authored Nov 28, 2024
1 parent 2e21bdc commit bed6c17
Showing 1 changed file with 131 additions and 82 deletions.
213 changes: 131 additions & 82 deletions interprot/make_viz_files/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,7 @@ def make_viz_files(checkpoint_files: list[str], sequences_file: str):
)

df = pl.read_parquet(sequences_file)
has_interpro = "InterPro" in df.columns
if has_interpro:
df = df.with_columns(
pl.col("InterPro").str.split(";").alias("interpro_ids")
)
has_pfam = "Pfam" in df.columns

# Pre-allocate numpy array for storing max activations
all_seqs_max_act = np.zeros((sae_dim, len(df)))
Expand All @@ -107,6 +103,7 @@ def make_viz_files(checkpoint_files: list[str], sequences_file: str):
sae_acts_cpu = sae_acts.cpu().numpy()
all_seqs_max_act[:, seq_idx] = np.max(sae_acts_cpu, axis=0)
sae_acts_int = (sae_acts_cpu * 10).astype(np.uint8)
# Convert to sparse matrix. This significantly reduces memory usage
sparse_acts = sparse.csr_matrix(sae_acts_int)
all_acts[seq_idx] = sparse_acts
# Clear CUDA cache periodically
Expand All @@ -119,97 +116,149 @@ def make_viz_files(checkpoint_files: list[str], sequences_file: str):

hidden_dim_to_seqs = {dim: {} for dim in range(sae_dim)}

# Calculate the top sequences for each hidden dimension and quartile
quartile_names = ["Q1", "Q2", "Q3", "Q4"]
act_ranges = [[0, 0.25], [0.25, 0.5], [0.5, 0.75], [0.75, 1]]
range_names = [f"{start}-{end}" for start, end in act_ranges]

for dim in tqdm(
range(sae_dim), desc="Finding highest activating seqs (Step 2/3)"
):
dim_maxes = all_seqs_max_act[dim]
if np.all(dim_maxes == 0):
non_zero_maxes = dim_maxes[dim_maxes > 0]

if len(non_zero_maxes) == 0:
print(f"Skipping dimension {dim} as it has no activations")
continue

# Get top Pfam families for sequences with activations greater than 0.75
if has_pfam:
top_families = get_top_pfam(
df, dim_maxes, act_gt=0.75, n_classes=3, frac_above_threshold=0.8
)
hidden_dim_to_seqs[dim]["top_pfam"] = top_families

non_zero_maxes = dim_maxes[dim_maxes > 0]
hidden_dim_to_seqs[dim]["freq_activate_among_all_seqs"] = len(
non_zero_maxes
) / len(dim_maxes)
quartiles = np.percentile(non_zero_maxes, [25, 50, 75])

q1_mask = (dim_maxes > 0) & (dim_maxes <= quartiles[0])
q2_mask = (dim_maxes > quartiles[0]) & (dim_maxes <= quartiles[1])
q3_mask = (dim_maxes > quartiles[1]) & (dim_maxes <= quartiles[2])
q4_mask = dim_maxes > quartiles[2]

quartile_indices = [
np.where(mask)[0] for mask in [q1_mask, q2_mask, q3_mask, q4_mask]
]
for q_name, q_indices in zip(quartile_names, quartile_indices):
hidden_dim_to_seqs[dim]["freq_active"] = len(non_zero_maxes) / len(
dim_maxes
)
hidden_dim_to_seqs[dim]["n_seqs"] = len(non_zero_maxes)
hidden_dim_to_seqs[dim]["max_act"] = float(dim_maxes.max())

normalized_acts = dim_maxes / dim_maxes.max()
for i, (start, end) in enumerate(act_ranges):
mask = (normalized_acts > start) & (normalized_acts <= end)
top_indices = heapq.nlargest(
NUM_SEQS_PER_DIM, q_indices, key=lambda i: dim_maxes[i]
NUM_SEQS_PER_DIM, np.where(mask)[0], key=lambda i: dim_maxes[i]
)
hidden_dim_to_seqs[dim][q_name] = {}
hidden_dim_to_seqs[dim][q_name]["n_seqs"] = len(q_indices)
hidden_dim_to_seqs[dim][q_name]["indices"] = top_indices
if has_interpro:
hidden_dim_to_seqs[dim][q_name]["interpro"] = get_top_interpro(
df, q_indices, top_n=10
)
range_name = range_names[i]
hidden_dim_to_seqs[dim][range_name] = {}
hidden_dim_to_seqs[dim][range_name]["indices"] = top_indices

for dim in tqdm(range(sae_dim), desc="Writing visualization files (Step 3/3)"):
viz_file = {"quartiles": {}}
if "freq_activate_among_all_seqs" in hidden_dim_to_seqs[dim]:
viz_file["freq_activate_among_all_seqs"] = hidden_dim_to_seqs[dim][
"freq_activate_among_all_seqs"
]
for quartile in quartile_names:
if quartile not in hidden_dim_to_seqs[dim]:
continue
quartile_examples = {
"examples": [],
"n_seqs": hidden_dim_to_seqs[dim][quartile]["n_seqs"],
}
if has_interpro:
quartile_examples["interpro"] = hidden_dim_to_seqs[dim][quartile][
"interpro"
]
quartile_indices = hidden_dim_to_seqs[dim][quartile]["indices"]

for seq_idx in quartile_indices:
seq_idx = int(seq_idx)
sae_acts = all_acts[seq_idx].toarray()
dim_acts = sae_acts[:, dim]
uniprot_id = df[seq_idx]["Entry"].item()[:-1]
alphafolddb_id = df[seq_idx]["AlphaFoldDB"].item().split(";")[0]
protein_name = df[seq_idx]["Protein names"].item()

examples = {
"sae_acts": [round(float(act) / 10, 1) for act in dim_acts],
"sequence": seq,
"alphafold_id": alphafolddb_id,
"uniprot_id": uniprot_id,
"name": protein_name,
}
quartile_examples["examples"].append(examples)

viz_file["quartiles"][quartile] = quartile_examples

with open(os.path.join(OUTPUT_ROOT_DIR, f"{dim}.json"), "w") as f:
json.dump(viz_file, f)


def get_top_interpro(original_df, indices, top_n=5):
df = original_df[indices]
total_rows = len(df)
counts = (
df.explode("interpro_ids")["interpro_ids"]
if not hidden_dim_to_seqs[dim]:
print(f"Skipping dimension {dim} as it has no sequences")
continue
write_viz_file(hidden_dim_to_seqs[dim], dim, all_acts, df, range_names)


def write_viz_file(dim_info, dim, all_acts, df, range_names):
viz_file = {"ranges": {}}
# Write how common the dimension is
if "freq_active" in dim_info:
viz_file["freq_active"] = dim_info["freq_active"]
if "n_seqs" in dim_info:
viz_file["n_seqs"] = dim_info["n_seqs"]
if "top_pfam" in dim_info:
viz_file["top_pfam"] = dim_info["top_pfam"]
if "max_act" in dim_info:
viz_file["max_act"] = dim_info["max_act"]

for range_name in range_names:
if range_name not in dim_info:
continue
range_examples = {
"examples": [],
}
top_indices = dim_info[range_name]["indices"]

for seq_idx in top_indices:
seq_idx = int(seq_idx)
sae_acts = all_acts[seq_idx].toarray()
dim_acts = sae_acts[:, dim]
uniprot_id = df[seq_idx]["Entry"].item()
alphafolddb_id = df[seq_idx]["AlphaFoldDB"].item().split(";")[0]
protein_name = df[seq_idx]["Protein names"].item()
sequence = df[seq_idx]["Sequence"].item()

examples = {
"sae_acts": [round(float(act) / 10, 1) for act in dim_acts],
"sequence": sequence,
"alphafold_id": alphafolddb_id,
"uniprot_id": uniprot_id,
"name": protein_name,
}
range_examples["examples"].append(examples)

viz_file["ranges"][range_name] = range_examples

with open(os.path.join(OUTPUT_ROOT_DIR, f"{dim}.json"), "w") as f:
json.dump(viz_file, f)



def get_top_pfam(df, dim_maxes, act_gt=0.75, n_classes=3, frac_above_threshold=0.8):
"""
Gets the top Pfam families of sequences with activations greater than a threshold.
For all sequences with activations greater than the threshold, it will return the top
n_classes Pfam families if they account for at least frac_above_threshold of the sequences.
Args:
df: DataFrame containing the sequences
dim_maxes: Numpy array of max activations for each sequence
act_gt: Threshold for activations
n_classes: Number of top Pfam families to return
frac_above_threshold: Fraction of sequences that must be accounted for by the top n_classes Pfam families
Returns:
List of top Pfam families
"""
normalized_acts = dim_maxes / dim_maxes.max()
df_dim = df.with_columns(pl.Series(normalized_acts).alias("act"))
non_zero = df_dim.filter(pl.col("act") > 0)
if len(non_zero) < 10:
return []

gt_50 = non_zero.filter(pl.col("act") > act_gt)
gt_50 = gt_50.with_columns(
pl.col("Pfam").str.strip_chars(";").str.split(";").alias("pfam_list")
)
exploded = gt_50.explode("pfam_list")
count_table = (
exploded["pfam_list"].value_counts().drop_nulls().sort("count", descending=True)
)
count_order = {value: i for i, value in enumerate(count_table["pfam_list"])}
exploded = (
exploded.with_columns(
pl.col("pfam_list").replace_strict(count_order).alias("pfam_ordered")
)
.sort("pfam_ordered")
.drop_nulls()
)
cleaned_df = exploded.unique(subset=["Entry"], maintain_order=True)
cleaned_df = cleaned_df.rename({"pfam_list": "pfam_common"})
keep = (
cleaned_df["pfam_common"]
.value_counts()
.drop_nulls()
.sort("count", descending=True)
.filter(pl.col("interpro_ids") != "")[:top_n]
)
freq_table = counts.with_columns(pl.col("count") / total_rows).rename(
{"count": "freq"}
)
freq_dict = freq_table.to_dict()
return {key: value.to_list() for key, value in freq_dict.items()}

if len(keep) >= 1:
top_count = sum(keep["count"][:n_classes])
if top_count > (len(gt_50) * frac_above_threshold):
return keep["pfam_common"][:n_classes].to_list()

return []


if __name__ == "__main__":
Expand Down

0 comments on commit bed6c17

Please sign in to comment.