Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions mujoco_warp/_src/broadphase_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
from mujoco_warp import BroadphaseType
from mujoco_warp import DisableBit
from mujoco_warp import test_data

from . import collision_driver
from mujoco_warp._src import collision_driver


def broadphase_caller(m, d):
Expand Down
3 changes: 1 addition & 2 deletions mujoco_warp/_src/collision_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@
from mujoco_warp import BroadphaseType
from mujoco_warp import DisableBit
from mujoco_warp import test_data
from mujoco_warp._src import types
from mujoco_warp._src.collision_primitive import Geom
from mujoco_warp._src.collision_primitive import plane_convex
from mujoco_warp.test_data.collision_sdf.utils import register_sdf_plugins

from . import types

_TOLERANCE = 5e-5


Expand Down
13 changes: 6 additions & 7 deletions mujoco_warp/_src/collision_gjk_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@
from mujoco_warp import GeomType
from mujoco_warp import Model
from mujoco_warp import test_data

from .collision_gjk import ccd
from .collision_gjk import multicontact
from .collision_primitive import Geom
from .types import MJ_MAX_EPAFACES
from .types import MJ_MAX_EPAHORIZON
from .warp_util import nested_kernel
from mujoco_warp._src.collision_gjk import ccd
from mujoco_warp._src.collision_gjk import multicontact
from mujoco_warp._src.collision_primitive import Geom
from mujoco_warp._src.types import MJ_MAX_EPAFACES
from mujoco_warp._src.types import MJ_MAX_EPAHORIZON
from mujoco_warp._src.warp_util import nested_kernel


def _geom_dist(
Expand Down
3 changes: 1 addition & 2 deletions mujoco_warp/_src/inverse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
from mujoco_warp import DisableBit
from mujoco_warp import IntegratorType
from mujoco_warp import test_data

from . import inverse
from mujoco_warp._src import inverse


def _assert_eq(a, b, name):
Expand Down
16 changes: 8 additions & 8 deletions mujoco_warp/_src/io_jax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _dims_match(test_obj, d1: Any, d2: Any, prefix: str = ""):
_dims_match(test_obj, a1, a2, prefix + f1.name + ".")
continue

if isinstance(f1.type, wp.types.array) or isinstance(f2.type, wp.types.array):
if isinstance(f1.type, wp.array) or isinstance(f2.type, wp.array):
s1, s2 = a1.shape, a2.shape
test_obj.assertEqual(len(s1), len(s2), f"{full_name} dims mismatch. Got {s1} and {s2}.")
test_obj.assertEqual(s1, s2, f"{full_name} dims mismatch. Got {s1} and {s2}.")
Expand Down Expand Up @@ -85,8 +85,8 @@ def _check_type_matches_annotation(test_obj, obj: Any, prefix: str = ""):
test_obj.assertIsInstance(np_scalar_type(val), type_, msg.format(**locals()))
continue

if isinstance(type_, wp.types.array):
test_obj.assertIsInstance(val, wp.types.array, msg.format(**locals()))
if isinstance(type_, wp.array):
test_obj.assertIsInstance(val, wp.array, msg.format(**locals()))
continue

origin_type = typing.get_origin(type_)
Expand All @@ -107,8 +107,8 @@ def _check_type_matches_annotation(test_obj, obj: Any, prefix: str = ""):
test_obj.assertIsInstance(np_scalar_type(val), type_, msg.format(**locals()))
continue

if isinstance(type_, wp.types.array):
test_obj.assertIsInstance(val, wp.types.array, msg.format(**locals()))
if isinstance(type_, wp.array):
test_obj.assertIsInstance(val, wp.array, msg.format(**locals()))
continue

test_obj.assertEqual(type(val), type_, msg.format(**locals()))
Expand All @@ -128,10 +128,10 @@ def _check_annotation_compat(
if v in (int, bool, float):
continue

if isinstance(v, wp.types.array):
if isinstance(v, wp.array):
continue

if v in wp.types.vector_types:
if wp.types.type_is_composite(v):
raise AssertionError(f"Vector types are not allowed. {info}")

if typing.get_origin(v) == tuple and (in_cls or in_tuple):
Expand Down Expand Up @@ -177,7 +177,7 @@ def _leading_dims_scale_w_nworld(test_obj, d1: Any, d2: Any, nworld1: int, nworl
_leading_dims_scale_w_nworld(test_obj, a1, a2, nworld1, nworld2, prefix + f1.name + ".")
continue

if isinstance(f1.type, wp.types.array) or isinstance(f2.type, wp.types.array):
if isinstance(f1.type, wp.array) or isinstance(f2.type, wp.array):
s1, s2 = a1.shape[0], a2.shape[0]
if s1 == s2:
continue
Expand Down
6 changes: 3 additions & 3 deletions mujoco_warp/_src/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
import warp as wp
from absl.testing import absltest

from .math import closest_segment_to_segment_points
from .math import upper_tri_index
from .math import upper_trid_index
from mujoco_warp._src.math import closest_segment_to_segment_points
from mujoco_warp._src.math import upper_tri_index
from mujoco_warp._src.math import upper_trid_index


class ClosestSegmentSegmentPointsTest(absltest.TestCase):
Expand Down
14 changes: 7 additions & 7 deletions mujoco_warp/_src/passive.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,15 +595,15 @@ def _flex_elasticity(
nedge = nvert * (nvert - 1) / 2
edges = wp.where(
dim == 3,
wp.types.matrix(0, 1, 1, 2, 2, 0, 2, 3, 0, 3, 1, 3, shape=(6, 2), dtype=int),
wp.types.matrix(1, 2, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0, shape=(6, 2), dtype=int),
wp.matrix(0, 1, 1, 2, 2, 0, 2, 3, 0, 3, 1, 3, shape=(6, 2), dtype=int),
wp.matrix(1, 2, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0, shape=(6, 2), dtype=int),
)
if timestep > 0.0 and not dsbl_damper:
kD = flex_damping[f] / timestep
else:
kD = 0.0

gradient = wp.types.matrix(0.0, shape=(6, 6))
gradient = wp.matrix(0.0, shape=(6, 6))
for e in range(nedge):
vert0 = flex_elem[(dim + 1) * elemid + edges[e, 0]]
vert1 = flex_elem[(dim + 1) * elemid + edges[e, 1]]
Expand All @@ -622,15 +622,15 @@ def _flex_elasticity(
previous = deformed - vel * timestep
elongation[e] = deformed * deformed - reference * reference + (deformed * deformed - previous * previous) * kD

metric = wp.types.matrix(0.0, shape=(6, 6))
metric = wp.matrix(0.0, shape=(6, 6))
id = int(0)
for ed1 in range(nedge):
for ed2 in range(ed1, nedge):
metric[ed1, ed2] = flex_stiffness[elemid, id]
metric[ed2, ed1] = flex_stiffness[elemid, id]
id += 1

force = wp.types.matrix(0.0, shape=(6, 3))
force = wp.matrix(0.0, shape=(6, 3))
for ed1 in range(nedge):
for ed2 in range(nedge):
for i in range(2):
Expand Down Expand Up @@ -684,7 +684,7 @@ def _flex_bending(
flex_vertadr[f] + flex_edgeflap[edgeid][1],
)

frc = wp.types.matrix(0.0, shape=(4, 3))
frc = wp.matrix(0.0, shape=(4, 3))
if flex_bending[edgeid, 16]:
v0 = flexvert_xpos_in[worldid, v[0]]
v1 = flexvert_xpos_in[worldid, v[1]]
Expand All @@ -695,7 +695,7 @@ def _flex_bending(
frc[3] = wp.cross(v1 - v0, v2 - v0)
frc[0] = -(frc[1] + frc[2] + frc[3])

force = wp.types.matrix(0.0, shape=(nvert, 3))
force = wp.matrix(0.0, shape=(nvert, 3))
for i in range(nvert):
for x in range(3):
for j in range(nvert):
Expand Down
3 changes: 1 addition & 2 deletions mujoco_warp/_src/ray_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@

import mujoco_warp as mjw
from mujoco_warp import test_data

from .types import vec6
from mujoco_warp._src.types import vec6

# tolerance for difference between MuJoCo and MJX ray calculations - mostly
# due to float precision
Expand Down
3 changes: 1 addition & 2 deletions mujoco_warp/_src/solver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
from mujoco_warp import ConeType
from mujoco_warp import SolverType
from mujoco_warp import test_data

from . import solver
from mujoco_warp._src import solver

# tolerance for difference between MuJoCo and MJWarp solver calculations - mostly
# due to float precision
Expand Down
7 changes: 3 additions & 4 deletions mujoco_warp/_src/support_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@
from mujoco_warp import ConeType
from mujoco_warp import State
from mujoco_warp import test_data

from .block_cholesky import create_blocked_cholesky_func
from .block_cholesky import create_blocked_cholesky_solve_func
from .warp_util import nested_kernel
from mujoco_warp._src.block_cholesky import create_blocked_cholesky_func
from mujoco_warp._src.block_cholesky import create_blocked_cholesky_solve_func
from mujoco_warp._src.warp_util import nested_kernel

# tolerance for difference between MuJoCo and MJWarp support calculations - mostly
# due to float precision
Expand Down
6 changes: 3 additions & 3 deletions mujoco_warp/_src/types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from absl.testing import absltest
from absl.testing import parameterized

from .types import Data
from .types import Model
from .types import Option
from mujoco_warp._src.types import Data
from mujoco_warp._src.types import Model
from mujoco_warp._src.types import Option


class TypesTest(parameterized.TestCase):
Expand Down
5 changes: 2 additions & 3 deletions mujoco_warp/_src/unroll_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@
import mujoco_warp as mjw
from mujoco_warp import ConeType
from mujoco_warp import test_data

from .io import find_keys
from .io import make_trajectory
from mujoco_warp._src.io import find_keys
from mujoco_warp._src.io import make_trajectory


class UnrollTest(parameterized.TestCase):
Expand Down
8 changes: 4 additions & 4 deletions mujoco_warp/_src/util_misc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
from absl.testing import absltest
from absl.testing import parameterized

from . import util_misc
from .types import MJ_MINVAL
from .types import WrapType
from .types import vec10
from mujoco_warp._src import util_misc
from mujoco_warp._src.types import MJ_MINVAL
from mujoco_warp._src.types import WrapType
from mujoco_warp._src.types import vec10


def _assert_eq(a, b, name):
Expand Down
13 changes: 5 additions & 8 deletions mujoco_warp/_src/warp_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from typing import Callable, Optional

import warp as wp
from warp._src.context import Module
from warp._src.context import get_module

_STACK = None

Expand Down Expand Up @@ -127,7 +125,7 @@ def nested_kernel(
f: Optional[Callable] = None,
*,
enable_backward: Optional[bool] = None,
module: Optional[Module] = None,
module: Optional[wp.Module] = None,
):
"""Decorator to register a Warp kernel from a Python function.

Expand Down Expand Up @@ -166,7 +164,7 @@ def my_kernel_with_args(a: wp.array(dtype=float), b: wp.array(dtype=float)):
Args:
f: The function to be registered as a kernel.
enable_backward: If False, the backward pass will not be generated.
module: The :class:`warp.context.Module` to which the kernel belongs. Alternatively,
module: The :class:`warp.Module` to which the kernel belongs. Alternatively,
if a string `"unique"` is provided, the kernel is assigned to a new module
named after the kernel name and hash. If None, the module is inferred from
the function's module.
Expand All @@ -182,7 +180,7 @@ def decorator(func):
qualname = func.__qualname__
parts = [part for part in qualname.split(".") if part != "<locals>"]
outer_functions = parts[:-1]
module_name = get_module(".".join([func.__module__] + outer_functions))
module_name = wp.get_module(".".join([func.__module__] + outer_functions))
else:
module_name = module

Expand Down Expand Up @@ -220,8 +218,7 @@ def _hash_arg(a):


def check_toolkit_driver():
if wp._src.context.runtime is None:
wp._src.context.init()
wp.init()
if wp.get_device().is_cuda:
if wp._src.context.runtime.toolkit_version < (12, 4) or wp._src.context.runtime.driver_version < (12, 4):
if not wp.is_conditional_graph_supported():
RuntimeError("Minimum supported CUDA version: 12.4.")
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies = [
"etils[epath]",
"mujoco>=3.3.7",
"numpy",
"warp-lang>=1.9.1",
"warp-lang>=1.11.0",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need these changes? we should only change this if mjwarp no longer works with >=1.9.1 - is that the case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's correct. For example, the current MR replaces internal references like:

if wp._src.context.runtime.toolkit_version < (12, 4) or wp._src.context.runtime.driver_version < (12, 4):

with:

if not wp.is_conditional_graph_supported():

The latter was introduced as part of the public API overhaul commit 2d6d379a, and released in v1.11.0.

This MR also uses wp.type_is_composite(), another API introduced in v1.11.0, to replace a reference to the internal symbol wp.types.vector_types.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah OK, understood. SGTM.

]

[[tool.uv.index]]
Expand All @@ -55,7 +55,7 @@ dev = [
"pygls>=1.0.0,<2.0.0",
"lsprotocol>=2023.0.1,<2024.0.0",
"mujoco>=3.3.7.dev0",
"warp-lang>=1.9.1.dev0",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We put this here because when we develop, we want to use a Warp nightly. Please revert this change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! ✔️

"warp-lang>=1.11.0.dev0",
]
# TODO(team): cpu and cuda JAX optional dependencies are temporary, remove after we land MJX:Warp
cpu = [
Expand Down
12 changes: 6 additions & 6 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading