We've worked around the sharded output issue #9726, and are now able to produce sharded outputs using the shardy path introduced in #9348.
However, we observe that gathering a sharded output to host (e.g. `sharded_tensor.to("cpu")`) goes through the `ReplicateShardedData` function. That function takes the sharded tensor, then compiles and executes an effectively no-op computation whose output is tagged as replicated, so that the runtime materializes the sharded tensor as a replicated one.
We have a few questions about this implementation:
1. At the point of `ReplicateShardedData`, the `OpSharding` already seems to carry enough information to reassemble the shards on host, rather than dispatching a device execution. Why aren't the shards pulled to host individually and reassembled in host memory? Is this just for convenience (i.e. reusing the existing unsharding infrastructure), or is there a more fundamental reason?
2. Why isn't the compilation of the `x += 0` graph in `ReplicateShardedData` cached in the computation cache?
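For context on the first question, the host-side alternative we have in mind could look roughly like the sketch below. It is illustrated with numpy for simplicity and assumes a plain 2-D tiled layout in row-major mesh order; real `OpSharding` layouts (replication, partial tiling, last-tile padding) are more general, and `reassemble_tiled` is a hypothetical helper, not a torch_xla API.

```python
import numpy as np

def reassemble_tiled(shards, device_mesh_shape):
    """Reassemble a 2-D tensor tiled across a device mesh, on host.

    `shards` is a flat list of per-device arrays in row-major mesh order;
    `device_mesh_shape` is (rows, cols). This mirrors what a simple 2-D
    tile_assignment in OpSharding would describe.
    """
    rows, cols = device_mesh_shape
    # Stitch each mesh row horizontally, then stack the rows vertically.
    row_blocks = [
        np.concatenate(shards[r * cols:(r + 1) * cols], axis=1)
        for r in range(rows)
    ]
    return np.concatenate(row_blocks, axis=0)

# Example: a 4x4 array tiled over a 2x2 device mesh.
full = np.arange(16).reshape(4, 4)
shards = [full[:2, :2], full[:2, 2:], full[2:, :2], full[2:, 2:]]
assert np.array_equal(reassemble_tiled(shards, (2, 2)), full)
```

The key property is that at no point does any single device need to hold the full tensor; only the host does.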
A drawback of the existing implementation described in (1) is that it forces an expensive on-device all-gather, which seems unnecessary and can overrun device memory: if the tensor being all-gathered (e.g. a large KV cache) fits in device memory only while sharded, the replicated copy cannot fit at all.
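To make the memory concern concrete, here is a back-of-the-envelope check. All numbers are hypothetical, chosen only to illustrate the failure mode:

```python
# Hypothetical sizes, for illustration only.
hbm_per_device_gib = 16   # accelerator memory per device
kv_cache_total_gib = 96   # full (unsharded) KV cache
num_devices = 8

# Sharded 8 ways, each device holds 12 GiB: fits comfortably.
shard_gib = kv_cache_total_gib / num_devices
assert shard_gib <= hbm_per_device_gib

# An on-device all-gather materializes the *full* tensor on every device,
# so each device would need the entire 96 GiB: this overflows.
assert not (kv_cache_total_gib <= hbm_per_device_gib)
```

In this regime a device-side all-gather cannot succeed at any batch size, while pulling shards to host and reassembling there would.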