From ec78ba9fa0b2039f9f97e6d1574b98b3b50a0ea0 Mon Sep 17 00:00:00 2001 From: Liam Date: Tue, 15 Oct 2024 15:14:20 -0400 Subject: [PATCH] Trim --- plm_interpretability/latent_probe/__main__.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/plm_interpretability/latent_probe/__main__.py b/plm_interpretability/latent_probe/__main__.py index b0476e9..2acbd40 100644 --- a/plm_interpretability/latent_probe/__main__.py +++ b/plm_interpretability/latent_probe/__main__.py @@ -108,7 +108,6 @@ class ResidueAnnotation: ] -# @functools.lru_cache(maxsize=5000) def get_sae_acts( seq: str, tokenizer: AutoTokenizer, @@ -154,14 +153,13 @@ def get_annotation_entries_for_class( # The note field is sometimes like "Homeobox", "Homeobox 1", etc., # so use string `in` to check. entries = [e for e in entries if class_name in e.get("note", "")] - if len(entries) > 0: + if len(entries) > 0 and len(seq) < 2000: seq_to_annotation_entries[seq] = entries seq_lengths.append(len(seq)) logger.info( f"Found {len(seq_to_annotation_entries)} sequences with class {class_name}." - f"Sequence length min: {min(seq_lengths)}, max: {max(seq_lengths)}, " - f"mean: {np.mean(seq_lengths)}." + f"Mean sequence length: {np.mean(seq_lengths):.2f}." ) if len(seq_to_annotation_entries) > max_seqs_per_task: