support get state dict and apply state dict (pytorch#2976)
Summary:
Pull Request resolved: pytorch#2976
X-link: pytorch/FBGEMM#4145
X-link: facebookresearch/FBGEMM#1226
# Functions
**Saving State Dict**
When saving the state dict, we convert IDs from local to global. The saved IDs therefore do not depend on the sharding decision, so they do not need to be shifted when tables are resharded.
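A minimal sketch of the local-to-global conversion, assuming a hypothetical per-shard row offset derived from the sharding plan (the actual offset bookkeeping in the implementation may differ):

```python
import torch

def local_to_global_ids(local_ids: torch.Tensor, shard_row_offset: int) -> torch.Tensor:
    """Shift locally-indexed embedding IDs into the global ID space.

    `shard_row_offset` is a hypothetical per-shard offset taken from the
    sharding plan; the saved state dict then holds global IDs that remain
    valid even if the table is later resharded with different offsets.
    """
    return local_ids + shard_row_offset
```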
**Checkpoint Loading Mode**
We have enabled load state dict mode for checkpoint loading, which allows us to cache all ID, weight, and bucket tensors in memory before applying them to the backend. This approach ensures that all data is loaded correctly, even though the checkpoint client does not support ordered tensor loading.
# Current Solution
The current solution caches all data in Python tensors, following these steps (a rough sketch follows the list):
- Set self.local_weight_counts based on checkpoint bucket tensor size.
- Enable load state dict mode to initialize local cache tensors.
- Call state_dict to get empty tensors for the checkpoint loader.
- The checkpoint loader writes data from the persisted checkpoint into the cached tensors.
- Call apply_state_dict to write all cached tensors to the backend.
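A hypothetical sketch of this loading flow. The method names (`enable_load_state_dict_mode`, `apply_state_dict`) and the `local_weight_counts` attribute follow the steps above, but the exact signatures are assumptions, not the real torchrec/FBGEMM API:

```python
def load_from_checkpoint(module, checkpoint_loader, bucket_sizes):
    # 1. Size the local caches from the checkpoint's bucket tensor sizes.
    module.local_weight_counts = bucket_sizes

    # 2. Enter load-state-dict mode so state_dict() returns cache tensors.
    module.enable_load_state_dict_mode()

    # 3. Hand the (empty) cached tensors to the checkpoint loader.
    cached_state = module.state_dict()

    # 4. The loader fills the cached tensors from the persisted checkpoint;
    #    load order does not matter because everything is cached first.
    checkpoint_loader.load(cached_state)

    # 5. Flush all cached IDs, weights, and optimizer state to the KV backend.
    module.apply_state_dict()
```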
**Apply State Dict Flow**
During the apply_state_dict step, we perform the following operations (see the sketch after this list):
- If optimizer offloading is enabled:
- Loop through chunks of weight and optimizer.
- Concatenate weight and optimizer together.
  - Write the chunk to the backend through the KVTensorWrapper interface.
- If optimizer offloading is disabled:
  - Set the optimizer state into the device tensor, indexed by ID.
  - Write the IDs and weights to the backend for each table.
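A simplified sketch of the optimizer-offloading branch: rows are processed in chunks so that only one chunk of the concatenated [weight | optimizer] tensor is materialized at a time. The `set_weights_and_ids` call and the chunk size are assumptions standing in for the real KVTensorWrapper write interface:

```python
import torch

CHUNK_SIZE = 10_000  # illustrative chunk size, in rows

def apply_table_with_offloaded_optimizer(ids, weights, opt_states, kv_wrapper):
    """Hypothetical sketch: write [weight | optimizer] rows to the backend in chunks.

    Concatenating per chunk (instead of once for the whole table) bounds peak
    memory, at the cost of issuing multiple smaller writes.
    """
    num_rows = ids.numel()
    for start in range(0, num_rows, CHUNK_SIZE):
        end = min(start + CHUNK_SIZE, num_rows)
        # Concatenate the weight and optimizer columns for this chunk only.
        chunk = torch.cat([weights[start:end], opt_states[start:end]], dim=1)
        # Stand-in for the KVTensorWrapper write call.
        kv_wrapper.set_weights_and_ids(chunk, ids[start:end])
```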
# Limitations
The current solution has two limitations:
- Memory overhead:
  - When writing data to the backend, the Python tensor's memory cannot be released until the entire tensor has been duplicated in the backend. This can lead to high memory usage, especially for single large tables.
- Performance regression:
  - With optimizer offloading, we need to concatenate the weight and optimizer tensors before writing to the backend. To avoid tripling the memory footprint of one large tensor, we loop through smaller chunks during writing, which can cause a performance regression.
# Future Improvements
After the first end-to-end version is ready, we plan to support unordered loading from the backend to improve performance and reduce memory overhead.
Reviewed By: bobbyliujb
Differential Revision: D74790154