Commit 1b261df

Merge pull request #96 from stanfordnlp/peterwz
[P1] Adaptive changes with pyvene 101 colab
frankaging authored Jan 28, 2024
2 parents d5beda8 + 5da8802 commit 1b261df
Showing 3 changed files with 30 additions and 59 deletions.
59 changes: 28 additions & 31 deletions pyvene/models/modeling_utils.py
@@ -253,35 +253,34 @@ def gather_neurons(tensor_input, unit, unit_locations_as_list):
         return tensor_input
 
     if "." in unit:
-        if unit in {"h.pos"}:
-            unit_locations = (
-                torch.tensor(unit_locations_as_list[0], device=tensor_input.device),
-                torch.tensor(unit_locations_as_list[1], device=tensor_input.device),
-            )
-            # we assume unit_locations is a tuple
-            head_unit_locations = unit_locations[0]
-            pos_unit_locations = unit_locations[1]
-
-            head_tensor_output = torch.gather(
-                tensor_input,
-                1,
-                head_unit_locations.reshape(
-                    *head_unit_locations.shape, *(1,) * (len(tensor_input.shape) - 2)
-                ).expand(-1, -1, *tensor_input.shape[2:]),
-            )  # b, h, s, d
-            d = head_tensor_output.shape[1]
-            pos_tensor_input = bhsd_to_bs_hd(head_tensor_output)
-            pos_tensor_output = torch.gather(
-                pos_tensor_input,
-                1,
-                pos_unit_locations.reshape(
-                    *pos_unit_locations.shape, *(1,) * (len(pos_tensor_input.shape) - 2)
-                ).expand(-1, -1, *pos_tensor_input.shape[2:]),
-            )  # b, num_unit (pos), num_unit (h)*d
-            tensor_output = bs_hd_to_bhsd(pos_tensor_output, d)
-
-            return tensor_output  # b, num_unit (h), num_unit (pos), d
-    elif unit in {"h", "pos"}:
+        unit_locations = (
+            torch.tensor(unit_locations_as_list[0], device=tensor_input.device),
+            torch.tensor(unit_locations_as_list[1], device=tensor_input.device),
+        )
+        # we assume unit_locations is a tuple
+        head_unit_locations = unit_locations[0]
+        pos_unit_locations = unit_locations[1]
+
+        head_tensor_output = torch.gather(
+            tensor_input,
+            1,
+            head_unit_locations.reshape(
+                *head_unit_locations.shape, *(1,) * (len(tensor_input.shape) - 2)
+            ).expand(-1, -1, *tensor_input.shape[2:]),
+        )  # b, h, s, d
+        d = head_tensor_output.shape[1]
+        pos_tensor_input = bhsd_to_bs_hd(head_tensor_output)
+        pos_tensor_output = torch.gather(
+            pos_tensor_input,
+            1,
+            pos_unit_locations.reshape(
+                *pos_unit_locations.shape, *(1,) * (len(pos_tensor_input.shape) - 2)
+            ).expand(-1, -1, *pos_tensor_input.shape[2:]),
+        )  # b, num_unit (pos), num_unit (h)*d
+        tensor_output = bs_hd_to_bhsd(pos_tensor_output, d)
+
+        return tensor_output  # b, num_unit (h), num_unit (pos), d
+    else:
         unit_locations = torch.tensor(
             unit_locations_as_list, device=tensor_input.device
         )

@@ -294,8 +293,6 @@ def gather_neurons(tensor_input, unit, unit_locations_as_list):
         )
         return tensor_output
 
-    raise ValueError(f"Not Implemented Gathering with Unit = {unit}")
-
 
 def scatter_neurons(
     tensor_input,
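The rewritten "." branch performs a two-stage gather: it first picks out the requested heads along dim 1, then flattens the head and feature dims, gathers the requested positions, and unflattens back to a (batch, heads, positions, dim) layout. A minimal self-contained sketch of that pattern follows; the bhsd_to_bs_hd / bs_hd_to_bhsd helpers here are simplified stand-ins written for illustration and may not match pyvene's exact implementations.

import torch

# Illustrative stand-ins for pyvene's reshape helpers (assumed behavior).
def bhsd_to_bs_hd(x):
    b, h, s, d = x.shape
    return x.permute(0, 2, 1, 3).reshape(b, s, h * d)

def bs_hd_to_bhsd(x, h):
    b, s, hd = x.shape
    d = hd // h
    return x.reshape(b, s, h, d).permute(0, 2, 1, 3)

b, h, s, d = 2, 4, 6, 3
hidden = torch.randn(b, h, s, d)         # per-head hidden states
head_idx = torch.tensor([[0, 2]] * b)    # heads to keep, per example
pos_idx = torch.tensor([[1, 3, 5]] * b)  # positions to keep, per example

# Stage 1: gather heads along dim 1.
head_out = torch.gather(
    hidden, 1, head_idx[:, :, None, None].expand(-1, -1, s, d)
)  # (b, 2, s, d)

# Stage 2: flatten heads into the feature dim, gather positions, unflatten.
flat = bhsd_to_bs_hd(head_out)  # (b, s, 2*d)
pos_out = torch.gather(
    flat, 1, pos_idx[:, :, None].expand(-1, -1, flat.shape[-1])
)  # (b, 3, 2*d)
result = bs_hd_to_bhsd(pos_out, head_out.shape[1])  # (b, 2, 3, d)

assert torch.allclose(result, hidden[:, [0, 2]][:, :, [1, 3, 5]])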
1 change: 1 addition & 0 deletions requirements.txt
@@ -10,3 +10,4 @@ huggingface-hub==0.20.3
 numpy==1.23.5
 fsspec==2023.6.0
 accelerate>=0.26.1
+sentencepiece==0.1.96
29 changes: 1 addition & 28 deletions tests/unit_tests/ModelUtilsTestCase.py
@@ -35,40 +35,14 @@ def test_gather_neurons_pos_h_positive(self):
         )
         self.assertTrue(torch.allclose(tensor_output, tensor_input[:, 1:3, 0:2, :]))
 
-    def _test_gather_neurons_negative(self, name, unit, expected_error_msg):
-        tensor_input = torch.rand((5, 3, 2))
-        with self.assertRaisesRegex(ValueError, expected_error_msg):
-            gather_neurons(tensor_input, unit, [[0, 1]] * 5)
-
-    def test_gather_neurons_negative(self):
-        self._test_gather_neurons_negative(
-            "dim",
-            "dim",
-            "Not Implemented Gathering with Unit = dim",
-        )
-        self._test_gather_neurons_negative(
-            "pos.dim",
-            "pos.dim",
-            "Not Implemented Gathering with Unit = pos.dim",
-        )
-        self._test_gather_neurons_negative(
-            "h.dim", "h.dim", "Not Implemented Gathering with Unit = h.dim"
-        )
-        self._test_gather_neurons_negative(
-            "h.pos.dim", "h.pos.dim", "Not Implemented Gathering with Unit = h.pos.dim"
-        )
-
     def test_output_to_subcomponent_gpt2_no_head_positive(self):
         # batch_size, seq_len, emb_dim
         tensor_input = torch.rand((2, 5, 6))
 
         golden_output = tensor_input.clone()
 
         tensor_output = output_to_subcomponent(
-            tensor_input,
-            "attention_input",
-            self.gpt2_model,
-            self.gpt2_config,
+            tensor_input, "attention_input", self.gpt2_model, self.gpt2_config,
         )
         self.assertTrue(torch.allclose(tensor_output, golden_output))
 
@@ -357,7 +331,6 @@ def suite():
     suite = unittest.TestSuite()
     suite.addTest(ModelUtilsTestCase("test_gather_neurons_pos_h_positive"))
     suite.addTest(ModelUtilsTestCase("test_gather_neurons_positive"))
-    suite.addTest(ModelUtilsTestCase("test_gather_neurons_negative"))
    suite.addTest(ModelUtilsTestCase("test_scatter_gathered_neurons_gpt2_positive"))
    suite.addTest(ModelUtilsTestCase("test_scatter_gathered_neurons_gpt2_qkv_positive"))
    suite.addTest(
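The surviving positive test exercises the paired head/position path of gather_neurons. Below is a sketch of that call pattern; the unit string, index layout, and toy shapes are inferred from the assertion in the first hunk rather than copied from the test's setup, so treat them as assumptions.

import torch
from pyvene.models.modeling_utils import gather_neurons

# Assumed toy shapes: batch=2, heads=4, positions=5, head_dim=3.
tensor_input = torch.rand((2, 4, 5, 3))

# For a "h.pos"-style unit, unit_locations_as_list is assumed to be a pair:
# per-example head indices and per-example position indices.
tensor_output = gather_neurons(
    tensor_input,
    "h.pos",
    [[[1, 2]] * 2, [[0, 1]] * 2],
)
# Expected to equal tensor_input[:, 1:3, 0:2, :], matching the assertion above.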
