Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/ninetoothed/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,16 @@ def _generate_offsets_and_mask(tensor, indices):
@staticmethod
def _generate_innermost_indices(tensor, use_power_of_2_sizes=True):
class _NextPowerOfTwoMaker(ast.NodeTransformer):
def visit_Constant(self, node):
value = node.value

if isinstance(value, int) and not isinstance(value, bool) and value > 0:
return ast.copy_location(
ast.Constant(value=1 << (value - 1).bit_length()), node
)

return self.generic_visit(node)

def visit_Name(self, node):
name = node.id

Expand Down
35 changes: 35 additions & 0 deletions tests/test_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,41 @@ def _application(input, scale, output):
assert torch.allclose(output, expected)


@pytest.mark.parametrize("device", get_available_devices())
def test_aot_with_static_non_power_of_two_innermost_sizes(device):
def _arrangement(input, output):
return input.tile((3,)), output.tile((3,))

def _application(input, output):
output = input # noqa: F841

tensors = (
Tensor(1, dtype=ninetoothed.float32),
Tensor(1, dtype=ninetoothed.float32),
)

kernel_name = (
f"static_non_power_of_two_innermost_sizes{_generate_kernel_name_suffix()}"
)
output_dir = ninetoothed.generation.CACHE_DIR

kernel = ninetoothed.make(
_arrangement,
_application,
tensors,
caller=device,
kernel_name=kernel_name,
output_dir=output_dir,
)

input = torch.randn((3,), dtype=torch.float32, device=device)
output = torch.empty_like(input)

kernel(input, output)

assert torch.allclose(input, output)


def test_overflow_terms():
terms = ninetoothed.aot._overflow_terms(("input", "scale"), (2, 0))

Expand Down
Loading