Commit c13a81b
fixed workspace allocation for TP overlap test with pure GEMM
Signed-off-by: Alp Dener <[email protected]>
denera committed Jan 17, 2025
1 parent 6e84892 commit c13a81b
Showing 1 changed file with 3 additions and 3 deletions.
tests/pytorch/distributed/run_gemm_with_overlap.py
@@ -598,7 +598,7 @@ def _fp8_gemm():
             tex.FP8FwdTensors.GEMM1_INPUT,
             fp8_dtype,
             torch.uint8 if opts.fp8_output else torch.bfloat16,
-            te.module.base.get_workspace(),
+            te.module.base.get_workspace().repeat(3),
             bias=None,
             use_bias=False,
             gelu=False,
@@ -639,7 +639,7 @@ def _fp8_gemm2(gemm1_out):
             tex.FP8FwdTensors.GEMM2_INPUT,
             fp8_dtype,
             torch.uint8 if opts.fp8_output else torch.bfloat16,
-            te.module.base.get_workspace(),
+            te.module.base.get_workspace().repeat(3),
             bias=None,
             use_bias=False,
             gelu=False,
@@ -662,7 +662,7 @@ def _gemm():
             kernel_t,
             gemm_inp,
             torch.bfloat16,
-            te.module.base.get_workspace(),
+            te.module.base.get_workspace().repeat(3),
             bias=None,
             use_bias=False,
             gelu=False,
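One way to read the fix: torch.Tensor.repeat(3) tiles the cached 1-D workspace three times, handing each GEMM call a buffer three times the default single-GEMM allocation — presumably one slice per compute stream used by the comm+GEMM overlap path. Below is a minimal sketch of what the change does to the tensor, assuming get_workspace() returns a cached 1-D uint8 workspace (the size used here is a placeholder, not the library's actual default):

    import torch

    # Hypothetical stand-in for te.module.base.get_workspace(); the 32 MiB
    # size is an assumption for illustration only.
    workspace = torch.empty(32 * 1024 * 1024, dtype=torch.uint8)

    # repeat(3) tiles the 1-D buffer three times into a fresh allocation,
    # so the GEMM receives a workspace three times the default size.
    tripled = workspace.repeat(3)
    assert tripled.numel() == 3 * workspace.numel()
    assert tripled.data_ptr() != workspace.data_ptr()  # a copy, not a view

Note that repeat() copies rather than views, so the tripled workspace is a new allocation made on every call; that is cheap enough for a test script like this one.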
