Skip to content

Commit

Permalink
Merge branch 'ccrouzet/gh-389-len' into 'main'
Browse files Browse the repository at this point in the history
Add a `len()` Built-in

See merge request omniverse/warp!982
  • Loading branch information
christophercrouzet committed Jan 20, 2025
2 parents c9c05a6 + 88de0cd commit ca2a2f7
Show file tree
Hide file tree
Showing 19 changed files with 497 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
- Add `example_tile_walker.py`, which reworks the existing `walker.py` to use Warp's tile API for matrix multiplication.
- Add operator overloads for `wp.struct` objects by defining `wp.func` functions ([GH-392](https://github.com/NVIDIA/warp/issues/392)).
- Add `example_tile_nbody.py`, an N-Body gravitational simulation example using Warp tile primitives.
- Add a `len()` built-in to retrieve the length of vectors, quaternions, transformations, matrices (number of rows), arrays (size of the first dimension), and tiles (number of rows) ([GH-389](https://github.com/NVIDIA/warp/issues/389)).

### Changed

Expand Down
18 changes: 18 additions & 0 deletions docs/codegen.rst
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,24 @@ The above program uses a static expression to select the right function given th
[2. 0.]
Example: Static Length Query
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Python's built-in function ``len()`` can also be evaluated statically for types with fixed length, such as vectors, quaternions, and matrices, and can be wrapped into ``wp.static()`` calls to initialize other constructs:

.. code:: python

    import warp as wp

    @wp.kernel
    def my_kernel(v: wp.vec2):
        m = wp.identity(n=wp.static(len(v) + 1), dtype=v.dtype)
        wp.expect_eq(wp.ddot(m, m), 3.0)

    v = wp.vec2(1, 2)
    wp.launch(my_kernel, 1, inputs=(v,))

Advanced Example: Branching Elimination with Static Loop Unrolling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In computational simulations, it's common to apply different operations or boundary conditions based on runtime variables. However, conditional branching using runtime variables often leads to performance issues due to register pressure, as the GPU may allocate resources for all branches even if some of them are never taken. To tackle this, we can utilize static loop unrolling via ``wp.static(...)``, which helps eliminate unnecessary branching at compile-time and improve parallel execution.
Expand Down
40 changes: 40 additions & 0 deletions docs/modules/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1913,6 +1913,46 @@ Utility
Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude


.. py:function:: len(a: Vector[Any,Scalar]) -> int
Retrieves the number of elements in a vector.


.. py:function:: len(a: Quaternion[Scalar]) -> int
:noindex:
:nocontentsentry:

Retrieves the number of elements in a quaternion.


.. py:function:: len(a: Matrix[Any,Any,Scalar]) -> int
:noindex:
:nocontentsentry:

Retrieves the number of rows in a matrix.


.. py:function:: len(a: Transformation[Float]) -> int
:noindex:
:nocontentsentry:

Retrieves the number of elements in a transformation.


.. py:function:: len(a: Array[Any]) -> int
:noindex:
:nocontentsentry:

Retrieves the size of the first dimension in an array.


.. py:function:: len(a: Tile) -> int
:noindex:
:nocontentsentry:

Retrieves the number of rows in a tile.




Geometry
Expand Down
55 changes: 55 additions & 0 deletions warp/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -6497,3 +6497,58 @@ def static(expr):
which includes constant variables and variables captured in the current closure in which the function or kernel is implemented.
"""
return expr


def _add_len_overload(arg_type, doc):
    """Register one ``len(a)`` overload whose argument ``a`` is of type ``arg_type``.

    All ``len`` overloads return ``int``, belong to the "Utility" group, and
    are registered with ``export=False``; only the argument type and the
    documentation string vary between them.
    """
    add_builtin(
        "len",
        input_types={"a": arg_type},
        value_type=int,
        doc=doc,
        group="Utility",
        export=False,
    )


# The overloads are registered in the same order as before:
# vector, quaternion, matrix, transformation, array, tile.
_add_len_overload(
    vector(length=Any, dtype=Scalar),
    "Retrieves the number of elements in a vector.",
)
_add_len_overload(
    quaternion(dtype=Scalar),
    "Retrieves the number of elements in a quaternion.",
)
_add_len_overload(
    matrix(shape=(Any, Any), dtype=Scalar),
    "Retrieves the number of rows in a matrix.",
)
_add_len_overload(
    transformation(dtype=Float),
    "Retrieves the number of elements in a transformation.",
)
_add_len_overload(
    array(dtype=Any),
    "Retrieves the size of the first dimension in an array.",
)
_add_len_overload(
    Tile(dtype=Any, M=Any, N=Any),
    "Retrieves the number of rows in a tile.",
)
51 changes: 51 additions & 0 deletions warp/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2882,11 +2882,62 @@ def evaluate_static_expression(adj, node) -> Tuple[Any, str]:
if static_code is None:
raise WarpCodegenError("Error extracting source code from wp.static() expression")

# Since this is an expression, we can enforce it to be defined on a single line.
static_code = static_code.replace("\n", "")

vars_dict = adj.get_static_evaluation_context()
# add constant variables to the static call context
constant_vars = {k: v.constant for k, v in adj.symbols.items() if isinstance(v, Var) and v.constant is not None}
vars_dict.update(constant_vars)

# Replace all constant `len()` expressions with their value.
if "len" in static_code:

    def eval_len(obj):
        """Return the length of ``obj`` for use inside a static expression.

        Warp vector, quaternion, and transformation types report
        ``_length_``, matrix types report the first entry of ``_shape_``,
        and tiles report ``M``. Anything else is forwarded to Python's
        built-in ``len()``.
        """
        if type_is_vector(obj):
            return obj._length_
        elif type_is_quaternion(obj):
            return obj._length_
        elif type_is_matrix(obj):
            return obj._shape_[0]
        elif type_is_transformation(obj):
            return obj._length_
        elif is_tile(obj):
            return obj.M

        return len(obj)

    # Context used to evaluate the individual `len()` calls: the regular
    # static context, plus every typed symbol mapped to its *type* (so that
    # `len(v)` resolves through `eval_len()`), plus the `len` override.
    len_expr_ctx = vars_dict.copy()
    constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None}
    len_expr_ctx.update(constant_types)
    len_expr_ctx.update({"len": eval_len})

    # We want to replace the expression code in-place,
    # so reparse it to get the correct column info.
    # The `ast` column offsets are UTF-8 *byte* offsets, so all slicing is
    # done on the encoded source to stay correct with non-ASCII characters.
    code_bytes = static_code.encode("utf-8")
    len_value_locs = []
    expr_tree = ast.parse(static_code)
    assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
    expr_root = expr_tree.body[0].value
    for expr_node in ast.walk(expr_root):
        # Only plain `len(<one argument>)` calls are candidates. The callee
        # must be checked to be an `ast.Name` first: calls through an
        # attribute or any other expression have no `.id`.
        if not (
            isinstance(expr_node, ast.Call)
            and isinstance(expr_node.func, ast.Name)
            and expr_node.func.id == "len"
            and len(expr_node.args) == 1
        ):
            continue

        start = expr_node.col_offset
        end = expr_node.end_col_offset
        len_expr = code_bytes[start:end].decode("utf-8")
        try:
            len_value = eval(len_expr, len_expr_ctx)
        except Exception:
            # Best effort: a call that cannot be resolved here is left
            # untouched for the evaluation of the full expression below.
            continue

        len_value_locs.append((len_value, start, end))

    if len_value_locs:
        # `ast.walk()` yields nodes in no specified order, whereas the
        # splicing below needs ascending, non-overlapping spans.
        len_value_locs.sort(key=lambda item: item[1])

        pieces = []
        loc = 0
        for value, start, end in len_value_locs:
            if start < loc:
                # Nested inside a `len()` call that was already replaced.
                continue

            pieces.append(code_bytes[loc:start].decode("utf-8"))
            pieces.append(f"{value}")
            loc = end

        pieces.append(code_bytes[loc:].decode("utf-8"))
        static_code = "".join(pieces)

try:
value = eval(static_code, vars_dict)
if warp.config.verbose:
Expand Down
12 changes: 12 additions & 0 deletions warp/native/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,18 @@ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k,
FP_VERIFY_ADJ_4(value, adj_value)
}

// Returns the extent of the first dimension of an array-like object.
// Matches any class template A<T> taking a single type parameter; the body
// only requires `a.shape[0]` to be valid and convertible to int.
template<template<typename> class A, typename T>
CUDA_CALLABLE inline int len(const A<T>& a)
{
return a.shape[0];
}

// Adjoint counterpart of len() for array-like types.
// Intentionally a no-op: neither adj_a nor adj_ret is written
// (presumably because an integer length carries no gradient).
template<template<typename> class A, typename T>
CUDA_CALLABLE inline void adj_len(const A<T>& a, A<T>& adj_a, int& adj_ret)
{
}


} // namespace wp

#include "fabric.h"
11 changes: 11 additions & 0 deletions warp/native/mat.h
Original file line number Diff line number Diff line change
Expand Up @@ -1650,4 +1650,15 @@ inline CUDA_CALLABLE void adj_mat44(float m00, float m01, float m02, float m03,
a33 += adj_ret.data[3][3];
}

// Returns the number of rows of a matrix, i.e. the compile-time `Rows`
// template parameter converted to int. The argument value itself is unused.
template<unsigned Rows, unsigned Cols, typename Type>
CUDA_CALLABLE inline int len(const mat_t<Rows,Cols,Type>& x)
{
return Rows;
}

// Adjoint counterpart of len() for matrices.
// Intentionally a no-op: adj_x is left untouched
// (presumably because an integer length carries no gradient).
template<unsigned Rows, unsigned Cols, typename Type>
CUDA_CALLABLE inline void adj_len(const mat_t<Rows,Cols,Type>& x, mat_t<Rows,Cols,Type>& adj_x, const int& adj_ret)
{
}

} // namespace wp
9 changes: 9 additions & 0 deletions warp/native/quat.h
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,15 @@ inline CUDA_CALLABLE quat_t<Type> quat_identity()
return quat_t<Type>(Type(0), Type(0), Type(0), Type(1));
}

// Returns the number of components of a quaternion, which is always 4.
// The argument value itself is unused.
template<typename Type>
CUDA_CALLABLE inline int len(const quat_t<Type>& x)
{
return 4;
}

// Adjoint counterpart of len() for quaternions.
// Intentionally a no-op: adj_x is left untouched
// (presumably because an integer length carries no gradient).
template<typename Type>
CUDA_CALLABLE inline void adj_len(const quat_t<Type>& x, quat_t<Type>& adj_x, const int& adj_ret)
{
}

} // namespace wp
11 changes: 11 additions & 0 deletions warp/native/spatial.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,17 @@ CUDA_CALLABLE inline void adj_lerp(const transform_t<Type>& a, const transform_t
adj_t += tensordot(b, adj_ret) - tensordot(a, adj_ret);
}

// Returns the number of elements of a transformation, hard-coded to 7.
// NOTE(review): presumably 3 translation + 4 rotation components; confirm
// against the definition of transform_t. The argument value is unused.
template<typename Type>
CUDA_CALLABLE inline int len(const transform_t<Type>& t)
{
return 7;
}

// Adjoint counterpart of len() for transformations.
// Intentionally a no-op: adj_t is left untouched
// (presumably because an integer length carries no gradient).
template<typename Type>
CUDA_CALLABLE inline void adj_len(const transform_t<Type>& t, transform_t<Type>& adj_t, const int& adj_ret)
{
}

template<typename Type>
using spatial_matrix_t = mat_t<6,6,Type>;

Expand Down
22 changes: 22 additions & 0 deletions warp/native/tile.h
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,28 @@ void tile_register_t<T, M, N>::print() const
WP_TILE_SYNC();
}

// Returns `M`, the first of the two compile-time dimensions of a register
// tile. The argument value itself is unused.
template <typename T, int M, int N>
inline CUDA_CALLABLE int len(const tile_register_t<T, M, N>& t)
{
return M;
}

// Adjoint counterpart of len() for register tiles.
// Intentionally a no-op: the second tile argument `a` is taken by const
// reference and nothing is written to adj_ret.
template <typename T, int M, int N>
inline CUDA_CALLABLE void adj_len(const tile_register_t<T, M, N>& t, const tile_register_t<T, M, N>& a, int& adj_ret)
{
}

// Returns `M`, the first of the two compile-time dimensions of a shared
// tile. The stride and ownership template parameters do not affect the
// result, and the argument value itself is unused.
template <typename T, int M, int N, int StrideM, int StrideN, bool Owner>
inline CUDA_CALLABLE int len(const tile_shared_t<T, M, N, StrideM, StrideN, Owner>& t)
{
return M;
}

// Adjoint counterpart of len() for shared tiles.
// Intentionally a no-op: the second tile argument `a` is taken by const
// reference and nothing is written to adj_ret.
template <typename T, int M, int N, int StrideM, int StrideN, bool Owner>
inline CUDA_CALLABLE void adj_len(const tile_shared_t<T, M, N, StrideM, StrideN, Owner>& t, const tile_shared_t<T, M, N, StrideM, StrideN, Owner>& a, int& adj_ret)
{
}

template <typename T, int M, int N>
inline CUDA_CALLABLE void print(const tile_register_t<T, M, N>& t)
{
Expand Down
10 changes: 10 additions & 0 deletions warp/native/vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -1311,5 +1311,15 @@ inline CUDA_CALLABLE void adj_vec4(float s, float& adj_s, const vec4& adj_ret)
adj_vec_t(s, adj_s, adj_ret);
}

// Returns the number of elements of a vector, i.e. the compile-time
// `Length` template parameter converted to int. The argument value itself
// is unused.
template<unsigned Length, typename Type>
CUDA_CALLABLE inline int len(const vec_t<Length, Type>& x)
{
return Length;
}

// Adjoint counterpart of len() for vectors.
// Intentionally a no-op: adj_x is left untouched
// (presumably because an integer length carries no gradient).
template<unsigned Length, typename Type>
CUDA_CALLABLE inline void adj_len(const vec_t<Length, Type>& x, vec_t<Length, Type>& adj_x, const int& adj_ret)
{
}

} // namespace wp
36 changes: 36 additions & 0 deletions warp/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3033,3 +3033,39 @@ def static(expr: Any) -> Any:
(excluding Warp arrays since they cannot be created in a Warp kernel at the moment).
"""
...


# `len()` overloads, one stub per supported argument kind. The bodies are
# `...`; only the signatures and docstrings carry information.
@over
def len(a: Vector[Any, Scalar]) -> int:
    """Retrieves the number of elements in a vector."""
    ...


@over
def len(a: Quaternion[Scalar]) -> int:
    """Retrieves the number of elements in a quaternion."""
    ...


@over
def len(a: Matrix[Any, Any, Scalar]) -> int:
    """Retrieves the number of rows in a matrix."""
    ...


@over
def len(a: Transformation[Float]) -> int:
    """Retrieves the number of elements in a transformation."""
    ...


@over
def len(a: Array[Any]) -> int:
    """Retrieves the size of the first dimension in an array."""
    ...


@over
def len(a: Tile) -> int:
    """Retrieves the number of rows in a tile."""
    ...
Loading

0 comments on commit ca2a2f7

Please sign in to comment.