Commit

TP communication overlap: enable the overlap between GEMM chunks at Hopper BF16

Signed-off-by: Sangkug Lym <[email protected]>
erhoo82 committed Nov 4, 2024
1 parent 05c0fb0 commit c4dcd95
Showing 1 changed file with 4 additions and 1 deletion.
transformer_engine/pytorch/cpp_extensions/gemm.py (4 additions, 1 deletion):

@@ -284,10 +284,13 @@ def gemm(
         assert (
             extra_output_tensor is not None
         ), "SPLIT_PIPELINED_RS requires extra output tensor"
+        # Disable the overlap between GEMM chunks at ampere and below
+        major, _ = torch.cuda.get_device_capability()
+        overlap_gemm_chunks = True if major >= 9 else False
         args = tuple(
             args
             + (
-                False,
+                overlap_gemm_chunks,
                 extra_output_tensor,
             )
         )
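The gating logic added by this commit can be sketched in isolation as follows. This is a minimal illustration, not the Transformer Engine code itself: `gemm_chunk_overlap_enabled` is a hypothetical helper name, and the `(major, minor)` tuple it takes mirrors what the real PyTorch API `torch.cuda.get_device_capability()` returns (e.g. `(9, 0)` for Hopper H100, `(8, 0)` for Ampere A100).

```python
def gemm_chunk_overlap_enabled(capability):
    """Decide whether GEMM chunks may overlap, given a CUDA compute
    capability tuple as returned by torch.cuda.get_device_capability().

    Hypothetical helper illustrating the commit's gating rule: overlap is
    enabled only on Hopper (compute capability 9.x) and newer GPUs.
    """
    major, _ = capability
    # Ampere (8.x) and below: disable overlap between GEMM chunks.
    return major >= 9

# Examples of the gate:
print(gemm_chunk_overlap_enabled((9, 0)))  # Hopper  -> True
print(gemm_chunk_overlap_enabled((8, 0)))  # Ampere  -> False
```

As a side note, the committed expression `True if major >= 9 else False` is equivalent to the plain comparison `major >= 9` used above; the behavior is identical.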
