Commit

Test and fix
liambai committed Oct 18, 2024
1 parent 74c4d91 commit 6ee6666
Showing 2 changed files with 69 additions and 5 deletions.
plm_interpretability/logistic_regression_probe/utils.py (10 changes: 5 additions & 5 deletions)
@@ -75,8 +75,8 @@ def extract_sequences(
 def train_test_split_by_homology(
     sequences: list[str],
     max_seqs: int,
-    test_ratio: float = 0.1,
-    similarity_threshold: float = 0.4,
+    test_ratio: float = 0.2,
+    similarity_threshold: float = 0.3,
 ) -> tuple[set[str], set[str]]:
     """
     Given a list of sequences and a max_seqs cutoff:
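
Not part of the commit: a minimal usage sketch of the function whose defaults changed above. Only the signature, the module path, and the new default values come from this diff; the example sequences and max_seqs value are made up for illustration, and the homology-clustering internals are not shown here.

    from plm_interpretability.logistic_regression_probe.utils import (
        train_test_split_by_homology,
    )

    sequences = ["MKTAYIAKQRQISFVK", "MKTAYIAKQRQISFVR", "GSHMASNDYTQQATQS"]

    # With the new defaults, roughly 20% of sequences go to the test split
    # (test_ratio=0.2) and train/test are separated at a 30% similarity
    # threshold (similarity_threshold=0.3), versus the previous 10% and 40%.
    train_seqs, test_seqs = train_test_split_by_homology(sequences, max_seqs=1000)
    assert train_seqs.isdisjoint(test_seqs)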
@@ -240,7 +240,7 @@ def make_examples_from_annotation_entries(
 
     # Sample 1-2 random annotations with the same length as the positive annotation
     # that don't overlap with the positive annotation as negative examples.
-    if start > annotation_length:
+    if start >= annotation_length:
         random_start_on_left = random.randint(0, start - annotation_length)
         random_end_on_left = random_start_on_left + annotation_length
         examples.append(
@@ -251,8 +251,8 @@ def make_examples_from_annotation_entries(
                     target=False,
                 )
             )
-    if end < len(sae_acts) - annotation_length:
-        random_start_on_right = random.randint(end, len(sae_acts) - annotation_length)
+    if end < len(seq) - annotation_length:
+        random_start_on_right = random.randint(end, len(seq) - annotation_length)
         random_end_on_right = random_start_on_right + annotation_length
         examples.append(
             Example(
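
Not part of the commit: a small standalone sketch of the two boundary checks changed above. Variable names mirror the hunks; annotation_length is assumed to be end - start + 1 (consistent with the pooled spans in the new test), and the surrounding loop and indexing conventions of make_examples_from_annotation_entries are not shown in this diff.

    import random

    seq = "AAAAAAAAAA"                      # toy sequence, length 10
    start, end = 4, 6                       # a positive annotation
    annotation_length = end - start + 1     # 3

    # Left side: random.randint is inclusive on both ends, so randint(0, 0) is
    # valid; `>=` therefore admits the edge case start == annotation_length
    # that the old `>` check skipped.
    if start >= annotation_length:
        left_start = random.randint(0, start - annotation_length)

    # Right side: the sampled window has to end inside the sequence, so the
    # upper bound is derived from len(seq) rather than from the activation
    # array, matching the second fix above.
    if end < len(seq) - annotation_length:
        right_start = random.randint(end, len(seq) - annotation_length)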
tests/logistic_regression_probe/test_utils.py (64 changes: 64 additions & 0 deletions)
@@ -86,6 +86,70 @@ def test_make_examples_from_annotation_entries(self, mock_get_sae_acts):
             plm_layer=24,
         )
 
+    @patch("plm_interpretability.logistic_regression_probe.utils.get_sae_acts")
+    def test_make_examples_from_annotation_entries_pool_over_annotation(self, mock_get_sae_acts):
+        seq_to_annotation_entries = {
+            "AAAAAAAAAA": [{"start": 4, "end": 6}],
+            "CCCCCCCCCC": [{"start": 1, "end": 3}, {"start": 5, "end": 6}],
+        }
+
+        mock_tokenizer = Mock()
+        mock_plm_model = Mock()
+        mock_sae_model = Mock()
+
+        mock_get_sae_acts.side_effect = [
+            [
+                [0.1, 0.2],
+                [0.3, 0.4],
+                [0.5, 0.6],
+                [0.7, 0.8],
+                [0.9, 1.0],
+                [1.1, 1.2],
+                [1.3, 1.4],
+                [1.5, 1.6],
+                [1.7, 1.8],
+                [1.9, 2.0],
+            ],
+            [
+                [2.1, 2.2],
+                [2.3, 2.4],
+                [2.5, 2.6],
+                [2.7, 2.8],
+                [2.9, 3.0],
+                [3.1, 3.2],
+                [3.3, 3.4],
+                [3.5, 3.6],
+                [3.7, 3.8],
+                [3.9, 4.0],
+            ],
+        ]
+
+        examples = make_examples_from_annotation_entries(
+            seq_to_annotation_entries,
+            mock_tokenizer,
+            mock_plm_model,
+            mock_sae_model,
+            plm_layer=24,
+            pool_over_annotation=True,
+        )
+        print(examples)
+
+        self.assertEqual(len(examples), 8)
+
+        self.assertIn(
+            Example(sae_acts=np.mean([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]], axis=0), target=True),
+            examples,
+        )
+        self.assertIn(
+            Example(sae_acts=np.mean([[2.1, 2.2], [2.3, 2.4], [2.5, 2.6]], axis=0), target=True),
+            examples,
+        )
+        self.assertIn(
+            Example(sae_acts=np.mean([[2.9, 3.0], [3.1, 3.2]], axis=0), target=True),
+            examples,
+        )
+        self.assertEqual(len([e for e in examples if e.target is False]), 5)
+
     def test_get_annotation_entries_for_class(self):
         mock_df = pd.DataFrame(
             {
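
Not part of the commit: a hand count that reproduces the totals this test asserts (3 positives, 5 negatives, 8 examples), assuming annotation_length = end - start + 1 and the boundary checks from the utils.py hunks above. The exact indexing inside make_examples_from_annotation_entries is not shown in this diff.

    seq_to_annotation_entries = {
        "AAAAAAAAAA": [{"start": 4, "end": 6}],
        "CCCCCCCCCC": [{"start": 1, "end": 3}, {"start": 5, "end": 6}],
    }

    positives = negatives = 0
    for seq, entries in seq_to_annotation_entries.items():
        for entry in entries:
            start, end = entry["start"], entry["end"]
            length = end - start + 1
            positives += 1                  # one pooled positive per annotation
            if start >= length:             # room for a left-side negative
                negatives += 1
            if end < len(seq) - length:     # room for a right-side negative
                negatives += 1

    print(positives, negatives)  # 3 5 -> the 8 examples asserted above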
