diff --git a/distributed_shampoo/utils/tests/shampoo_distributor_test.py b/distributed_shampoo/utils/tests/shampoo_distributor_test.py index ff23203..dc720e0 100644 --- a/distributed_shampoo/utils/tests/shampoo_distributor_test.py +++ b/distributed_shampoo/utils/tests/shampoo_distributor_test.py @@ -68,8 +68,8 @@ def _get_distributor_type(self) -> Type[DistributorInterface]: def test_update_params(self) -> None: # Explicitly disable the gradient of the bias layer and call merge_and_block_gradients() # to update the local gradient selector. - self._model.linear_layers[0].weight.grad = torch.ones((5, 10)) - self._model.linear_layers[0].bias.grad = None + self._model.linear_layers[0].weight.grad = torch.ones((5, 10)) # type: ignore[index, union-attr] + self._model.linear_layers[0].bias.grad = None # type: ignore[index, union-attr] self._distributor.merge_and_block_gradients() actual_masked_blocked_params = self._distributor.local_masked_blocked_params @@ -93,8 +93,8 @@ def test_update_params(self) -> None: def test_local_grad_selector(self) -> None: # Explicitly disable the gradient of the bias layer and call merge_and_block_gradients() # to update the local gradient selector for the bias layer (i.e., 3rd block). - self._model.linear_layers[0].weight.grad = torch.ones((5, 10)) - self._model.linear_layers[0].bias.grad = None + self._model.linear_layers[0].weight.grad = torch.ones((5, 10)) # type: ignore[index, union-attr] + self._model.linear_layers[0].bias.grad = None # type: ignore[index, union-attr] self._distributor.merge_and_block_gradients() expected_local_grad_selector = (True, True, False) @@ -119,15 +119,15 @@ def test_local_blocked_params(self) -> None: def test_local_block_info_list(self) -> None: expected_local_block_info_list = ( BlockInfo( - param=self._model.linear_layers[0].weight, + param=self._model.linear_layers[0].weight, # type: ignore[index, union-attr] composable_block_ids=(0, "block_0"), ), BlockInfo( - param=self._model.linear_layers[0].weight, + param=self._model.linear_layers[0].weight, # type: ignore[index, union-attr] composable_block_ids=(0, "block_1"), ), BlockInfo( - param=self._model.linear_layers[0].bias, + param=self._model.linear_layers[0].bias, # type: ignore[index, union-attr] composable_block_ids=(1, "block_0"), ), ) @@ -137,8 +137,8 @@ def test_local_block_info_list(self) -> None: ) def test_merge_and_block_gradients(self) -> None: - self._model.linear_layers[0].weight.grad = torch.ones((5, 10)) - self._model.linear_layers[0].bias.grad = None + self._model.linear_layers[0].weight.grad = torch.ones((5, 10)) # type: ignore[index, union-attr] + self._model.linear_layers[0].bias.grad = None # type: ignore[index, union-attr] actual_local_masked_block_grads = self._distributor.merge_and_block_gradients() expected_local_masked_block_grads = ( torch.ones((5, 5)),