diff --git a/mujoco_warp/_src/broadphase_test.py b/mujoco_warp/_src/broadphase_test.py index 1e1e35c63..bc4027313 100644 --- a/mujoco_warp/_src/broadphase_test.py +++ b/mujoco_warp/_src/broadphase_test.py @@ -25,8 +25,7 @@ from mujoco_warp import BroadphaseType from mujoco_warp import DisableBit from mujoco_warp import test_data - -from . import collision_driver +from mujoco_warp._src import collision_driver def broadphase_caller(m, d): diff --git a/mujoco_warp/_src/collision_driver_test.py b/mujoco_warp/_src/collision_driver_test.py index 3487279eb..a309b38ae 100644 --- a/mujoco_warp/_src/collision_driver_test.py +++ b/mujoco_warp/_src/collision_driver_test.py @@ -24,12 +24,11 @@ from mujoco_warp import BroadphaseType from mujoco_warp import DisableBit from mujoco_warp import test_data +from mujoco_warp._src import types from mujoco_warp._src.collision_primitive import Geom from mujoco_warp._src.collision_primitive import plane_convex from mujoco_warp.test_data.collision_sdf.utils import register_sdf_plugins -from . import types - _TOLERANCE = 5e-5 diff --git a/mujoco_warp/_src/collision_gjk_test.py b/mujoco_warp/_src/collision_gjk_test.py index 92999a111..b49937320 100644 --- a/mujoco_warp/_src/collision_gjk_test.py +++ b/mujoco_warp/_src/collision_gjk_test.py @@ -21,13 +21,12 @@ from mujoco_warp import GeomType from mujoco_warp import Model from mujoco_warp import test_data - -from .collision_gjk import ccd -from .collision_gjk import multicontact -from .collision_primitive import Geom -from .types import MJ_MAX_EPAFACES -from .types import MJ_MAX_EPAHORIZON -from .warp_util import nested_kernel +from mujoco_warp._src.collision_gjk import ccd +from mujoco_warp._src.collision_gjk import multicontact +from mujoco_warp._src.collision_primitive import Geom +from mujoco_warp._src.types import MJ_MAX_EPAFACES +from mujoco_warp._src.types import MJ_MAX_EPAHORIZON +from mujoco_warp._src.warp_util import nested_kernel def _geom_dist( diff --git a/mujoco_warp/_src/inverse_test.py b/mujoco_warp/_src/inverse_test.py index bcaedd009..319a773d0 100644 --- a/mujoco_warp/_src/inverse_test.py +++ b/mujoco_warp/_src/inverse_test.py @@ -25,8 +25,7 @@ from mujoco_warp import DisableBit from mujoco_warp import IntegratorType from mujoco_warp import test_data - -from . import inverse +from mujoco_warp._src import inverse def _assert_eq(a, b, name): diff --git a/mujoco_warp/_src/io_jax_test.py b/mujoco_warp/_src/io_jax_test.py index 84223d1a0..37d26e855 100644 --- a/mujoco_warp/_src/io_jax_test.py +++ b/mujoco_warp/_src/io_jax_test.py @@ -50,7 +50,7 @@ def _dims_match(test_obj, d1: Any, d2: Any, prefix: str = ""): _dims_match(test_obj, a1, a2, prefix + f1.name + ".") continue - if isinstance(f1.type, wp.types.array) or isinstance(f2.type, wp.types.array): + if isinstance(f1.type, wp.array) or isinstance(f2.type, wp.array): s1, s2 = a1.shape, a2.shape test_obj.assertEqual(len(s1), len(s2), f"{full_name} dims mismatch. Got {s1} and {s2}.") test_obj.assertEqual(s1, s2, f"{full_name} dims mismatch. Got {s1} and {s2}.") @@ -85,8 +85,8 @@ def _check_type_matches_annotation(test_obj, obj: Any, prefix: str = ""): test_obj.assertIsInstance(np_scalar_type(val), type_, msg.format(**locals())) continue - if isinstance(type_, wp.types.array): - test_obj.assertIsInstance(val, wp.types.array, msg.format(**locals())) + if isinstance(type_, wp.array): + test_obj.assertIsInstance(val, wp.array, msg.format(**locals())) continue origin_type = typing.get_origin(type_) @@ -107,8 +107,8 @@ def _check_type_matches_annotation(test_obj, obj: Any, prefix: str = ""): test_obj.assertIsInstance(np_scalar_type(val), type_, msg.format(**locals())) continue - if isinstance(type_, wp.types.array): - test_obj.assertIsInstance(val, wp.types.array, msg.format(**locals())) + if isinstance(type_, wp.array): + test_obj.assertIsInstance(val, wp.array, msg.format(**locals())) continue test_obj.assertEqual(type(val), type_, msg.format(**locals())) @@ -128,10 +128,10 @@ def _check_annotation_compat( if v in (int, bool, float): continue - if isinstance(v, wp.types.array): + if isinstance(v, wp.array): continue - if v in wp.types.vector_types: + if wp.types.type_is_composite(v): raise AssertionError(f"Vector types are not allowed. {info}") if typing.get_origin(v) == tuple and (in_cls or in_tuple): @@ -177,7 +177,7 @@ def _leading_dims_scale_w_nworld(test_obj, d1: Any, d2: Any, nworld1: int, nworl _leading_dims_scale_w_nworld(test_obj, a1, a2, nworld1, nworld2, prefix + f1.name + ".") continue - if isinstance(f1.type, wp.types.array) or isinstance(f2.type, wp.types.array): + if isinstance(f1.type, wp.array) or isinstance(f2.type, wp.array): s1, s2 = a1.shape[0], a2.shape[0] if s1 == s2: continue diff --git a/mujoco_warp/_src/math_test.py b/mujoco_warp/_src/math_test.py index f10ac3b85..e239ee522 100644 --- a/mujoco_warp/_src/math_test.py +++ b/mujoco_warp/_src/math_test.py @@ -16,9 +16,9 @@ import warp as wp from absl.testing import absltest -from .math import closest_segment_to_segment_points -from .math import upper_tri_index -from .math import upper_trid_index +from mujoco_warp._src.math import closest_segment_to_segment_points +from mujoco_warp._src.math import upper_tri_index +from mujoco_warp._src.math import upper_trid_index class ClosestSegmentSegmentPointsTest(absltest.TestCase): diff --git a/mujoco_warp/_src/passive.py b/mujoco_warp/_src/passive.py index 3b1ddd9a9..98dedd8ab 100644 --- a/mujoco_warp/_src/passive.py +++ b/mujoco_warp/_src/passive.py @@ -595,15 +595,15 @@ def _flex_elasticity( nedge = nvert * (nvert - 1) / 2 edges = wp.where( dim == 3, - wp.types.matrix(0, 1, 1, 2, 2, 0, 2, 3, 0, 3, 1, 3, shape=(6, 2), dtype=int), - wp.types.matrix(1, 2, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0, shape=(6, 2), dtype=int), + wp.matrix(0, 1, 1, 2, 2, 0, 2, 3, 0, 3, 1, 3, shape=(6, 2), dtype=int), + wp.matrix(1, 2, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0, shape=(6, 2), dtype=int), ) if timestep > 0.0 and not dsbl_damper: kD = flex_damping[f] / timestep else: kD = 0.0 - gradient = wp.types.matrix(0.0, shape=(6, 6)) + gradient = wp.matrix(0.0, shape=(6, 6)) for e in range(nedge): vert0 = flex_elem[(dim + 1) * elemid + edges[e, 0]] vert1 = flex_elem[(dim + 1) * elemid + edges[e, 1]] @@ -622,7 +622,7 @@ def _flex_elasticity( previous = deformed - vel * timestep elongation[e] = deformed * deformed - reference * reference + (deformed * deformed - previous * previous) * kD - metric = wp.types.matrix(0.0, shape=(6, 6)) + metric = wp.matrix(0.0, shape=(6, 6)) id = int(0) for ed1 in range(nedge): for ed2 in range(ed1, nedge): @@ -630,7 +630,7 @@ def _flex_elasticity( metric[ed2, ed1] = flex_stiffness[elemid, id] id += 1 - force = wp.types.matrix(0.0, shape=(6, 3)) + force = wp.matrix(0.0, shape=(6, 3)) for ed1 in range(nedge): for ed2 in range(nedge): for i in range(2): @@ -684,7 +684,7 @@ def _flex_bending( flex_vertadr[f] + flex_edgeflap[edgeid][1], ) - frc = wp.types.matrix(0.0, shape=(4, 3)) + frc = wp.matrix(0.0, shape=(4, 3)) if flex_bending[edgeid, 16]: v0 = flexvert_xpos_in[worldid, v[0]] v1 = flexvert_xpos_in[worldid, v[1]] @@ -695,7 +695,7 @@ def _flex_bending( frc[3] = wp.cross(v1 - v0, v2 - v0) frc[0] = -(frc[1] + frc[2] + frc[3]) - force = wp.types.matrix(0.0, shape=(nvert, 3)) + force = wp.matrix(0.0, shape=(nvert, 3)) for i in range(nvert): for x in range(3): for j in range(nvert): diff --git a/mujoco_warp/_src/ray_test.py b/mujoco_warp/_src/ray_test.py index 2a6e96033..9530db74b 100644 --- a/mujoco_warp/_src/ray_test.py +++ b/mujoco_warp/_src/ray_test.py @@ -21,8 +21,7 @@ import mujoco_warp as mjw from mujoco_warp import test_data - -from .types import vec6 +from mujoco_warp._src.types import vec6 # tolerance for difference between MuJoCo and MJX ray calculations - mostly # due to float precision diff --git a/mujoco_warp/_src/solver_test.py b/mujoco_warp/_src/solver_test.py index 5dfd27079..3eef0d806 100644 --- a/mujoco_warp/_src/solver_test.py +++ b/mujoco_warp/_src/solver_test.py @@ -25,8 +25,7 @@ from mujoco_warp import ConeType from mujoco_warp import SolverType from mujoco_warp import test_data - -from . import solver +from mujoco_warp._src import solver # tolerance for difference between MuJoCo and MJWarp solver calculations - mostly # due to float precision diff --git a/mujoco_warp/_src/support_test.py b/mujoco_warp/_src/support_test.py index 1d84d18b1..37c316780 100644 --- a/mujoco_warp/_src/support_test.py +++ b/mujoco_warp/_src/support_test.py @@ -25,10 +25,9 @@ from mujoco_warp import ConeType from mujoco_warp import State from mujoco_warp import test_data - -from .block_cholesky import create_blocked_cholesky_func -from .block_cholesky import create_blocked_cholesky_solve_func -from .warp_util import nested_kernel +from mujoco_warp._src.block_cholesky import create_blocked_cholesky_func +from mujoco_warp._src.block_cholesky import create_blocked_cholesky_solve_func +from mujoco_warp._src.warp_util import nested_kernel # tolerance for difference between MuJoCo and MJWarp support calculations - mostly # due to float precision diff --git a/mujoco_warp/_src/types_test.py b/mujoco_warp/_src/types_test.py index d8ac91ef8..b85b3ef59 100644 --- a/mujoco_warp/_src/types_test.py +++ b/mujoco_warp/_src/types_test.py @@ -22,9 +22,9 @@ from absl.testing import absltest from absl.testing import parameterized -from .types import Data -from .types import Model -from .types import Option +from mujoco_warp._src.types import Data +from mujoco_warp._src.types import Model +from mujoco_warp._src.types import Option class TypesTest(parameterized.TestCase): diff --git a/mujoco_warp/_src/unroll_test.py b/mujoco_warp/_src/unroll_test.py index 1b015dc3c..a1e62d1e4 100644 --- a/mujoco_warp/_src/unroll_test.py +++ b/mujoco_warp/_src/unroll_test.py @@ -30,9 +30,8 @@ import mujoco_warp as mjw from mujoco_warp import ConeType from mujoco_warp import test_data - -from .io import find_keys -from .io import make_trajectory +from mujoco_warp._src.io import find_keys +from mujoco_warp._src.io import make_trajectory class UnrollTest(parameterized.TestCase): diff --git a/mujoco_warp/_src/util_misc_test.py b/mujoco_warp/_src/util_misc_test.py index 935bef3d0..45925dd1c 100644 --- a/mujoco_warp/_src/util_misc_test.py +++ b/mujoco_warp/_src/util_misc_test.py @@ -22,10 +22,10 @@ from absl.testing import absltest from absl.testing import parameterized -from . import util_misc -from .types import MJ_MINVAL -from .types import WrapType -from .types import vec10 +from mujoco_warp._src import util_misc +from mujoco_warp._src.types import MJ_MINVAL +from mujoco_warp._src.types import WrapType +from mujoco_warp._src.types import vec10 def _assert_eq(a, b, name): diff --git a/mujoco_warp/_src/warp_util.py b/mujoco_warp/_src/warp_util.py index 1544ec63a..0301a386b 100644 --- a/mujoco_warp/_src/warp_util.py +++ b/mujoco_warp/_src/warp_util.py @@ -18,8 +18,6 @@ from typing import Callable, Optional import warp as wp -from warp._src.context import Module -from warp._src.context import get_module _STACK = None @@ -127,7 +125,7 @@ def nested_kernel( f: Optional[Callable] = None, *, enable_backward: Optional[bool] = None, - module: Optional[Module] = None, + module: Optional[wp.Module] = None, ): """Decorator to register a Warp kernel from a Python function. @@ -166,7 +164,7 @@ def my_kernel_with_args(a: wp.array(dtype=float), b: wp.array(dtype=float)): Args: f: The function to be registered as a kernel. enable_backward: If False, the backward pass will not be generated. - module: The :class:`warp.context.Module` to which the kernel belongs. Alternatively, + module: The :class:`warp.Module` to which the kernel belongs. Alternatively, if a string `"unique"` is provided, the kernel is assigned to a new module named after the kernel name and hash. If None, the module is inferred from the function's module. @@ -182,7 +180,7 @@ def decorator(func): qualname = func.__qualname__ parts = [part for part in qualname.split(".") if part != ""] outer_functions = parts[:-1] - module_name = get_module(".".join([func.__module__] + outer_functions)) + module_name = wp.get_module(".".join([func.__module__] + outer_functions)) else: module_name = module @@ -220,8 +218,7 @@ def _hash_arg(a): def check_toolkit_driver(): - if wp._src.context.runtime is None: - wp._src.context.init() + wp.init() if wp.get_device().is_cuda: - if wp._src.context.runtime.toolkit_version < (12, 4) or wp._src.context.runtime.driver_version < (12, 4): + if not wp.is_conditional_graph_supported(): RuntimeError("Minimum supported CUDA version: 12.4.") diff --git a/pyproject.toml b/pyproject.toml index 1f8b8cd86..53b989e05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "etils[epath]", "mujoco>=3.3.7", "numpy", - "warp-lang>=1.9.1", + "warp-lang>=1.11.0", ] [[tool.uv.index]] @@ -55,7 +55,7 @@ dev = [ "pygls>=1.0.0,<2.0.0", "lsprotocol>=2023.0.1,<2024.0.0", "mujoco>=3.3.7.dev0", - "warp-lang>=1.9.1.dev0", + "warp-lang>=1.11.0.dev0", ] # TODO(team): cpu and cuda JAX optional dependencies are temporary, remove after we land MJX:Warp cpu = [ diff --git a/uv.lock b/uv.lock index 155eeb613..539993dd7 100644 --- a/uv.lock +++ b/uv.lock @@ -528,7 +528,7 @@ requires-dist = [ { name = "pytest", marker = "extra == 'dev'" }, { name = "pytest-xdist", marker = "extra == 'dev'" }, { name = "ruff", marker = "extra == 'dev'" }, - { name = "warp-lang", specifier = ">=1.9.1", index = "https://pypi.nvidia.com/" }, + { name = "warp-lang", specifier = ">=1.11.0", index = "https://pypi.nvidia.com/" }, { name = "warp-lang", marker = "extra == 'dev'", specifier = ">=1.9.1.dev0", index = "https://pypi.nvidia.com/" }, ] provides-extras = ["dev", "cpu", "cuda"] @@ -1239,17 +1239,17 @@ wheels = [ [[package]] name = "warp-lang" -version = "1.11.0.dev20251121" +version = "1.11.0" source = { registry = "https://pypi.nvidia.com/" } dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.3.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] wheels = [ - { url = "https://pypi.nvidia.com/warp-lang/warp_lang-1.11.0.dev20251121-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d7663ac8909b1c5ef9813d26fecb62dbbf9a5d48c4b69df2c4d99d92807c4779" }, - { url = "https://pypi.nvidia.com/warp-lang/warp_lang-1.11.0.dev20251121-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:572ef50d47e183399ae8f62e238fd14849178ef84178bc6c5f84f41bd1b361e8" }, - { url = "https://pypi.nvidia.com/warp-lang/warp_lang-1.11.0.dev20251121-py3-none-manylinux_2_34_aarch64.whl", hash = "sha256:c8d127bcca7e6f5207a331e5a861836b61fb28ea9271eb727ee304c79c3a9c10" }, - { url = "https://pypi.nvidia.com/warp-lang/warp_lang-1.11.0.dev20251121-py3-none-win_amd64.whl", hash = "sha256:e5117fe029cef773779d0821a20edddbd501362e32e95f79be82cc38f5137612" }, + { url = "https://pypi.nvidia.com/warp-lang/warp_lang-1.11.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3a4f1c9a6e721d7de7d6dad6b242c54afaf20c6e14a767c0da03e5e963fcc13c" }, + { url = "https://pypi.nvidia.com/warp-lang/warp_lang-1.11.0-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:524dce20de6162ba25333552168ebf430973050e00d9f8116b8df41a60d25d6e" }, + { url = "https://pypi.nvidia.com/warp-lang/warp_lang-1.11.0-py3-none-manylinux_2_34_aarch64.whl", hash = "sha256:1ae6cfc226107f96e4d495b41a3dab32488e8ee8f074b0e1bcaf22e7fb8c904d" }, + { url = "https://pypi.nvidia.com/warp-lang/warp_lang-1.11.0-py3-none-win_amd64.whl", hash = "sha256:80d8493cbe243a3510134f3af289646d7bd7484217a30ecf565d676466ef8a5e" }, ] [[package]]