@@ -179,7 +179,9 @@ class MemoryEfficientAttentionNormalize {
multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;

- ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
+ auto s_prime = s_prime_[row];
+ auto scale = s_prime == 0 ? 0 : 1 / s_prime;
+ ElementCompute alpha = isLast ? scale : 1;
ElementCompute beta = alpha * m_prime_[row];

intermediate = mul_add_source(beta, converted_source); // X = beta * C
@@ -209,7 +211,9 @@ class MemoryEfficientAttentionNormalize {
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_accumulator;

- ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
+ auto s_prime = s_prime_[row];
+ auto scale = s_prime == 0 ? 0 : 1 / s_prime;
+ ElementCompute alpha = isLast ? scale : 1;

intermediate = mul_accumulator(
alpha, converted_accumulator); // X = alpha * C + uniform
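The epilogue change above guards the final softmax normalization against a zero denominator: when every key for a query row has been masked off, the running row sum `s_prime_[row]` is 0, and the old `1 / s_prime_[row]` produced an infinite `alpha` (and NaNs in the output). A minimal host-side sketch of the intended behavior (the function name is illustrative, not part of the kernel):

```cpp
#include <cstdio>

// Sketch: a fully masked row has a zero softmax denominator; its output
// should be rescaled by 0 (stay zero) rather than by 1/0 (inf/NaN).
float normalize_scale(float s_prime) {
  return s_prime == 0.0f ? 0.0f : 1.0f / s_prime;
}

int main() {
  std::printf("%f\n", normalize_scale(4.0f)); // 0.25: usual 1/s_prime rescale
  std::printf("%f\n", normalize_scale(0.0f)); // 0.0: masked row stays zero
}
```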
18 changes: 17 additions & 1 deletion csrc/include/natten/cuda/fna/fna_backward.cuh
@@ -89,7 +89,11 @@ void fna_backward_generic(
float attn_scale,
IntTuple query_tile_shape,
IntTuple key_tile_shape,
- IntTuple num_splits_key) {
+ IntTuple num_splits_key,
+ bool has_dot_product_min,
+ bool has_dot_product_max,
+ float dot_product_min,
+ float dot_product_max) {
static constexpr auto kRank =
std::tuple_size<decltype(spatial_extent)>::value;
using Dim = typename GetDim<kRank>::type;
@@ -157,6 +161,18 @@ void fna_backward_generic(

p.num_splits_key = tuple_to_na_dim<Dim>(num_splits_key);

+ // Optional dot product clipping
+ p.has_dot_product_clip = has_dot_product_min || has_dot_product_max;
+ p.has_dot_product_min = has_dot_product_min;
+ p.has_dot_product_max = has_dot_product_max;
+ if (has_dot_product_min) {
+   p.dot_product_min = dot_product_min;
+ }
+ if (has_dot_product_max) {
+   p.dot_product_max = dot_product_max;
+ }
+ //

int64_t size_bytes = p.workspace_size();
if (size_bytes) {
void* workspace_ptr = nullptr;
18 changes: 17 additions & 1 deletion csrc/include/natten/cuda/fna/fna_forward.cuh
@@ -79,7 +79,11 @@ void fna_forward_generic(
float attn_scale,
void* logsumexp_ptr,
IntTuple query_tile_shape,
- IntTuple key_tile_shape) {
+ IntTuple key_tile_shape,
+ bool has_dot_product_min,
+ bool has_dot_product_max,
+ float dot_product_min,
+ float dot_product_max) {
static constexpr auto kRank =
std::tuple_size<decltype(spatial_extent)>::value;
using Dim = typename GetDim<kRank>::type;
@@ -167,6 +171,18 @@ void fna_forward_generic(
p.query_tile_shape = tuple_to_na_dim<Dim>(query_tile_shape);
p.key_tile_shape = tuple_to_na_dim<Dim>(key_tile_shape);

+ // Optional dot product clipping
+ p.has_dot_product_clip = has_dot_product_min || has_dot_product_max;
+ p.has_dot_product_min = has_dot_product_min;
+ p.has_dot_product_max = has_dot_product_max;
+ if (has_dot_product_min) {
+   p.dot_product_min = dot_product_min;
+ }
+ if (has_dot_product_max) {
+   p.dot_product_max = dot_product_max;
+ }
+ //

if (smem_bytes > 0xc000) {
auto err = cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
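Both launchers, `fna_forward_generic` and `fna_backward_generic`, now take the same four trailing clipping arguments and forward them into the kernel parameter struct, deriving `has_dot_product_clip` as the OR of the two flags. A rough sketch of that plumbing, using an illustrative stand-in struct (the real `Params` types live in the kernel headers below):

```cpp
#include <limits>

// Illustrative stand-in for the kernel Params fields added in this diff.
struct ClipParams {
  bool has_dot_product_clip = false;
  bool has_dot_product_min = false;
  bool has_dot_product_max = false;
  float dot_product_min = -std::numeric_limits<float>::infinity();
  float dot_product_max = std::numeric_limits<float>::infinity();
};

ClipParams make_clip_params(bool has_min, bool has_max,
                            float min_val, float max_val) {
  ClipParams p;
  // Clipping is active if either bound is provided; an unset bound keeps
  // its +/-infinity default, so it can never fire in the kernel.
  p.has_dot_product_clip = has_min || has_max;
  p.has_dot_product_min = has_min;
  p.has_dot_product_max = has_max;
  if (has_min) p.dot_product_min = min_val;
  if (has_max) p.dot_product_max = max_val;
  return p;
}
```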
40 changes: 39 additions & 1 deletion csrc/include/natten/cuda/fna/kernel_backward.h
@@ -692,6 +692,17 @@ struct FusedNeighborhoodAttentionBackwardKernel {
bool is_fully_block_sparse = false;
bool has_q_padding = false;

+ // Optional dot product clipping -- all fields must be set explicitly to
+ // avoid extra comparisons in the kernel.
+ bool has_dot_product_clip = false;
+ bool has_dot_product_min = false;
+ bool has_dot_product_max = false;
+ accum_t dot_product_min =
+     -cutlass::platform::numeric_limits<accum_t>::infinity();
+ accum_t dot_product_max =
+     cutlass::platform::numeric_limits<accum_t>::infinity();
+ //

// Dimensions/strides
int32_t head_dim = -1;
int32_t head_dim_value = -1;
@@ -1615,7 +1626,6 @@ struct FusedNeighborhoodAttentionBackwardKernel {
mma.set_prologue_done(kPrologueQK);
mma.set_zero_outside_bounds(/*!skipBoundsChecks*/ true);
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
- accum = cutlass::multiplies<typename Mma::FragmentC>()(scale, accum);

// Epilogue: add LSE + exp and store that to our shared memory buffer
// shmem <- (matmul_result -
@@ -1629,6 +1639,34 @@
auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset(
lane_id, warp_id, output_tile_coords);

+ // Dot product scale
+ accum = cutlass::multiplies<typename Mma::FragmentC>()(scale, accum);
+
+ // (Optional) clip dot products (mask off out-of-bounds dot products)
+ if (p.has_dot_product_clip) {
+   if (not p.has_dot_product_max) {
+     for (int i = 0; i < Mma::FragmentC::kElements; ++i) {
+       accum[i] = accum[i] < p.dot_product_min
+           ? -cutlass::platform::numeric_limits<accum_t>::infinity()
+           : accum[i];
+     }
+   } else if (not p.has_dot_product_min) {
+     for (int i = 0; i < Mma::FragmentC::kElements; ++i) {
+       accum[i] = accum[i] > p.dot_product_max
+           ? -cutlass::platform::numeric_limits<accum_t>::infinity()
+           : accum[i];
+     }
+   } else {
+     // assert(p.has_dot_product_min && p.has_dot_product_max);
+     for (int i = 0; i < Mma::FragmentC::kElements; ++i) {
+       accum[i] =
+           (accum[i] < p.dot_product_min || accum[i] > p.dot_product_max)
+           ? -cutlass::platform::numeric_limits<accum_t>::infinity()
+           : accum[i];
+     }
+   }
+ }

if (not p.is_fully_block_sparse) {
// Neighborhood Attention masking
Dim first_col, query_bound, row_idx;
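Note that "clipping" here masks rather than clamps: a scaled dot product outside the configured interval is replaced with negative infinity, so after the softmax it contributes exp(-inf) = 0, exactly like a position masked by the neighborhood attention window. The three-way branch keeps the hot loop to a single comparison when only one bound is set. A scalar sketch of the per-element semantics (simplified to plain floats):

```cpp
#include <cmath>

// Scalar model of what the fragment loops above do per element.
float clip_dot_product(float x, bool has_min, bool has_max,
                       float min_val, float max_val) {
  if ((has_min && x < min_val) || (has_max && x > max_val)) {
    return -INFINITY; // masked out: exp(-inf) == 0 in the softmax
  }
  return x;
}
```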
50 changes: 47 additions & 3 deletions csrc/include/natten/cuda/fna/kernel_forward.h
@@ -200,6 +200,17 @@ struct FusedNeighborhoodAttentionKernel {
bool is_fully_block_sparse = false;
bool has_kv_padding = false;

+ // Optional dot product clipping -- all fields must be set explicitly to
+ // avoid extra comparisons in the kernel.
+ bool has_dot_product_clip = false;
+ bool has_dot_product_min = false;
+ bool has_dot_product_max = false;
+ accum_t dot_product_min =
+     -cutlass::platform::numeric_limits<accum_t>::infinity();
+ accum_t dot_product_max =
+     cutlass::platform::numeric_limits<accum_t>::infinity();
+ //

// Moves pointers to what we should process
// Returns "false" if there is no work to do
CUTLASS_DEVICE bool advance_to_block() {
@@ -734,6 +745,34 @@
MM1::Mma::drain_cp_asyncs();
}

+ // (Optional) clip dot products (mask off out-of-bounds dot products)
+ if (p.has_dot_product_clip) {
+   if (not p.has_dot_product_max) {
+     for (int i = 0; i < MM0::Mma::FragmentC::kElements; ++i) {
+       accum[i] = accum[i] * p.scale;
+       accum[i] = accum[i] < p.dot_product_min
+           ? -cutlass::platform::numeric_limits<accum_t>::infinity()
+           : accum[i];
+     }
+   } else if (not p.has_dot_product_min) {
+     for (int i = 0; i < MM0::Mma::FragmentC::kElements; ++i) {
+       accum[i] = accum[i] * p.scale;
+       accum[i] = accum[i] > p.dot_product_max
+           ? -cutlass::platform::numeric_limits<accum_t>::infinity()
+           : accum[i];
+     }
+   } else {
+     // assert(p.has_dot_product_min && p.has_dot_product_max);
+     for (int i = 0; i < MM0::Mma::FragmentC::kElements; ++i) {
+       accum[i] = accum[i] * p.scale;
+       accum[i] =
+           (accum[i] < p.dot_product_min || accum[i] > p.dot_product_max)
+           ? -cutlass::platform::numeric_limits<accum_t>::infinity()
+           : accum[i];
+     }
+   }
+ }

if (not p.is_fully_block_sparse) {
// Neighborhood Attention masking
Dim first_col, key_bound, row_idx;
@@ -791,7 +830,7 @@
last_kv_col,
is_first_kv_iter,
iteratorC_tile_offset,
- p.scale);
+ p.has_dot_product_clip ? 1.0 : p.scale);

// Output results to shared-memory

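The `1.0` above is deliberate: when clipping is enabled, `p.scale` has already been folded into `accum` inside the clip loops, so passing `p.scale` to the softmax update again would scale the dot products twice. A simplified sketch of the invariant (plain floats, hypothetical helper name):

```cpp
#include <cmath>

// Sketch: attn_scale must be applied exactly once before the softmax
// update. With clipping on, it is applied in the clip loop, so the
// softmax update receives 1.0 instead.
void scale_and_clip(float* accum, int n, float scale,
                    bool clip, float min_val, float max_val) {
  if (clip) {
    for (int i = 0; i < n; ++i) {
      accum[i] *= scale; // scaled here...
      if (accum[i] < min_val || accum[i] > max_val)
        accum[i] = -INFINITY;
    }
  }
  float softmax_scale = clip ? 1.0f : scale; // ...so not scaled again here
  (void)softmax_scale; // would be passed to the softmax update
}
```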
@@ -967,8 +1006,13 @@
map_index_to_coord((int32_t)thread_id(), problem_size_0_m);
auto query_offset = (query_idx * p.lse_strideM).sum();
if (is_coord_within_upper_bound(query_idx, problem_size_0_m)) {
- p.logsumexp_ptr[query_offset] = accum_t(mi[thread_id()] / kLog2e) +
-     cutlass::fast_log(accum_t(s_prime[thread_id()]));
+ if (mi[thread_id()] ==
+     -cutlass::platform::numeric_limits<accum_t>::infinity()) {
+   p.logsumexp_ptr[query_offset] = 0.0f;
+ } else {
+   p.logsumexp_ptr[query_offset] = accum_t(mi[thread_id()] / kLog2e) +
+       cutlass::fast_log(accum_t(s_prime[thread_id()]));
+ }
//} else if (query_offset < lse_dim) {
// p.logsumexp_ptr[query_offset] =
// cutlass::platform::numeric_limits<accum_t>::infinity();
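The logsumexp fix covers the fully masked row: with LSE_i = m_i + log(sum_j exp(s_ij - m_i)), masking every s_ij to negative infinity gives m_i = -inf and a zero sum, so the old expression evaluated -inf + log(0) = NaN; the kernel now writes 0 for such rows. A host-side sketch in natural log (the kernel tracks `mi` in base 2 and converts via `kLog2e`):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Numerically safe logsumexp matching the kernel's convention:
// a fully masked row (all scores -inf) stores LSE = 0 instead of NaN.
float logsumexp_row(const std::vector<float>& scores) {
  float m = -INFINITY;
  for (float s : scores) m = std::max(m, s);
  if (m == -INFINITY) return 0.0f; // fully masked row
  float sum = 0.0f;
  for (float s : scores) sum += std::exp(s - m);
  return m + std::log(sum);
}
```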