From 6ee666610d5770187a60c75d843e77fa722f7d58 Mon Sep 17 00:00:00 2001
From: Liam
Date: Thu, 17 Oct 2024 20:31:25 -0400
Subject: [PATCH] Fix negative-example sampling bounds; add pooling test

In make_examples_from_annotation_entries, allow left-side negative
windows that start at position 0 (`start >= annotation_length` instead
of `>`), and bound right-side windows by the sequence length rather
than the activation list. Also adjust the train/test split defaults and
add a test covering pool_over_annotation=True.
---
 .../logistic_regression_probe/utils.py        | 10 +--
 tests/logistic_regression_probe/test_utils.py | 64 +++++++++++++++++++
 2 files changed, 69 insertions(+), 5 deletions(-)

diff --git a/plm_interpretability/logistic_regression_probe/utils.py b/plm_interpretability/logistic_regression_probe/utils.py
index 2537a68..3bffae7 100644
--- a/plm_interpretability/logistic_regression_probe/utils.py
+++ b/plm_interpretability/logistic_regression_probe/utils.py
@@ -75,8 +75,8 @@ def extract_sequences(
 def train_test_split_by_homology(
     sequences: list[str],
     max_seqs: int,
-    test_ratio: float = 0.1,
-    similarity_threshold: float = 0.4,
+    test_ratio: float = 0.2,
+    similarity_threshold: float = 0.3,
 ) -> tuple[set[str], set[str]]:
     """
     Given a list of sequences and a max_seqs cutoff:
@@ -240,7 +240,7 @@ def make_examples_from_annotation_entries(
 
             # Sample 1-2 random annotations with the same length as the positive annotation
             # that don't overlap with the positive annotation as negative examples.
-            if start > annotation_length:
+            if start >= annotation_length:
                 random_start_on_left = random.randint(0, start - annotation_length)
                 random_end_on_left = random_start_on_left + annotation_length
                 examples.append(
@@ -251,8 +251,8 @@ def make_examples_from_annotation_entries(
                         target=False,
                     )
                 )
-            if end < len(sae_acts) - annotation_length:
-                random_start_on_right = random.randint(end, len(sae_acts) - annotation_length)
+            if end < len(seq) - annotation_length:
+                random_start_on_right = random.randint(end, len(seq) - annotation_length)
                 random_end_on_right = random_start_on_right + annotation_length
                 examples.append(
                     Example(
diff --git a/tests/logistic_regression_probe/test_utils.py b/tests/logistic_regression_probe/test_utils.py
index 07c7003..2c1b816 100644
--- a/tests/logistic_regression_probe/test_utils.py
+++ b/tests/logistic_regression_probe/test_utils.py
@@ -86,6 +86,70 @@ def test_make_examples_from_annotation_entries(self, mock_get_sae_acts):
             plm_layer=24,
         )
 
+    @patch("plm_interpretability.logistic_regression_probe.utils.get_sae_acts")
+    def test_make_examples_from_annotation_entries_pool_over_annotation(self, mock_get_sae_acts):
+        seq_to_annotation_entries = {
+            "AAAAAAAAAA": [{"start": 4, "end": 6}],
+            "CCCCCCCCCC": [{"start": 1, "end": 3}, {"start": 5, "end": 6}],
+        }
+
+        mock_tokenizer = Mock()
+        mock_plm_model = Mock()
+        mock_sae_model = Mock()
+
+        mock_get_sae_acts.side_effect = [
+            [
+                [0.1, 0.2],
+                [0.3, 0.4],
+                [0.5, 0.6],
+                [0.7, 0.8],
+                [0.9, 1.0],
+                [1.1, 1.2],
+                [1.3, 1.4],
+                [1.5, 1.6],
+                [1.7, 1.8],
+                [1.9, 2.0],
+            ],
+            [
+                [2.1, 2.2],
+                [2.3, 2.4],
+                [2.5, 2.6],
+                [2.7, 2.8],
+                [2.9, 3.0],
+                [3.1, 3.2],
+                [3.3, 3.4],
+                [3.5, 3.6],
+                [3.7, 3.8],
+                [3.9, 4.0],
+            ],
+        ]
+
+        examples = make_examples_from_annotation_entries(
+            seq_to_annotation_entries,
+            mock_tokenizer,
+            mock_plm_model,
+            mock_sae_model,
+            plm_layer=24,
+            pool_over_annotation=True,
+        )
+        print(examples)
+
+        self.assertEqual(len(examples), 8)
+
+        self.assertIn(
+            Example(sae_acts=np.mean([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]], axis=0), target=True),
+            examples,
+        )
+        self.assertIn(
+            Example(sae_acts=np.mean([[2.1, 2.2], [2.3, 2.4], [2.5, 2.6]], axis=0), target=True),
+            examples,
+        )
+        self.assertIn(
+            Example(sae_acts=np.mean([[2.9, 3.0], [3.1, 3.2]], axis=0), target=True),
+            examples,
+        )
+        self.assertEqual(len([e for e in examples if e.target is False]), 5)
+
     def test_get_annotation_entries_for_class(self):
         mock_df = pd.DataFrame(
             {
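
Note, not part of the patch: the two utils.py boundary fixes are easy to
misread, so below is a minimal, self-contained Python sketch of the
negative-window sampling logic they correct, assuming 0-indexed,
half-open [start, end) annotations. sample_negative_windows and its
parameters are hypothetical names for illustration, not the repo's API.

import random

def sample_negative_windows(
    seq_len: int, start: int, end: int
) -> list[tuple[int, int]]:
    """Sample up to one same-length, non-overlapping window per side of [start, end)."""
    length = end - start
    windows: list[tuple[int, int]] = []
    # Left side: a window [s, s + length) with s <= start - length cannot
    # overlap the annotation. Using `start >= length` (the patched check)
    # admits the valid case where the only candidate window starts at 0,
    # which the old `start > length` wrongly skipped.
    if start >= length:
        s = random.randint(0, start - length)
        windows.append((s, s + length))
    # Right side: bound by the sequence length (the patch swaps
    # len(sae_acts) for len(seq)), so the sampled window always ends
    # inside the sequence.
    if end < seq_len - length:
        s = random.randint(end, seq_len - length)
        windows.append((s, s + length))
    return windows

# A length-3 annotation flush against the left boundary still yields a
# left-side window under `>=`; the old `>` check would have produced none.
print(sample_negative_windows(seq_len=10, start=3, end=6))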