Skip to content

Commit 86deef3

Browse files
authored
Add SarBp optimizations and accuracy improvements (#1140)
* Add SarBp optimizations and accuracy improvements This PR includes several SarBp performance optimizations for the fltflt case: - Adds fltflt_sqrt_fast(), which uses fewer operations at a very slight accuracy cost. - Adds fltflt_norm3d(), which uses fewer normalizations and calls fltflt_sqrt_fast() - Splits the calculation of bin such that more terms can be precomputed and stored in shared memory. Reduced inner loop bin calculation from ~24 FLOPs to ~18. In addition, this PR adjusts the computation of the weight w to preserve more bits. Previously, the mixed and fltflt implementations computed bin as: bin = static_cast<loose_compute_t>(diffR * dr_inv) + bin_offset; w = bin - ::floor(bin) However, with large bin counts, this can leave relatively few bits of precision for w. The fltflt and mixed variants have been adjusted to preserve more accuracy at the cost of performance. All told, the fltflt version is ~15% faster due to the optimizations, but the mixed-precision version is slower due to increased use of FP64. In the future, a new option may be added to reduce the precision of the bin calculation for scenarios/ranges where that makes sense. Signed-off-by: Thomas Benson <tbenson@nvidia.com>
1 parent 5bf9643 commit 86deef3

File tree

7 files changed

+467
-47
lines changed

7 files changed

+467
-47
lines changed

bench/00_misc/fltflt_arithmetic.cu

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,157 @@ NVBENCH_BENCH_TYPES(fltflt_bench_sqrt, NVBENCH_TYPE_AXES(precision_types))
409409
.add_int64_power_of_two_axis("Array Size", nvbench::range(24, 24, 1))
410410
.add_int64_axis("Iterations", {250});
411411

412+
//==============================================================================
// Square Root Fast Benchmark
// For float/double, this is identical to the sqrt benchmark (sqrtf/sqrt).
// For fltflt, this uses fltflt_sqrt_fast instead of fltflt_sqrt.
//
// Each thread carries ILP_FACTOR independent sqrt dependency chains so the
// benchmark measures throughput with instruction-level parallelism; the final
// lane sum keeps every chain live and produces one output element per thread.
//==============================================================================
template <typename T>
__global__ void iterative_sqrt_fast_kernel(T* __restrict__ result, int64_t size, int32_t iterations)
{
  // Promote blockIdx.x to 64-bit before the multiply: the unsigned 32-bit
  // product blockIdx.x * blockDim.x can wrap for very large grids even though
  // the destination variable is int64_t.
  const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < size) {
    T val[ILP_FACTOR];
    const T init_val = static_cast<T>(2.718281828);

    // First iteration seeds each lane from the constant initial value.
#pragma unroll
    for (int ilp = 0; ilp < ILP_FACTOR; ilp++) {
      if constexpr (std::is_same_v<T, fltflt>) {
        val[ilp] = fltflt_sqrt_fast(init_val);
      } else {
        val[ilp] = sqrt(init_val);
      }
    }

    // Remaining iterations form a serial dependency chain per lane.
#pragma unroll ITER_UNROLL_FACTOR
    for (int32_t i = 1; i < iterations; i++) {
#pragma unroll
      for (int ilp = 0; ilp < ILP_FACTOR; ilp++) {
        if constexpr (std::is_same_v<T, fltflt>) {
          val[ilp] = fltflt_sqrt_fast(val[ilp]);
        } else {
          val[ilp] = sqrt(val[ilp]);
        }
      }
    }

    // Reduce the lanes so the compiler cannot discard any chain.
    T result_val = val[0];
#pragma unroll
    for (int ilp = 1; ilp < ILP_FACTOR; ilp++) {
      result_val = result_val + val[ilp];
    }
    result[idx] = result_val;
  }
}
454+
455+
// Host-side nvbench driver for the fast-sqrt benchmark. Reads the array-size
// and iteration-count axes, allocates the output tensor, registers throughput
// counters, and launches iterative_sqrt_fast_kernel on the nvbench stream.
template <typename PrecisionType>
void fltflt_bench_sqrt_fast(nvbench::state &state, nvbench::type_list<PrecisionType>)
{
  const index_t num_elems = static_cast<index_t>(state.get_int64("Array Size"));
  const int32_t num_iters = static_cast<int32_t>(state.get_int64("Iterations"));
  cudaExecutor exec{0};

  auto result = make_tensor<PrecisionType>({num_elems});

  state.add_element_count(num_elems, "NumElements");
  state.add_global_memory_writes<PrecisionType>(num_elems);

  constexpr int threads_per_block = 256;
  const int num_blocks = static_cast<int>((num_elems + threads_per_block - 1) / threads_per_block);

  // Make sure allocation work has completed before the timed region starts.
  exec.sync();

  state.exec([&](nvbench::launch &launch) {
    const auto stream = (cudaStream_t)launch.get_stream();
    iterative_sqrt_fast_kernel<<<num_blocks, threads_per_block, 0, stream>>>(
      result.Data(), num_elems, num_iters);
  });
}

NVBENCH_BENCH_TYPES(fltflt_bench_sqrt_fast, NVBENCH_TYPE_AXES(precision_types))
  .add_int64_power_of_two_axis("Array Size", nvbench::range(24, 24, 1))
  .add_int64_axis("Iterations", {250});
481+
482+
//==============================================================================
// 3D Norm Benchmark: sqrt(dx^2 + dy^2 + dz^2)
// Each ILP lane has distinct dx values that depend on the previous iteration's
// result, creating a true dependency chain that prevents CSE across lanes.
//==============================================================================
template <typename T>
__global__ void iterative_norm3d_kernel(T* __restrict__ result, int64_t size, int32_t iterations)
{
  // Promote blockIdx.x to 64-bit before the multiply: the unsigned 32-bit
  // product blockIdx.x * blockDim.x can wrap for very large grids even though
  // the destination variable is int64_t.
  const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < size) {
    // Per-lane dx values create independent dependency chains
    T dx[ILP_FACTOR];
    const T dy = static_cast<T>(-487293.18274);
    const T dz = static_cast<T>(183649.27391);

#pragma unroll
    for (int ilp = 0; ilp < ILP_FACTOR; ilp++) {
      dx[ilp] = static_cast<T>(312847.91837) + static_cast<T>(ilp * 0.1);
    }

#pragma unroll ITER_UNROLL_FACTOR
    for (int32_t i = 0; i < iterations; i++) {
#pragma unroll
      for (int ilp = 0; ilp < ILP_FACTOR; ilp++) {
        T norm;
        if constexpr (std::is_same_v<T, fltflt>) {
          norm = fltflt_norm3d(dx[ilp], dy, dz);
        } else {
          norm = sqrt(dx[ilp] * dx[ilp] + dy * dy + dz * dz);
        }
        // Feed result back into dx to create a dependency chain.
        // Add the computed norm and subtract off the approximate
        // expected norm to keep dx in a stable range while preventing
        // the compiler from optimizing away the computation.
        if constexpr (std::is_same_v<T, fltflt>) {
          // fltflt addition/subtraction is expensive and we do not want to bias the benchmark
          // too much, so at least keep the expected norm as a float rather than fltflt to
          // reduce the cost of the subtraction.
          dx[ilp] = dx[ilp] + (norm - 607499.4f);
        } else {
          dx[ilp] = dx[ilp] + (norm - static_cast<T>(607499.4));
        }
      }
    }

    // Reduce the lanes so the compiler cannot discard any chain.
    T result_val = dx[0];
#pragma unroll
    for (int ilp = 1; ilp < ILP_FACTOR; ilp++) {
      result_val = result_val + dx[ilp];
    }
    result[idx] = result_val;
  }
}
535+
536+
// Host-side nvbench driver for the 3D norm benchmark. Reads the array-size
// and iteration-count axes, allocates the output tensor, registers throughput
// counters, and launches iterative_norm3d_kernel on the nvbench stream.
template <typename PrecisionType>
void fltflt_bench_norm3d(nvbench::state &state, nvbench::type_list<PrecisionType>)
{
  const index_t num_elems = static_cast<index_t>(state.get_int64("Array Size"));
  const int32_t num_iters = static_cast<int32_t>(state.get_int64("Iterations"));
  cudaExecutor exec{0};

  auto result = make_tensor<PrecisionType>({num_elems});

  state.add_element_count(num_elems, "NumElements");
  state.add_global_memory_writes<PrecisionType>(num_elems);

  constexpr int threads_per_block = 256;
  const int num_blocks = static_cast<int>((num_elems + threads_per_block - 1) / threads_per_block);

  // Make sure allocation work has completed before the timed region starts.
  exec.sync();

  state.exec([&](nvbench::launch &launch) {
    const auto stream = (cudaStream_t)launch.get_stream();
    iterative_norm3d_kernel<<<num_blocks, threads_per_block, 0, stream>>>(
      result.Data(), num_elems, num_iters);
  });
}

NVBENCH_BENCH_TYPES(fltflt_bench_norm3d, NVBENCH_TYPE_AXES(precision_types))
  .add_int64_power_of_two_axis("Array Size", nvbench::range(24, 24, 1))
  .add_int64_axis("Iterations", {250});
562+
412563
//==============================================================================
413564
// Absolute Value Benchmark
414565
//==============================================================================

bench/scripts/run_fltflt_benchmarks.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,49 @@ def parse_benchmark_output(output, verbose=False):
203203
return results
204204

205205

206+
def parse_benchmark_output_no_type(output, verbose=False):
    """
    Parse nvbench output for benchmarks without a type axis (fltflt-only).
    Returns a dict with a single 'fltflt' key.
    """
    results = {}
    lines = strip_ansi(output).strip().split('\n')

    # Locate the 'GPU Time' column index from the markdown table header row.
    gpu_time_col_idx = None
    for line in lines:
        if 'GPU Time' not in line or '|' not in line:
            continue
        cols = [col.strip() for col in line.split('|')]
        for j, col in enumerate(cols):
            if col == 'GPU Time':
                gpu_time_col_idx = j
                break
        if gpu_time_col_idx is not None:
            if verbose:
                print(f" Found GPU Time at column index {gpu_time_col_idx} in: {line.rstrip()}")
            break

    if gpu_time_col_idx is None:
        print(" Warning: Could not find GPU Time column in output")
        return results

    # Scan data rows (skipping the header and separator rows) and take the
    # first parseable GPU time: these benchmarks have a single fltflt variant.
    for line in lines:
        if '|' not in line or 'GPU Time' in line or '---' in line:
            continue
        cols = [col.strip() for col in line.split('|')]
        if len(cols) <= gpu_time_col_idx:
            continue
        gpu_time_str = cols[gpu_time_col_idx]
        gpu_time_ms = parse_time_value(gpu_time_str)
        if gpu_time_ms is not None:
            if verbose:
                print(f" Parsed: type=fltflt, gpu_time_col={gpu_time_str!r}, value={gpu_time_ms:.6f} ms")
            results['fltflt'] = gpu_time_ms
            break

    return results
247+
248+
206249
def format_time(time_ms):
207250
"""Format a time in ms with appropriate precision and units."""
208251
if time_ms is None:
@@ -254,7 +297,7 @@ def print_summary(results, relative):
254297
print("-" * 66)
255298

256299
# Order benchmarks - use the canonical order but only show benchmarks that were actually run
257-
bench_order = ['add', 'sub', 'mul', 'div', 'sqrt', 'abs', 'fma', 'madd', 'round', 'trunc', 'floor', 'fmod', 'cast2dbl', 'cast2fltflt']
300+
bench_order = ['add', 'sub', 'mul', 'div', 'sqrt', 'sqrt_fast', 'norm3d', 'abs', 'fma', 'madd', 'round', 'trunc', 'floor', 'fmod', 'cast2dbl', 'cast2fltflt']
258301
# Filter to only benchmarks present in results
259302
bench_order = [b for b in bench_order if b in results]
260303

@@ -386,7 +429,9 @@ def main():
386429
print()
387430

388431
# List of benchmarks to run
389-
all_benchmarks = ['add', 'sub', 'mul', 'div', 'sqrt', 'abs', 'fma', 'madd', 'round', 'trunc', 'floor', 'fmod', 'cast2dbl', 'cast2fltflt']
432+
all_benchmarks = ['add', 'sub', 'mul', 'div', 'sqrt', 'sqrt_fast', 'norm3d', 'abs', 'fma', 'madd', 'round', 'trunc', 'floor', 'fmod', 'cast2dbl', 'cast2fltflt']
433+
# Benchmarks that only have a fltflt variant (no float/double type axis)
434+
fltflt_only_benchmarks = set()
390435
benchmarks = args.benchmarks if args.benchmarks is not None else all_benchmarks
391436

392437
# Validate user-provided benchmarks
@@ -410,7 +455,10 @@ def main():
410455
continue
411456

412457
# Parse results
413-
results = parse_benchmark_output(output, verbose=args.verbose)
458+
if bench in fltflt_only_benchmarks:
459+
results = parse_benchmark_output_no_type(output, verbose=args.verbose)
460+
else:
461+
results = parse_benchmark_output(output, verbose=args.verbose)
414462

415463
if not results:
416464
print(f" Warning: Could not parse results for {bench}")

include/matx/kernels/fltflt.h

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,10 @@ struct alignas(8) fltflt {
5353
float hi;
5454
float lo;
5555

56-
// The default constructor does not initialize the components, so the value is indeterminate.
57-
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ fltflt() = default;
56+
// The default constructor does not initialize the components, so the value is indeterminate. Some versions of
57+
// nvcc will warn about __host__ and __device__ annotations on default constructors because default
58+
// constructors will not run in all conditions (e.g., in static shared memory CUDA kernel allocations).
59+
__MATX_INLINE__ fltflt() = default;
5860
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ constexpr explicit fltflt(double x)
5961
: hi(static_cast<float>(x)), lo(static_cast<float>(x - static_cast<double>(hi))) {}
6062
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ constexpr explicit fltflt(float x) : hi(x), lo(0.0f) {}
@@ -142,6 +144,10 @@ static __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ float fdividef_rn(float a,
142144
static __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ float fltflt_rsqrt(float x)
143145
{
144146
#if defined(__CUDA_ARCH__)
147+
// rsqrtf has up to 2 ULP of error. This is less precise than 1.0f / ::sqrtf(x), which
148+
// would be 0.5 ULP of error. We currently use rsqrtf() because it is significantly faster
149+
// while maintaining 44+ bits of precision in testing thus far, but we may need to revisit
150+
// this in the future.
145151
return rsqrtf(x);
146152
#else
147153
return 1.0f / ::sqrtf(x);
@@ -519,6 +525,60 @@ static __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ fltflt fltflt_sqrt(fltflt a
519525
return fltflt_add(prod, yn);
520526
}
521527

528+
// fltflt_sqrt_fast() is a faster approximation of fltflt_sqrt() that uses a single FMA to
529+
// compute the residual a - yn^2 instead of full fltflt subtraction. The FMA computes
530+
// a.hi - yn*yn exactly (exact multiply, single rounding), and adding a.lo recovers the
531+
// input's low-order bits. The result has precision comparable to fltflt_sqrt for most
532+
// values at roughly 1/5 the cost (~7 FLOPs vs ~35+). We do see differences for some
533+
// inputs. For example, for 1e9*pi + sqrt(2), fltflt_sqrt() matches the fp64
534+
// baseline in all mantissa bits and fltflt_sqrt_fast() matches the first 45 mantissa bits.
535+
// This function may eventually become the default sqrt() implementation.
536+
static __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ fltflt fltflt_sqrt_fast(fltflt a) {
537+
// Initial reciprocal-sqrt estimate; the a.hi == 0.0f guard avoids calling
// fltflt_rsqrt with a zero argument.
const float xn = (a.hi == 0.0f) ? 0.0f : detail::fltflt_rsqrt(a.hi);
538+
// yn ~= sqrt(a.hi), formed as a.hi * (1/sqrt(a.hi)).
const float yn = detail::fmul_rn(a.hi, xn);
539+
// residual = (a.hi - yn*yn) + a.lo; the FMA evaluates a.hi - yn*yn with a single rounding.
const float residual = detail::fadd_rn(
540+
detail::fmaf_rn(-yn, yn, a.hi), a.lo);
541+
// Newton-style correction: residual * (0.5 * xn) ~= residual / (2*sqrt(a.hi)).
const float correction = detail::fmul_rn(
542+
detail::fmul_rn(xn, 0.5f), residual);
543+
// Renormalize the (yn, correction) pair into a canonical hi/lo result.
return fltflt_fast_two_sum(yn, correction);
544+
}
545+
546+
// fltflt_norm3d() computes sqrt(dx^2 + dy^2 + dz^2) with minimal intermediate
547+
// normalizations. Instead of the separate fltflt_mul + fltflt_fma + fltflt_fma + fltflt_sqrt_fast
548+
// chain (5 normalizations, ~50 ops), this function computes all three exact squares,
549+
// accumulates with a single normalization, and applies fltflt_sqrt_fast (~39 ops).
550+
// The three inputs are assumed to be normalized fltflt values.
551+
static __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ fltflt fltflt_norm3d(fltflt dx, fltflt dy, fltflt dz) {
552+
// Exact squares of hi components (each captures full rounding error)
553+
const fltflt px = fltflt_two_prod_fma(dx.hi, dx.hi);
554+
const fltflt py = fltflt_two_prod_fma(dy.hi, dy.hi);
555+
const fltflt pz = fltflt_two_prod_fma(dz.hi, dz.hi);
556+
557+
// Sum the three .hi values using two_sum to capture rounding errors
558+
const fltflt s = fltflt_two_sum(px.hi, py.hi);
559+
const fltflt t = fltflt_two_sum(s.hi, pz.hi);
560+
561+
// Accumulate all eight low-order terms into a single float:
562+
// - two_sum rounding errors: s.lo, t.lo
563+
// - two_prod_fma error terms: px.lo, py.lo, pz.lo
564+
// - cross terms from squaring: 2*dx.hi*dx.lo, 2*dy.hi*dy.lo, 2*dz.hi*dz.lo
565+
// All terms are O(eps) relative to t.hi, so their sum is at most 8*eps*|t.hi|.
566+
// This may result in slight precision loss due to potential overlap between
567+
// lo and t.hi, but this should still be valid for ~44 bits prior to the sqrt.
568+
float lo = detail::fadd_rn(t.lo, s.lo);
569+
lo = detail::fadd_rn(lo, px.lo);
570+
lo = detail::fadd_rn(lo, py.lo);
571+
lo = detail::fadd_rn(lo, pz.lo);
572+
// hi + hi doubles exactly in binary floating point; each FMA then folds the
// 2*hi*lo cross term of (hi + lo)^2 into the accumulator with one rounding.
lo = detail::fmaf_rn(detail::fadd_rn(dx.hi, dx.hi), dx.lo, lo);
573+
lo = detail::fmaf_rn(detail::fadd_rn(dy.hi, dy.hi), dy.lo, lo);
574+
lo = detail::fmaf_rn(detail::fadd_rn(dz.hi, dz.hi), dz.lo, lo);
575+
576+
// Single normalization before sqrt
577+
const fltflt sum_sq = fltflt_fast_two_sum(t.hi, lo);
578+
579+
// Square root of the accumulated sum via the fast single-FMA-residual variant.
return fltflt_sqrt_fast(sum_sq);
580+
}
581+
522582
// Scalar sqrt overload so unary operator dispatch can handle fltflt expressions
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ fltflt sqrt(fltflt a)
{
  // Delegate to the full-precision double-float square root.
  return fltflt_sqrt(a);
}
524584

0 commit comments

Comments
 (0)