Commit d42d120

Ignore OSS mypy false-positive errors on `__getitem__` of `nn.Sequential` (#82)

Summary:
Pull Request resolved: #82

This diff suppresses these false-positive errors with targeted line-level `# type: ignore[index, union-attr]` comments.
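
For context, a minimal sketch of the kind of false positive being suppressed. `ToyModel` and `make_model` below are hypothetical stand-ins for the test fixture, and the sketch assumes the fixture exposes the model under the broad `nn.Module` type, which would account for the `[index, union-attr]` codes seen in the diff:

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear_layers = nn.Sequential(nn.Linear(10, 5))


def make_model() -> nn.Module:
    # Returning the broad `nn.Module` type (as a generic test fixture might)
    # sends mypy through `Module.__getattr__`, which is annotated as
    # returning `Union[Tensor, Module]`.
    return ToyModel()


model = make_model()
# Valid at runtime: `linear_layers[0]` is an `nn.Linear` with a `weight`
# parameter. Statically, however, mypy flags indexing the `Module` branch
# of the union ([index]) and the `.weight` access on the plain `Module`
# that `nn.Sequential.__getitem__` is annotated to return ([union-attr]).
model.linear_layers[0].weight.grad = torch.ones((5, 10))  # type: ignore[index, union-attr]

Scoping each ignore to the two reported error codes keeps the rest of the line type-checked, rather than blanket-ignoring the file or inserting runtime casts.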

Reviewed By: hjmshi, runame

Differential Revision: D69360427

fbshipit-source-id: 3db3da64e3d83ee3a2309da0b052cfef8a27c2e9
tsunghsienlee authored and facebook-github-bot committed Feb 26, 2025
1 parent 4bd0f72 commit d42d120
Showing 1 changed file with 9 additions and 9 deletions.
distributed_shampoo/utils/tests/shampoo_distributor_test.py (9 additions, 9 deletions)
@@ -68,8 +68,8 @@ def _get_distributor_type(self) -> Type[DistributorInterface]:
     def test_update_params(self) -> None:
         # Explicitly disable the gradient of the bias layer and call merge_and_block_gradients()
         # to update the local gradient selector.
-        self._model.linear_layers[0].weight.grad = torch.ones((5, 10))
-        self._model.linear_layers[0].bias.grad = None
+        self._model.linear_layers[0].weight.grad = torch.ones((5, 10))  # type: ignore[index, union-attr]
+        self._model.linear_layers[0].bias.grad = None  # type: ignore[index, union-attr]
         self._distributor.merge_and_block_gradients()

         actual_masked_blocked_params = self._distributor.local_masked_blocked_params
@@ -93,8 +93,8 @@ def test_update_params(self) -> None:
     def test_local_grad_selector(self) -> None:
         # Explicitly disable the gradient of the bias layer and call merge_and_block_gradients()
         # to update the local gradient selector for the bias layer (i.e., 3rd block).
-        self._model.linear_layers[0].weight.grad = torch.ones((5, 10))
-        self._model.linear_layers[0].bias.grad = None
+        self._model.linear_layers[0].weight.grad = torch.ones((5, 10))  # type: ignore[index, union-attr]
+        self._model.linear_layers[0].bias.grad = None  # type: ignore[index, union-attr]
         self._distributor.merge_and_block_gradients()

         expected_local_grad_selector = (True, True, False)
@@ -119,15 +119,15 @@ def test_local_blocked_params(self) -> None:
     def test_local_block_info_list(self) -> None:
         expected_local_block_info_list = (
             BlockInfo(
-                param=self._model.linear_layers[0].weight,
+                param=self._model.linear_layers[0].weight,  # type: ignore[index, union-attr]
                 composable_block_ids=(0, "block_0"),
             ),
             BlockInfo(
-                param=self._model.linear_layers[0].weight,
+                param=self._model.linear_layers[0].weight,  # type: ignore[index, union-attr]
                 composable_block_ids=(0, "block_1"),
             ),
             BlockInfo(
-                param=self._model.linear_layers[0].bias,
+                param=self._model.linear_layers[0].bias,  # type: ignore[index, union-attr]
                 composable_block_ids=(1, "block_0"),
             ),
         )
@@ -137,8 +137,8 @@ def test_local_block_info_list(self) -> None:
         )

     def test_merge_and_block_gradients(self) -> None:
-        self._model.linear_layers[0].weight.grad = torch.ones((5, 10))
-        self._model.linear_layers[0].bias.grad = None
+        self._model.linear_layers[0].weight.grad = torch.ones((5, 10))  # type: ignore[index, union-attr]
+        self._model.linear_layers[0].bias.grad = None  # type: ignore[index, union-attr]
         actual_local_masked_block_grads = self._distributor.merge_and_block_gradients()
         expected_local_masked_block_grads = (
             torch.ones((5, 5)),
