Skip to content

Commit ccae43d

Browse files
henrylhtsang authored and meta-codesync[bot] committed
pass stream for initialize (#5003)
Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2016 Pull Request resolved: #5003 This doesn't actually change any behavior, but it makes the printed debug logs more informative. Reviewed By: sryap Differential Revision: D84372042 fbshipit-source-id: a0bd1f894005144026de55dac3fce11132677c0e
1 parent 61b22f5 commit ccae43d

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_utils.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,10 @@ static void launch_fmha_op(const typename Operation::Arguments& arguments) {
121121
"This kernel is not supported. Last CUDA error is: ",
122122
cudaGetErrorString(cudaGetLastError()));
123123

124-
status = op.initialize(arguments, workspace.mutable_data_ptr());
124+
status = op.initialize(
125+
arguments,
126+
workspace.mutable_data_ptr(),
127+
at::cuda::getCurrentCUDAStream());
125128
TORCH_CHECK(
126129
status == cutlass::Status::kSuccess,
127130
"Failed to initialize the CUTLASS kernel. Last CUDA error is: ",

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ class FMHA {
205205
/// Supplied params struct must be constructed by calling Kernel::to_underlying_arguments()
206206
static Status
207207
run(Params& params, cudaStream_t stream = nullptr) {
208-
CUTLASS_TRACE_HOST("FMHA::run()");
208+
CUTLASS_TRACE_HOST("FMHA::run(), stream: " << (stream ? "non-null" : "null"));
209209
dim3 const block = Kernel::get_block_shape();
210210
dim3 const grid = get_grid_shape(params);
211211

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha_device_bwd.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ class Sm100FmhaBwd {
331331
/// Supplied params struct must be constructed by calling Kernel::to_underlying_arguments()
332332
static Status
333333
run(Params& params, cudaStream_t stream = nullptr) {
334-
CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()");
334+
CUTLASS_TRACE_HOST("FmhaDeviceBwd::run(), stream: " << (stream ? "non-null" : "null"));
335335

336336
Status result = Status::kSuccess;
337337
result = params.op_sum_OdO.run(stream);

0 commit comments

Comments
 (0)