Skip to content

Commit

Permalink
added missing copy of AG+GEMM input into comm buffer
Browse files Browse the repository at this point in the history
Signed-off-by: Alp Dener <[email protected]>
  • Loading branch information
denera committed Dec 5, 2024
1 parent ec2d5ae commit 8fe3942
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions examples/jax/comm_gemm_overlap/comm_gemm_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from jax.experimental import mesh_utils

import transformer_engine.jax as te
from transformer_engine.jax.cpp_extensions import gemm_impl
from transformer_engine.jax.cpp_extensions import gemm_impl, copy_into_overlap_buffer
from transformer_engine.jax.gemm import (
initialize_comm_gemm_overlaps,
destroy_comm_gemm_overlaps,
Expand Down Expand Up @@ -124,14 +124,15 @@
if myrank == 0:
print(
f"{myrank}: INPUTS {lhs.shape} x {rhs.shape}\n"
+ f"{myrank}: LHS sharding: {lhs.sharding}\n"
+ f"{myrank}: RHS sharding: {rhs.sharding}\n",
+ f"{myrank}: LHS sharding: {lhs.sharding.spec}\n"
+ f"{myrank}: RHS sharding: {rhs.sharding.spec}\n",
flush=True,
)


@jax.jit
def te_gemm(A, B):
copy_into_overlap_buffer(A, overlap_name, True)
return gemm_impl(
A,
jax.lax.with_sharding_constraint(B, weight_no_fsdp_sharding),
Expand All @@ -145,10 +146,9 @@ def te_gemm(A, B):

if myrank == 0:
print(
f"{myrank}: {'AG -> GEMM' if args.comm_type == 'AG' else 'GEMM -> RS'} OUTPUTS:\n"
+ f"{myrank}: GEMM output: {output.shape} | {output.sharding}\n"
+ f"{myrank}: {'Gathered LHS' if args.comm_type == 'AG' else 'Scattered output:'}: "
+ f"{extra_out.shape} | {extra_out.sharding}\n",
f"{myrank}: {'AG -> GEMM' if args.comm_type == 'AG' else 'GEMM -> RS'} OUTPUT "
+ f"{output.shape}\n"
+ f"{myrank}: Sharding: {output.sharding.spec}\n",
flush=True,
)

Expand Down

0 comments on commit 8fe3942

Please sign in to comment.