[mosaic-gpu] add multicast ptr support to TMA with overlapped gemm and all reduce examples #28679
Conversation
Please split off the examples into another PR. Adding a Pallas or Mosaic feature should be done on its own, with appropriate tests in tests/mosaic/gpu_test.py and tests/pallas/mosaic_gpu_test.py (which are missing here).
@@ -361,6 +364,7 @@ def copy_smem_to_gmem(
    reduction_op: If set, perform the specified reduction operation when storing
      to GMEM. For example, using ``"add"`` is conceptually equivalent to
      doing ``src += dst``.
    team_id: If set, the dst ref will be translated to a multicast memory address.
How does a Pallas user get control over teams? It's not a JAX-level concept, so it doesn't make sense to surface it here. How does XLA manage that?
I think what you can do at the JAX level is take an axis name and perform the reduction along that JAX mesh axis.
Finally: does it ever make sense to use team_id without reduction_op? If not, we should add checks.
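A minimal sketch of the kind of check suggested here, assuming `team_id` and `reduction_op` are the keyword arguments of `copy_smem_to_gmem` shown in the diff (illustrative only):

```python
# Hypothetical validation inside copy_smem_to_gmem (names taken from the diff):
# a multicast destination only makes sense together with a reduction.
if team_id is not None and reduction_op is None:
  raise ValueError("team_id requires reduction_op to be set")
```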
@@ -639,6 +646,7 @@ def as_gpu_kernel(
    kernel_name: str | None = None,
    ir_version: int | None = None,
    thread_semantics: LoweringSemantics = LoweringSemantics.Lane,
    input_output_aliases: tuple[tuple[int, int], ...] = (),
The addition of input_output_aliases is a separate change. Please send an independent PR for this
llvm.inline_asm(
    i32,
    [mc_ptr, x, y, z, w],
    "multimem.st.relaxed.sys.global.v4.f32 [$1], {$2, $3, $4, $5};",
Is this a broadcast?
  return return_regs[0], return_regs[1], return_regs[2], return_regs[3]


def multimem_st_128(mc_ptr, x, y, z, w):
Please check that the args have the f32 type
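For example, a hedged sketch of such a check using the MLIR Python bindings already used in this file (assuming `x`, `y`, `z`, `w` are `ir.Value`s):

```python
# Sketch: reject non-f32 operands before emitting the PTX store.
f32 = ir.F32Type.get()
if any(v.type != f32 for v in (x, y, z, w)):
  raise ValueError(
      f"multimem_st_128 expects f32 operands, got {[v.type for v in (x, y, z, w)]}"
  )
```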
def multimem_st_128(mc_ptr, x, y, z, w):
  i32 = ir.IntegerType.get_signless(32)
  llvm.inline_asm(
      i32,
This snippet does not have a result. Use ir.Type.parse("!llvm.void") and remove the result register constraint.
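A sketch of what this change might look like; the operand constraint string is an assumption, and the register placeholders shift down once the `=r` output constraint is dropped:

```python
# Sketch: no result value, so use the LLVM void type and drop the output
# register constraint; the operands then map to $0..$4.
llvm.inline_asm(
    ir.Type.parse("!llvm.void"),
    [mc_ptr, x, y, z, w],
    "multimem.st.relaxed.sys.global.v4.f32 [$0], {$1, $2, $3, $4};",
    "l,f,f,f,f",  # assumed constraints: 64-bit pointer + four f32 registers
    has_side_effects=True,
)
```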
  return_regs = [
      llvm.extractvalue(i32, return_struct, [i]) for i in range(4)
  ]
  return return_regs[0], return_regs[1], return_regs[2], return_regs[3]
If you use f16x2, then you should return the results after bitcasting them to ir.VectorType.get((2,), bf16). You might be able to use the bitcast from this file.
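Roughly, under the assumption that the `bitcast` helper in this file accepts a value and a target type:

```python
# Sketch: reinterpret each 32-bit result register as a 2-element bf16 vector.
bf16 = ir.BF16Type.get()
vec_ty = ir.VectorType.get((2,), bf16)
return tuple(bitcast(reg, vec_ty) for reg in return_regs)
```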
)


def wait_loop(uc_ptr, num_gpus=8, is_relaxed=False):
We already have semaphores in Pallas. Is that not enough?
@@ -1365,3 +1365,80 @@ def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:
      result = vector.insertelement(elem, result, position=c(offset + i, index))
    offset += vty.shape[0]
  return result


def signal_with_red(mc_ptr, is_relaxed=False):
What's the purpose of this function?
   i64 = ir.IntegerType.get_signless(64)
-  arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
+  arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty, i32]
Instead of changing the signature of this function, you could just add a new function in runtime.cc and call it to translate a regular pointer to an mc pointer. It doesn't have to be bundled with the TMA desc initialization.
#28941 might be relevant to you btw (it adds support for TMA to remote addresses).
This PR is dependent on #28595 to use a newer version of the PTX ISA. While compiling the kernels, JAX would pick PTX ISA 8.0, which does not allow us to use multimem.ld_reduce.
cc @apaszke