[mosaic-gpu] add multicast ptr support to TMA with overlapped gemm and all reduce examples #28679

Open · wants to merge 2 commits into main

Conversation

@Amir-19 (Contributor) commented May 12, 2025

  • Added multicast (MC) pointer support to copy_smem_to_gmem (see the sketch after this list).
  • The MC pointer currently only supports team 0, which is NVSHMEM_TEAM_WORLD.
  • Added a binding to get the number of SMs on the GPU; this is used in our examples to launch persistent kernels.
  • Added simple GEMM examples overlapped with all-reduce, covering both one-shot and two-shot GEMM+AR.
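
A rough usage sketch of the new parameter from a Pallas kernel (the import path, the wait_smem_to_gmem call, and the exact signature are assumptions based on the diff further down, not necessarily the final API):

```python
from jax.experimental.pallas import mosaic_gpu as plgpu

def kernel(x_smem, out_gmem_ref):
  # Store SMEM to GMEM through a multicast pointer, reducing into the
  # destination on every GPU in team 0 (NVSHMEM_TEAM_WORLD).
  plgpu.copy_smem_to_gmem(x_smem, out_gmem_ref, reduction_op="add", team_id=0)
  plgpu.wait_smem_to_gmem(0)
```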

This PR depends on #28595 in order to use a newer version of the PTX ISA. Without it, JAX picks PTX ISA 8.0 when compiling the kernels, which does not allow us to use multimem.ld_reduce.

cc @apaszke

@apaszke (Member) left a comment:

Please split off the examples into another PR. Adding a Pallas or Mosaic feature should be done on its own, with appropriate tests in tests/mosaic/gpu_test.py and tests/pallas/mosaic_gpu_test.py (which are missing here)

@@ -361,6 +364,7 @@ def copy_smem_to_gmem(
reduction_op: If set, perform the specified reduction operation when storing
to GMEM. For example, using ``"add"`` is conceptually equivalent to
doing ``src += dst``.
team_id: If set, the dst ref will be translated to a multicast memory address.

apaszke (Member):

How does a Pallas user get control over teams? It's not a JAX-level concept so it doesn't make sense to surface it here. How does XLA manage that?

I think what you can do on the JAX level is take an axis name, and perform the reduction along that JAX mesh axis.

Finally: does it ever make sense to use team_id without reduction_op? If not, we should add checks
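
If the answer is "no", a minimal sketch of such a check (parameter names taken from the diff above; where exactly the check should live is an assumption):

```python
# Hypothetical validation inside copy_smem_to_gmem.
if team_id is not None and reduction_op is None:
  raise ValueError("team_id is only supported together with reduction_op")
```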

@@ -639,6 +646,7 @@ def as_gpu_kernel(
kernel_name: str | None = None,
ir_version: int | None = None,
thread_semantics: LoweringSemantics = LoweringSemantics.Lane,
input_output_aliases: tuple[tuple[int, int], ...] = (),

apaszke (Member):

The addition of input_output_aliases is a separate change. Please send an independent PR for this

llvm.inline_asm(
i32,
[mc_ptr, x, y, z, w],
"multimem.st.relaxed.sys.global.v4.f32 [$1], {$2, $3, $4, $5};",

apaszke (Member):

Is this a broadcast?

return return_regs[0], return_regs[1], return_regs[2], return_regs[3]


def multimem_st_128(mc_ptr, x, y, z, w):

apaszke (Member):

Please check that the args have the f32 type
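
A minimal sketch of such a check using the MLIR Python bindings (assuming x, y, z, and w are ir.Value operands):

```python
f32 = ir.F32Type.get()
for operand in (x, y, z, w):
  if operand.type != f32:
    raise ValueError(f"multimem_st_128 expects f32 operands, got {operand.type}")
```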

def multimem_st_128(mc_ptr, x, y, z, w):
i32 = ir.IntegerType.get_signless(32)
llvm.inline_asm(
i32,

apaszke (Member):

This snippet does not have a result. Use ir.Type.parse("!llvm.void") and remove the result register constraint
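
A sketch of the suggested change (the constraint string is an assumption, since it is not visible in the diff):

```python
llvm.inline_asm(
    ir.Type.parse("!llvm.void"),
    [mc_ptr, x, y, z, w],
    # With the result register constraint gone, operand numbering starts at $0.
    "multimem.st.relaxed.sys.global.v4.f32 [$0], {$1, $2, $3, $4};",
    "l,f,f,f,f",
)
```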

return_regs = [
llvm.extractvalue(i32, return_struct, [i]) for i in range(4)
]
return return_regs[0], return_regs[1], return_regs[2], return_regs[3]

apaszke (Member):

If you use f16x2 then you should return the results after bitcasting them to ir.VectorType.get((2,), bf16). You might be able to use the bitcast from this file
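
A sketch of what that could look like, assuming the bitcast helper defined in this file (the helper name and signature are assumptions):

```python
bf16 = ir.BF16Type.get()
vec_ty = ir.VectorType.get((2,), bf16)
# Reinterpret each i32 register as a 2-element bf16 vector before returning.
return tuple(bitcast(reg, vec_ty) for reg in return_regs)
```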

)


def wait_loop(uc_ptr, num_gpus=8, is_relaxed=False):

apaszke (Member):

We already have semaphores in Pallas. Is that not enough?

@@ -1365,3 +1365,80 @@ def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:
result = vector.insertelement(elem, result, position=c(offset + i, index))
offset += vty.shape[0]
return result


def signal_with_red(mc_ptr, is_relaxed=False):

apaszke (Member):

What's the purpose of this function?

i64 = ir.IntegerType.get_signless(64)
arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty, i32]

apaszke (Member):

Instead of changing the signature of this function, you could just add a new function in runtime.cc and call it to translate a regular pointer to an mc pointer. It doesn't have to be bundled with the TMA desc initialization

@apaszke self-assigned this May 22, 2025

@apaszke (Member) commented May 22, 2025

#28941 might be relevant to you btw (it adds support for TMA to remote addresses)
