Skip to content

Commit

Permalink
Merge pull request #81 from stanfordnlp/peterwz
Browse files Browse the repository at this point in the history
gather_neurons() unit tests
  • Loading branch information
frankaging authored Jan 22, 2024
2 parents aec9c42 + 7cc5f2d commit de246a1
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 12 deletions.
10 changes: 7 additions & 3 deletions pyvene/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,12 @@ def bs_hd_to_bhsd(tensor, h):


def gather_neurons(tensor_input, unit, unit_locations_as_list):
"""Gather intervening neurons"""

"""
Gather intervening neurons.
:param tensor_input: tensors of shape (batch_size, sequence_length, ...) if `unit` is "pos" or "h", tensors of shape (batch_size, num_heads, sequence_length, ...) if `unit` is "h.pos"
:param unit: the intervention units to gather. Units could be "h" - head number, "pos" - position in the sequence, or "dim" - a particular dimension in the embedding space. If intervening multiple units, they are ordered and separated by `.`. Currently only support "pos", "h", and "h.pos" units.
:param unit_locations_as_list: tuple of lists of lists of positions to gather in tensor_input, according to the unit.
"""
if "." in unit:
unit_locations = (
torch.tensor(unit_locations_as_list[0], device=tensor_input.device),
Expand All @@ -268,7 +272,7 @@ def gather_neurons(tensor_input, unit, unit_locations_as_list):
1,
unit_locations.reshape(
*unit_locations.shape, *(1,) * (len(tensor_input.shape) - 2)
).expand(-1, -1, *tensor_input.shape[2:]),
).expand(-1, -1, *tensor_input.shape[2:])
)

return tensor_output
Expand Down
18 changes: 9 additions & 9 deletions tests/integration_tests/InterventionWithGPT2TestCase.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _test_with_position_intervention(
intervention_type,
positions=[0],
use_fast=False,
use_boardcast=False,
use_broadcast=False,
):
max_position = np.max(np.array(positions))
if isinstance(positions[0], list):
Expand Down Expand Up @@ -157,7 +157,7 @@ def _test_with_position_intervention(
self.gpt2, base["input_ids"], {}, {_key: base_activations[_key]}
)

if use_boardcast:
if use_broadcast:
assert isinstance(positions[0], int)
_, out_output_1 = intervenable(
base, [source], {"sources->base": positions[0]}
Expand Down Expand Up @@ -361,7 +361,7 @@ def test_with_location_broadcast_vanilla_intervention_positive(self):
intervention_stream=stream,
intervention_type=VanillaIntervention,
positions=[random.randint(0, 3)],
use_boardcast=True,
use_broadcast=True,
)
print(f"testing broadcast with stream: {stream} with a single position (with fast)")
self._test_with_position_intervention(
Expand All @@ -370,7 +370,7 @@ def test_with_location_broadcast_vanilla_intervention_positive(self):
intervention_type=VanillaIntervention,
positions=[random.randint(0, 3)],
use_fast=True,
use_boardcast=True,
use_broadcast=True,
)

def _test_with_position_intervention_constant_source(
Expand All @@ -381,7 +381,7 @@ def _test_with_position_intervention_constant_source(
positions=[0],
use_base_only=False,
use_fast=False,
use_boardcast=False,
use_broadcast=False,
):
max_position = np.max(np.array(positions))
if isinstance(positions[0], list):
Expand Down Expand Up @@ -432,7 +432,7 @@ def _test_with_position_intervention_constant_source(
)

if use_base_only:
if use_boardcast:
if use_broadcast:
_, out_output = intervenable(
base,
unit_locations={"base": positions[0]},
Expand All @@ -443,7 +443,7 @@ def _test_with_position_intervention_constant_source(
unit_locations={"base": ([[positions] * b_s])},
)
else:
if use_boardcast:
if use_broadcast:
_, out_output = intervenable(
base,
unit_locations={"sources->base": (None, positions[0])},
Expand Down Expand Up @@ -483,15 +483,15 @@ def test_with_position_intervention_constant_source_vanilla_intervention_positiv
intervention_type=VanillaIntervention,
positions=[0],
use_base_only=True,
use_boardcast=True
use_broadcast=True
)
self._test_with_position_intervention_constant_source(
intervention_layer=random.randint(0, 3),
intervention_stream=stream,
intervention_type=VanillaIntervention,
positions=[0],
use_base_only=True,
use_boardcast=True,
use_broadcast=True,
use_fast=True
)

Expand Down
56 changes: 56 additions & 0 deletions tests/unit_tests/ModelUtilsTestCase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import unittest
from ..utils import *
from pyvene.models.modeling_utils import *

class ModelUtilsTestCase(unittest.TestCase):
    """Unit tests for ``gather_neurons`` from ``pyvene.models.modeling_utils``."""

    @classmethod
    def setUpClass(cls):
        """No shared fixtures are required; each test builds its own tensors."""

    def test_gather_neurons_positive(self):
        """Single-unit gathering ("pos" and "h") on a 3-D tensor.

        With locations ``[0, 1]`` for every batch row, the result is expected
        to equal the slice ``tensor_input[:, 0:2, :]``.
        """
        tensor_input = torch.rand((5, 3, 2))  # batch_size, seq_len, emb_dim
        tensor_output = gather_neurons(tensor_input, "pos", [[0, 1]] * 5)
        self.assertTrue(torch.allclose(tensor_output, tensor_input[:, 0:2, :]))
        tensor_output = gather_neurons(tensor_input, "h", [[0, 1]] * 5)
        self.assertTrue(torch.allclose(tensor_output, tensor_input[:, 0:2, :]))

    def test_gather_neurons_pos_h_positive(self):
        """Two-unit gathering ("h.pos") on a 4-D tensor.

        Head locations ``[1, 2]`` and position locations ``[0, 1]`` are
        expected to yield the slice ``tensor_input[:, 1:3, 0:2, :]``.
        """
        tensor_input = torch.rand((5, 4, 3, 2))  # batch_size, #heads, seq_len, emb_dim
        tensor_output = gather_neurons(
            tensor_input, "h.pos", ([[1, 2]] * 5, [[0, 1]] * 5)
        )
        self.assertTrue(
            torch.allclose(tensor_output, tensor_input[:, 1:3, 0:2, :])
        )

    def _test_gather_neurons_negative(self, name, unit, expected_error_msg):
        """Check that gathering with ``unit`` raises a matching AssertionError.

        :param name: label for the case; reported through ``subTest`` so a
            failure identifies which unit string was being exercised.
        :param unit: unit string handed to ``gather_neurons``.
        :param expected_error_msg: pattern passed to ``assertRaisesRegex`` and
            searched for in the raised error's message.
        """
        tensor_input = torch.rand((5, 3, 2))
        # subTest keeps the remaining cases running after one of them fails.
        with self.subTest(name):
            with self.assertRaisesRegex(AssertionError, expected_error_msg):
                gather_neurons(tensor_input, unit, [[0, 1]] * 5)

    def test_gather_neurons_negative(self):
        """Unit strings involving "dim" are expected to be rejected."""
        self._test_gather_neurons_negative(
            "dim", "dim", "Not Implemented Gathering with Unit = dim"
        )
        self._test_gather_neurons_negative(
            "pos.dim", "pos.dim", "Not Implemented Gathering with Unit = pos.dim"
        )
        self._test_gather_neurons_negative(
            "h.dim", "h.dim", "Not Implemented Gathering with Unit = h.dim"
        )
        self._test_gather_neurons_negative(
            "h.pos.dim",
            "h.pos.dim",
            "Not Implemented Gathering with Unit = h.pos.dim",
        )


def suite():
    """Assemble the three ``ModelUtilsTestCase`` tests into one suite.

    :return: a ``unittest.TestSuite`` holding the tests in the order listed.
    """
    test_names = (
        "test_gather_neurons_pos_h_positive",
        "test_gather_neurons_positive",
        "test_gather_neurons_negative",
    )
    collected = unittest.TestSuite()
    for test_name in test_names:
        collected.addTest(ModelUtilsTestCase(test_name))
    return collected


if __name__ == "__main__":
    # Run the hand-assembled suite with the default text runner.
    unittest.TextTestRunner().run(suite())

0 comments on commit de246a1

Please sign in to comment.