forked from jax-ml/jax
New and improved _shard_device_array function. (jax-ml#2958)
This gets the performance of sharding DeviceArray arguments to pmap roughly back to what it was prior to jax-ml@07571ae. It does so by re-introducing a _shard_device_array function that can handle arbitrary array slices.

Benchmark results compared to jax-ml@87d9590 (i.e. just prior to the regression):

```
---------Benchmark summary for pmap_shard_device_array---------
  nargs    nshards       mean      %std    relative    mean/baseline
-------  ---------  ---------  --------  ----------  ---------------
     10          8  0.0479975  12.0865      1                1.09631
    100          8  0.32916     5.7446      6.85786          1.10263
    500          8  1.5563      2.68041    32.4246           1.10066
    100          2  0.136431    8.33826     2.84245          1.15886
    100          4  0.198815    5.91716     4.1422           1.11409
    100          8  0.31788     4.80559     6.62285          1.06637
```

This still seems a bit slower than it was before, but gets most of the performance back. We can further optimize in future changes if needed.

Fixes jax-ml#2958 (hopefully)
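To illustrate what "handling arbitrary array slices" can mean, here is a minimal NumPy sketch of splitting one array into shards from a list of index tuples. It is not the code from this commit: `shard_by_indices` and `leading_axis_indices` are hypothetical names, and the sketch leaves out device placement entirely.

```python
import numpy as np


def shard_by_indices(x, indices):
    """Return one shard per entry in `indices` (hypothetical helper).

    Each entry is a tuple of ints and/or slices, so a shard can be a
    single row, a block of rows, or any other basic-indexing slice.
    """
    return [x[idx] for idx in indices]


def leading_axis_indices(nshards):
    """Index tuples that pick one row of the leading axis per shard."""
    return [(i,) for i in range(nshards)]


x = np.arange(8 * 3).reshape(8, 3)

# One shard per row of the leading axis.
rows = shard_by_indices(x, leading_axis_indices(8))

# The same helper also accepts slices, e.g. two blocks of four rows.
halves = shard_by_indices(x, [(slice(0, 4),), (slice(4, 8),)])
```

Taking the indices as data rather than hard-coding "one row per shard" is what lets a single function cover both cases.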
Showing 2 changed files with 55 additions and 11 deletions.