Skip to content
10 changes: 10 additions & 0 deletions src/ntops/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
bmm,
clamp,
conv2d,
copysign,
cos,
div,
dropout,
Expand All @@ -21,13 +22,17 @@
isnan,
layer_norm,
le,
lcm,
lgamma,
lt,
max_pool2d,
mm,
mul,
ne,
neg,
nextafter,
pow,
rad2deg,
relu,
rms_norm,
rotary_position_embedding,
Expand All @@ -52,6 +57,7 @@
"bmm",
"clamp",
"conv2d",
"copysign",
"cos",
"div",
"dropout",
Expand All @@ -64,13 +70,17 @@
"isnan",
"layer_norm",
"le",
"lcm",
"lgamma",
"lt",
"max_pool2d",
"mm",
"mul",
"ne",
"neg",
"nextafter",
"pow",
"rad2deg",
"relu",
"rms_norm",
"rotary_position_embedding",
Expand Down
82 changes: 82 additions & 0 deletions src/ntops/kernels/copysign.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor
from ninetoothed.language import libdevice

from ntops.kernels.element_wise import arrangement


BLOCK_SIZE = 1024


def broadcast_2d_arrangement(input, other, output, block_size=None):
    """Arrange two mutually expanded operands and the output as 1-D tiles.

    ``input`` is expanded along its second dimension to ``other.shape[1]``;
    ``other`` is then expanded along its first dimension to the first
    dimension of the expanded ``input``. All three tensors are flattened and
    tiled with ``block_size``, which falls back to
    ``ninetoothed.block_size()`` when it is ``None``.

    Returns:
        A 3-tuple ``(input, other, output)`` of the arranged tensors.
    """
    tile_shape = (ninetoothed.block_size() if block_size is None else block_size,)

    expanded_input = input.expand((-1, other.shape[1]))
    expanded_other = other.expand((expanded_input.shape[0], -1))

    return (
        expanded_input.flatten().tile(tile_shape),
        expanded_other.flatten().tile(tile_shape),
        output.flatten().tile(tile_shape),
    )


def application(input, other, output):
    # copysign on 32-bit patterns: keep every bit of `input` except bit 31 and
    # take bit 31 (the float32 sign bit) from `other`.
    input_bits = ntl.cast(input, ntl.uint32, bitcast=True)
    other_bits = ntl.cast(other, ntl.uint32, bitcast=True)
    # 0x7FFFFFFF selects the low 31 bits; 0x80000000 selects bit 31.
    output_bits = (input_bits & 0x7FFFFFFF) | (other_bits & 0x80000000)
    # Reinterpret the combined bits as float32 and assign them to `output`.
    # NOTE(review): presumably the framework treats this assignment as the
    # kernel's store, hence the unused-variable suppression — confirm.
    output = ntl.cast(output_bits, ntl.float32, bitcast=True)  # noqa: F841


def double_application(input, other, output):
    # 64-bit variant of the bit-level copysign: keep the low 63 bits of
    # `input` and take bit 63 (the float64 sign bit) from `other`.
    input_bits = ntl.cast(input, ntl.uint64, bitcast=True)
    other_bits = ntl.cast(other, ntl.uint64, bitcast=True)
    # 0x7FFF...FFFF selects the low 63 bits; 0x8000...0000 selects bit 63.
    output_bits = (input_bits & 0x7FFFFFFFFFFFFFFF) | (other_bits & 0x8000000000000000)
    # Reinterpret the combined bits as float64 and assign them to `output`.
    output = ntl.cast(output_bits, ntl.float64, bitcast=True)  # noqa: F841


def iluvatar_double_application(input, other, output):
    # Both branches of the `where` are 0.0, so every element of the result is
    # 0.0 regardless of `input`; `other` is never read.
    # NOTE(review): this does not compute copysign. It looks like a
    # placeholder for a configuration that is not expected to produce real
    # results — confirm against the caller before relying on it.
    output = ntl.where(input == input, 0.0, 0.0)  # noqa: F841


def half_application(input, other, output):
    # 16-bit variant of the bit-level copysign: keep the low 15 bits of
    # `input` and take bit 15 (the float16 sign bit) from `other`.
    input_bits = ntl.cast(input, ntl.uint16, bitcast=True)
    other_bits = ntl.cast(other, ntl.uint16, bitcast=True)
    # 0x7FFF selects the low 15 bits; 0x8000 selects bit 15.
    output_bits = (input_bits & 0x7FFF) | (other_bits & 0x8000)
    # Reinterpret the combined bits as float16 and assign them to `output`.
    output = ntl.cast(output_bits, ntl.float16, bitcast=True)  # noqa: F841


def iluvatar_half_application(input, other, output):
    # Half-precision path that avoids the 16-bit bitcast: both operands are
    # converted to float32, combined with libdevice.copysign, and the result
    # is converted back to float16 before being assigned to `output`.
    output = ntl.cast(libdevice.copysign(ntl.cast(input, ntl.float32), ntl.cast(other, ntl.float32)), ntl.float16)  # noqa: F841


def premake(
    ndim,
    half=False,
    double=False,
    iluvatar_double=False,
    iluvatar_half=False,
    broadcast_2d=False,
    dtype=None,
    block_size=BLOCK_SIZE,
):
    """Select the arrangement, application and tensors for a copysign kernel.

    Args:
        ndim: Number of dimensions given to each of the three ``Tensor``s.
        half: Select ``half_application``.
        double: Select ``double_application``.
        iluvatar_double: Select ``iluvatar_double_application``.
        iluvatar_half: Select ``iluvatar_half_application``.
        broadcast_2d: Use ``broadcast_2d_arrangement`` instead of the shared
            element-wise ``arrangement``.
        dtype: Data type passed to every ``Tensor``.
        block_size: Block size bound into the arrangement.

    When several flags are set, the priority is ``iluvatar_double``,
    ``iluvatar_half``, ``half``, ``double``; with none set the plain
    ``application`` is used.

    Returns:
        A tuple ``(arrangement, application, tensors)``.
    """
    if broadcast_2d:
        arrangement_ = functools.partial(broadcast_2d_arrangement, block_size=block_size)
    else:
        arrangement_ = functools.partial(arrangement, block_size=block_size)

    # One symbolic tensor each for input, other and output.
    tensors = tuple(Tensor(ndim, dtype=dtype) for _ in range(3))

    # Listed in priority order; the first enabled flag decides.
    candidates = (
        (iluvatar_double, iluvatar_double_application),
        (iluvatar_half, iluvatar_half_application),
        (half, half_application),
        (double, double_application),
    )
    application_ = next(
        (func for enabled, func in candidates if enabled),
        application,
    )

    return arrangement_, application_, tensors
193 changes: 193 additions & 0 deletions src/ntops/kernels/lcm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


BLOCK_SIZE = 64


def broadcast_2d_arrangement(input, other, output, block_size=None):
    """Flatten and tile ``input``, ``other`` and ``output`` after expansion.

    ``input`` is first expanded so its second dimension equals
    ``other.shape[1]``, then ``other`` is expanded so its first dimension
    equals ``input.shape[0]``. Every tensor is flattened and tiled into
    blocks of ``block_size`` elements; when ``block_size`` is ``None`` the
    value of ``ninetoothed.block_size()`` is used.

    Returns:
        A 3-tuple of the arranged ``(input, other, output)`` tensors.
    """
    if block_size is None:
        block_size = ninetoothed.block_size()

    def _as_tiles(tensor):
        return tensor.flatten().tile((block_size,))

    lhs = input.expand((-1, other.shape[1]))
    rhs = other.expand((lhs.shape[0], -1))

    return _as_tiles(lhs), _as_tiles(rhs), _as_tiles(output)


def _gcd_parts(input, other, iterations):
    # Runs `iterations` steps of the Euclidean algorithm element-wise on
    # |input| and |other| and returns (|input|, |other|, a), where `a` is the
    # running gcd candidate after the last step.
    # NOTE(review): `a` is the true gcd only if `iterations` is large enough
    # for `b` to have reached 0 — confirm the counts chosen by callers.
    x = ntl.abs(input)
    y = ntl.abs(other)
    a = x
    b = y

    for _ in range(iterations):
        # Substitute 1 for a zero divisor so the modulo never divides by zero;
        # the remainder for those elements is discarded by the selects below.
        safe_b = ntl.where(b == 0, 1, b)
        r = a % safe_b
        # Where b is already 0 the pair (a, b) is left unchanged; elsewhere it
        # advances to (b, a % b).
        a = ntl.where(b == 0, a, b)
        b = ntl.where(b == 0, b, r)

    return x, y, a


def _apply_lcm(input, other, output, iterations):
    # Element-wise lcm computed as (|input| // gcd) * |other|, with the gcd
    # obtained from `iterations` Euclidean steps in `_gcd_parts`.
    x, y, gcd = _gcd_parts(input, other, iterations)
    # Avoid dividing by zero; elements whose gcd is 0 are forced to 0 below.
    safe_gcd = ntl.where(gcd == 0, 1, gcd)
    # Dividing before multiplying keeps the intermediate value smaller.
    value = (x // safe_gcd) * y
    # True where an operand is negative and equal to its own negation.
    # NOTE(review): presumably this detects the minimum integer of the dtype,
    # whose negation wraps back to itself — confirm.
    input_min = (input < 0) & (-input == input)
    other_min = (other < 0) & (-other == other)
    min_overflow = input_min | other_min
    # Flagged elements take the flagged operand itself as the result; `input`
    # wins when both operands are flagged.
    overflow_value = ntl.where(input_min, input, other)
    value = ntl.where(min_overflow, overflow_value, value)
    # A gcd of 0 yields 0; everything else gets the computed value.
    output = ntl.where(gcd == 0, 0, value)  # noqa: F841


def _apply_lcm_abs(input, other, output, iterations):
    # Same computation as `_apply_lcm`, except that the absolute value of the
    # product (|input| // gcd) * |other| is taken before the special cases
    # are applied.
    x, y, gcd = _gcd_parts(input, other, iterations)
    # Avoid dividing by zero; elements whose gcd is 0 are forced to 0 below.
    safe_gcd = ntl.where(gcd == 0, 1, gcd)
    value = ntl.abs((x // safe_gcd) * y)
    # True where an operand is negative and equal to its own negation.
    # NOTE(review): presumably this detects the minimum integer of the dtype,
    # whose negation wraps back to itself — confirm.
    input_min = (input < 0) & (-input == input)
    other_min = (other < 0) & (-other == other)
    min_overflow = input_min | other_min
    # Flagged elements take the flagged operand itself as the result; `input`
    # wins when both operands are flagged.
    overflow_value = ntl.where(input_min, input, other)
    value = ntl.where(min_overflow, overflow_value, value)
    # A gcd of 0 yields 0; everything else gets the computed value.
    output = ntl.where(gcd == 0, 0, value)  # noqa: F841


def _apply_lcm_dynamic(input, other, output, max_iterations, absolute_output):
    # lcm variant whose Euclidean loop stops as soon as every `b` is 0, or
    # after `max_iterations` steps, whichever comes first. When
    # `absolute_output` is true the absolute value of the product is taken.
    x = ntl.abs(input)
    y = ntl.abs(other)
    # True where an operand is negative and equal to its own negation.
    # NOTE(review): presumably this detects the minimum integer of the dtype,
    # whose negation wraps back to itself — confirm.
    input_min = (input < 0) & (-input == input)
    other_min = (other < 0) & (-other == other)
    min_overflow = input_min | other_min
    # Flagged elements enter the loop as the pair (1, 1); their result is
    # replaced after the loop, so only x and y of unflagged elements matter.
    a = ntl.where(min_overflow, 1, x)
    b = ntl.where(min_overflow, 1, y)
    iteration = 0

    # NOTE(review): `ntl.max(b)` is presumably a reduction over the whole
    # block, so the loop runs until the slowest element finishes — confirm.
    while (ntl.max(b) != 0) and (iteration < max_iterations):
        # Substitute 1 for a zero divisor so the modulo never divides by zero.
        safe_b = ntl.where(b == 0, 1, b)
        r = a % safe_b
        # Elements with b == 0 keep (a, b); the others advance to (b, a % b).
        a = ntl.where(b == 0, a, b)
        b = ntl.where(b == 0, b, r)
        iteration += 1

    # Avoid dividing by zero when the gcd candidate is 0.
    safe_gcd = ntl.where(a == 0, 1, a)
    value = (x // safe_gcd) * y
    if absolute_output:
        value = ntl.abs(value)
    # Flagged elements take the flagged operand itself as the result; `input`
    # wins when both operands are flagged.
    overflow_value = ntl.where(input_min, input, other)
    value = ntl.where(min_overflow, overflow_value, value)
    # The result is 0 whenever either operand is 0.
    output = ntl.where((input == 0) | (other == 0), 0, value)  # noqa: F841


def application_16(input, other, output):
    # lcm with a fixed count of 16 Euclidean steps.
    _apply_lcm(input, other, output, 16)


def application_16_dynamic(input, other, output):
    # Early-exit lcm capped at 16 Euclidean steps, without the absolute-value
    # step on the product.
    _apply_lcm_dynamic(input, other, output, 16, False)


def application_16_dynamic_i32(input, other, output):
    # Early-exit lcm capped at 16 Euclidean steps; both operands are cast to
    # int32 before the computation.
    _apply_lcm_dynamic(ntl.cast(input, ntl.int32), ntl.cast(other, ntl.int32), output, 16, False)


def application_24(input, other, output):
    # lcm with a fixed count of 24 Euclidean steps.
    _apply_lcm(input, other, output, 24)


def application_24_dynamic(input, other, output):
    # Early-exit lcm capped at 24 Euclidean steps, without the absolute-value
    # step on the product.
    _apply_lcm_dynamic(input, other, output, 24, False)


def application_24_dynamic_i32(input, other, output):
    # Early-exit lcm capped at 24 Euclidean steps; both operands are cast to
    # int32 before the computation.
    _apply_lcm_dynamic(ntl.cast(input, ntl.int32), ntl.cast(other, ntl.int32), output, 24, False)


def application_32(input, other, output):
    # lcm with a fixed count of 32 Euclidean steps.
    _apply_lcm(input, other, output, 32)


def application_48(input, other, output):
    # lcm with a fixed count of 48 Euclidean steps.
    _apply_lcm(input, other, output, 48)


def application_48_dynamic_abs(input, other, output):
    # Early-exit lcm capped at 48 Euclidean steps, with the absolute value of
    # the product taken.
    _apply_lcm_dynamic(input, other, output, 48, True)


def application_48_dynamic_i32(input, other, output):
    # Early-exit lcm capped at 48 Euclidean steps; both operands are cast to
    # int32 before the computation.
    _apply_lcm_dynamic(ntl.cast(input, ntl.int32), ntl.cast(other, ntl.int32), output, 48, False)


def application_48_abs(input, other, output):
    # lcm with a fixed count of 48 Euclidean steps, with the absolute value
    # of the product taken.
    _apply_lcm_abs(input, other, output, 48)


def application_64(input, other, output):
    # lcm with a fixed count of 64 Euclidean steps.
    _apply_lcm(input, other, output, 64)


def application_96(input, other, output):
    # lcm with a fixed count of 96 Euclidean steps.
    _apply_lcm(input, other, output, 96)


def application_96_dynamic_abs(input, other, output):
    # Early-exit lcm capped at 96 Euclidean steps, with the absolute value of
    # the product taken.
    _apply_lcm_dynamic(input, other, output, 96, True)


def application_96_abs(input, other, output):
    # lcm with a fixed count of 96 Euclidean steps, with the absolute value
    # of the product taken.
    _apply_lcm_abs(input, other, output, 96)


def premake(
    ndim,
    iterations=96,
    absolute_output=False,
    dynamic_iterations=False,
    small_integer=False,
    broadcast_2d=False,
    dtype=None,
    block_size=BLOCK_SIZE,
):
    """Select the arrangement, application and tensors for an lcm kernel.

    Args:
        ndim: Number of dimensions given to each of the three ``Tensor``s.
        iterations: Euclidean step count (or cap, for dynamic variants).
        absolute_output: Select a variant that takes the absolute value of
            the product.
        dynamic_iterations: Select an early-exit (``*_dynamic*``) variant.
        small_integer: Together with ``dynamic_iterations``, select an
            ``*_i32`` variant.
        broadcast_2d: Use ``broadcast_2d_arrangement`` instead of the shared
            element-wise ``arrangement``.
        dtype: Data type passed to every ``Tensor``.
        block_size: Block size bound into the arrangement.

    Returns:
        A tuple ``(arrangement, application, tensors)``.

    Raises:
        KeyError: If no application is registered for the requested
            combination of ``iterations`` and flags.
    """
    chosen_arrangement = broadcast_2d_arrangement if broadcast_2d else arrangement
    arrangement_ = functools.partial(chosen_arrangement, block_size=block_size)

    # One symbolic tensor each for input, other and output.
    tensors = tuple(Tensor(ndim, dtype=dtype) for _ in range(3))

    applications = {
        # Fixed step count: keyed by the bare iteration count.
        16: application_16,
        24: application_24,
        32: application_32,
        48: application_48,
        64: application_64,
        96: application_96,
        # Fixed step count with absolute output: (iterations, True).
        (48, True): application_48_abs,
        (96, True): application_96_abs,
        # Dynamic step count: (iterations, absolute_output, True).
        (16, False, True): application_16_dynamic,
        (24, False, True): application_24_dynamic,
        (48, True, True): application_48_dynamic_abs,
        (96, True, True): application_96_dynamic_abs,
        # Dynamic step count on int32:
        # (iterations, absolute_output, True, True).
        (16, False, True, True): application_16_dynamic_i32,
        (24, False, True, True): application_24_dynamic_i32,
        (48, False, True, True): application_48_dynamic_i32,
    }

    if dynamic_iterations and small_integer:
        key = (iterations, absolute_output, True, True)
    elif dynamic_iterations:
        key = (iterations, absolute_output, True)
    elif absolute_output:
        key = (iterations, True)
    else:
        key = iterations

    return arrangement_, applications[key], tensors
28 changes: 28 additions & 0 deletions src/ntops/kernels/lgamma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor
from ninetoothed.language import libdevice

from ntops.kernels.element_wise import arrangement


BLOCK_SIZE = 8192


def application(input, output):
    # Element-wise lgamma: pass `input` to libdevice.lgamma and assign the
    # result to `output`.
    output = libdevice.lgamma(input)  # noqa: F841


def half_application(input, output):
    # Half-precision path: convert `input` to float32, evaluate
    # libdevice.lgamma, then convert the result to float16 for `output`.
    output = ntl.cast(libdevice.lgamma(ntl.cast(input, ntl.float32)), ntl.float16)  # noqa: F841


def premake(ndim, half=False, dtype=None, block_size=BLOCK_SIZE):
    """Select the arrangement, application and tensors for an lgamma kernel.

    Args:
        ndim: Number of dimensions given to both ``Tensor``s.
        half: Use ``half_application`` instead of ``application``.
        dtype: Data type passed to both ``Tensor``s.
        block_size: Block size bound into the element-wise ``arrangement``.

    Returns:
        A tuple ``(arrangement, application, tensors)``, where ``tensors`` is
        a pair of ``Tensor`` objects.
    """
    arrangement_ = functools.partial(arrangement, block_size=block_size)

    first_tensor = Tensor(ndim, dtype=dtype)
    second_tensor = Tensor(ndim, dtype=dtype)

    if half:
        application_ = half_application
    else:
        application_ = application

    return arrangement_, application_, (first_tensor, second_tensor)
Loading