Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/ntops/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
bmm,
clamp,
conv2d,
copysign,
cos,
div,
dropout,
Expand All @@ -20,14 +21,18 @@
isinf,
isnan,
layer_norm,
lcm,
le,
lgamma,
lt,
max_pool2d,
mm,
mul,
ne,
neg,
nextafter,
pow,
rad2deg,
relu,
rms_norm,
rotary_position_embedding,
Expand All @@ -52,6 +57,7 @@
"bmm",
"clamp",
"conv2d",
"copysign",
"cos",
"div",
"dropout",
Expand All @@ -63,14 +69,18 @@
"isinf",
"isnan",
"layer_norm",
"lcm",
"le",
"lgamma",
"lt",
"max_pool2d",
"mm",
"mul",
"ne",
"neg",
"nextafter",
"pow",
"rad2deg",
"relu",
"rms_norm",
"rotary_position_embedding",
Expand Down
22 changes: 22 additions & 0 deletions src/ntops/kernels/copysign.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import functools

from ninetoothed import Tensor
from ninetoothed.language import libdevice

from ntops.kernels.element_wise import arrangement


def application(input, other, output):
output = libdevice.copysign(input, other) # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
)

return arrangement_, application, tensors
33 changes: 33 additions & 0 deletions src/ntops/kernels/lcm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, other, output):
a = ntl.abs(input)
b = ntl.abs(other)
for _ in range(32):
b_is_zero = b == 0
b_safe = ntl.where(b_is_zero, 1, b)
a_mod_b = a % b_safe
new_a = ntl.where(b_is_zero, a, b)
new_b = ntl.where(b_is_zero, b, a_mod_b)
a, b = new_a, new_b
a_is_zero = a == 0
a_safe = ntl.where(a_is_zero, 1, a)
output = ntl.abs(input // a_safe * other) # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
)

return arrangement_, application, tensors
18 changes: 18 additions & 0 deletions src/ntops/kernels/lgamma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import functools

from ninetoothed import Tensor
from ninetoothed.language import libdevice

from ntops.kernels.element_wise import arrangement


def application(input, output):
output = libdevice.lgamma(input) # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))

return arrangement_, application, tensors
22 changes: 22 additions & 0 deletions src/ntops/kernels/nextafter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import functools

from ninetoothed import Tensor
from ninetoothed.language import libdevice

from ntops.kernels.element_wise import arrangement


def application(input, other, output):
output = libdevice.nextafter(input, other) # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
)

return arrangement_, application, tensors
17 changes: 17 additions & 0 deletions src/ntops/kernels/rad2deg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import functools

from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, output):
output = input * 57.29577951308232 # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))

return arrangement_, application, tensors
10 changes: 10 additions & 0 deletions src/ntops/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ntops.torch.bmm import bmm
from ntops.torch.clamp import clamp
from ntops.torch.conv2d import conv2d
from ntops.torch.copysign import copysign
from ntops.torch.cos import cos
from ntops.torch.div import div
from ntops.torch.dropout import dropout
Expand All @@ -19,15 +20,19 @@
from ntops.torch.isinf import isinf
from ntops.torch.isnan import isnan
from ntops.torch.layer_norm import layer_norm
from ntops.torch.lcm import lcm
from ntops.torch.le import le
from ntops.torch.lgamma import lgamma
from ntops.torch.lt import lt
from ntops.torch.matmul import matmul
from ntops.torch.max_pool2d import max_pool2d
from ntops.torch.mm import mm
from ntops.torch.mul import mul
from ntops.torch.ne import ne
from ntops.torch.neg import neg
from ntops.torch.nextafter import nextafter
from ntops.torch.pow import pow
from ntops.torch.rad2deg import rad2deg
from ntops.torch.relu import relu
from ntops.torch.rms_norm import rms_norm
from ntops.torch.rotary_position_embedding import rotary_position_embedding
Expand All @@ -51,6 +56,7 @@
"bmm",
"clamp",
"conv2d",
"copysign",
"cos",
"div",
"dropout",
Expand All @@ -62,15 +68,19 @@
"isinf",
"isnan",
"layer_norm",
"lcm",
"le",
"lgamma",
"lt",
"matmul",
"max_pool2d",
"mm",
"mul",
"ne",
"neg",
"nextafter",
"pow",
"rad2deg",
"relu",
"rms_norm",
"rotary_position_embedding",
Expand Down
15 changes: 15 additions & 0 deletions src/ntops/torch/copysign.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def copysign(input, other, *, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.copysign.premake, input.ndim)

kernel(input, other, out)

return out
15 changes: 15 additions & 0 deletions src/ntops/torch/lcm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def lcm(input, other, *, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.lcm.premake, input.ndim)

kernel(input, other, out)

return out
15 changes: 15 additions & 0 deletions src/ntops/torch/lgamma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def lgamma(input, *, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.lgamma.premake, input.ndim)

kernel(input, out)

return out
15 changes: 15 additions & 0 deletions src/ntops/torch/nextafter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def nextafter(input, other, *, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.nextafter.premake, input.ndim)

kernel(input, other, out)

return out
15 changes: 15 additions & 0 deletions src/ntops/torch/rad2deg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def rad2deg(input, *, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.rad2deg.premake, input.ndim)

kernel(input, out)

return out
21 changes: 21 additions & 0 deletions tests/test_copysign.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest
import torch

import ntops
from tests.skippers import skip_if_cuda_not_available
from tests.utils import generate_arguments


@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_copysign(shape, dtype, device, rtol, atol):
# TODO: Test for `float16` later.
if dtype is torch.float16:
return
input = torch.randn(shape, dtype=dtype, device=device)
other = torch.randn(shape, dtype=dtype, device=device)

ninetoothed_output = ntops.torch.copysign(input, other)
reference_output = torch.copysign(input, other)

assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
27 changes: 27 additions & 0 deletions tests/test_lcm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
import torch

import ntops
from tests.skippers import skip_if_cuda_not_available
from tests.utils import generate_arguments


@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments(False))
def test_lcm(shape, dtype, device, rtol, atol):
if dtype == torch.bool:
input = torch.randint(1, 10, size=shape, dtype=torch.int32, device=device)
other = torch.randint(1, 10, size=shape, dtype=torch.int32, device=device)
else:
upper_bound = 20
input = torch.randint(
1, upper_bound, size=shape, dtype=dtype, device=device
)
other = torch.randint(
1, upper_bound, size=shape, dtype=dtype, device=device
)

ninetoothed_output = ntops.torch.lcm(input, other)
reference_output = torch.lcm(input, other)

assert torch.equal(ninetoothed_output, reference_output)
21 changes: 21 additions & 0 deletions tests/test_lgamma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest
import torch

import ntops
from tests.skippers import skip_if_cuda_not_available
from tests.utils import generate_arguments


@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_lgamma(shape, dtype, device, rtol, atol):
# TODO: Test for `float16` later.
if dtype is torch.float16:
return
# Use positive values since lgamma is defined for positive inputs.
input = torch.rand(shape, dtype=dtype, device=device) * 5 + 0.1

ninetoothed_output = ntops.torch.lgamma(input)
reference_output = torch.lgamma(input)

assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
Loading