From bed6c17643ece2454747be14143786272d7ec85e Mon Sep 17 00:00:00 2001
From: etowahadams
Date: Wed, 27 Nov 2024 21:43:13 -0500
Subject: [PATCH] New viz file with ranges (#28)

* fix: sequence

* family info

* fix: uniprot id

* fix: max
---
 interprot/make_viz_files/__main__.py | 213 ++++++++++++++++-----------
 1 file changed, 131 insertions(+), 82 deletions(-)

diff --git a/interprot/make_viz_files/__main__.py b/interprot/make_viz_files/__main__.py
index b0560ee..41401b0 100644
--- a/interprot/make_viz_files/__main__.py
+++ b/interprot/make_viz_files/__main__.py
@@ -81,11 +81,7 @@ def make_viz_files(checkpoint_files: list[str], sequences_file: str):
         )
 
         df = pl.read_parquet(sequences_file)
-        has_interpro = "InterPro" in df.columns
-        if has_interpro:
-            df = df.with_columns(
-                pl.col("InterPro").str.split(";").alias("interpro_ids")
-            )
+        has_pfam = "Pfam" in df.columns
 
         # Pre-allocate numpy array for storing max activations
         all_seqs_max_act = np.zeros((sae_dim, len(df)))
@@ -107,6 +103,7 @@ def make_viz_files(checkpoint_files: list[str], sequences_file: str):
                 sae_acts_cpu = sae_acts.cpu().numpy()
                 all_seqs_max_act[:, seq_idx] = np.max(sae_acts_cpu, axis=0)
                 sae_acts_int = (sae_acts_cpu * 10).astype(np.uint8)
+                # Convert to sparse matrix. This significantly reduces memory usage
                 sparse_acts = sparse.csr_matrix(sae_acts_int)
                 all_acts[seq_idx] = sparse_acts
                 # Clear CUDA cache periodically
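The new comment in this hunk is the heart of the memory story: per-sequence activations are quantized to one decimal place and stored as a SciPy CSR matrix, then dequantized on read in the writing step below. A minimal sketch of that round trip with toy values (note this assumes activations stay below 25.6, since the uint8 scaling would otherwise overflow):

```python
import numpy as np
from scipy import sparse

# Toy (seq_len, sae_dim) activations; real values come from the SAE forward pass.
sae_acts_cpu = np.array([[0.0, 1.23], [0.0, 0.0], [4.56, 0.0]])

# Quantize to one decimal place. uint8 caps at 255, so this assumes
# activations stay below 25.6; larger values would wrap around.
sae_acts_int = (sae_acts_cpu * 10).astype(np.uint8)

# CSR stores only the nonzero entries, which is why it saves memory
# when most residues do not activate a given dimension.
sparse_acts = sparse.csr_matrix(sae_acts_int)

# Dequantize on read, exactly as the writing step does.
dim_acts = sparse_acts.toarray()[:, 0]
print([round(float(act) / 10, 1) for act in dim_acts])  # [0.0, 0.0, 4.5]
```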
@@ -119,97 +116,149 @@ def make_viz_files(checkpoint_files: list[str], sequences_file: str):
 
         hidden_dim_to_seqs = {dim: {} for dim in range(sae_dim)}
 
-        # Calculate the top sequences for each hidden dimension and quartile
-        quartile_names = ["Q1", "Q2", "Q3", "Q4"]
+        act_ranges = [[0, 0.25], [0.25, 0.5], [0.5, 0.75], [0.75, 1]]
+        range_names = [f"{start}-{end}" for start, end in act_ranges]
+
         for dim in tqdm(
             range(sae_dim), desc="Finding highest activating seqs (Step 2/3)"
         ):
             dim_maxes = all_seqs_max_act[dim]
-            if np.all(dim_maxes == 0):
+            non_zero_maxes = dim_maxes[dim_maxes > 0]
+
+            if len(non_zero_maxes) == 0:
+                print(f"Skipping dimension {dim} as it has no activations")
                 continue
+
+            # Get top Pfam families for sequences with normalized activations above 0.75
+            if has_pfam:
+                top_families = get_top_pfam(
+                    df, dim_maxes, act_gt=0.75, n_classes=3, frac_above_threshold=0.8
+                )
+                hidden_dim_to_seqs[dim]["top_pfam"] = top_families
+
             non_zero_maxes = dim_maxes[dim_maxes > 0]
-            hidden_dim_to_seqs[dim]["freq_activate_among_all_seqs"] = len(
-                non_zero_maxes
-            ) / len(dim_maxes)
-            quartiles = np.percentile(non_zero_maxes, [25, 50, 75])
-
-            q1_mask = (dim_maxes > 0) & (dim_maxes <= quartiles[0])
-            q2_mask = (dim_maxes > quartiles[0]) & (dim_maxes <= quartiles[1])
-            q3_mask = (dim_maxes > quartiles[1]) & (dim_maxes <= quartiles[2])
-            q4_mask = dim_maxes > quartiles[2]
-
-            quartile_indices = [
-                np.where(mask)[0] for mask in [q1_mask, q2_mask, q3_mask, q4_mask]
-            ]
-            for q_name, q_indices in zip(quartile_names, quartile_indices):
+            hidden_dim_to_seqs[dim]["freq_active"] = len(non_zero_maxes) / len(
+                dim_maxes
+            )
+            hidden_dim_to_seqs[dim]["n_seqs"] = len(non_zero_maxes)
+            hidden_dim_to_seqs[dim]["max_act"] = float(dim_maxes.max())
+
+            normalized_acts = dim_maxes / dim_maxes.max()
+            for i, (start, end) in enumerate(act_ranges):
+                mask = (normalized_acts > start) & (normalized_acts <= end)
                 top_indices = heapq.nlargest(
-                    NUM_SEQS_PER_DIM, q_indices, key=lambda i: dim_maxes[i]
+                    NUM_SEQS_PER_DIM, np.where(mask)[0], key=lambda i: dim_maxes[i]
                 )
-                hidden_dim_to_seqs[dim][q_name] = {}
-                hidden_dim_to_seqs[dim][q_name]["n_seqs"] = len(q_indices)
-                hidden_dim_to_seqs[dim][q_name]["indices"] = top_indices
-                if has_interpro:
-                    hidden_dim_to_seqs[dim][q_name]["interpro"] = get_top_interpro(
-                        df, q_indices, top_n=10
-                    )
+                range_name = range_names[i]
+                hidden_dim_to_seqs[dim][range_name] = {}
+                hidden_dim_to_seqs[dim][range_name]["indices"] = top_indices
 
         for dim in tqdm(range(sae_dim), desc="Writing visualization files (Step 3/3)"):
-            viz_file = {"quartiles": {}}
-            if "freq_activate_among_all_seqs" in hidden_dim_to_seqs[dim]:
-                viz_file["freq_activate_among_all_seqs"] = hidden_dim_to_seqs[dim][
-                    "freq_activate_among_all_seqs"
-                ]
-            for quartile in quartile_names:
-                if quartile not in hidden_dim_to_seqs[dim]:
-                    continue
-                quartile_examples = {
-                    "examples": [],
-                    "n_seqs": hidden_dim_to_seqs[dim][quartile]["n_seqs"],
-                }
-                if has_interpro:
-                    quartile_examples["interpro"] = hidden_dim_to_seqs[dim][quartile][
-                        "interpro"
-                    ]
-                quartile_indices = hidden_dim_to_seqs[dim][quartile]["indices"]
-
-                for seq_idx in quartile_indices:
-                    seq_idx = int(seq_idx)
-                    sae_acts = all_acts[seq_idx].toarray()
-                    dim_acts = sae_acts[:, dim]
-                    uniprot_id = df[seq_idx]["Entry"].item()[:-1]
-                    alphafolddb_id = df[seq_idx]["AlphaFoldDB"].item().split(";")[0]
-                    protein_name = df[seq_idx]["Protein names"].item()
-
-                    examples = {
-                        "sae_acts": [round(float(act) / 10, 1) for act in dim_acts],
-                        "sequence": seq,
-                        "alphafold_id": alphafolddb_id,
-                        "uniprot_id": uniprot_id,
-                        "name": protein_name,
-                    }
-                    quartile_examples["examples"].append(examples)
-
-                viz_file["quartiles"][quartile] = quartile_examples
-
-            with open(os.path.join(OUTPUT_ROOT_DIR, f"{dim}.json"), "w") as f:
-                json.dump(viz_file, f)
-
-
-def get_top_interpro(original_df, indices, top_n=5):
-    df = original_df[indices]
-    total_rows = len(df)
-    counts = (
-        df.explode("interpro_ids")["interpro_ids"]
+            if not hidden_dim_to_seqs[dim]:
+                print(f"Skipping dimension {dim} as it has no sequences")
+                continue
+            write_viz_file(hidden_dim_to_seqs[dim], dim, all_acts, df, range_names)
+
+
+def write_viz_file(dim_info, dim, all_acts, df, range_names):
+    viz_file = {"ranges": {}}
+    # Record how common the dimension is
+    if "freq_active" in dim_info:
+        viz_file["freq_active"] = dim_info["freq_active"]
+    if "n_seqs" in dim_info:
+        viz_file["n_seqs"] = dim_info["n_seqs"]
+    if "top_pfam" in dim_info:
+        viz_file["top_pfam"] = dim_info["top_pfam"]
+    if "max_act" in dim_info:
+        viz_file["max_act"] = dim_info["max_act"]
+
+    for range_name in range_names:
+        if range_name not in dim_info:
+            continue
+        range_examples = {"examples": []}
+        top_indices = dim_info[range_name]["indices"]
+
+        for seq_idx in top_indices:
+            seq_idx = int(seq_idx)
+            sae_acts = all_acts[seq_idx].toarray()
+            dim_acts = sae_acts[:, dim]
+            uniprot_id = df[seq_idx]["Entry"].item()
+            alphafolddb_id = df[seq_idx]["AlphaFoldDB"].item().split(";")[0]
+            protein_name = df[seq_idx]["Protein names"].item()
+            sequence = df[seq_idx]["Sequence"].item()
+
+            examples = {
+                "sae_acts": [round(float(act) / 10, 1) for act in dim_acts],
+                "sequence": sequence,
+                "alphafold_id": alphafolddb_id,
+                "uniprot_id": uniprot_id,
+                "name": protein_name,
+            }
+            range_examples["examples"].append(examples)
+
+        viz_file["ranges"][range_name] = range_examples
+
+    with open(os.path.join(OUTPUT_ROOT_DIR, f"{dim}.json"), "w") as f:
+        json.dump(viz_file, f)
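To make the new output format concrete, here is the rough shape of the per-dimension JSON that `write_viz_file` produces (one file per SAE dimension, e.g. `42.json`), shown as a Python dict. All values here are hypothetical; the keys mirror the code above, and `top_pfam` only appears when the sequences file has a Pfam column:

```python
# Hypothetical contents of one per-dimension viz file; values are invented.
example_viz_file = {
    "freq_active": 0.012,     # fraction of all sequences that activate this dim
    "n_seqs": 3114,           # number of sequences with any activation
    "max_act": 8.7,           # max activation across all sequences
    "top_pfam": ["PF00069"],  # present only when the Pfam column exists
    "ranges": {
        "0.75-1": {
            "examples": [
                {
                    "sae_acts": [0.0, 0.3, 8.7],  # per residue, one decimal place
                    "sequence": "MKTAYIAK...",    # hypothetical sequence
                    "alphafold_id": "P12345",     # hypothetical IDs
                    "uniprot_id": "P12345",
                    "name": "Hypothetical kinase",
                },
            ],
        },
        # "0.5-0.75", "0.25-0.5", and "0-0.25" follow the same shape
    },
}
```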
+
+
+def get_top_pfam(df, dim_maxes, act_gt=0.75, n_classes=3, frac_above_threshold=0.8):
+    """
+    Gets the top Pfam families of sequences whose normalized max activation
+    exceeds a threshold.
+
+    For all sequences with normalized activations greater than the threshold,
+    it returns the top n_classes Pfam families if they account for at least
+    frac_above_threshold of those sequences.
+
+    Args:
+        df: DataFrame containing the sequences
+        dim_maxes: Numpy array of max activations for each sequence
+        act_gt: Threshold for normalized activations
+        n_classes: Number of top Pfam families to return
+        frac_above_threshold: Fraction of sequences that must be accounted for
+            by the top n_classes Pfam families
+
+    Returns:
+        List of top Pfam families
+    """
+    normalized_acts = dim_maxes / dim_maxes.max()
+    df_dim = df.with_columns(pl.Series(normalized_acts).alias("act"))
+    non_zero = df_dim.filter(pl.col("act") > 0)
+    if len(non_zero) < 10:
+        return []
+
+    above_thresh = non_zero.filter(pl.col("act") > act_gt)
+    above_thresh = above_thresh.with_columns(
+        pl.col("Pfam").str.strip_chars(";").str.split(";").alias("pfam_list")
+    )
+    exploded = above_thresh.explode("pfam_list")
+    # Rank families by how many sequences they appear in
+    count_table = (
+        exploded["pfam_list"].value_counts().drop_nulls().sort("count", descending=True)
+    )
+    count_order = {value: i for i, value in enumerate(count_table["pfam_list"])}
+    exploded = (
+        exploded.with_columns(
+            pl.col("pfam_list").replace_strict(count_order).alias("pfam_ordered")
+        )
+        .sort("pfam_ordered")
         .drop_nulls()
+    )
+    # Keep each sequence once, assigned to its most common Pfam family
+    cleaned_df = exploded.unique(subset=["Entry"], maintain_order=True)
+    cleaned_df = cleaned_df.rename({"pfam_list": "pfam_common"})
+    keep = (
+        cleaned_df["pfam_common"]
         .value_counts()
+        .drop_nulls()
         .sort("count", descending=True)
-        .filter(pl.col("interpro_ids") != "")[:top_n]
-    )
-    freq_table = counts.with_columns(pl.col("count") / total_rows).rename(
-        {"count": "freq"}
     )
-    freq_dict = freq_table.to_dict()
-    return {key: value.to_list() for key, value in freq_dict.items()}
+
+    if len(keep) >= 1:
+        top_count = sum(keep["count"][:n_classes])
+        if top_count > (len(above_thresh) * frac_above_threshold):
+            return keep["pfam_common"][:n_classes].to_list()
+
+    return []
 
 
 if __name__ == "__main__":
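A quick sanity check of the majority test at the heart of get_top_pfam, inlined with toy data rather than calling the function itself (the sketch assumes polars and numpy are installed; the entries and family IDs are invented):

```python
import numpy as np
import polars as pl

# Twelve toy sequences: nine share PF00069, three carry other families.
df = pl.DataFrame({
    "Entry": [f"P{i:05d}" for i in range(12)],
    "Pfam": ["PF00069;"] * 9 + ["PF00001;", "PF00002;", "PF00003;"],
})
dim_maxes = np.array([5.0] * 9 + [4.9, 4.8, 0.1])

# Normalize as get_top_pfam does; 11 of the 12 sequences clear act_gt=0.75.
df = df.with_columns(pl.Series(dim_maxes / dim_maxes.max()).alias("act"))
above_thresh = df.filter(pl.col("act") > 0.75)

counts = (
    above_thresh.with_columns(pl.col("Pfam").str.strip_chars(";").str.split(";"))
    .explode("Pfam")["Pfam"]
    .value_counts()
    .sort("count", descending=True)
)
top_count = counts["count"][:3].sum()  # n_classes=3 covers 9 + 1 + 1 = 11 seqs
print(top_count > len(above_thresh) * 0.8)  # True -> families get reported
print(counts["Pfam"][0])  # PF00069 dominates the top activation range
```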