diff --git a/torchrec/optim/tests/test_keyed.py b/torchrec/optim/tests/test_keyed.py index 4be370e6a..9bb988d55 100644 --- a/torchrec/optim/tests/test_keyed.py +++ b/torchrec/optim/tests/test_keyed.py @@ -81,7 +81,6 @@ def test_load_state_dict(self) -> None: "one": 1.0, "tensor": torch.tensor([5.0, 6.0]), "sharded_tensor": sharded_tensor.full( - # pyre-ignore [28] sharding_spec.ChunkShardingSpec( dim=0, placements=["rank:0/cpu"] ), @@ -116,7 +115,6 @@ def test_load_state_dict(self) -> None: "one": 1.0, "tensor": torch.tensor([5.0, 6.0]), "sharded_tensor": sharded_tensor.full( - # pyre-ignore [28] sharding_spec.ChunkShardingSpec(dim=0, placements=["rank:0/cpu"]), (4,), fill_value=1.0, @@ -157,7 +155,6 @@ def test_load_state_dict(self) -> None: expected_state_dict["state"]["param_1"]["tensor"] = torch.tensor([50.0, 60.0]) # pyre-ignore [6] expected_state_dict["state"]["param_1"]["sharded_tensor"] = sharded_tensor.full( - # pyre-ignore [28] sharding_spec.ChunkShardingSpec(dim=0, placements=["rank:0/cpu"]), (4,), fill_value=10.0,