Why does ReplicateShardedData perform an uncached compile+device execute on each sharded output to return to host #9751

@jameszianxuTT

Description

We've worked around the sharded output issue #9726, and are now able to produce sharded outputs using the shardy path introduced in #9348.

However, we observe that gathering sharded outputs to host (e.g. `sharded_tensor.to("cpu")`) goes through the ReplicateShardedData function, which takes a sharded output, then compiles and executes an effectively no-op computation whose output is tagged as replicated, in order to turn the sharded tensor into a replicated one that can be transferred.

We have a few questions about this implementation:

  1. At the point of ReplicateShardedData, it seems we already have enough information in the OpSharding to reassemble the shards on the host, rather than dispatching a device execute. Why aren't the shards pulled to host and reassembled in host memory instead? Is this just for convenience (i.e. reusing the existing unsharding infrastructure), or is there a more fundamental reason?
  2. Why isn't the compilation of the no-op `x += 0` graph built by ReplicateShardedData cached in the computation cache?
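To illustrate the alternative suggested in question 1: given the per-device shards plus the sharding metadata, the full tensor can in principle be reassembled entirely in host memory, with no device collective. Below is a minimal NumPy sketch of that idea (the function and its signature are hypothetical, not torch_xla's actual API); it assumes the shards arrive in row-major device-mesh order and that each mesh axis partitions exactly one tensor dimension:

```python
import numpy as np

def gather_on_host(shards, mesh_shape, sharded_dims):
    """Hypothetical host-side reassembly sketch (not torch_xla's API).

    shards:       flat list of per-device arrays, in row-major mesh order
    mesh_shape:   shape of the logical device mesh, e.g. (2, 2)
    sharded_dims: tensor dimension partitioned by each mesh axis
    """
    # Arrange the flat shard list into the logical device mesh.
    grid = np.empty(len(shards), dtype=object)
    for i, shard in enumerate(shards):
        grid[i] = shard
    grid = grid.reshape(mesh_shape)

    # Collapse one mesh axis at a time, innermost first, by
    # concatenating its shards along the tensor dim it partitions.
    for mesh_axis in reversed(range(len(mesh_shape))):
        tensor_dim = sharded_dims[mesh_axis]
        moved = np.moveaxis(grid, mesh_axis, 0)
        out = np.empty(moved.shape[1:], dtype=object)
        for idx in np.ndindex(*moved.shape[1:]):
            out[idx] = np.concatenate(
                [moved[(i,) + idx] for i in range(moved.shape[0])],
                axis=tensor_dim,
            )
        grid = out
    return grid[()]  # 0-d object array now holds the full tensor
```

For example, a 4x4 tensor tiled on a 2x2 mesh (mesh axis 0 partitioning dim 0, mesh axis 1 partitioning dim 1) is rebuilt from its four quarter-shards without ever materializing the full tensor on a device, which also sidesteps the device-memory concern raised below.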

A drawback of the existing implementation described in (1) is that it forces an expensive on-device all-gather that seems unnecessary, and it may exhaust device memory if the tensor being gathered (e.g. a large KV cache) cannot actually fit on a single device.
