Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/65_distributed_gemm/65_distributed_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,6 @@

#include "helper.h"

// Distributed GEMM helpers
#include "dist_gemm_helpers.h"

using namespace cute;

/////////////////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -135,6 +132,9 @@ static constexpr int TP_ = TP{};
#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))

// Distributed GEMM helpers
#include "dist_gemm_helpers.h"

// Distributed GEMM tiling/sharding schedule
// Choices:
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1486,6 +1486,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {


CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED))
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else
int warp_idx = cutlass::canonical_warp_idx_sync();
auto role = warp_idx_to_role(warp_idx);
uint32_t lane_predicate = cute::elect_one_sync();
Expand Down Expand Up @@ -1810,6 +1813,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
/* no-op */

}
#endif
}

static dim3 get_block_shape() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,9 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {


CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED))
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else
int warp_idx = cutlass::canonical_warp_idx_sync();
auto role = warp_idx_to_role(warp_idx);
uint32_t lane_predicate = cute::elect_one_sync();
Expand Down Expand Up @@ -1804,6 +1807,7 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
/* no-op */

}
#endif
}

static dim3 get_block_shape() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,9 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
}

CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED))
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else

TileScheduler tile_scheduler{params.tile_scheduler};

Expand Down Expand Up @@ -612,6 +615,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {

/* no-op, donate regs and exit */
}
#endif
}

};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ struct Sm100FmhaGenKernelWarpspecialized {
}

CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED))
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else

TileScheduler tile_scheduler{params.tile_scheduler};

Expand Down Expand Up @@ -569,6 +572,7 @@ struct Sm100FmhaGenKernelWarpspecialized {

/* no-op, donate regs and exit */
}
#endif
}

};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,9 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {


CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) {
#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED))
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else

TileScheduler tile_scheduler(params.tile_scheduler);

Expand Down Expand Up @@ -814,6 +817,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
#endif
}

template<class BlkCoord>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,6 @@

#include "helper.h"

// Distributed GEMM helpers
#include "dist_gemm_helpers.h"

using namespace cute;

/////////////////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -135,6 +132,9 @@ static constexpr int TP_ = TP{};
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))

// Distributed GEMM helpers
#include "dist_gemm_helpers.h"

// Distributed GEMM tiling/sharding schedule
// Choices:
//
Expand Down
4 changes: 4 additions & 0 deletions examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ struct FmhaKernelTma {
}

CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if ! defined(CUTLASS_ARCH_MMA_SM90A_ENABLED)
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else
TileScheduler tile_scheduler{params.tile_scheduler};

// Shared memory.
Expand Down Expand Up @@ -216,6 +219,7 @@ struct FmhaKernelTma {
result, typename CollectiveMainloop::TiledMmaPV{},
params.problem_size, params.epilogue,
epi_load_pipeline, storage.epilogue);
#endif
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ struct FmhaKernelTmaWarpSpecialized {

CUTLASS_DEVICE void operator()(const Params &params, char* smem) {

#if ! defined(CUTLASS_ARCH_MMA_SM90A_ENABLED)
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else
enum class WarpGroupRole {
Producer = 0,
Consumer0 = 1,
Expand Down Expand Up @@ -412,6 +415,7 @@ struct FmhaKernelTmaWarpSpecialized {
if constexpr (kIsEpilogueLocked) ; math_wg_order_barrier.arrive();
}
}
#endif
}
};

Expand Down