Ignore OSS mypy false-positive errors on `__getitem__` of `nn.Sequential`

Summary: This diff suppresses the false-positive `[index]` and `[union-attr]` errors that OSS mypy reports on these `nn.Sequential` accesses by adding targeted `# type: ignore` comments, without changing any runtime behavior.
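
For context, the sketch below illustrates the kind of false positive being suppressed. It is an illustration only: `ToyModel` and `model` are hypothetical names, not code from this repository, and it assumes the model handle is typed as the base `nn.Module`, so that mypy resolves `linear_layers` through `nn.Module.__getattr__` (annotated as returning `Union[Tensor, Module]`) and cannot see that the attribute is really an `nn.Sequential`:

    import torch
    import torch.nn as nn


    class ToyModel(nn.Module):
        """Hypothetical model: `linear_layers` is an nn.Sequential at runtime."""

        def __init__(self) -> None:
            super().__init__()
            self.linear_layers = nn.Sequential(nn.Linear(10, 5))


    # Typed as the base class, as test fixtures often are; from here on, mypy
    # sees `model.linear_layers` only through `nn.Module.__getattr__`.
    model: nn.Module = ToyModel()

    # This runs fine, but mypy flags indexing into the `Union[Tensor, Module]`
    # result ([index]) and the chained attribute access on that union
    # ([union-attr]), hence the suppression:
    model.linear_layers[0].weight.grad = torch.ones((5, 10))  # type: ignore[index, union-attr]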

Differential Revision: D69360427
tsunghsienlee authored and facebook-github-bot committed Feb 9, 2025
1 parent 9c5700a commit ad6c4b9
Showing 1 changed file with 15 additions and 9 deletions.
distributed_shampoo/utils/tests/shampoo_distributor_test.py
@@ -68,8 +68,10 @@ def _get_distributor_type(self) -> Type[DistributorInterface]:
     def test_update_params(self) -> None:
         # Explicitly disable the gradient of the bias layer and call merge_and_block_gradients()
         # to update the local gradient selector.
-        self._model.linear_layers[0].weight.grad = torch.ones((5, 10))
-        self._model.linear_layers[0].bias.grad = None
+        (
+            self._model.linear_layers[0].weight.grad,
+            self._model.linear_layers[0].bias.grad,
+        ) = torch.ones((5, 10)), None  # type: ignore[index, union-attr]
         self._distributor.merge_and_block_gradients()

         actual_masked_blocked_params = self._distributor.local_masked_blocked_params
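
A note on the shape of this change: the two gradient assignments are folded into one parenthesized tuple assignment, presumably so that a single `# type: ignore[index, union-attr]` comment can cover both targets at once, since a `# type: ignore` comment only silences errors reported on its own line. The same pattern repeats in the hunks below.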
@@ -93,8 +95,10 @@ def test_update_params(self) -> None:
     def test_local_grad_selector(self) -> None:
         # Explicitly disable the gradient of the bias layer and call merge_and_block_gradients()
         # to update the local gradient selector for the bias layer (i.e., 3rd block).
-        self._model.linear_layers[0].weight.grad = torch.ones((5, 10))
-        self._model.linear_layers[0].bias.grad = None
+        (
+            self._model.linear_layers[0].weight.grad,
+            self._model.linear_layers[0].bias.grad,
+        ) = torch.ones((5, 10)), None  # type: ignore[index, union-attr]
         self._distributor.merge_and_block_gradients()

         expected_local_grad_selector = (True, True, False)
@@ -119,15 +123,15 @@ def test_local_blocked_params(self) -> None:
     def test_local_block_info_list(self) -> None:
         expected_local_block_info_list = (
             BlockInfo(
-                param=self._model.linear_layers[0].weight,
+                param=self._model.linear_layers[0].weight,  # type: ignore[index, union-attr]
                 composable_block_ids=(0, "block_0"),
             ),
             BlockInfo(
-                param=self._model.linear_layers[0].weight,
+                param=self._model.linear_layers[0].weight,  # type: ignore[index, union-attr]
                 composable_block_ids=(0, "block_1"),
             ),
             BlockInfo(
-                param=self._model.linear_layers[0].bias,
+                param=self._model.linear_layers[0].bias,  # type: ignore[index, union-attr]
                 composable_block_ids=(1, "block_0"),
             ),
         )
@@ -137,8 +141,10 @@ def test_local_block_info_list(self) -> None:
         )

     def test_merge_and_block_gradients(self) -> None:
-        self._model.linear_layers[0].weight.grad = torch.ones((5, 10))
-        self._model.linear_layers[0].bias.grad = None
+        (
+            self._model.linear_layers[0].weight.grad,
+            self._model.linear_layers[0].bias.grad,
+        ) = torch.ones((5, 10)), None  # type: ignore[index, union-attr]
         actual_local_masked_block_grads = self._distributor.merge_and_block_gradients()
         expected_local_masked_block_grads = (
             torch.ones((5, 5)),
