From 88de0cd8beb0568366e5129b102124e577c8f7be Mon Sep 17 00:00:00 2001 From: Christopher Crouzet Date: Tue, 14 Jan 2025 11:43:47 +1300 Subject: [PATCH] Add a `len()` built-in Fixes GH-389. --- CHANGELOG.md | 1 + docs/codegen.rst | 18 ++++++++++++ docs/modules/functions.rst | 40 +++++++++++++++++++++++++++ warp/builtins.py | 55 +++++++++++++++++++++++++++++++++++++ warp/codegen.py | 51 ++++++++++++++++++++++++++++++++++ warp/native/array.h | 12 ++++++++ warp/native/mat.h | 11 ++++++++ warp/native/quat.h | 9 ++++++ warp/native/spatial.h | 11 ++++++++ warp/native/tile.h | 22 +++++++++++++++ warp/native/vec.h | 10 +++++++ warp/stubs.py | 36 ++++++++++++++++++++++++ warp/tests/test_array.py | 35 ++++++++++++++++++++++++ warp/tests/test_mat.py | 56 ++++++++++++++++++++++++++++++++++++++ warp/tests/test_quat.py | 26 ++++++++++++++++++ warp/tests/test_static.py | 16 +++++++++++ warp/tests/test_tile.py | 28 +++++++++++++++++++ warp/tests/test_vec.py | 55 +++++++++++++++++++++++++++++++++++++ warp/types.py | 5 ++++ 19 files changed, 497 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 85c31ff67..1a47d3ec3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ - Add `example_tile_walker.py`, which reworks the existing `walker.py` to use Warp's tile API for matrix multiplication. - Add operator overloads for `wp.struct` objects by defining `wp.func` functions ([GH-392](https://github.com/NVIDIA/warp/issues/392)). - Add `example_tile_nbody.py`, an N-Body gravitational simulation example using Warp tile primitives. +- Add a `len()` built-in to retrieve the number of elements for vec/quat/mat/arrays ([GH-389](https://github.com/NVIDIA/warp/issues/389)). ### Changed diff --git a/docs/codegen.rst b/docs/codegen.rst index b4984781c..af17a9b3e 100644 --- a/docs/codegen.rst +++ b/docs/codegen.rst @@ -446,6 +446,24 @@ The above program uses a static expression to select the right function given th [2. 0.] +Example: Static Length Query +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Python's built-in function ``len()`` can also be evaluated statically for types with fixed length, such as vectors, quaternions, and matrices, and can be wrapped into ``wp.static()`` calls to initialize other constructs: + +.. code:: python + + import warp as wp + + @wp.kernel + def my_kernel(v: wp.vec2): + m = wp.identity(n=wp.static(len(v) + 1), dtype=v.dtype) + wp.expect_eq(wp.ddot(m, m), 3.0) + + v = wp.vec2(1, 2) + wp.launch(my_kernel, 1, inputs=(v,)) + + Advanced Example: Branching Elimination with Static Loop Unrolling ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ In computational simulations, it's common to apply different operations or boundary conditions based on runtime variables. However, conditional branching using runtime variables often leads to performance issues due to register pressure, as the GPU may allocate resources for all branches even if some of them are never taken. To tackle this, we can utilize static loop unrolling via ``wp.static(...)``, which helps eliminate unnecessary branching at compile-time and improve parallel execution. diff --git a/docs/modules/functions.rst b/docs/modules/functions.rst index 2d1c1a034..8ece86c88 100644 --- a/docs/modules/functions.rst +++ b/docs/modules/functions.rst @@ -1913,6 +1913,46 @@ Utility Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude +.. py:function:: len(a: Vector[Any,Scalar]) -> int + + Retrieves the number of elements in a vector. + + +.. py:function:: len(a: Quaternion[Scalar]) -> int + :noindex: + :nocontentsentry: + + Retrieves the number of elements in a quaternion. + + +.. py:function:: len(a: Matrix[Any,Any,Scalar]) -> int + :noindex: + :nocontentsentry: + + Retrieves the number of rows in a matrix. + + +.. py:function:: len(a: Transformation[Float]) -> int + :noindex: + :nocontentsentry: + + Retrieves the number of elements in a transformation. + + +.. py:function:: len(a: Array[Any]) -> int + :noindex: + :nocontentsentry: + + Retrieves the size of the first dimension in an array. + + +.. py:function:: len(a: Tile) -> int + :noindex: + :nocontentsentry: + + Retrieves the number of rows in a tile. + + Geometry diff --git a/warp/builtins.py b/warp/builtins.py index f4a4a4d4b..d74d0d642 100644 --- a/warp/builtins.py +++ b/warp/builtins.py @@ -6497,3 +6497,58 @@ def static(expr): which includes constant variables and variables captured in the current closure in which the function or kernel is implemented. """ return expr + + +add_builtin( + "len", + input_types={"a": vector(length=Any, dtype=Scalar)}, + value_type=int, + doc="Retrieves the number of elements in a vector.", + group="Utility", + export=False, +) + +add_builtin( + "len", + input_types={"a": quaternion(dtype=Scalar)}, + value_type=int, + doc="Retrieves the number of elements in a quaternion.", + group="Utility", + export=False, +) + +add_builtin( + "len", + input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)}, + value_type=int, + doc="Retrieves the number of rows in a matrix.", + group="Utility", + export=False, +) + +add_builtin( + "len", + input_types={"a": transformation(dtype=Float)}, + value_type=int, + doc="Retrieves the number of elements in a transformation.", + group="Utility", + export=False, +) + +add_builtin( + "len", + input_types={"a": array(dtype=Any)}, + value_type=int, + doc="Retrieves the size of the first dimension in an array.", + group="Utility", + export=False, +) + +add_builtin( + "len", + input_types={"a": Tile(dtype=Any, M=Any, N=Any)}, + value_type=int, + doc="Retrieves the number of rows in a tile.", + group="Utility", + export=False, +) diff --git a/warp/codegen.py b/warp/codegen.py index e339a08d1..cff299c8b 100644 --- a/warp/codegen.py +++ b/warp/codegen.py @@ -2882,11 +2882,62 @@ def evaluate_static_expression(adj, node) -> Tuple[Any, str]: if static_code is None: raise WarpCodegenError("Error extracting source code from wp.static() expression") + # Since this is an expression, we can enforce it to be defined on a single line. + static_code = static_code.replace("\n", "") + vars_dict = adj.get_static_evaluation_context() # add constant variables to the static call context constant_vars = {k: v.constant for k, v in adj.symbols.items() if isinstance(v, Var) and v.constant is not None} vars_dict.update(constant_vars) + # Replace all constant `len()` expressions with their value. + if "len" in static_code: + + def eval_len(obj): + if type_is_vector(obj): + return obj._length_ + elif type_is_quaternion(obj): + return obj._length_ + elif type_is_matrix(obj): + return obj._shape_[0] + elif type_is_transformation(obj): + return obj._length_ + elif is_tile(obj): + return obj.M + + return len(obj) + + len_expr_ctx = vars_dict.copy() + constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None} + len_expr_ctx.update(constant_types) + len_expr_ctx.update({"len": eval_len}) + + # We want to replace the expression code in-place, + # so reparse it to get the correct column info. + len_value_locs = [] + expr_tree = ast.parse(static_code) + assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr) + expr_root = expr_tree.body[0].value + for expr_node in ast.walk(expr_root): + if isinstance(expr_node, ast.Call) and expr_node.func.id == "len" and len(expr_node.args) == 1: + len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset] + try: + len_value = eval(len_expr, len_expr_ctx) + except Exception: + pass + else: + len_value_locs.append((len_value, expr_node.col_offset, expr_node.end_col_offset)) + + if len_value_locs: + new_static_code = "" + loc = 0 + for value, start, end in len_value_locs: + new_static_code += f"{static_code[loc:start]}{value}" + loc = end + + new_static_code += static_code[len_value_locs[-1][2] :] + static_code = new_static_code + try: value = eval(static_code, vars_dict) if warp.config.verbose: diff --git a/warp/native/array.h b/warp/native/array.h index ef9b2b43d..c8e0d5adc 100644 --- a/warp/native/array.h +++ b/warp/native/array.h @@ -1106,6 +1106,18 @@ inline CUDA_CALLABLE void adj_atomic_max(const A1& buf, int i, int j, int k, FP_VERIFY_ADJ_4(value, adj_value) } +template class A, typename T> +CUDA_CALLABLE inline int len(const A& a) +{ + return a.shape[0]; +} + +template class A, typename T> +CUDA_CALLABLE inline void adj_len(const A& a, A& adj_a, int& adj_ret) +{ +} + + } // namespace wp #include "fabric.h" diff --git a/warp/native/mat.h b/warp/native/mat.h index ee084d608..c02b22e31 100644 --- a/warp/native/mat.h +++ b/warp/native/mat.h @@ -1650,4 +1650,15 @@ inline CUDA_CALLABLE void adj_mat44(float m00, float m01, float m02, float m03, a33 += adj_ret.data[3][3]; } +template +CUDA_CALLABLE inline int len(const mat_t& x) +{ + return Rows; +} + +template +CUDA_CALLABLE inline void adj_len(const mat_t& x, mat_t& adj_x, const int& adj_ret) +{ +} + } // namespace wp diff --git a/warp/native/quat.h b/warp/native/quat.h index 90f9c556d..a81873f80 100644 --- a/warp/native/quat.h +++ b/warp/native/quat.h @@ -1229,6 +1229,15 @@ inline CUDA_CALLABLE quat_t quat_identity() return quat_t(Type(0), Type(0), Type(0), Type(1)); } +template +CUDA_CALLABLE inline int len(const quat_t& x) +{ + return 4; +} +template +CUDA_CALLABLE inline void adj_len(const quat_t& x, quat_t& adj_x, const int& adj_ret) +{ +} } // namespace wp diff --git a/warp/native/spatial.h b/warp/native/spatial.h index 482615369..365f64b54 100644 --- a/warp/native/spatial.h +++ b/warp/native/spatial.h @@ -400,6 +400,17 @@ CUDA_CALLABLE inline void adj_lerp(const transform_t& a, const transform_t adj_t += tensordot(b, adj_ret) - tensordot(a, adj_ret); } +template +CUDA_CALLABLE inline int len(const transform_t& t) +{ + return 7; +} + +template +CUDA_CALLABLE inline void adj_len(const transform_t& t, transform_t& adj_t, const int& adj_ret) +{ +} + template using spatial_matrix_t = mat_t<6,6,Type>; diff --git a/warp/native/tile.h b/warp/native/tile.h index 652d52063..438a0805a 100644 --- a/warp/native/tile.h +++ b/warp/native/tile.h @@ -1008,6 +1008,28 @@ void tile_register_t::print() const WP_TILE_SYNC(); } +template +inline CUDA_CALLABLE int len(const tile_register_t& t) +{ + return M; +} + +template +inline CUDA_CALLABLE void adj_len(const tile_register_t& t, const tile_register_t& a, int& adj_ret) +{ +} + +template +inline CUDA_CALLABLE int len(const tile_shared_t& t) +{ + return M; +} + +template +inline CUDA_CALLABLE void adj_len(const tile_shared_t& t, const tile_shared_t& a, int& adj_ret) +{ +} + template inline CUDA_CALLABLE void print(const tile_register_t& t) { diff --git a/warp/native/vec.h b/warp/native/vec.h index b89d38ade..3ca999f3d 100644 --- a/warp/native/vec.h +++ b/warp/native/vec.h @@ -1311,5 +1311,15 @@ inline CUDA_CALLABLE void adj_vec4(float s, float& adj_s, const vec4& adj_ret) adj_vec_t(s, adj_s, adj_ret); } +template +CUDA_CALLABLE inline int len(const vec_t& x) +{ + return Length; +} + +template +CUDA_CALLABLE inline void adj_len(const vec_t& x, vec_t& adj_x, const int& adj_ret) +{ +} } // namespace wp diff --git a/warp/stubs.py b/warp/stubs.py index c3163ab57..1583a4139 100644 --- a/warp/stubs.py +++ b/warp/stubs.py @@ -3033,3 +3033,39 @@ def static(expr: Any) -> Any: (excluding Warp arrays since they cannot be created in a Warp kernel at the moment). """ ... + + +@over +def len(a: Vector[Any, Scalar]) -> int: + """Retrieves the number of elements in a vector.""" + ... + + +@over +def len(a: Quaternion[Scalar]) -> int: + """Retrieves the number of elements in a quaternion.""" + ... + + +@over +def len(a: Matrix[Any, Any, Scalar]) -> int: + """Retrieves the number of rows in a matrix.""" + ... + + +@over +def len(a: Transformation[Float]) -> int: + """Retrieves the number of elements in a transformation.""" + ... + + +@over +def len(a: Array[Any]) -> int: + """Retrieves the size of the first dimension in an array.""" + ... + + +@over +def len(a: Tile) -> int: + """Retrieves the number of rows in a tile.""" + ... diff --git a/warp/tests/test_array.py b/warp/tests/test_array.py index c336f3f48..de627b2f2 100644 --- a/warp/tests/test_array.py +++ b/warp/tests/test_array.py @@ -2802,6 +2802,40 @@ def test_casting(test, device): assert idxs.strides == (12,) +@wp.kernel +def array_len_kernel( + a1: wp.array(dtype=int), + a2: wp.array(dtype=float, ndim=3), + out: wp.array(dtype=int), +): + length = len(a1) + wp.expect_eq(len(a1), 123) + out[0] = len(a1) + + length = len(a2) + wp.expect_eq(len(a2), 2) + out[1] = len(a2) + + +def test_array_len(test, device): + a1 = wp.zeros(123, dtype=int, device=device) + a2 = wp.zeros((2, 3, 4), dtype=float, device=device) + out = wp.empty(2, dtype=int, device=device) + wp.launch( + array_len_kernel, + dim=(1,), + inputs=( + a1, + a2, + ), + outputs=(out,), + device=device, + ) + + test.assertEqual(out.numpy()[0], 123) + test.assertEqual(out.numpy()[1], 2) + + devices = get_test_devices() @@ -2873,6 +2907,7 @@ def test_array_new_del(self): add_function_test(TestArray, "test_alloc_strides", test_alloc_strides, devices=devices) add_function_test(TestArray, "test_casting", test_casting, devices=devices) +add_function_test(TestArray, "test_array_len", test_array_len, devices=devices) try: import torch diff --git a/warp/tests/test_mat.py b/warp/tests/test_mat.py index 0bfa467d6..669b19cae 100644 --- a/warp/tests/test_mat.py +++ b/warp/tests/test_mat.py @@ -6,6 +6,7 @@ # license agreement from NVIDIA CORPORATION is strictly prohibited. import unittest +from typing import Any import numpy as np @@ -1737,6 +1738,54 @@ def test_constructors_constant_shape(): m[i, j] = float(i * j) +Mat23 = wp.mat((2, 3), dtype=wp.float16) + + +@wp.kernel +def matrix_len_kernel( + m1: wp.mat22, + m2: wp.mat((3, 3), float), + m3: wp.mat((Any, Any), float), + m4: Mat23, + out: wp.array(dtype=int), +): + length = wp.static(len(m1)) + wp.expect_eq(len(m1), 2) + out[0] = len(m1) + + length = len(m2) + wp.expect_eq(wp.static(len(m2)), 3) + out[1] = len(m2) + + length = len(m3) + wp.expect_eq(len(m3), 4) + out[2] = wp.static(len(m3)) + + length = wp.static(len(m4)) + wp.expect_eq(wp.static(len(m4)), 2) + out[3] = wp.static(len(m4)) + + foo = wp.mat22() + length = len(foo) + wp.expect_eq(len(foo), 2) + out[4] = len(foo) + + +def test_matrix_len(test, device): + m1 = wp.mat22() + m2 = wp.mat33() + m3 = wp.mat44() + m4 = Mat23() + out = wp.empty(5, dtype=int, device=device) + wp.launch(matrix_len_kernel, dim=(1,), inputs=(m1, m2, m3, m4), outputs=(out,), device=device) + + test.assertEqual(out.numpy()[0], 2) + test.assertEqual(out.numpy()[1], 3) + test.assertEqual(out.numpy()[2], 4) + test.assertEqual(out.numpy()[3], 2) + test.assertEqual(out.numpy()[4], 2) + + devices = get_test_devices() @@ -1876,6 +1925,13 @@ def test_tpl_ops_with_anon(self): dtype=dtype, ) +add_function_test( + TestMat, + "test_matrix_len", + test_matrix_len, + devices=devices, +) + if __name__ == "__main__": wp.clear_kernel_cache() diff --git a/warp/tests/test_quat.py b/warp/tests/test_quat.py index ee705d234..52c236657 100644 --- a/warp/tests/test_quat.py +++ b/warp/tests/test_quat.py @@ -2095,6 +2095,30 @@ def make_quat(*args): test.assertSequenceEqual(wptype(24) / v, make_quat(12, 6, 4, 3)) +@wp.kernel +def quat_len_kernel( + q: wp.quat, + out: wp.array(dtype=int), +): + length = wp.static(len(q)) + wp.expect_eq(wp.static(len(q)), 4) + out[0] = wp.static(len(q)) + + foo = wp.quat() + length = len(foo) + wp.expect_eq(len(foo), 4) + out[1] = len(foo) + + +def test_quat_len(test, device): + q = wp.quat() + out = wp.empty(2, dtype=int, device=device) + wp.launch(quat_len_kernel, dim=(1,), inputs=(q,), outputs=(out,), device=device) + + test.assertEqual(out.numpy()[0], 4) + test.assertEqual(out.numpy()[1], 4) + + devices = get_test_devices() @@ -2203,6 +2227,8 @@ class TestQuat(unittest.TestCase): TestQuat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype ) +add_function_test(TestQuat, "test_quat_len", test_quat_len, devices=devices) + if __name__ == "__main__": wp.clear_kernel_cache() diff --git a/warp/tests/test_static.py b/warp/tests/test_static.py index b8a77d041..259343399 100644 --- a/warp/tests/test_static.py +++ b/warp/tests/test_static.py @@ -536,6 +536,21 @@ def test_static_function_hash(test, _): test.assertEqual(hash1, hash3) +@wp.kernel +def static_len_query_kernel(v1: wp.vec2): + v2 = wp.vec3() + m = wp.identity(n=wp.static(len(v1) + len(v2) + 1), dtype=float) + wp.expect_eq(wp.ddot(m, m), 6.0) + + t = wp.transform_identity(float) + wp.expect_eq(wp.static(len(t)), 7) + + +def test_static_len_query(test, _): + v1 = wp.vec2() + wp.launch(static_len_query_kernel, 1, inputs=(v1,)) + + devices = get_test_devices() @@ -561,6 +576,7 @@ def test_static_python_call(self): add_function_test(TestStatic, "test_static_constant_hash", test_static_constant_hash, devices=None) add_function_test(TestStatic, "test_static_function_hash", test_static_function_hash, devices=None) +add_function_test(TestStatic, "test_static_len_query", test_static_len_query, devices=None) if __name__ == "__main__": diff --git a/warp/tests/test_tile.py b/warp/tests/test_tile.py index 5ac89eb7c..853ddaa4e 100644 --- a/warp/tests/test_tile.py +++ b/warp/tests/test_tile.py @@ -635,6 +635,33 @@ def test_tile_assign(test, device): assert_np_equal(a.grad.numpy(), np.ones_like(a.numpy())) +@wp.kernel +def tile_len_kernel( + a: wp.array(dtype=float, ndim=2), + out: wp.array(dtype=int), +): + x = wp.tile_load(a, 0, 0, m=TILE_M, n=TILE_N) + + length = wp.static(len(x)) + wp.expect_eq(wp.static(len(x)), TILE_M) + out[0] = wp.static(len(x)) + + +def test_tile_len(test, device): + a = wp.zeros((TILE_M, TILE_N), dtype=float, device=device) + out = wp.empty(1, dtype=int, device=device) + wp.launch_tiled( + tile_len_kernel, + dim=(1,), + inputs=(a,), + outputs=(out,), + block_dim=32, + device=device, + ) + + test.assertEqual(out.numpy()[0], TILE_M) + + # #----------------------------------------- # # center of mass computation @@ -737,6 +764,7 @@ class TestTile(unittest.TestCase): add_function_test(TestTile, "test_tile_broadcast_grad", test_tile_broadcast_grad, devices=devices) add_function_test(TestTile, "test_tile_view", test_tile_view, devices=devices) add_function_test(TestTile, "test_tile_assign", test_tile_assign, devices=devices) +add_function_test(TestTile, "test_tile_len", test_tile_len, devices=devices) if __name__ == "__main__": diff --git a/warp/tests/test_vec.py b/warp/tests/test_vec.py index 811b7792c..24e0b8cd3 100644 --- a/warp/tests/test_vec.py +++ b/warp/tests/test_vec.py @@ -6,6 +6,7 @@ # license agreement from NVIDIA CORPORATION is strictly prohibited. import unittest +from typing import Any import numpy as np @@ -1240,6 +1241,54 @@ def test_constructors_constant_length(): v[i] = float(i) +Vec123 = wp.vec(123, dtype=wp.float16) + + +@wp.kernel +def vector_len_kernel( + v1: wp.vec2, + v2: wp.vec(3, float), + v3: wp.vec(Any, float), + v4: Vec123, + out: wp.array(dtype=int), +): + length = wp.static(len(v1)) + wp.expect_eq(len(v1), 2) + out[0] = len(v1) + + length = len(v2) + wp.expect_eq(wp.static(len(v2)), 3) + out[1] = len(v2) + + length = len(v3) + wp.expect_eq(len(v3), 4) + out[2] = wp.static(len(v3)) + + length = wp.static(len(v4)) + wp.expect_eq(wp.static(len(v4)), 123) + out[3] = wp.static(len(v4)) + + foo = wp.vec2() + length = len(foo) + wp.expect_eq(len(foo), 2) + out[4] = len(foo) + + +def test_vector_len(test, device): + v1 = wp.vec2() + v2 = wp.vec3() + v3 = wp.vec4() + v4 = Vec123() + out = wp.empty(5, dtype=int, device=device) + wp.launch(vector_len_kernel, dim=(1,), inputs=(v1, v2, v3, v4), outputs=(out,), device=device) + + test.assertEqual(out.numpy()[0], 2) + test.assertEqual(out.numpy()[1], 3) + test.assertEqual(out.numpy()[2], 4) + test.assertEqual(out.numpy()[3], 123) + test.assertEqual(out.numpy()[4], 2) + + devices = get_test_devices() @@ -1350,6 +1399,12 @@ def test_tpl_ops_with_anon(self): test_tpl_constructor_error_numeric_args_mismatch, devices=devices, ) +add_function_test( + TestVec, + "test_vector_len", + test_vector_len, + devices=devices, +) if __name__ == "__main__": diff --git a/warp/types.py b/warp/types.py index a38de3e66..fd7075196 100644 --- a/warp/types.py +++ b/warp/types.py @@ -1360,6 +1360,11 @@ def type_is_matrix(t): return getattr(t, "_wp_generic_type_hint_", None) is Matrix +# returns True if the passed *type* is a transformation +def type_is_transformation(t): + return getattr(t, "_wp_generic_type_hint_", None) is Transformation + + value_types = (int, float, builtins.bool) + scalar_types