
A custom op that wraps a FusionDefinition and takes/produces DTensors. #3703

Draft
wants to merge 4 commits into base: main

Conversation

wujingyue (Collaborator)
No description provided.

wujingyue marked this pull request as draft on January 14, 2025, 06:18
@wujingyue (Collaborator, Author) commented:

!test

@wujingyue (Collaborator, Author) commented:

Cc @jjsjann123

wujingyue added a commit that referenced this pull request Jan 14, 2025
wujingyue added a commit that referenced this pull request Jan 14, 2025

github-actions bot commented Jan 16, 2025

PR Reviewer Guide 🔍

(Review updated until commit 89a7dd4)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Potential Bug

The FusionDefinitionWrapper class does not handle exceptions. If the wrapped define_fusion function raises while the fusion is being defined or executed, the error propagates uncaught and the wrapper provides no handling or context.

# Imports assumed by this excerpt; exact module paths may vary across torch/nvfuser versions.
from collections.abc import Callable, Iterable
from typing import cast

import nvfuser
from nvfuser import FusionDefinition
from torch.distributed.tensor import DTensor, Placement, Shard


class FusionDefinitionWrapper:
    def __init__(self, define_fusion: Callable[[FusionDefinition], None]):
        """Wraps a function that defines a fusion without `multidevice_schedule`."""
        self._define_fusion = define_fusion

    def __call__(self, in_dtensors: Iterable[DTensor]) -> list[DTensor]:
        define_fn = self._define_fusion

        class Model(FusionDefinition):
            def definition(self):
                define_fn(self)

            def _find_tensor_by_index(self, index: int) -> nvfuser.Tensor:
                for t in self.sched.tensors():
                    if t.index == index:
                        return t
                return None

            def multidevice_schedule(self):
                for in_tensor_index, in_dtensor in zip(self.inputs(), in_dtensors):
                    in_tensor = self._find_tensor_by_index(in_tensor_index)

                    # Set the device mesh.
                    assert (
                        in_dtensor.device_mesh.ndim == 1
                    ), "nvFuser's Python API only supports 1D meshes."
                    mesh = nvfuser.DeviceMesh(
                        in_dtensor.device_mesh.mesh.view(-1).tolist()
                    )
                    self.sched._set_device_mesh(in_tensor, mesh)

                    # Parallelize.
                    assert len(in_dtensor.placements) == 1, "Expect a 1D mesh"
                    placement: Placement = in_dtensor.placements[0]
                    if placement.is_shard():
                        dim = cast(Shard, placement).dim
                        self.sched.parallelize(
                            in_tensor, dim, nvfuser.ParallelType.mesh_x
                        )

        in_tensors = [in_dtensor.to_local() for in_dtensor in in_dtensors]
        model = Model()
        out_tensors = model.execute(in_tensors)

        out_dtensors = []
        for out_tensor in out_tensors:
            # FIXME: we should collect output meshes/placements from nvFuser.
            out_dtensor = DTensor.from_local(
                out_tensor, in_dtensors[0].device_mesh, in_dtensors[0].placements
            )
            out_dtensors.append(out_dtensor)
        return out_dtensors
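As a possible remedy (not part of this PR), a small helper could re-raise failures from the user-supplied definition with context instead of letting them propagate bare. The helper name run_with_fusion_context and the wrapping points shown in the comments are assumptions for illustration only:

# Hypothetical sketch: re-raise failures from fusion definition/execution with context.
from typing import Callable, TypeVar

T = TypeVar("T")


def run_with_fusion_context(step: str, fn: Callable[[], T]) -> T:
    """Run `fn` and, if it raises, report which wrapper step failed."""
    try:
        return fn()
    except Exception as e:
        raise RuntimeError(f"FusionDefinitionWrapper: {step} failed") from e


# Inside FusionDefinitionWrapper.__call__, the two failure-prone steps could then be
# wrapped like this (assuming `Model` and `in_tensors` as defined above):
#   model = run_with_fusion_context("building the fusion definition", Model)
#   out_tensors = run_with_fusion_context(
#       "executing the fusion", lambda: model.execute(in_tensors)
#   )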

Code Smell

The test_plus_one function depends on the setup_process_group fixture but does not verify that the process group was actually initialized before using it.

# Imports assumed by this excerpt; FusionDefinitionWrapper comes from the module under test.
import pytest
import torch
import torch.distributed as dist
from nvfuser import DataType, FusionDefinition
from torch.distributed.tensor import Shard


@pytest.mark.mpi
def test_plus_one(setup_process_group):
    def define_fusion(fd: FusionDefinition):
        inp = fd.define_tensor(
            (-1, -1), contiguity=(False, False), dtype=DataType.Float
        )
        one = fd.define_scalar(1.0, dtype=DataType.Float)
        out = fd.ops.add(inp, one)
        fd.add_output(out)

    op = FusionDefinitionWrapper(define_fusion)

    num_devices = dist.get_world_size()
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    in_tensor = torch.randn(num_devices, 4)
    mesh = dist.device_mesh.init_device_mesh("cuda", [num_devices])
    in_dtensor = dist.tensor.distribute_tensor(in_tensor, mesh, [Shard(0)])

    out_dtensors = op([in_dtensor])

    assert len(out_dtensors) == 1
    out_dtensor = out_dtensors[0]
    torch.testing.assert_close(out_dtensor.to_local(), in_dtensor.to_local() + 1)
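
For reference, a minimal sketch of the guard the reviewer suggests; the test name test_plus_one_guarded and the explicit check are illustrative assumptions, not code from this PR:

# Hypothetical sketch: fail fast if the fixture did not bring up the process group.
import pytest
import torch.distributed as dist


@pytest.mark.mpi
def test_plus_one_guarded(setup_process_group):
    if not dist.is_initialized():
        pytest.fail("setup_process_group did not initialize torch.distributed")
    assert dist.get_world_size() >= 1
    # ... the rest of the test proceeds as in the excerpt above ...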
