diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 9cda3f9dd..14d6577d9 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -1679,9 +1679,9 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta): offsets (Optional[torch.Tensor]): jagged slices, represented as cumulative offsets. stride (Optional[int]): number of examples per batch. - stride_per_key_per_rank (Optional[List[List[int]]]): batch size - (number of examples) per key per rank, with the outer list representing the - keys and the inner list representing the values. + stride_per_key_per_rank (Optional[Union[torch.IntTensor, List[List[int]]]]): + batch size (number of examples) per key per rank, with the outer list + representing the keys and the inner list representing the values. Each value in the inner list represents the number of examples in the batch from the rank of its index in a distributed context. length_per_key (Optional[List[int]]): start length for each key.