csrc/jit/handle.hpp: 28 changes (12 additions, 16 deletions)
@@ -37,7 +37,7 @@ static void unload_library(const LibraryHandle& library) {

static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel,
const cudaStream_t& stream, const int& smem_size,
const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) {
const dim3& grid_dim, const dim3& block_dim, const dim3& cluster_dim) {
if (smem_size > 0)
DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

@@ -51,12 +51,10 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel,

// NOTES: must use `static` or the `attr` will be destructed
static LaunchAttrHandle attr;
if (cluster_dim > 1) {
attr.id = cudaLaunchAttributeClusterDimension;
attr.val.clusterDim = {static_cast<unsigned>(cluster_dim), 1, 1};
config.attrs = &attr;
config.numAttrs = 1;
}
attr.id = cudaLaunchAttributeClusterDimension;
attr.val.clusterDim = {cluster_dim.x, cluster_dim.y, cluster_dim.z};
config.attrs = &attr;
config.numAttrs = 1;
return config;
}
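For reference, a self-contained sketch of the runtime-API pattern this function now follows (kernel, grid, and cluster values are hypothetical, not taken from this PR): the cluster shape is attached unconditionally as a launch attribute, with `{1, 1, 1}` as the degenerate no-clustering case.

```cpp
#include <cuda_runtime.h>

__global__ void dummy_kernel(float* out) {}  // hypothetical kernel

cudaError_t launch_with_cluster(float* out, cudaStream_t stream) {
    cudaLaunchConfig_t config = {};
    config.gridDim = dim3(128, 4, 1);   // each grid dim must be a multiple of the cluster dim
    config.blockDim = dim3(256, 1, 1);
    config.dynamicSmemBytes = 0;
    config.stream = stream;

    // Attach the cluster shape, as the patched `construct_launch_config`
    // now does for every launch
    cudaLaunchAttribute attr;
    attr.id = cudaLaunchAttributeClusterDimension;
    attr.val.clusterDim = {2, 4, 1};  // x: TMA multicast, y: split-K, per this PR's convention
    config.attrs = &attr;
    config.numAttrs = 1;

    return cudaLaunchKernelEx(&config, dummy_kernel, out);
}
```

Here `attr` can live on the stack because the launch happens in the same scope; the original keeps it `static` because the returned config stores a pointer to it that must outlive the function.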

@@ -95,7 +93,7 @@ static void unload_library(const LibraryHandle& library) {

static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel,
const cudaStream_t& stream, const int& smem_size,
const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) {
const dim3& grid_dim, const dim3& block_dim, const dim3& cluster_dim) {
if (smem_size > 0)
DG_CUDA_DRIVER_CHECK(cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size));

@@ -113,14 +111,12 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel,

// NOTES: must use `static` or the `attr` will be destructed
static LaunchAttrHandle attr;
if (cluster_dim > 1) {
attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
attr.value.clusterDim.x = cluster_dim;
attr.value.clusterDim.y = 1;
attr.value.clusterDim.z = 1;
config.attrs = &attr;
config.numAttrs = 1;
}
attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
attr.value.clusterDim.x = cluster_dim.x;
attr.value.clusterDim.y = cluster_dim.y;
attr.value.clusterDim.z = cluster_dim.z;
config.attrs = &attr;
config.numAttrs = 1;
return config;
}
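The driver-API overload mirrors the same change; a minimal sketch of that path (function handle, grid, and cluster values are placeholders):

```cpp
#include <cuda.h>

CUresult launch_with_cluster(CUfunction func, CUstream stream, void** kernel_params,
                             unsigned cluster_x, unsigned cluster_y) {
    CUlaunchAttribute attr;
    attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
    attr.value.clusterDim.x = cluster_x;  // TMA multicast dimension
    attr.value.clusterDim.y = cluster_y;  // split-K dimension
    attr.value.clusterDim.z = 1;

    CUlaunchConfig config = {};
    config.gridDimX = 128 * cluster_x;   // placeholder grid, divisible by the cluster
    config.gridDimY = cluster_y;
    config.gridDimZ = 1;
    config.blockDimX = 256;
    config.blockDimY = config.blockDimZ = 1;
    config.sharedMemBytes = 0;
    config.hStream = stream;
    config.attrs = &attr;
    config.numAttrs = 1;

    return cuLaunchKernelEx(&config, func, kernel_params, /*extra=*/nullptr);
}
```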

csrc/jit/kernel_runtime.hpp: 19 changes (11 additions, 8 deletions)
@@ -12,13 +12,13 @@ struct LaunchArgs {
std::pair<int, int> grid_dim;
int num_threads;
int smem_size;
int cluster_dim;
std::pair<int, int> cluster_dim; // (tma_multicast_dim, split_k_dim)

LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1):
grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {}
LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim_x = 1, const int& cluster_dim_y = 1):
grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim({cluster_dim_x, cluster_dim_y}) {}

LaunchArgs(const std::pair<int, int>& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1):
grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {}
LaunchArgs(const std::pair<int, int>& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim_x = 1, const int& cluster_dim_y = 1):
grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim({cluster_dim_x, cluster_dim_y}) {}
};
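How a call site might use the widened constructor (this assumes the `LaunchArgs` definition above; thread and cluster values are illustrative, not from this PR):

```cpp
void example(int num_m_blocks, int num_n_blocks, int smem_size) {
    // Existing call sites compile unchanged: both cluster dims default to 1
    const LaunchArgs plain({num_m_blocks, num_n_blocks}, /*num_threads=*/384, smem_size);

    // New form: TMA multicast of 2 (cluster x) combined with a split-K
    // factor of 4 (cluster y); the runtime turns this into cluster {2, 4, 1}
    const LaunchArgs split_k({num_m_blocks, num_n_blocks}, /*num_threads=*/384,
                             smem_size, /*cluster_dim_x=*/2, /*cluster_dim_y=*/4);
}
```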

class KernelRuntime final {
@@ -102,14 +102,17 @@ class LaunchRuntime {
static_cast<unsigned>(launch_args.grid_dim.second),
1};
const dim3& block_dim = {static_cast<unsigned>(launch_args.num_threads), 1, 1};
const dim3& cluster_dim = {static_cast<unsigned>(launch_args.cluster_dim.first),
static_cast<unsigned>(launch_args.cluster_dim.second),
1};
auto config = construct_launch_config(kernel, stream, launch_args.smem_size,
grid_dim, block_dim, launch_args.cluster_dim);
grid_dim, block_dim, cluster_dim);

// Launch in the derived class
if (get_env<int>("DG_JIT_DEBUG")) {
printf("Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, stream: %ld\n",
printf("Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: {%d, %d}, stream: %ld\n",
launch_args.grid_dim.first, launch_args.grid_dim.second, launch_args.num_threads,
launch_args.smem_size, launch_args.cluster_dim, stream.id());
launch_args.smem_size, launch_args.cluster_dim.first, launch_args.cluster_dim.second, stream.id());
}
Derived::launch_impl(kernel, config, args);
}
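With hypothetical values (grid {128, 1}, 384 threads, 196608 bytes of shared memory, cluster {2, 4}, stream 0), the updated debug line would render as:

```
Launch kernel with {128, 1} x 384, shared memory: 196608 bytes, cluster: {2, 4}, stream: 0
```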
csrc/jit_kernels/heuristics/common.hpp: 147 changes (101 additions, 46 deletions)
@@ -65,6 +65,8 @@ struct GemmConfig {
int block_m, block_n, block_k;
int num_stages, num_last_stages;

int k_slices;

// Templated device configs
int num_sms;
int tc_util;
@@ -99,7 +101,8 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne
const int& block_m, const int& block_n, const int& block_k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const int& num_stages, const MulticastConfig& multicast_config) {
const int& num_stages, const MulticastConfig& multicast_config,
const bool split_k) {
const int& ab_elem_size = static_cast<int>(c10::elementSize(ab_dtype));
const int& cd_elem_size = static_cast<int>(c10::elementSize(cd_dtype));

@@ -110,7 +113,7 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne
const int& swizzle_cd_mode = ArchSpec::enable_cd_swizzle(cd_dtype) ? get_swizzle_mode(block_n, cd_elem_size) : 0;

// Different archs have different epilogue pipelines
const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype);
const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, split_k ? torch::kFloat : cd_dtype);

// A/B shared memory
const int& smem_a_per_stage = load_block_m * block_k * ab_elem_size;
@@ -122,7 +125,7 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne
const int& smem_extra_sfb = ArchSpec::get_extra_sfb_smem_size(m, n, k, block_m, block_n, block_k);

// M-barriers and tensor memory pointers
const int& smem_barrier = ArchSpec::get_barrier_smem_size(num_stages);
const int& smem_barrier = ArchSpec::get_barrier_smem_size(num_stages + (split_k ? 1 : 0));
const int& smem_tmem_ptr = ArchSpec::get_tmem_ptr_smem_size();
const int& smem_tensor_map = ArchSpec::get_tensormap_smem_size(gemm_type);

@@ -146,6 +149,25 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne
};
}

template <typename ArchSpec>
static void get_k_slices_options(const GemmType& gemm_type, const int k, const int block_k, std::vector<int>& k_slices_options) {
const auto aligned_k = align(k, block_k);
int max_k_slices = 1;
if (ArchSpec::support_split_k() and (gemm_type != GemmType::KGroupedContiguous)) {
// max potential cluster size for sm90 is 8, queried by cudaOccupancyMaxPotentialClusterSize()
max_k_slices = 8;
}
for (int k_slices = 1; k_slices <= max_k_slices; k_slices *= 2) {
if (aligned_k % k_slices != 0) {
continue;
}
if ((aligned_k / k_slices) % block_k != 0) {
continue;
}
k_slices_options.push_back(k_slices);
}
}
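A standalone re-implementation of the same enumeration with worked inputs, to make the filter concrete (the `ArchSpec`/`align` plumbing is replaced by plain arguments; illustrative only):

```cpp
#include <cstdio>
#include <vector>

// Powers of two up to max_k_slices such that each K slice is a whole
// number of block_k tiles.
std::vector<int> k_slices_options(int aligned_k, int block_k, int max_k_slices) {
    std::vector<int> options;
    for (int k_slices = 1; k_slices <= max_k_slices; k_slices *= 2) {
        if (aligned_k % k_slices != 0) continue;
        if ((aligned_k / k_slices) % block_k != 0) continue;
        options.push_back(k_slices);
    }
    return options;
}

int main() {
    // aligned_k = 7168, block_k = 64: 7168 / {1,2,4,8} = {7168,3584,1792,896},
    // all multiples of 64 -> prints "1 2 4 8"
    for (int s : k_slices_options(7168, 64, 8)) std::printf("%d ", s);
    std::printf("\n");
    // aligned_k = 384, block_k = 128: 384 / 2 = 192 is not a multiple of 128,
    // so only the trivial split survives -> prints "1"
    for (int s : k_slices_options(384, 128, 8)) std::printf("%d ", s);
    std::printf("\n");
}
```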

template <typename ArchSpec>
static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& kernel_type,
const int& m, const int& n, const int& k, const int& num_groups,
@@ -170,51 +192,71 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
const auto& get_num_blocks = [=](const int& block_m, const int& block_n) {
return ceil_div(m, block_m) * ceil_div(n, block_n) * num_groups;
};
const auto& get_num_waves = [=](const int& block_m, const int& block_n) {
return ceil_div(get_num_blocks(block_m, block_n), num_sms);
const auto& get_num_waves = [=](const int& block_m, const int& block_n, const int& k_slices, const int& num_sms_available) {
return ceil_div(k_slices * get_num_blocks(block_m, block_n), num_sms_available);
};
const auto& get_last_wave_util = [=](const int& block_m, const int& block_n) {
const auto& num_last_blocks = get_num_blocks(block_m, block_n) % num_sms;
return num_last_blocks == 0 ? num_sms : num_last_blocks;
// block padding should not be taken into account for last_wave_util
const auto& get_block_padding_cost = [=](const int& block_m, const int& block_n, const int& k_slices) {
return ceil_div((block_m - m % block_m) % block_m * k_slices * num_groups, block_m) * ceil_div(n, block_n)
+ ceil_div((block_n - n % block_n) % block_n * k_slices * num_groups, block_n) * ceil_div(m, block_m);
};
const auto& get_last_wave_util = [=](const int& block_m, const int& block_n, const int& k_slices, const int& num_sms_available) {
const auto& num_last_blocks = (k_slices * get_num_blocks(block_m, block_n)) % num_sms_available;
return (num_last_blocks == 0 ? num_sms_available : num_last_blocks) - get_block_padding_cost(block_m, block_n, k_slices);
};

std::vector<int> k_slices_options;
get_k_slices_options<ArchSpec>(gemm_type, k, block_k, k_slices_options);
int best_k_slices = 1;

// Decide block sizes by waves
int best_block_m = 0, best_block_n = 0;
int best_num_waves = 0, best_last_util = 0;
for (const auto& block_m: block_ms) {
for (const auto& block_n: block_ns) {
const int& num_waves = get_num_waves(block_m, block_n);
const auto& last_util = get_last_wave_util(block_m, block_n);
if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, block_m, block_n, block_k))
continue;

bool success = false;
if (best_block_m == 0 or best_block_n == 0 or num_waves < best_num_waves) {
success = true;
} else if (num_waves == best_num_waves) {
// Check last wave utilization
success = last_util > best_last_util;
if (last_util == best_last_util) {
// Case 1: same `block_m`, smaller `block_n` (wasted)
success |= block_m == best_block_m and block_n < best_block_n;
// Case 2: same `block_n`, smaller `block_m` (wasted)
success |= block_n == best_block_n and block_m < best_block_m;
// Case 3: different for both `block_m` and `block_n`, larger `block_n` is better
// NOTES: don't pick `block_m/block_n` larger than shape `m/n` in this case
success |= block_m != best_block_m and block_n > best_block_n
and block_n <= n and block_m <= m;
for (const auto& k_slices: k_slices_options) {
// Number of concurrently available SMs is limited by cluster size.
int num_sms_available = ArchSpec::get_num_sms_available_by_cluster_size(k_slices, num_sms);
const int& num_waves = get_num_waves(block_m, block_n, k_slices, num_sms_available);
const auto& last_util = get_last_wave_util(block_m, block_n, k_slices, num_sms_available);
if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, block_m, block_n, block_k))
continue;

bool success = false;
if (best_block_m == 0 or best_block_n == 0 or num_waves < best_num_waves) {
success = true;
} else if (num_waves == best_num_waves) {
// Check last wave utilization
success = last_util > best_last_util;
if (last_util == best_last_util) {
// smaller k_slices is better
success = k_slices < best_k_slices;
if (k_slices == best_k_slices) {
// Case 1: same `block_m`, smaller `block_n` (wasted)
success |= block_m == best_block_m and block_n < best_block_n;
// Case 2: same `block_n`, smaller `block_m` (wasted)
success |= block_n == best_block_n and block_m < best_block_m;
// Case 3: different for both `block_m` and `block_n`, larger `block_n` is better
// NOTES: don't pick `block_m/block_n` larger than shape `m/n` in this case
success |= block_m != best_block_m and block_n > best_block_n
and block_n <= n and block_m <= m;
}
}
}
}

// Replace with the new config if successful
if (success) {
best_block_m = block_m, best_block_n = block_n;
best_num_waves = num_waves, best_last_util = last_util;
// Replace with the new config if successful
if (success) {
best_block_m = block_m, best_block_n = block_n;
best_num_waves = num_waves, best_last_util = last_util;
best_k_slices = k_slices;
}
}
}
}
DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0);

int k_partitioned = k / best_k_slices;
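A worked pass through the wave arithmetic above, under assumed values (132 SMs, m = n = 4096 with 128 x 128 blocks, one group) and the simplifying assumption that `get_num_sms_available_by_cluster_size` keeps all 132 SMs usable at these small cluster sizes:

```cpp
#include <cstdio>
#include <initializer_list>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
    const int m = 4096, n = 4096, block_m = 128, block_n = 128, num_sms = 132;
    const int num_blocks = ceil_div(m, block_m) * ceil_div(n, block_n);  // 32 * 32 = 1024

    for (int k_slices : {1, 2}) {
        const int total = k_slices * num_blocks;
        const int num_waves = ceil_div(total, num_sms);
        const int last_util = total % num_sms == 0 ? num_sms : total % num_sms;
        // k_slices=1: waves=8,  last-wave blocks=100
        // k_slices=2: waves=16, last-wave blocks=68
        std::printf("k_slices=%d: waves=%d, last-wave blocks=%d\n",
                    k_slices, num_waves, last_util);
    }
    // Since m and n divide their block sizes exactly, get_block_padding_cost
    // is 0 here; with m = 4000, each column of M-blocks would carry
    // (128 - 4000 % 128) % 128 = 96 padded rows, which the new cost term
    // subtracts from the last-wave utilization.
}
```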

// Decide the number of TMA multicasts and whether broadcast on A
MulticastConfig best_multicast_config = {1, false};
const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
@@ -223,41 +265,53 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
bool order[2] = {false, true};
if (best_block_m > best_block_n)
std::swap(order[0], order[1]);
for (const bool& is_multicast_on_a: order) {
if (m >= 512 and is_legal[static_cast<int>(is_multicast_on_a)]) {
best_multicast_config = {2, is_multicast_on_a};
break;
// limited by cluster size
if (best_k_slices <= 1 or
best_k_slices * get_num_blocks(best_block_m, best_block_n) <= ArchSpec::get_num_sms_available_by_cluster_size(best_k_slices * 2, num_sms)) {
for (const bool& is_multicast_on_a: order) {
if (m >= 512 and is_legal[static_cast<int>(is_multicast_on_a)]) {
best_multicast_config = {2, is_multicast_on_a};
break;
}
}
}

// Always pick the largest number of stage
constexpr int smem_capacity = ArchSpec::smem_capacity;
int best_num_stages = 0;
SharedMemoryConfig best_smem_config;

for (int num_stages = 12; num_stages > 0; -- num_stages) {
if (not ArchSpec::is_num_stages_legal(ab_dtype, cd_dtype, num_stages, best_block_m, best_block_n, block_k))
continue;

best_smem_config = get_smem_config<ArchSpec>(gemm_type, kernel_type,
m, n, k,
m, n, k_partitioned,
best_block_m, best_block_n, block_k,
major_a, major_b,
ab_dtype, cd_dtype,
num_stages, best_multicast_config);
num_stages, best_multicast_config,
best_k_slices > 1);
if (best_smem_config.smem_size <= smem_capacity) {
best_num_stages = num_stages;
break;
}
}
DG_HOST_ASSERT(best_num_stages != 0);
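The stage loop above is a greedy descent: it starts at 12 stages and keeps the first (largest) count whose footprint fits in shared memory. A rough sketch of the shape of that check, keeping only the A/B pipeline terms (the real `get_smem_config` also accounts for the epilogue, swizzling, barriers, tensor maps, and scaling factors; all byte counts are illustrative):

```cpp
// Illustrative only: largest num_stages whose A/B staging fits.
int pick_num_stages(int block_m, int block_n, int block_k,
                    int ab_elem_size, int smem_other, int smem_capacity) {
    for (int num_stages = 12; num_stages > 0; --num_stages) {
        const int smem_a_per_stage = block_m * block_k * ab_elem_size;
        const int smem_b_per_stage = block_n * block_k * ab_elem_size;
        const int total = smem_other + num_stages * (smem_a_per_stage + smem_b_per_stage);
        if (total <= smem_capacity)
            return num_stages;  // first hit is the largest feasible count
    }
    return 0;  // the assert above guarantees this is never hit in practice
}
```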

int num_min_sms = ArchSpec::get_num_sms_available_by_cluster_size(best_k_slices * best_multicast_config.num_multicast, num_sms);
// Recompute the minimal number of SMs required
// NOTES: less L2 cache usage and less GPU frequency drop
int num_min_sms = num_sms;
if (ArchSpec::should_minimize_num_sms()) {
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, best_num_waves);
num_min_sms = align(num_min_sms, best_multicast_config.num_multicast);
DG_HOST_ASSERT(num_min_sms <= num_sms);
if (best_k_slices > 1) {
int num_partitioned_blocks = best_k_slices * get_num_blocks(best_block_m, best_block_n);
DG_HOST_ASSERT(num_partitioned_blocks <= num_min_sms);
num_min_sms = num_partitioned_blocks;
} else {
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, best_num_waves);
num_min_sms = align(num_min_sms, best_multicast_config.num_multicast);
DG_HOST_ASSERT(num_min_sms <= num_sms);
}
}

const auto& config = GemmConfig {
@@ -272,7 +326,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
.block_n = best_block_n,
.block_k = block_k,
.num_stages = best_num_stages,
.num_last_stages = ceil_div(k, block_k) % best_num_stages,
.num_last_stages = ceil_div(k_partitioned, block_k) % best_num_stages,
.k_slices = best_k_slices,
.num_sms = num_min_sms,
.tc_util = device_runtime->get_tc_util(),
.multicast_config = best_multicast_config,
@@ -291,12 +346,12 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
ab_dtype, cd_dtype, with_accumulation, num_sms);
static std::set<decltype(key)> printed;
if (printed.count(key) == 0) {
printf("GEMM type: %d, kernel type: %d, M: %d, N: %d, K: %d, groups: %d, "
printf("GEMM type: %d, kernel type: %d, M: %d, N: %d, K: %d, k_slices: %d, groups: %d, "
"A major: %d, B major: %d, AB dtype: %s, CD dtype: %s, accumulation: %d, "
"SM limit: %d -> block M: %d, block N: %d, block K: %d, stages: %d, last stages: %d, "
"SMs: %d, multicast: %d, multicast on A: %d, shared memory: %d bytes, swizzle A: %d, "
"swizzle B: %d, swizzle CD: %d, SMs: %d, threads: %d, TC util: %d%%\n",
static_cast<int>(gemm_type), static_cast<int>(kernel_type), m, n, k, num_groups,
static_cast<int>(gemm_type), static_cast<int>(kernel_type), m, n, k, best_k_slices, num_groups,
static_cast<int>(major_a), static_cast<int>(major_b), c10::toString(ab_dtype), c10::toString(cd_dtype),
static_cast<int>(with_accumulation), num_sms, best_block_m, best_block_n, block_k,
best_num_stages, config.num_last_stages, num_min_sms, best_multicast_config.num_multicast,