Skip to content

[mosaic-gpu] add multicast ptr support to TMA with overlapped gemm and all reduce examples #28679

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions jax/_src/pallas/mosaic_gpu/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def _copy_smem_to_gmem_lowering(
has_user_predicate,
commit_group,
reduction_op,
team_id,
):
if has_user_predicate:
flat_args, user_predicate = flat_args[:-1], flat_args[-1]
Expand Down Expand Up @@ -268,6 +269,7 @@ def _copy_smem_to_gmem_lowering(
predicate=predicate,
arrive=commit_group,
reduction_op=reduction_op,
team_id=team_id,
**copy_params,
)
return ()
Expand Down Expand Up @@ -347,6 +349,7 @@ def copy_smem_to_gmem(
*,
commit_group: bool = True,
reduction_op: mgpu.ReductionOp | None = None,
team_id: int | None = None,
) -> None:
"""Asynchronously copies a SMEM reference to a GMEM reference.

Expand All @@ -361,6 +364,7 @@ def copy_smem_to_gmem(
reduction_op: If set, perform the specified reduction operation when storing
to GMEM. For example, using ``"add"`` is conceptually equivalent to
doing ``src += dst``.
team_id: if set, dst ref would be translated to a multicast memory addr
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does a Pallas user get control over teams? It's not a JAX-level concept so it doesn't make sense to surface it here. How does XLA manage that?

I think what you can do on the JAX level is take an axis name, and perform the reduction along that JAX mesh axis.

Finally: does it ever make sense to use team_id without reduction_op? If not, we should add checks


See also:
:func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
Expand Down Expand Up @@ -389,6 +393,7 @@ def copy_smem_to_gmem(
has_user_predicate=predicate is not None,
commit_group=commit_group,
reduction_op=reduction_op,
team_id = team_id,
)
return None

Expand Down
19 changes: 16 additions & 3 deletions jax/experimental/mosaic/gpu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,14 @@ def supports_cross_device_collectives():


@mosaic_gpu_p.def_abstract_eval
def _mosaic_gpu_abstract_eval(*_, module, out_types):
def _mosaic_gpu_abstract_eval(
*_,
module,
out_types,
input_output_aliases,
):
del module # Unused.
del input_output_aliases # Unused.
return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types]

# TODO(apaszke): Implement a proper system for managing kernel lifetimes
Expand Down Expand Up @@ -618,8 +624,9 @@ def _run_serde_pass(
def _declare_runtime_functions():
"""Declares the runtime functions that can be used by the generated code."""
ptr_ty = ir.Type.parse("!llvm.ptr")
i32 = ir.IntegerType.get_signless(32)
i64 = ir.IntegerType.get_signless(64)
arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty, i32]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of changing the signature of this function, you could just add a new function in runtime.cc and call it to translate a regular pointer to an mc pointer. It doesn't have to be bundled with the TMA desc initialization

init_tma_desc_type = ir.FunctionType.get(arg_tys, [])
func.FuncOp(
"mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private"
Expand All @@ -639,6 +646,7 @@ def as_gpu_kernel(
kernel_name: str | None = None,
ir_version: int | None = None,
thread_semantics: LoweringSemantics = LoweringSemantics.Lane,
input_output_aliases: tuple[tuple[int, int], ...] = (),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The addition of input_output_aliases is a separate change. Please send an independent PR for this

):
if isinstance(in_shape, list):
in_shape = tuple(in_shape)
Expand Down Expand Up @@ -680,7 +688,12 @@ def _check_args(*args):
)

def bind(*args) -> Any:
return mosaic_gpu_p.bind(*args, module=module, out_types=out_shape)
return mosaic_gpu_p.bind(
*args,
module=module,
out_types=out_shape,
input_output_aliases=input_output_aliases,
)

if prof_spec is not None:
@jax.jit
Expand Down
Loading