[WIP] optimize get() to only fetch necessary tensor slice #36
Conversation
torchstore/client.py
Outdated
) -> torch.Tensor:
    """Fetches slices from all volume storages and stitch together to return the whole tensor"""

    # dtensor_slice = None
This is a switch for me to test the old (`dtensor_slice=None`) and new code paths during development.
torchstore/client.py
Outdated
return assembled_tensor

def _compute_slice_intersection(
I've been putting these in utils. We can consider creating a dtensor_utils file.
    stored_tensor, stored_slice, request.tensor_slice
)

if extracted_tensor is not None:
What does it mean for the extracted tensor to be `None` here?
The workflow has two steps:
- In the client, we try to find the overlapping tensor slice.
- We send the overlapping tensor slice to the volume, so when the volume fetches, it uses the slice info passed in from the client. Theoretically, the slice passed in by the client MUST overlap with what's stored in the volume. If there's no overlap, `extracted_tensor` will be `None` here.
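As a rough, self-contained sketch of that intersection step (the `TensorSlice` below is a simplified stand-in with just `offsets` and `local_shape`, not torchstore's actual class or its `_compute_slice_intersection`), the logic might look like this, returning `None` when the requested and stored slices don't overlap:

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class TensorSlice:
    # Assumed fields: where the slice starts in the global tensor, and its shape.
    offsets: Tuple[int, ...]
    local_shape: Tuple[int, ...]


def compute_slice_intersection(
    stored: TensorSlice, requested: TensorSlice
) -> Optional[TensorSlice]:
    """Return the overlapping region of two slices in global coordinates, or None."""
    offsets, shape = [], []
    for s_off, s_len, r_off, r_len in zip(
        stored.offsets, stored.local_shape, requested.offsets, requested.local_shape
    ):
        start = max(s_off, r_off)
        stop = min(s_off + s_len, r_off + r_len)
        if stop <= start:  # no overlap along this dimension means no overlap at all
            return None
        offsets.append(start)
        shape.append(stop - start)
    return TensorSlice(offsets=tuple(offsets), local_shape=tuple(shape))
```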
Should we raise an error in this case? It sounds like at this point the code expects the shard to exist.
Discussed offline; will address in a follow-up that requests all tensors inside the whole volume.
Lgtm! Tysm!
    tensor_slice, dtensor_slice
)

if tensor_slice is None:
This looks great to me. Another thing to consider: if we've already fetched the entire tensor slice region, we can skip fetching it again. DTensor is often replicated.
I'll create an issue for this and follow up. I probably need a discussion with you first on how to create a test case with a replicated DTensor.
torchstore/client.py
Outdated
assert device_mesh_shape == tensor_slice.mesh_shape

- return assemble_global_tensor(
+ assembled_tensor = assemble_global_tensor(
Is this safe because we never access regions of the tensor that are not initialized? Is there any danger here of us returning uninitialized memory?
Can you give more information on this? I'm not sure what you mean by "access regions of the tensor which are not initialized". I thought that if we are able to access the volume info, the tensor is already properly initialized?
In `assemble_global_tensor`, the first thing we do is create a `torch.empty` tensor of the correct global size:
torchstore/torchstore/utils.py, line 72 in bef9ba7:
global_tensor = torch.empty(
Now that I'm thinking about it, for particularly large tensor types the behavior here is somewhat unwanted as well. Ideally we create a tensor of the correct local size and correct for the offsets. devmate may be able to help with the mapping logic; can you give it a shot?
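A minimal sketch of that suggestion, under assumed names and a simplified shard representation (not the code in this PR): allocate only the requested local size, then shift each shard's global offsets into that local frame before copying.

```python
import torch


def assemble_requested_region(shards, requested_offsets, requested_shape, dtype=torch.float32):
    """Assemble only the requested region instead of a full global-size buffer.

    `shards` is assumed to be an iterable of (tensor, global_offsets) pairs whose
    regions already lie entirely within the requested region.
    """
    # Allocate only the requested (local) size, not the global size.
    out = torch.empty(requested_shape, dtype=dtype)
    for tensor, global_offsets in shards:
        # Translate the shard's global offsets into the requested region's frame.
        local_offsets = [g - r for g, r in zip(global_offsets, requested_offsets)]
        view = out
        for dim, (off, length) in enumerate(zip(local_offsets, tensor.shape)):
            view = view.narrow(dim, off, length)
        view.copy_(tensor)
    return out
```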
torchstore/client.py
Outdated
- async def _get_distributed_whole_tensor(self, key: str) -> torch.Tensor:
-     """Fetches slices from all volume storages and stitch together to return the whole tensor"""
+ async def _get_distributed_whole_tensor(
Should we rename this function now that it's not technically fetching the whole tensor?
torchstore/storage_volume.py
Outdated
| f"Tensor slice {request.tensor_slice} not found in any stored shards for {key}" | ||
| ) | ||
|
|
||
| def _extract_tensor_subset( |
For this to be correct in the current implementation, I believe it must always return `None` if the entire `requested_slice` is not present.
await transport_buffer.write_from(shard["tensor"])
stored_slice = shard["slice"]
stored_tensor = shard["tensor"]
What might be nice here is to have an if statement like "requested_tensor_slice is in shard":

    if request_slice not in shard:
        continue
Verification added to return early if there's no intersection, or if the intersection is not exactly the tensor slice in the request.
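For reference, a containment check along the lines suggested above could be sketched like this (assumed `TensorSlice` fields and shard layout, not the repo's actual classes):

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class TensorSlice:
    # Assumed fields, mirroring the offsets/local_shape used in this PR.
    offsets: Tuple[int, ...]
    local_shape: Tuple[int, ...]

    def __contains__(self, other: "TensorSlice") -> bool:
        """True if `other` lies entirely within this slice."""
        return all(
            s_off <= o_off and o_off + o_len <= s_off + s_len
            for s_off, s_len, o_off, o_len in zip(
                self.offsets, self.local_shape, other.offsets, other.local_shape
            )
        )


# Hypothetical usage in the volume's shard loop (shard["slice"] being the stored slice):
# if request.tensor_slice not in shard["slice"]:
#     continue
```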
@@ -0,0 +1,127 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Nit: why `_ext` for the filename?
It's for "extended". Running the whole test_sharding suite is too much (>30 tests) and CI will fail, so I picked some sharding tests (~20) to be "test_sharding_basic" and the rest of them (more complete) to be "test_sharding_ext".
  TORCHSTORE_RDMA_ENABLED=0 \
  pytest tests/test_tensor_slice.py \
  --cov=. --cov-report=xml -vv -s
- name: Run test_resharding_basic tests with coverage
ty!
if stored_object_type is ObjectType.TENSOR:
    full_tensor = await self._get_tensor(key)
    # TODO: we should get the part of interest in this branch.
Could you explain this to me? Or rather, confirm my understanding is correct: if the stored object is a tensor, then we always fetch the entire tensor and then slice for the requested spec, since the tensor can only be in one storage volume?
Yes, your description is accurate.
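For illustration, that branch boils down to something like the sketch below (the helper name is hypothetical, not torchstore's API): fetch the full stored tensor, then narrow it to the requested `local_shape`/`offsets`.

```python
import torch


def extract_requested_region(full_tensor, offsets, local_shape):
    """Carve the requested region out of an already-fetched full tensor."""
    view = full_tensor
    for dim, (off, length) in enumerate(zip(offsets, local_shape)):
        view = view.narrow(dim, off, length)
    return view


full = torch.arange(4).reshape(2, 2)                  # [[0, 1], [2, 3]]
col = extract_requested_region(full, (0, 0), (2, 1))  # first column
print(col.tolist())                                   # [[0], [2]]
```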
    intersection_slice is None
    or intersection_slice.local_shape != request.tensor_slice.local_shape
    or intersection_slice.offsets != request.tensor_slice.offsets
):
IIUC, we only return here if we have complete overlap between the requested tensor slice and the storage volume slice? I think this is reasonable since the client has knowledge in advance of what's being requested.
> IIUC, we only return here if we have complete overlap between the requested tensor slice and the storage volume slice?

Right, to clarify, "complete overlap" here means the requested slice is a subset of what's stored.

> I think this is reasonable since the client has knowledge in advance of what's being requested?

There's a consistency model here between metadata and actual storage: we always peek into the metadata first, then do the actual fetch. So as long as the metadata never lies, the actual data fetch should go smoothly.
torchstore/utils.py
Outdated
# A dev print util.
def color_print(s, color=None, **kwargs):
Small nit: we're accumulating a lot in utils.py. It might be worth placing this under example/ for now, or we can think a bit about the overall folder structure.
Yes, this is very personalized. I just removed it since it's not used.
    full_tensor,
    request.tensor_slice.local_shape,
    request.tensor_slice.offsets,
# Stored object is a DTensor. Return full tensor if
Can we assert this is the case here?
Done.
Before this change, when fetching a tensor, we always fetch the whole tensor that's stored on each volume, e.g.

Full Tensor: `[[0, 1], [2, 3]]`
Storage Volume 0: `[[0, 1]]`
Storage Volume 1: `[[2, 3]]`

When we want to get a tensor slice of shape `[2, 1]` and offset `[0, 0]`, the `get()` method first fetches the full tensor `[[0, 1], [2, 3]]` and then extracts its first column. This is not efficient. After this change, we only fetch `[[0]]` from volume 0 and `[[2]]` from volume 1 and assemble them together.

Added an `example/dtensor.py` to help development. We can remove it later.
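To make the example concrete, here is a small, self-contained sketch in plain PyTorch (assumed shard layout, not torchstore's API) of what each volume contributes after this change:

```python
import torch

# The example from the description: a 2x2 tensor row-sharded across two volumes.
volume_shards = {
    0: {"offsets": (0, 0), "tensor": torch.tensor([[0, 1]])},  # first row
    1: {"offsets": (1, 0), "tensor": torch.tensor([[2, 3]])},  # second row
}

# Requested slice: shape [2, 1] at offset [0, 0], i.e. the first column.
req_offsets, req_shape = (0, 0), (2, 1)

for vol_id, shard in volume_shards.items():
    tensor, offsets = shard["tensor"], shard["offsets"]
    # Per-dimension overlap of the request with this shard, in global coordinates.
    starts = [max(o, r) for o, r in zip(offsets, req_offsets)]
    stops = [min(o + s, r + q)
             for o, s, r, q in zip(offsets, tensor.shape, req_offsets, req_shape)]
    if any(b <= a for a, b in zip(starts, stops)):
        continue  # this volume holds nothing the request needs
    # Only this piece is fetched after the change (instead of the whole shard).
    piece = tensor
    for dim, (a, b, o) in enumerate(zip(starts, stops, offsets)):
        piece = piece.narrow(dim, a - o, b - a)
    print(f"volume {vol_id} sends {piece.tolist()}")
# volume 0 sends [[0]]
# volume 1 sends [[2]]
```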