diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6bc7b1214..6ea2e2a00 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,8 @@
- Add support for wp.tile_load() where the source array shape is not a multiple of the tile dimension, out of bounds reads will be zero-filled
- Add support for higher dimensional (up to 4d) tile shapes and memory operations
- Add `example_tile_walker.py`, which reworks the existing `walker.py` to use Warp's tile API for matrix multiplication.
+- Add functions `norm_l1`, `norm_l2`, `norm_huber`, `norm_pseudo_huber`, `smooth_normalize` for vector types to a new `math.py` module.
+- `SemiImplicitIntegrator` and `FeatherstoneIntegrator` now have an optional `friction_smoothing` constructor argument (defaults to 1.0) that controls softness of the friction norm computation.
- Add operator overloads for `wp.struct` objects by defining `wp.func` functions ([GH-392](https://github.com/NVIDIA/warp/issues/392)).
- Add `example_tile_nbody.py`, an N-Body gravitational simulation example using Warp tile primitives.
- Add a `len()` built-in to retrieve the number of elements for vec/quat/mat/arrays ([GH-389](https://github.com/NVIDIA/warp/issues/389)).
@@ -34,12 +36,15 @@
- `wp.Bvh` constructor now supports multiple construction methods, including `SAH` ( Surface Area Heuristics), `Median` and `LBVH`.
- Avoid recompilation of modules when changing `block_dim`.
- Improve memory consumption, compilation and runtime performance when using in-place vector/matrix assignments in kernels that have `enable_backward` set to False ([GH-332](https://github.com/NVIDIA/warp/issues/332)).
+- `update_vbo_transforms` kernel launches in OpenGLRenderer are no longer recorded on the tape.
- Fix the `len()` operator returning the total size of a matrix instead of its first dimension.
- Change exception types and error messages thrown by tile functions for improved consistency.
- Vector/Matrix/Quaternion component `+=` and `-=` operations compile and run faster in the backward pass.
### Fixed
+- Fix gradient instability in rigid-body contact handling for `SemiImplicitIntegrator`, `FeatherstoneIntegrator` ([GH-349](https://github.com/NVIDIA/warp/issues/349)).
+- Fix overload resolution of generic Warp functions with default arguments.
- Fix autodiff Jacobian computation in `warp.autograd.jacobian_ad` where in some cases gradients were not zero-ed out properly.
- Fix plotting issues in `warp.autograd.jacobian_plot`.
- Fix errors during graph capture caused by module unloading ([GH-401](https://github.com/NVIDIA/warp/issues/401)).
diff --git a/build_docs.py b/build_docs.py
index d9f7b819b..40d3e4187 100644
--- a/build_docs.py
+++ b/build_docs.py
@@ -10,6 +10,7 @@
import shutil
import subprocess
+import warp # ensure all API functions are loaded # noqa: F401
from warp.context import export_functions_rst, export_stubs
parser = argparse.ArgumentParser(description="Warp Sphinx Documentation Builder")
diff --git a/docs/img/norm_huber.svg b/docs/img/norm_huber.svg
new file mode 100644
index 000000000..7ee76ca94
--- /dev/null
+++ b/docs/img/norm_huber.svg
@@ -0,0 +1,2628 @@
+
+
diff --git a/docs/img/norm_pseudo_huber.svg b/docs/img/norm_pseudo_huber.svg
new file mode 100644
index 000000000..08d8d5900
--- /dev/null
+++ b/docs/img/norm_pseudo_huber.svg
@@ -0,0 +1,2753 @@
+
+
diff --git a/docs/modules/functions.rst b/docs/modules/functions.rst
index 71d855aaf..e8a0ed335 100644
--- a/docs/modules/functions.rst
+++ b/docs/modules/functions.rst
@@ -590,6 +590,11 @@ Vector Math
while the corresponding eigenvalues are returned in ``d``.
+.. autofunction:: warp.math.norm_l1
+.. autofunction:: warp.math.norm_l2
+.. autofunction:: warp.math.norm_huber
+.. autofunction:: warp.math.norm_pseudo_huber
+.. autofunction:: warp.math.smooth_normalize
Quaternion Math
diff --git a/warp/__init__.py b/warp/__init__.py
index 71a9f86f0..0ee20812a 100644
--- a/warp/__init__.py
+++ b/warp/__init__.py
@@ -118,6 +118,8 @@
from . import builtins
from warp.builtins import static
+from warp.math import *
+
import warp.config as config
__version__ = config.version
diff --git a/warp/context.py b/warp/context.py
index 0b56146ad..18d177925 100644
--- a/warp/context.py
+++ b/warp/context.py
@@ -395,7 +395,8 @@ def get_overload(self, arg_types, kwarg_types):
if not warp.codegen.func_match_args(f, arg_types, kwarg_types):
continue
- if len(f.input_types) != len(arg_types):
+ acceptable_arg_num = len(f.input_types) - len(f.defaults) <= len(arg_types) <= len(f.input_types)
+ if not acceptable_arg_num:
continue
# try to match the given types to the function template types
@@ -412,6 +413,10 @@ def get_overload(self, arg_types, kwarg_types):
arg_names = f.input_types.keys()
overload_annotations = dict(zip(arg_names, arg_types))
+ # add defaults
+ for k, d in f.defaults.items():
+ if k not in overload_annotations:
+ overload_annotations[k] = warp.codegen.strip_reference(warp.codegen.get_arg_type(d))
ovl = shallowcopy(f)
ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations)
@@ -755,8 +760,15 @@ def func(f):
scope_locals = inspect.currentframe().f_back.f_locals
m = get_module(f.__module__)
+ doc = getattr(f, "__doc__", "") or ""
Function(
- func=f, key=name, namespace="", module=m, value_func=None, scope_locals=scope_locals
+ func=f,
+ key=name,
+ namespace="",
+ module=m,
+ value_func=None,
+ scope_locals=scope_locals,
+ doc=doc.strip(),
) # value_type not known yet, will be inferred during Adjoint.build()
# use the top of the list of overloads for this key
@@ -1061,7 +1073,8 @@ def overload(kernel, arg_types=Union[None, Dict[str, Any], List[Any]]):
raise RuntimeError("wp.overload() called with invalid argument!")
-builtin_functions = {}
+# native functions that are part of the Warp API
+builtin_functions: Dict[str, Function] = {}
def get_generic_vtypes():
@@ -1330,6 +1343,28 @@ def initializer_list_func(args, return_type):
setattr(warp, key, func)
+def register_api_function(
+ function: Function,
+ group: str = "Other",
+ hidden=False,
+):
+ """Main entry point to register a Warp Python function to be part of the Warp API and appear in the documentation.
+
+ Args:
+ function (Function): Warp function to be registered.
+ group (str): Classification used for the documentation.
+ input_types (Mapping[str, Any]): Signature of the user-facing function.
+ Variadic arguments are supported by prefixing the parameter names
+ with asterisks as in `*args` and `**kwargs`. Generic arguments are
+ supported with types such as `Any`, `Float`, `Scalar`, etc.
+ value_type (Any): Type returned by the function.
+ hidden (bool): Whether to add that function into the documentation.
+ """
+ function.group = group
+ function.hidden = hidden
+ builtin_functions[function.key] = function
+
+
# global dictionary of modules
user_modules = {}
@@ -6182,14 +6217,19 @@ def export_functions_rst(file): # pragma: no cover
# build dictionary of all functions by group
groups = {}
- for _k, f in builtin_functions.items():
+ functions = list(builtin_functions.values())
+
+ for f in functions:
# build dict of groups
if f.group not in groups:
groups[f.group] = []
- # append all overloads to the group
- for o in f.overloads:
- groups[f.group].append(o)
+ if hasattr(f, "overloads"):
+ # append all overloads to the group
+ for o in f.overloads:
+ groups[f.group].append(o)
+ else:
+ groups[f.group].append(f)
# Keep track of what function and query types have been written
written_functions = set()
@@ -6209,6 +6249,10 @@ def export_functions_rst(file): # pragma: no cover
print("---------------", file=file)
for f in g:
+ if f.func:
+ # f is a Warp function written in Python, we can use autofunction
+ print(f".. autofunction:: {f.func.__module__}.{f.key}", file=file)
+ continue
for f_prefix, query_type in query_types:
if f.key.startswith(f_prefix) and query_type not in written_query_types:
print(f".. autoclass:: {query_type}", file=file)
@@ -6266,24 +6310,32 @@ def export_stubs(file): # pragma: no cover
print(header, file=file)
print(file=file)
- for k, g in builtin_functions.items():
- for f in g.overloads:
- args = ", ".join(f"{k}: {type_str(v)}" for k, v in f.input_types.items())
+ def add_stub(f):
+ args = ", ".join(f"{k}: {type_str(v)}" for k, v in f.input_types.items())
- return_str = ""
+ return_str = ""
- if f.hidden: # or f.generic:
- continue
+ if f.hidden: # or f.generic:
+ return
+ return_type = f.value_type
+ if f.value_func:
return_type = f.value_func(None, None)
- if return_type:
- return_str = " -> " + type_str(return_type)
-
- print("@over", file=file)
- print(f"def {f.key}({args}){return_str}:", file=file)
- print(f' """{f.doc}', file=file)
- print(' """', file=file)
- print(" ...\n\n", file=file)
+ if return_type:
+ return_str = " -> " + type_str(return_type)
+
+ print("@over", file=file)
+ print(f"def {f.key}({args}){return_str}:", file=file)
+ print(f' """{f.doc}', file=file)
+ print(' """', file=file)
+ print(" ...\n\n", file=file)
+
+ for g in builtin_functions.values():
+ if hasattr(g, "overloads"):
+ for f in g.overloads:
+ add_stub(f)
+ else:
+ add_stub(g)
def export_builtins(file: io.TextIOBase): # pragma: no cover
@@ -6309,6 +6361,8 @@ def ctype_ret_str(t):
file.write('extern "C" {\n\n')
for k, g in builtin_functions.items():
+ if not hasattr(g, "overloads"):
+ continue
for f in g.overloads:
if not f.export or f.generic:
continue
diff --git a/warp/math.py b/warp/math.py
new file mode 100644
index 000000000..98bc09aec
--- /dev/null
+++ b/warp/math.py
@@ -0,0 +1,147 @@
+# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+from typing import Any
+
+import warp as wp
+
+"""
+Vector norm functions
+"""
+
+__all__ = [
+ "norm_l1",
+ "norm_l2",
+ "norm_huber",
+ "norm_pseudo_huber",
+ "smooth_normalize",
+]
+
+
+@wp.func
+def norm_l1(v: Any):
+ """
+ Computes the L1 norm of a vector v.
+
+ .. math:: \\|v\\|_1 = \\sum_i |v_i|
+
+ Args:
+ v (Vector[Any,Float]): The vector to compute the L1 norm of.
+
+ Returns:
+ float: The L1 norm of the vector.
+ """
+ n = float(0.0)
+ for i in range(len(v)):
+ n += wp.abs(v[i])
+ return n
+
+
+@wp.func
+def norm_l2(v: Any):
+ """
+ Computes the L2 norm of a vector v.
+
+ .. math:: \\|v\\|_2 = \\sqrt{\\sum_i v_i^2}
+
+ Args:
+ v (Vector[Any,Float]): The vector to compute the L2 norm of.
+
+ Returns:
+ float: The L2 norm of the vector.
+ """
+ return wp.length(v)
+
+
+@wp.func
+def norm_huber(v: Any, delta: float = 1.0):
+ """
+ Computes the Huber norm of a vector v with a given delta.
+
+ .. math::
+ H(v) = \\begin{cases} \\frac{1}{2} \\|v\\|^2 & \\text{if } \\|v\\| \\leq \\delta \\\\ \\delta(\\|v\\| - \\frac{1}{2}\\delta) & \\text{otherwise} \\end{cases}
+
+ .. image:: /img/norm_huber.svg
+ :align: center
+
+ Args:
+ v (Vector[Any,Float]): The vector to compute the Huber norm of.
+ delta (float): The threshold value, defaults to 1.0.
+
+ Returns:
+ float: The Huber norm of the vector.
+ """
+ a = wp.dot(v, v)
+ if a <= delta * delta:
+ return 0.5 * a
+ return delta * (wp.sqrt(a) - 0.5 * delta)
+
+
+@wp.func
+def norm_pseudo_huber(v: Any, delta: float = 1.0):
+ """
+ Computes the "pseudo" Huber norm of a vector v with a given delta.
+
+ .. math::
+ H^\\prime(v) = \\delta \\sqrt{1 + \\frac{\\|v\\|^2}{\\delta^2}}
+
+ .. image:: /img/norm_pseudo_huber.svg
+ :align: center
+
+ Args:
+ v (Vector[Any,Float]): The vector to compute the Huber norm of.
+ delta (float): The threshold value, defaults to 1.0.
+
+ Returns:
+ float: The Huber norm of the vector.
+ """
+ a = wp.dot(v, v)
+ return delta * wp.sqrt(1.0 + a / (delta * delta))
+
+
+@wp.func
+def smooth_normalize(v: Any, delta: float = 1.0):
+ """
+ Normalizes a vector using the pseudo-Huber norm.
+
+ See :func:`norm_pseudo_huber`.
+
+ .. math::
+ \\frac{v}{H^\\prime(v)}
+
+ Args:
+ v (Vector[Any,Float]): The vector to normalize.
+ delta (float): The threshold value, defaults to 1.0.
+
+ Returns:
+ Vector[Any,Float]: The normalized vector.
+ """
+ return v / norm_pseudo_huber(v, delta)
+
+
+# register API functions so they appear in the documentation
+
+wp.context.register_api_function(
+ norm_l1,
+ group="Vector Math",
+)
+wp.context.register_api_function(
+ norm_l2,
+ group="Vector Math",
+)
+wp.context.register_api_function(
+ norm_huber,
+ group="Vector Math",
+)
+wp.context.register_api_function(
+ norm_pseudo_huber,
+ group="Vector Math",
+)
+wp.context.register_api_function(
+ smooth_normalize,
+ group="Vector Math",
+)
diff --git a/warp/render/render_opengl.py b/warp/render/render_opengl.py
index 9a237a016..73bc6a99d 100644
--- a/warp/render/render_opengl.py
+++ b/warp/render/render_opengl.py
@@ -847,6 +847,7 @@ def allocate_instances(self, positions, rotations=None, colors1=None, colors2=No
vbo_transforms,
],
device=self.device,
+ record_tape=False,
)
vbo_transforms = vbo_transforms.numpy()
@@ -908,6 +909,7 @@ def update_instances(self, transforms: wp.array = None, scalings: wp.array = Non
vbo_transforms,
],
device=self.device,
+ record_tape=False,
)
self._instance_transform_cuda_buffer.unmap()
@@ -2507,6 +2509,7 @@ def update_body_transforms(self, body_tf: wp.array):
vbo_transforms,
],
device=self._device,
+ record_tape=False,
)
self._instance_transform_cuda_buffer.unmap()
diff --git a/warp/sim/__init__.py b/warp/sim/__init__.py
index 28374f74f..08e5812ce 100644
--- a/warp/sim/__init__.py
+++ b/warp/sim/__init__.py
@@ -50,4 +50,9 @@
ModelShapeMaterials,
State,
)
-from .utils import load_mesh, quat_from_euler, quat_to_euler, velocity_at_point
+from .utils import (
+ load_mesh,
+ quat_from_euler,
+ quat_to_euler,
+ velocity_at_point,
+)
diff --git a/warp/sim/collide.py b/warp/sim/collide.py
index b35757444..3bb923d9d 100644
--- a/warp/sim/collide.py
+++ b/warp/sim/collide.py
@@ -1567,9 +1567,9 @@ def collide(model, state, edge_sdf_iter: int = 10, iterate_mesh_vertices: bool =
# generate soft contacts for particles and shapes except ground plane (last shape)
if model.particle_count and model.shape_count > 1:
if requires_grad:
- model.soft_contact_body_pos = wp.clone(model.soft_contact_body_pos)
- model.soft_contact_body_vel = wp.clone(model.soft_contact_body_vel)
- model.soft_contact_normal = wp.clone(model.soft_contact_normal)
+ model.soft_contact_body_pos = wp.empty_like(model.soft_contact_body_pos)
+ model.soft_contact_body_vel = wp.empty_like(model.soft_contact_body_vel)
+ model.soft_contact_normal = wp.empty_like(model.soft_contact_normal)
# clear old count
model.soft_contact_count.zero_()
wp.launch(
@@ -1666,12 +1666,12 @@ def collide(model, state, edge_sdf_iter: int = 10, iterate_mesh_vertices: bool =
if model.shape_contact_pair_count or model.ground and model.shape_ground_contact_pair_count:
if requires_grad:
- model.rigid_contact_point0 = wp.clone(model.rigid_contact_point0)
- model.rigid_contact_point1 = wp.clone(model.rigid_contact_point1)
- model.rigid_contact_offset0 = wp.clone(model.rigid_contact_offset0)
- model.rigid_contact_offset1 = wp.clone(model.rigid_contact_offset1)
- model.rigid_contact_normal = wp.clone(model.rigid_contact_normal)
- model.rigid_contact_thickness = wp.clone(model.rigid_contact_thickness)
+ model.rigid_contact_point0 = wp.empty_like(model.rigid_contact_point0)
+ model.rigid_contact_point1 = wp.empty_like(model.rigid_contact_point1)
+ model.rigid_contact_offset0 = wp.empty_like(model.rigid_contact_offset0)
+ model.rigid_contact_offset1 = wp.empty_like(model.rigid_contact_offset1)
+ model.rigid_contact_normal = wp.empty_like(model.rigid_contact_normal)
+ model.rigid_contact_thickness = wp.empty_like(model.rigid_contact_thickness)
model.rigid_contact_count = wp.zeros_like(model.rigid_contact_count)
model.rigid_contact_tids = wp.zeros_like(model.rigid_contact_tids)
model.rigid_contact_shape0 = wp.empty_like(model.rigid_contact_shape0)
diff --git a/warp/sim/integrator_euler.py b/warp/sim/integrator_euler.py
index d05a93999..84319083d 100644
--- a/warp/sim/integrator_euler.py
+++ b/warp/sim/integrator_euler.py
@@ -871,6 +871,7 @@ def eval_rigid_contacts(
contact_shape0: wp.array(dtype=int),
contact_shape1: wp.array(dtype=int),
force_in_world_frame: bool,
+ friction_smoothing: float,
# outputs
body_f: wp.array(dtype=wp.spatial_vector),
):
@@ -924,6 +925,8 @@ def eval_rigid_contacts(
n = contact_normal[tid]
bx_a = contact_point0[tid]
bx_b = contact_point1[tid]
+ r_a = wp.vec3(0.0)
+ r_b = wp.vec3(0.0)
if body_a >= 0:
X_wb_a = body_q[body_a]
X_com_a = body_com[body_a]
@@ -990,12 +993,16 @@ def eval_rigid_contacts(
# ft = wp.vec3(vx, 0.0, vz)
# Coulomb friction (smooth, but gradients are numerically unstable around |vt| = 0)
- # ft = wp.normalize(vt)*wp.min(kf*wp.length(vt), abs(mu*d*ke))
ft = wp.vec3(0.0)
if d < 0.0:
- ft = wp.normalize(vt) * wp.min(kf * wp.length(vt), -mu * (fn + fd))
+ # use a smooth vector norm to avoid gradient instability at/around zero velocity
+ vs = wp.norm_huber(vt, delta=friction_smoothing)
+ if vs > 0.0:
+ fr = vt / vs
+ ft = fr * wp.min(kf * vs, -mu * (fn + fd))
f_total = n * (fn + fd) + ft
+ # f_total = n * (fn + fd)
# f_total = n * fn
if body_a >= 0:
@@ -1761,7 +1768,7 @@ def eval_tetrahedral_forces(model: Model, state: State, control: Control, partic
)
-def eval_body_contact_forces(model: Model, state: State, particle_f: wp.array):
+def eval_body_contact_forces(model: Model, state: State, particle_f: wp.array, friction_smoothing: float = 1.0):
if model.rigid_contact_max and (
model.ground and model.shape_ground_contact_pair_count or model.shape_contact_pair_count
):
@@ -1782,6 +1789,7 @@ def eval_body_contact_forces(model: Model, state: State, particle_f: wp.array):
model.rigid_contact_shape0,
model.rigid_contact_shape1,
False,
+ friction_smoothing,
],
outputs=[state.body_f],
device=model.device,
@@ -1880,7 +1888,15 @@ def eval_muscle_forces(model: Model, state: State, control: Control, body_f: wp.
)
-def compute_forces(model: Model, state: State, control: Control, particle_f: wp.array, body_f: wp.array, dt: float):
+def compute_forces(
+ model: Model,
+ state: State,
+ control: Control,
+ particle_f: wp.array,
+ body_f: wp.array,
+ dt: float,
+ friction_smoothing: float = 1.0,
+):
# damped springs
eval_spring_forces(model, state, particle_f)
@@ -1906,7 +1922,7 @@ def compute_forces(model: Model, state: State, control: Control, particle_f: wp.
eval_particle_ground_contact_forces(model, state, particle_f)
# body contacts
- eval_body_contact_forces(model, state, particle_f)
+ eval_body_contact_forces(model, state, particle_f, friction_smoothing=friction_smoothing)
# particle shape contact
eval_particle_body_contact_forces(model, state, particle_f, body_f, body_f_in_world_frame=False)
@@ -1941,12 +1957,14 @@ class SemiImplicitIntegrator(Integrator):
"""
- def __init__(self, angular_damping: float = 0.05):
+ def __init__(self, angular_damping: float = 0.05, friction_smoothing: float = 1.0):
"""
Args:
angular_damping (float, optional): Angular damping factor. Defaults to 0.05.
+ friction_smoothing (float, optional): The delta value for the Huber norm (see :func:`warp.math.norm_huber`) used for the friction velocity normalization. Defaults to 1.0.
"""
self.angular_damping = angular_damping
+ self.friction_smoothing = friction_smoothing
def simulate(self, model: Model, state_in: State, state_out: State, dt: float, control: Control = None):
with wp.ScopedTimer("simulate", False):
@@ -1962,7 +1980,7 @@ def simulate(self, model: Model, state_in: State, state_out: State, dt: float, c
if control is None:
control = model.control(clone_variables=False)
- compute_forces(model, state_in, control, particle_f, body_f, dt)
+ compute_forces(model, state_in, control, particle_f, body_f, dt, friction_smoothing=self.friction_smoothing)
self.integrate_bodies(model, state_in, state_out, dt, self.angular_damping)
diff --git a/warp/sim/integrator_featherstone.py b/warp/sim/integrator_featherstone.py
index 48b7f7487..c8d6557a2 100644
--- a/warp/sim/integrator_featherstone.py
+++ b/warp/sim/integrator_featherstone.py
@@ -1525,18 +1525,27 @@ class FeatherstoneIntegrator(Integrator):
"""
def __init__(
- self, model, angular_damping=0.05, update_mass_matrix_every=1, use_tile_gemm=False, fuse_cholesky=True
+ self,
+ model,
+ angular_damping=0.05,
+ update_mass_matrix_every=1,
+ friction_smoothing=1.0,
+ use_tile_gemm=False,
+ fuse_cholesky=True,
):
"""
Args:
model (Model): the model to be simulated.
angular_damping (float, optional): Angular damping factor. Defaults to 0.05.
update_mass_matrix_every (int, optional): How often to update the mass matrix (every n-th time the :meth:`simulate` function gets called). Defaults to 1.
+ friction_smoothing (float, optional): The delta value for the Huber norm (see :func:`warp.math.norm_huber`) used for the friction velocity normalization. Defaults to 1.0.
"""
self.angular_damping = angular_damping
self.update_mass_matrix_every = update_mass_matrix_every
+ self.friction_smoothing = friction_smoothing
self.use_tile_gemm = use_tile_gemm
self.fuse_cholesky = fuse_cholesky
+
self._step = 0
self.compute_articulation_indices(model)
@@ -1834,6 +1843,7 @@ def simulate(self, model: Model, state_in: State, state_out: State, dt: float, c
model.rigid_contact_shape0,
model.rigid_contact_shape1,
True,
+ self.friction_smoothing,
],
outputs=[body_f],
device=model.device,
diff --git a/warp/sim/model.py b/warp/sim/model.py
index da6803885..4230ac722 100644
--- a/warp/sim/model.py
+++ b/warp/sim/model.py
@@ -2785,7 +2785,7 @@ def add_shape_plane(
c = np.cross(normal, (0.0, 1.0, 0.0))
angle = np.arcsin(np.linalg.norm(c))
axis = np.abs(c) / np.linalg.norm(c)
- rot = wp.quat_from_axis_angle(axis, angle)
+ rot = wp.quat_from_axis_angle(wp.vec3(*axis), wp.float32(angle))
scale = wp.vec3(width, length, 0.0)
return self._add_shape(
diff --git a/warp/stubs.py b/warp/stubs.py
index 04d96c840..73d94d180 100644
--- a/warp/stubs.py
+++ b/warp/stubs.py
@@ -123,6 +123,8 @@
from . import builtins
from warp.builtins import static
+from warp.math import *
+
import warp.config as config
__version__ = config.version
@@ -3028,3 +3030,92 @@ def len(a: Array[Any]) -> int:
def len(a: Tile) -> int:
"""Return the number of rows in a tile."""
...
+
+
+@over
+def norm_l1(v: Any):
+ """Computes the L1 norm of a vector v.
+
+ .. math:: \|v\|_1 = \sum_i |v_i|
+
+ Args:
+ v (Vector[Any,Float]): The vector to compute the L1 norm of.
+
+ Returns:
+ float: The L1 norm of the vector.
+ """
+ ...
+
+
+@over
+def norm_l2(v: Any):
+ """Computes the L2 norm of a vector v.
+
+ .. math:: \|v\|_2 = \sqrt{\sum_i v_i^2}
+
+ Args:
+ v (Vector[Any,Float]): The vector to compute the L2 norm of.
+
+ Returns:
+ float: The L2 norm of the vector.
+ """
+ ...
+
+
+@over
+def norm_huber(v: Any, delta: float):
+ """Computes the Huber norm of a vector v with a given delta.
+
+ .. math::
+ H(v) = \begin{cases} \frac{1}{2} \|v\|^2 & \text{if } \|v\| \leq \delta \\ \delta(\|v\| - \frac{1}{2}\delta) & \text{otherwise} \end{cases}
+
+ .. image:: /img/norm_huber.svg
+ :align: center
+
+ Args:
+ v (Vector[Any,Float]): The vector to compute the Huber norm of.
+ delta (float): The threshold value, defaults to 1.0.
+
+ Returns:
+ float: The Huber norm of the vector.
+ """
+ ...
+
+
+@over
+def norm_pseudo_huber(v: Any, delta: float):
+ """Computes the "pseudo" Huber norm of a vector v with a given delta.
+
+ .. math::
+ H^\prime(v) = \delta \sqrt{1 + \frac{\|v\|^2}{\delta^2}}
+
+ .. image:: /img/norm_pseudo_huber.svg
+ :align: center
+
+ Args:
+ v (Vector[Any,Float]): The vector to compute the Huber norm of.
+ delta (float): The threshold value, defaults to 1.0.
+
+ Returns:
+ float: The Huber norm of the vector.
+ """
+ ...
+
+
+@over
+def smooth_normalize(v: Any, delta: float):
+ """Normalizes a vector using the pseudo-Huber norm.
+
+ See :func:`norm_pseudo_huber`.
+
+ .. math::
+ \frac{v}{H^\prime(v)}
+
+ Args:
+ v (Vector[Any,Float]): The vector to normalize.
+ delta (float): The threshold value, defaults to 1.0.
+
+ Returns:
+ Vector[Any,Float]: The normalized vector.
+ """
+ ...
diff --git a/warp/tests/test_math.py b/warp/tests/test_math.py
index 62425dacd..16af59a1a 100644
--- a/warp/tests/test_math.py
+++ b/warp/tests/test_math.py
@@ -6,7 +6,7 @@
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import unittest
-from typing import NamedTuple
+from typing import Any, NamedTuple
import numpy as np
@@ -50,6 +50,51 @@ def test_scalar_math(test, device):
assert_np_equal(tape.gradients[x].numpy(), np.array([adj_float_results_expected[i]]), tol=1e-6)
+@wp.kernel
+def test_vec_norm_kernel(vs: wp.array(dtype=Any), out: wp.array(dtype=float, ndim=2)):
+ tid = wp.tid()
+ out[tid, 0] = wp.norm_l1(vs[tid])
+ out[tid, 1] = wp.norm_l2(vs[tid])
+ out[tid, 2] = wp.norm_huber(vs[tid])
+ out[tid, 3] = wp.norm_pseudo_huber(vs[tid])
+
+
+def test_vec_norm(test, device):
+ # ground-truth implementations from SciPy
+ def huber(delta, x):
+ if x <= delta:
+ return 0.5 * x**2
+ else:
+ return delta * (x - 0.5 * delta)
+
+ def pseudo_huber(delta, x):
+ return delta**2 * (np.sqrt(1 + (x / delta) ** 2) - 1)
+
+ v0 = wp.vec3(-2.0, -1.0, -3.0)
+ v1 = wp.vec3(2.0, 1.0, 3.0)
+ v2 = wp.vec3(0.0, 0.0, 0.0)
+
+ xs = wp.array([v0, v1, v2], dtype=wp.vec3, requires_grad=True, device=device)
+ out = wp.empty((len(xs), 4), dtype=wp.float32, requires_grad=True, device=device)
+
+ wp.launch(test_vec_norm_kernel, dim=len(xs), inputs=[xs], outputs=[out], device=device)
+
+ for i, x in enumerate([v0, v1, v2]):
+ assert_np_equal(
+ out.numpy()[i],
+ np.array(
+ [
+ np.linalg.norm(x, ord=1),
+ np.linalg.norm(x, ord=2),
+ huber(1.0, wp.length(x)),
+ # note SciPy defines the Pseudo-Huber loss slightly differently
+ pseudo_huber(1.0, wp.length(x)) + 1.0,
+ ]
+ ),
+ tol=1e-6,
+ )
+
+
devices = get_test_devices()
@@ -117,6 +162,7 @@ def test_mat_type(self):
add_function_test(TestMath, "test_scalar_math", test_scalar_math, devices=devices)
+add_function_test(TestMath, "test_vec_norm", test_vec_norm, devices=devices)
if __name__ == "__main__":
diff --git a/warp/tests/test_sim_grad.py b/warp/tests/test_sim_grad.py
index 7b6d677e7..0c5cf78bd 100644
--- a/warp/tests/test_sim_grad.py
+++ b/warp/tests/test_sim_grad.py
@@ -11,6 +11,7 @@
import warp as wp
import warp.sim
+import warp.sim.render
from warp.tests.unittest_utils import *
@@ -23,8 +24,7 @@ def evaluate_loss(
loss: wp.array(dtype=float),
):
tid = wp.tid()
- # wp.atomic_add(loss, 0, weighting * (target - joint_q[tid * 2 + 1]) ** 2.0)
- d = wp.abs(target - joint_q[tid * 2 + 1])
+ d = (target - joint_q[tid * 2 + 1]) ** 2.0
wp.atomic_add(loss, 0, weighting * d)
@@ -34,7 +34,13 @@ def assign_action(action: wp.array(dtype=float), joint_act: wp.array(dtype=float
joint_act[2 * tid] = action[tid]
-def gradcheck(func, inputs, device, eps=1e-1, tol=1e-2):
+@wp.kernel
+def assign_force(action: wp.array(dtype=float), body_f: wp.array(dtype=wp.spatial_vector)):
+ tid = wp.tid()
+ body_f[2 * tid] = wp.spatial_vector(0.0, 0.0, 0.0, action[tid], 0.0, 0.0)
+
+
+def gradcheck(func, inputs, device, eps=1e-1, tol=1e-2, print_grad=False):
"""
Checks that the gradient of the Warp kernel is correct by comparing it to the
numerical gradient computed using finite differences.
@@ -46,56 +52,64 @@ def f(xs):
output = func(*wp_xs)
return output.numpy()[0]
+ # compute analytical gradient
+ tape = wp.Tape()
+ with tape:
+ output = func(*inputs)
+
+ tape.backward(loss=output)
+
# compute numerical gradient
- numerical_grad = []
np_xs = []
for i in range(len(inputs)):
np_xs.append(inputs[i].numpy().flatten().copy())
- numerical_grad.append(np.zeros_like(np_xs[-1]))
- inputs[i].requires_grad = True
- for i in range(len(np_xs)):
+ for i in range(len(inputs)):
+ fd_grad = np.zeros_like(np_xs[i])
for j in range(len(np_xs[i])):
np_xs[i][j] += eps
y1 = f(np_xs)
np_xs[i][j] -= 2 * eps
y2 = f(np_xs)
np_xs[i][j] += eps
- numerical_grad[i][j] = (y1 - y2) / (2 * eps)
-
- # compute analytical gradient
- tape = wp.Tape()
- with tape:
- output = func(*inputs)
-
- tape.backward(loss=output)
-
- # compare gradients
- for i in range(len(inputs)):
- grad = tape.gradients[inputs[i]]
- assert_np_equal(grad.numpy(), numerical_grad[i], tol=tol)
+ fd_grad[j] = (y1 - y2) / (2 * eps)
+
+ # compare gradients
+ ad_grad = tape.gradients[inputs[i]].numpy()
+ if print_grad:
+ print("grad ad:", ad_grad)
+ print("grad fd:", fd_grad)
+ assert_np_equal(ad_grad, fd_grad, tol=tol)
# ensure the signs match
- assert np.allclose(grad.numpy() * numerical_grad[i] > 0, True)
+ assert np.allclose(ad_grad * fd_grad > 0, True)
tape.zero()
-def test_box_pushing_on_rails(test, device, joint_type, integrator_type):
- # Two boxes on a rail (prismatic or D6 joint), one is pushed, the other is passive.
+def test_sphere_pushing_on_rails(
+ test,
+ device,
+ joint_type,
+ integrator_type,
+ apply_force=False,
+ static_contacts=True,
+ print_grad=False,
+):
+ # Two spheres on a rail (prismatic or D6 joint), one is pushed, the other is passive.
# The absolute distance to a target is measured and gradients are compared for
# a push that is too far and too close.
num_envs = 2
- num_steps = 200
- sim_substeps = 2
+ num_steps = 150
+ sim_substeps = 10
dt = 1 / 30
- target = 5.0
+ target = 3.0
if integrator_type == 0:
- contact_ke = 1e5
- contact_kd = 1e3
+ contact_ke = 1e3
+ contact_kd = 1e1
else:
- contact_ke = 1e5
+ contact_ke = 1e3
contact_kd = 1e1
complete_builder = wp.sim.ModelBuilder()
@@ -104,16 +118,16 @@ def test_box_pushing_on_rails(test, device, joint_type, integrator_type):
complete_builder.default_shape_kd = contact_kd
for _ in range(num_envs):
- builder = wp.sim.ModelBuilder()
+ builder = wp.sim.ModelBuilder(gravity=0.0)
builder.default_shape_ke = complete_builder.default_shape_ke
builder.default_shape_kd = complete_builder.default_shape_kd
b0 = builder.add_body(name="pusher")
- builder.add_shape_box(b0, density=1000.0)
+ builder.add_shape_sphere(b0, radius=0.4, density=100.0)
b1 = builder.add_body(name="passive")
- builder.add_shape_box(b1, hx=0.4, hy=0.4, hz=0.4, density=1000.0)
+ builder.add_shape_sphere(b1, radius=0.47, density=100.0)
if joint_type == 0:
builder.add_joint_prismatic(-1, b0)
@@ -122,7 +136,7 @@ def test_box_pushing_on_rails(test, device, joint_type, integrator_type):
builder.add_joint_d6(-1, b0, linear_axes=[wp.sim.JointAxis((1.0, 0.0, 0.0))])
builder.add_joint_d6(-1, b1, linear_axes=[wp.sim.JointAxis((1.0, 0.0, 0.0))])
- builder.joint_q[-2:] = [0.0, 1.0]
+ builder.joint_q[-2:] = [0.0, 2.0]
complete_builder.add_builder(builder)
assert complete_builder.body_count == 2 * num_envs
@@ -135,6 +149,15 @@ def test_box_pushing_on_rails(test, device, joint_type, integrator_type):
model.joint_attach_ke = 32000.0 * 16
model.joint_attach_kd = 500.0 * 4
+ model.shape_geo.scale.requires_grad = False
+ model.shape_geo.thickness.requires_grad = False
+
+ if static_contacts:
+ wp.sim.eval_fk(model, model.joint_q, model.joint_qd, None, model)
+ model.rigid_contact_margin = 10.0
+ state = model.state()
+ wp.sim.collide(model, state)
+
if integrator_type == 0:
integrator = wp.sim.FeatherstoneIntegrator(model, update_mass_matrix_every=num_steps * sim_substeps)
elif integrator_type == 1:
@@ -143,40 +166,57 @@ def test_box_pushing_on_rails(test, device, joint_type, integrator_type):
else:
integrator = wp.sim.XPBDIntegrator(iterations=2, rigid_contact_relaxation=1.0)
- # renderer = wp.sim.render.SimRenderer(model, "test_sim_grad.usd", scaling=1.0)
+ # renderer = wp.sim.render.SimRendererOpenGL(model, "test_sim_grad.usd", scaling=1.0)
renderer = None
render_time = 0.0
+ if renderer:
+ renderer.render_sphere("target", pos=wp.vec3(target, 0, 0), rot=wp.quat_identity(), radius=0.1, color=(1, 0, 0))
+
def rollout(action: wp.array) -> wp.array:
nonlocal render_time
states = [model.state() for _ in range(num_steps * sim_substeps + 1)]
- if not isinstance(integrator, wp.sim.FeatherstoneIntegrator):
- # apply initial generalized coordinates
- wp.sim.eval_fk(model, model.joint_q, model.joint_qd, None, states[0])
+ wp.sim.eval_fk(model, model.joint_q, model.joint_qd, None, states[0])
control_active = model.control()
control_nop = model.control()
- wp.launch(
- assign_action,
- dim=num_envs,
- inputs=[action],
- outputs=[control_active.joint_act],
- device=model.device,
- )
+ if not apply_force:
+ wp.launch(
+ assign_action,
+ dim=num_envs,
+ inputs=[action],
+ outputs=[control_active.joint_act],
+ device=model.device,
+ )
i = 0
for step in range(num_steps):
- wp.sim.collide(model, states[i])
- control = control_active if step < 10 else control_nop
+ state = states[i]
+ if not static_contacts:
+ wp.sim.collide(model, state)
+ if apply_force:
+ control = control_nop
+ else:
+ control = control_active if step < 10 else control_nop
if renderer:
renderer.begin_frame(render_time)
- renderer.render(states[i])
+ renderer.render(state)
renderer.end_frame()
render_time += dt
for _ in range(sim_substeps):
- integrator.simulate(model, states[i], states[i + 1], dt / sim_substeps, control)
+ state = states[i]
+ next_state = states[i + 1]
+ if apply_force and step < 10:
+ wp.launch(
+ assign_force,
+ dim=num_envs,
+ inputs=[action],
+ outputs=[state.body_f],
+ device=model.device,
+ )
+ integrator.simulate(model, state, next_state, dt / sim_substeps, control)
i += 1
if not isinstance(integrator, wp.sim.FeatherstoneIntegrator):
@@ -184,39 +224,40 @@ def rollout(action: wp.array) -> wp.array:
wp.sim.eval_ik(model, states[-1], states[-1].joint_q, states[-1].joint_qd)
loss = wp.zeros(1, requires_grad=True, device=device)
+ weighting = 1.0
wp.launch(
evaluate_loss,
dim=num_envs,
- inputs=[states[-1].joint_q, 1.0, target],
+ inputs=[states[-1].joint_q, weighting, target],
outputs=[loss],
device=model.device,
)
- if renderer:
- renderer.save()
+ # if renderer:
+ # renderer.save()
return loss
action_too_far = wp.array(
- [5000.0 for _ in range(num_envs)],
+ [80.0 for _ in range(num_envs)],
device=device,
dtype=wp.float32,
requires_grad=True,
)
- tol = 1e-2
+ tol = 2e-1
if isinstance(integrator, wp.sim.XPBDIntegrator):
# Euler, XPBD do not yield as accurate gradients, but at least the
# signs should match
tol = 0.1
- gradcheck(rollout, [action_too_far], device=device, eps=0.2, tol=tol)
+ gradcheck(rollout, [action_too_far], device=device, eps=0.2, tol=tol, print_grad=print_grad)
action_too_close = wp.array(
- [1500.0 for _ in range(num_envs)],
+ [40.0 for _ in range(num_envs)],
device=device,
dtype=wp.float32,
requires_grad=True,
)
- gradcheck(rollout, [action_too_close], device=device, eps=0.2, tol=tol)
+ gradcheck(rollout, [action_too_close], device=device, eps=0.2, tol=tol, print_grad=print_grad)
devices = get_test_devices()
@@ -226,15 +267,15 @@ class TestSimGradients(unittest.TestCase):
pass
-for int_type, int_name in enumerate(["featherstone", "semiimplicit"]):
- for jt_type, jt_name in enumerate(["prismatic", "d6"]):
- test_name = f"test_box_pushing_on_rails_{int_name}_{jt_name}"
+for jt_type, jt_name in enumerate(["prismatic", "d6"]):
+ test_name = f"test_sphere_pushing_on_rails_{jt_name}"
- def test_fn(self, device, jt_type=jt_type, int_type=int_type):
- return test_box_pushing_on_rails(self, device, jt_type, int_type)
-
- add_function_test(TestSimGradients, test_name, test_fn, devices=devices)
+ def test_fn(self, device, jt_type=jt_type, int_type=1):
+ return test_sphere_pushing_on_rails(
+ self, device, jt_type, int_type, apply_force=True, static_contacts=True, print_grad=False
+ )
+ add_function_test(TestSimGradients, test_name, test_fn, devices=devices)
if __name__ == "__main__":
wp.clear_kernel_cache()
diff --git a/warp/tests/test_sim_grad_bounce_linear.py b/warp/tests/test_sim_grad_bounce_linear.py
new file mode 100644
index 000000000..d81724c70
--- /dev/null
+++ b/warp/tests/test_sim_grad_bounce_linear.py
@@ -0,0 +1,196 @@
+import numpy as np
+
+import warp as wp
+import warp.optim
+import warp.sim
+import warp.sim.render
+from warp.tests.unittest_utils import *
+
+
+@wp.kernel
+def update_trajectory_kernel(
+ trajectory: wp.array(dtype=wp.vec3),
+ q: wp.array(dtype=wp.transform),
+ time_step: wp.int32,
+ q_idx: wp.int32,
+):
+ trajectory[time_step] = wp.transform_get_translation(q[q_idx])
+
+
+@wp.kernel
+def trajectory_loss_kernel(
+ trajectory: wp.array(dtype=wp.vec3f),
+ target_trajectory: wp.array(dtype=wp.vec3f),
+ loss: wp.array(dtype=wp.float32),
+):
+ tid = wp.tid()
+ diff = trajectory[tid] - target_trajectory[tid]
+ distance_loss = wp.dot(diff, diff)
+ wp.atomic_add(loss, 0, distance_loss)
+
+
+class BallBounceLinearTest:
+ def __init__(self, gravity=True, rendering=False):
+ # Ball bouncing scenario inspired by https://github.com/NVIDIA/warp/issues/349
+ self.fps = 30
+ self.num_frames = 60
+ self.sim_substeps = 20 # XXX need to use enough substeps to achieve smooth gradients
+ self.frame_dt = 1.0 / self.fps
+ self.sim_dt = self.frame_dt / self.sim_substeps
+ self.sim_duration = self.num_frames * self.frame_dt
+ self.sim_steps = int(self.sim_duration // self.sim_dt)
+
+ self.target_force_linear = 100.0
+
+ if gravity:
+ builder = wp.sim.ModelBuilder(up_vector=wp.vec3(0, 0, 1))
+ else:
+ builder = wp.sim.ModelBuilder(gravity=0.0, up_vector=wp.vec3(0, 0, 1))
+
+ b = builder.add_body(origin=wp.transform((0.5, 0.0, 1.0), wp.quat_identity()), name="ball")
+ builder.add_shape_sphere(
+ body=b, radius=0.1, density=100.0, ke=2000.0, kd=10.0, kf=200.0, mu=0.2, thickness=0.01
+ )
+ builder.set_ground_plane(ke=10, kd=10, kf=0.0, mu=0.2)
+ self.model = builder.finalize(requires_grad=True)
+
+ self.time = np.linspace(0, self.sim_duration, self.sim_steps)
+
+ self.integrator = wp.sim.SemiImplicitIntegrator()
+ if rendering:
+ self.renderer = wp.sim.render.SimRendererOpenGL(self.model, "ball_bounce_linear")
+ else:
+ self.renderer = None
+
+ self.loss = wp.array([0], dtype=wp.float32, requires_grad=True)
+ self.states = [self.model.state() for _ in range(self.sim_steps + 1)]
+ self.target_states = [self.model.state() for _ in range(self.sim_steps + 1)]
+
+ self.target_force = wp.array([0.0, 0.0, 0.0, 0.0, self.target_force_linear, 0.0], dtype=wp.spatial_vectorf)
+
+ self.trajectory = wp.empty(len(self.time), dtype=wp.vec3, requires_grad=True)
+ self.target_trajectory = wp.empty(len(self.time), dtype=wp.vec3)
+
+ def _reset(self):
+ self.loss = wp.array([0], dtype=wp.float32, requires_grad=True)
+
+ def generate_target_trajectory(self):
+ for i in range(self.sim_steps):
+ curr_state = self.target_states[i]
+ next_state = self.target_states[i + 1]
+ curr_state.clear_forces()
+ if i == 0:
+ wp.copy(curr_state.body_f, self.target_force, dest_offset=0, src_offset=0, count=1)
+ wp.sim.collide(self.model, curr_state)
+ self.integrator.simulate(self.model, curr_state, next_state, self.sim_dt)
+ wp.launch(kernel=update_trajectory_kernel, dim=1, inputs=[self.target_trajectory, curr_state.body_q, i, 0])
+
+ def forward(self, force: wp.array):
+ for i in range(self.sim_steps):
+ curr_state = self.states[i]
+ next_state = self.states[i + 1]
+ curr_state.clear_forces()
+ if i == 0:
+ wp.copy(curr_state.body_f, force, dest_offset=0, src_offset=0, count=1)
+ wp.sim.collide(self.model, curr_state)
+ self.integrator.simulate(self.model, curr_state, next_state, self.sim_dt)
+ wp.launch(kernel=update_trajectory_kernel, dim=1, inputs=[self.trajectory, curr_state.body_q, i, 0])
+
+ if self.renderer:
+ self.renderer.begin_frame(self.time[i])
+ self.renderer.render(curr_state)
+ self.renderer.end_frame()
+
+ def step(self, force: wp.array):
+ self.tape = wp.Tape()
+ self._reset()
+ with self.tape:
+ self.forward(force)
+ wp.launch(
+ kernel=trajectory_loss_kernel,
+ dim=len(self.trajectory),
+ inputs=[self.trajectory, self.target_trajectory, self.loss],
+ )
+ self.tape.backward(self.loss)
+ force_grad = force.grad.numpy()[0, 4]
+ self.tape.zero()
+
+ return self.loss.numpy()[0], force_grad
+
+ def evaluate(self, num_samples, plot_results=False):
+ forces = np.linspace(0, self.target_force_linear * 2, num_samples)
+ losses = np.zeros_like(forces)
+ grads = np.zeros_like(forces)
+
+ for i, fx in enumerate(forces):
+ force = wp.array([[0.0, 0.0, 0.0, 0.0, fx, 0.0]], dtype=wp.spatial_vectorf, requires_grad=True)
+ losses[i], grads[i] = self.step(force)
+ if plot_results:
+ print(f"Iteration {i + 1}/{num_samples}")
+ print(f"Force: {fx:.2f}, Loss: {losses[i]:.6f}, Grad: {grads[i]:.6f}")
+
+ assert np.isfinite(losses[i])
+ assert np.isfinite(grads[i])
+ if i > 0:
+ assert grads[i] >= grads[i - 1]
+
+ if plot_results:
+ import matplotlib.pyplot as plt
+
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
+
+ # Plot the loss curve
+ ax1.plot(forces, losses, label="Loss")
+ ax1.set_xlabel("Force")
+ ax1.set_ylabel("Loss")
+ ax1.set_title("Loss vs Force")
+ ax1.legend()
+
+ # Make sure the grads are not too large
+ grads = np.clip(grads, -1e4, 1e4)
+
+ # Plot the gradient curve
+ ax2.plot(forces, grads, label="Gradient", color="orange")
+ ax2.set_xlabel("Force")
+ ax2.set_ylabel("Gradient")
+ ax2.set_title("Gradient vs Force")
+ ax2.legend()
+
+ plt.suptitle("Loss and Gradient vs Force")
+ plt.tight_layout(rect=[0, 0, 1, 0.95])
+ plt.show()
+
+ return losses, grads
+
+
+def test_sim_grad_bounce_linear(test, device):
+ model = BallBounceLinearTest()
+ model.generate_target_trajectory()
+
+ num_samples = 20
+ losses, grads = model.evaluate(num_samples=num_samples)
+ # gradients must approximate linear behavior with zero crossing in the middle
+ test.assertTrue(np.abs(grads[1:] - grads[:-1]).max() < 1.1)
+ test.assertTrue(np.all(grads[: num_samples // 2] <= 0.0))
+ test.assertTrue(np.all(grads[num_samples // 2 :] >= 0.0))
+ # losses must follow a parabolic behavior
+ test.assertTrue(np.allclose(losses[: num_samples // 2], losses[num_samples // 2 :][::-1], atol=1.0))
+ diffs = losses[1:] - losses[:-1]
+ test.assertTrue(np.all(diffs[: num_samples // 2 - 1] <= 0.0))
+ test.assertTrue(np.all(diffs[num_samples // 2 - 1 :] >= 0.0))
+ # second derivative must be constant positive
+ diffs2 = diffs[1:] - diffs[:-1]
+ test.assertTrue(np.allclose(diffs2, diffs2[0], atol=1e-2))
+ test.assertTrue(np.all(diffs2 >= 0.0))
+
+
+class TestSimGradBounceLinear(unittest.TestCase):
+ pass
+
+
+devices = get_test_devices()
+add_function_test(TestSimGradBounceLinear, "test_sim_grad_bounce_linear", test_sim_grad_bounce_linear, devices=devices)
+
+if __name__ == "__main__":
+ wp.clear_kernel_cache()
+ unittest.main(verbosity=2, failfast=True)