From 013abd6b2d38dc5f69ca0ddb76e8951cff97bc48 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sat, 16 May 2026 02:55:54 +0000 Subject: [PATCH 1/2] Add a test case for AOT non-power-of-two innermost sizes --- tests/test_aot.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/test_aot.py b/tests/test_aot.py index c2ff3e6..7a8b2cd 100644 --- a/tests/test_aot.py +++ b/tests/test_aot.py @@ -343,6 +343,41 @@ def _application(input, scale, output): assert torch.allclose(output, expected) +@pytest.mark.parametrize("device", get_available_devices()) +def test_aot_with_static_non_power_of_two_innermost_sizes(device): + def _arrangement(input, output): + return input.tile((3,)), output.tile((3,)) + + def _application(input, output): + output = input # noqa: F841 + + tensors = ( + Tensor(1, dtype=ninetoothed.float32), + Tensor(1, dtype=ninetoothed.float32), + ) + + kernel_name = ( + f"static_non_power_of_two_innermost_sizes{_generate_kernel_name_suffix()}" + ) + output_dir = ninetoothed.generation.CACHE_DIR + + kernel = ninetoothed.make( + _arrangement, + _application, + tensors, + caller=device, + kernel_name=kernel_name, + output_dir=output_dir, + ) + + input = torch.randn((3,), dtype=torch.float32, device=device) + output = torch.empty_like(input) + + kernel(input, output) + + assert torch.allclose(input, output) + + def test_overflow_terms(): terms = ninetoothed.aot._overflow_terms(("input", "scale"), (2, 0)) From 7b9df90a71491034ce6e8089a91ef33bd43ccbeb Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sat, 16 May 2026 02:58:24 +0000 Subject: [PATCH 2/2] Fix AOT non-power-of-two innermost sizes --- src/ninetoothed/generation.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/ninetoothed/generation.py b/src/ninetoothed/generation.py index 4e9e6ca..6ea1d0f 100644 --- a/src/ninetoothed/generation.py +++ b/src/ninetoothed/generation.py @@ -780,6 +780,16 @@ def _generate_offsets_and_mask(tensor, indices): @staticmethod def _generate_innermost_indices(tensor, use_power_of_2_sizes=True): class _NextPowerOfTwoMaker(ast.NodeTransformer): + def visit_Constant(self, node): + value = node.value + + if isinstance(value, int) and not isinstance(value, bool) and value > 0: + return ast.copy_location( + ast.Constant(value=1 << (value - 1).bit_length()), node + ) + + return self.generic_visit(node) + def visit_Name(self, node): name = node.id