Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/65_distributed_gemm/65_distributed_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,6 @@

#include "helper.h"

// Distributed GEMM helpers
#include "dist_gemm_helpers.h"

using namespace cute;

/////////////////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -135,6 +132,9 @@ static constexpr int TP_ = TP{};
#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))

// Distributed GEMM helpers
#include "dist_gemm_helpers.h"

// Distributed GEMM tiling/sharding schedule
// Choices:
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1486,6 +1486,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {


CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
#if ! defined(__CUDA_ARCH_FEAT_SM100_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else
int warp_idx = cutlass::canonical_warp_idx_sync();
auto role = warp_idx_to_role(warp_idx);
uint32_t lane_predicate = cute::elect_one_sync();
Expand Down Expand Up @@ -1810,6 +1813,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
/* no-op */

}
#endif
}

static dim3 get_block_shape() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,9 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {


CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
#if ! defined(__CUDA_ARCH_FEAT_SM100_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else
int warp_idx = cutlass::canonical_warp_idx_sync();
auto role = warp_idx_to_role(warp_idx);
uint32_t lane_predicate = cute::elect_one_sync();
Expand Down Expand Up @@ -1804,6 +1807,7 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
/* no-op */

}
#endif
}

static dim3 get_block_shape() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,9 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
}

CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if ! defined(__CUDA_ARCH_FEAT_SM100_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else

TileScheduler tile_scheduler{params.tile_scheduler};

Expand Down Expand Up @@ -612,6 +615,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {

/* no-op, donate regs and exit */
}
#endif
}

};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ struct Sm100FmhaGenKernelWarpspecialized {
}

CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if ! defined(__CUDA_ARCH_FEAT_SM100_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else

TileScheduler tile_scheduler{params.tile_scheduler};

Expand Down Expand Up @@ -569,6 +572,7 @@ struct Sm100FmhaGenKernelWarpspecialized {

/* no-op, donate regs and exit */
}
#endif
}

};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,9 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {


CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) {
#if ! defined(__CUDA_ARCH_FEAT_SM100_ALL)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alihassanijr, this only covers 100a but not 100f. You could take a look at the launch control header file for the 100f macro.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't using the family macros break builds with CTK 12.8 and earlier?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both CUTLASS_ARCH_MMA_SM100F_SUPPORTED and CUTLASS_ARCH_MMA_SM100F_ENABLED are conditioned on CTK >= 12.9, so Sm100 users with CTK 12.8 will wind up with empty kernels.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. So if CTK < 12.9, check 100a only; otherwise check 100a || 100f.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or does CUDA_ARCH_FAMILY(1000) simply evaluate to false in CTK 12.8?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right; defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) should work -- they're already conditioned on the correct CTK compiler version.

printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else

TileScheduler tile_scheduler(params.tile_scheduler);

Expand Down Expand Up @@ -814,6 +817,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
#endif
}

template<class BlkCoord>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,6 @@

#include "helper.h"

// Distributed GEMM helpers
#include "dist_gemm_helpers.h"

using namespace cute;

/////////////////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -135,6 +132,9 @@ static constexpr int TP_ = TP{};
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))

// Distributed GEMM helpers
#include "dist_gemm_helpers.h"

// Distributed GEMM tiling/sharding schedule
// Choices:
//
Expand Down
4 changes: 4 additions & 0 deletions examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ struct FmhaKernelTma {
}

CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else
TileScheduler tile_scheduler{params.tile_scheduler};

// Shared memory.
Expand Down Expand Up @@ -216,6 +219,7 @@ struct FmhaKernelTma {
result, typename CollectiveMainloop::TiledMmaPV{},
params.problem_size, params.epilogue,
epi_load_pipeline, storage.epilogue);
#endif
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ struct FmhaKernelTmaWarpSpecialized {

CUTLASS_DEVICE void operator()(const Params &params, char* smem) {

#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else
enum class WarpGroupRole {
Producer = 0,
Consumer0 = 1,
Expand Down Expand Up @@ -412,6 +415,7 @@ struct FmhaKernelTmaWarpSpecialized {
if constexpr (kIsEpilogueLocked) ; math_wg_order_barrier.arrive();
}
}
#endif
}
};

Expand Down