Conversation

@kaiyuan-li
Contributor

Before this change, when fetching a tensor, we always fetch the whole tensor that's stored on each volume, e.g.
Full Tensor: [[0, 1], [2, 3]]
Storage Volume 0: [[0, 1]]
Storage Volume 1: [[2, 3]]

When we want to get a tensor slice of shape [2, 1] and offset [0, 0], the get() method first fetches the full tensor [[0, 1], [2, 3]] and then extracts the first column from it. This is not efficient.

After this change, we only fetch [[0]] from volume 0 and [[2]] from volume 1 and assemble them together.

Added an example/dtensor.py to help with development. We can remove it later.
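
To make the idea concrete, here is a minimal, self-contained sketch of the new behavior (plain torch, not the actual torchstore API; the volume layout and helper logic below are illustrative assumptions):

    import torch

    # Toy setup matching the example above: each "volume" stores one row
    # of the full tensor together with that shard's global offsets.
    full = torch.tensor([[0, 1], [2, 3]])
    volumes = {0: (full[0:1, :], (0, 0)),   # volume 0 stores [[0, 1]] at offset (0, 0)
               1: (full[1:2, :], (1, 0))}   # volume 1 stores [[2, 3]] at offset (1, 0)

    req_shape, req_offsets = (2, 1), (0, 0)  # requested slice: the first column
    out = torch.empty(req_shape, dtype=full.dtype)

    for stored, off in volumes.values():
        # Per-dimension overlap between the request and this volume's shard.
        starts = [max(o, r) for o, r in zip(off, req_offsets)]
        stops = [min(o + s, r + q) for o, s, r, q in
                 zip(off, stored.shape, req_offsets, req_shape)]
        if any(b <= a for a, b in zip(starts, stops)):
            continue  # this volume holds nothing we need
        # Copy only the overlapping region: [[0]] from volume 0, [[2]] from volume 1.
        local = stored[tuple(slice(a - o, b - o) for a, b, o in zip(starts, stops, off))]
        out[tuple(slice(a - r, b - r) for a, b, r in zip(starts, stops, req_offsets))] = local

    print(out)  # tensor([[0], [2]])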

@kaiyuan-li kaiyuan-li requested a review from LucasLLC September 19, 2025 16:22
@meta-cla meta-cla bot added the CLA Signed label Sep 19, 2025
) -> torch.Tensor:
"""Fetches slices from all volume storages and stitch together to return the whole tensor"""

# dtensor_slice = None
Contributor Author

This is a switch for me to test the old (dtensor_slice=None) and new code paths during development.


return assembled_tensor

def _compute_slice_intersection(
Contributor

I've been putting these in utils. We could consider creating a dtensor_utils file.

stored_tensor, stored_slice, request.tensor_slice
)

if extracted_tensor is not None:
Contributor

What does it mean for the extracted tensor to be None here?

Contributor Author

The workflow has 2 steps:

  1. In the client, we try to find the overlapping tensor slice.
  2. We send the overlapping tensor slice to the volume. When the volume is fetching, it uses the slice info passed in from the client. In theory, the slice passed in by the client MUST overlap with what's stored in the volume. If there's no overlap, extracted_tensor will be None here (see the sketch below).
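
A minimal sketch of that contract (hypothetical names, not the real _compute_slice_intersection signature): the helper returns the overlapping region, or None when the slice sent by the client does not overlap the stored shard at all.

    from typing import Optional, Tuple

    def compute_slice_intersection(
        stored_offsets: Tuple[int, ...], stored_shape: Tuple[int, ...],
        req_offsets: Tuple[int, ...], req_shape: Tuple[int, ...],
    ) -> Optional[Tuple[Tuple[int, ...], Tuple[int, ...]]]:
        # Overlap of two axis-aligned boxes, one dimension at a time.
        starts = tuple(max(s, r) for s, r in zip(stored_offsets, req_offsets))
        stops = tuple(min(s + n, r + m) for s, n, r, m in
                      zip(stored_offsets, stored_shape, req_offsets, req_shape))
        if any(b <= a for a, b in zip(starts, stops)):
            return None  # no overlap -> the caller sees None, as discussed above
        return starts, tuple(b - a for a, b in zip(starts, stops))

    # A shard at offset (1, 0) with shape (1, 2) vs. a request at (0, 0) with shape (1, 2):
    assert compute_slice_intersection((1, 0), (1, 2), (0, 0), (1, 2)) is None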

Contributor

Should we raise an error in this case? It sounds like at this point the code expects the shard to exist.

Contributor Author

Discussed offline; will address in a follow-up to request all tensors inside the whole volume.

Contributor
@LucasLLC left a comment

Lgtm! Tysm!

tensor_slice, dtensor_slice
)

if tensor_slice is None:
Contributor

This looks great to me. Another thing to consider -- if we've already fetched the entire tensor slice region, we can avoid fetching it again. DTensor is often replicated.

Contributor Author

I'll create an issue for this and follow up. Probably need a discussion with you on how to create a test case with replicated dtensor first.

assert device_mesh_shape == tensor_slice.mesh_shape

return assemble_global_tensor(
assembled_tensor = assemble_global_tensor(
Contributor

Is this safe because we never access regions of the tensor which are not initialized? Is there any danger here of us returning uninitialized memory?

Contributor Author

Can you give more information on this? I'm not sure what you mean by "access regions of the tensor which are not initialized". I thought that if we are able to access the volume info, the tensor is already properly initialized?

Contributor

In assemble_global_tensor, the first thing we do is create a torch.empty tensor of the correct global size --

    global_tensor = torch.empty(

Now that I'm thinking about it, for particularly large tensors the behavior here is somewhat unwanted as well. Ideally we would create a tensor of the correct local size and correct for the offsets. devmate may be able to help with the mapping logic; can you give it a shot?
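
One possible shape for that mapping (a sketch under the assumption that the shards have already been intersected with the request; assemble_requested_region is a hypothetical name, not the existing assemble_global_tensor):

    import torch

    def assemble_requested_region(shards, req_offsets, req_shape, dtype):
        # shards: iterable of (tensor, global_offsets) pairs that lie inside the request.
        out = torch.empty(req_shape, dtype=dtype)  # allocate the local size, not the global size
        for shard, offsets in shards:
            # Shift each shard's global offsets into the request's coordinate frame.
            idx = tuple(slice(o - r, o - r + n)
                        for o, r, n in zip(offsets, req_offsets, shard.shape))
            out[idx] = shard
        return out

    # e.g. the two column pieces from the PR description:
    pieces = [(torch.tensor([[0]]), (0, 0)), (torch.tensor([[2]]), (1, 0))]
    print(assemble_requested_region(pieces, (0, 0), (2, 1), torch.int64))  # tensor([[0], [2]])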


async def _get_distributed_whole_tensor(self, key: str) -> torch.Tensor:
"""Fetches slices from all volume storages and stitch together to return the whole tensor"""
async def _get_distributed_whole_tensor(
Contributor

Should we rename this function now that it's not technically fetching the whole_tensor?

f"Tensor slice {request.tensor_slice} not found in any stored shards for {key}"
)

def _extract_tensor_subset(
Contributor

For this to be correct in the current implementation, I believe it must always return None if the entire requested_slice is not present.

await transport_buffer.write_from(shard["tensor"])
stored_slice = shard["slice"]
stored_tensor = shard["tensor"]

Contributor
@LucasLLC Oct 2, 2025

What might be nice here is to have an if statement like: requested_tensor_slice is in shard

if request_slice not in shard:
    continue

Contributor Author

Verification added to return early if there's no intersection, or if the intersection is not exactly the tensor slice in the request.

@@ -0,0 +1,127 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor

nit: why _ext for filename?

Contributor Author

It's for "extended". Running the whole test_sharding suite is too much (>30 tests) and CI would fail, so I picked some of the sharding tests (~20) to be "test_sharding_basic" and the rest of them (more complete) to be "test_sharding_ext".

TORCHSTORE_RDMA_ENABLED=0 \
pytest tests/test_tensor_slice.py \
--cov=. --cov-report=xml -vv -s
- name: Run test_resharding_basic tests with coverage
Contributor

ty!


if stored_object_type is ObjectType.TENSOR:
full_tensor = await self._get_tensor(key)
# TODO: we should get the part of interest in this branch.
Contributor

Could you explain this to me? Or rather confirm my understanding is correct:

If the stored object is a tensor, then we always fetch the entire tensor and then slice for the requested spec since the tensor can only be in one storage volume?

Contributor Author

Yes, your description is accurate.
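
For reference, the single-volume branch boils down to cutting the requested (offsets, local_shape) region out of the fetched tensor; a rough sketch with torch.narrow (illustrative only, not the code in this PR):

    import torch

    def slice_full_tensor(full_tensor, offsets, local_shape):
        # Narrow one dimension at a time; narrow() returns a view, so nothing is copied.
        for dim, (off, length) in enumerate(zip(offsets, local_shape)):
            full_tensor = full_tensor.narrow(dim, off, length)
        return full_tensor

    full = torch.arange(6).reshape(2, 3)
    print(slice_full_tensor(full, (0, 1), (2, 2)))  # tensor([[1, 2], [4, 5]])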

intersection_slice is None
or intersection_slice.local_shape != request.tensor_slice.local_shape
or intersection_slice.offsets != request.tensor_slice.offsets
):
Contributor

iiuc, we only return here if we have complete overlap between the requested tensor slice and the storage volume slice?

I think this is reasonable since the client has knowledge in advance of what's being requested?

Contributor Author

iiuc, we only return here if we have complete overlap between the requested tensor slice and the storage volume slice?

Right. To clarify, "complete overlap" here means the requested slice is a subset of what's stored in the volume.

I think this is reasonable since the client has knowledge in advance of what's being requested?

There's a consistency model here between the metadata and the actual storage: we always peek at the metadata first, then do the actual fetch. So as long as the metadata never lies, the actual data fetch should go smoothly.
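
That subset condition can be written as a simple per-dimension containment check; a sketch assuming offsets/local_shape tuples (hypothetical helper, not the code in the diff):

    def slice_contains(stored_offsets, stored_shape, req_offsets, req_shape):
        # The request lies entirely inside the stored shard iff, in every
        # dimension, the requested interval is within the stored interval.
        return all(so <= ro and ro + rn <= so + sn
                   for so, sn, ro, rn in zip(stored_offsets, stored_shape,
                                             req_offsets, req_shape))

    assert slice_contains((0, 0), (2, 2), (0, 0), (2, 1))      # column fits inside the shard
    assert not slice_contains((0, 0), (1, 2), (0, 0), (2, 1))  # request spans two shards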



# A dev print util.
def color_print(s, color=None, **kwargs):
Contributor

Small nit: we're accumulating a lot in utils.py. It might be worth placing this under example/ for now, or we can think a bit about the overall folder structure.

Contributor Author

Yes, this is very personalized. I just removed it since it's not used.

full_tensor,
request.tensor_slice.local_shape,
request.tensor_slice.offsets,
# Stored object is a DTensor. Return full tensor if
Contributor

Can we assert this is the case here?

Contributor Author

Done.

LucasLLC and others added 6 commits October 14, 2025 08:25
* latest rdma updates from monarch

* remove test code

* remove test code
#56)

* latest rdma updates from monarch

* remove test code

* remove test code

* working v1

* removing test code

* v1

* add v1 gate

* nits

* linter
@kaiyuan-li kaiyuan-li merged commit 2f2902c into main Oct 14, 2025
5 checks passed
@kaiyuan-li kaiyuan-li deleted the lky_get_optimization branch October 14, 2025 16:01