diff --git a/CHANGELOG.md b/CHANGELOG.md index e96c875e6..60e956128 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ - Add per-module option to disable fused floating point operations, use `wp.set_module_options({"fuse_fp": False})` ([GH-379](https://github.com/NVIDIA/warp/issues/379)). - Add per-module option to add CUDA-C line information for profiling, use `wp.set_module_options({"lineinfo": True})`. +- Add `example_tile_walker.py`, which reworks the existing `walker.py` to use Warp's tile API for matrix multiplication. ### Changed diff --git a/warp/examples/optim/example_walker.py b/warp/examples/optim/example_walker.py index 7a51b9d21..1d77f4b37 100644 --- a/warp/examples/optim/example_walker.py +++ b/warp/examples/optim/example_walker.py @@ -15,6 +15,11 @@ # simulated forward in time and then evaluated based on the center of mass # momentum of the mesh. # +# This example uses the deprecated wp.matmul() for matrix multiplication, +# which will be removed in a future version. See the updated version of +# this example, example_tile_walker.py, in examples/tile for the new +# approach to GEMMs using Warp's tile API. +# ########################################################################### import math diff --git a/warp/examples/tile/example_tile_walker.py b/warp/examples/tile/example_tile_walker.py new file mode 100644 index 000000000..26debf808 --- /dev/null +++ b/warp/examples/tile/example_tile_walker.py @@ -0,0 +1,319 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +########################################################################### +# Example Tile Walker +# +# Trains a tetrahedral mesh quadruped to run. Feeds 8 time-varying input +# phases as inputs into a single layer fully connected network with a tanh +# activation function. Interprets the output of the network as tet +# activations, which are fed into the wp.sim soft mesh model. This is +# simulated forward in time and then evaluated based on the center of mass +# momentum of the mesh. +# +# This example uses the Warp tile API, which as of Warp 1.6 is the +# recommended way to handle matrix multiplication. example_walker.py in +# examples/optim demonstrates the old way of doing matrix multiplication, +# wp.matmul(), which will be deprecated in a future version. +# +########################################################################### + +import math +import os + +import numpy as np +from pxr import Gf, Usd, UsdGeom + +import warp as wp +import warp.examples +import warp.optim +import warp.sim +import warp.sim.render + +PHASE_COUNT = 8 +PHASE_STEP = wp.constant((2.0 * math.pi) / PHASE_COUNT) +PHASE_FREQ = wp.constant(5.0) +ACTIVATION_STRENGTH = wp.constant(0.3) + +TILE_TETS = wp.constant(8) +TILE_THREADS = 64 + + +@wp.kernel +def loss_kernel(com: wp.array(dtype=wp.vec3), loss: wp.array(dtype=float)): + tid = wp.tid() + vx = com[tid][0] + vy = com[tid][1] + vz = com[tid][2] + delta = wp.sqrt(vx * vx) + wp.sqrt(vy * vy) - vz + + wp.atomic_add(loss, 0, delta) + + +@wp.kernel +def com_kernel(velocities: wp.array(dtype=wp.vec3), n: int, com: wp.array(dtype=wp.vec3)): + tid = wp.tid() + v = velocities[tid] + a = v / wp.float32(n) + wp.atomic_add(com, 0, a) + + +@wp.kernel +def compute_phases(phases: wp.array(dtype=float), sim_time: float): + tid = wp.tid() + phases[tid] = wp.sin(PHASE_FREQ * sim_time + wp.float32(tid) * PHASE_STEP) + + +@wp.func +def tanh(x: float): + return wp.tanh(x) * ACTIVATION_STRENGTH + + +@wp.kernel +def network( + phases: wp.array2d(dtype=float), weights: wp.array2d(dtype=float), tet_activations: wp.array2d(dtype=float) +): + # output tile index + i = wp.tid() + + # GEMM + p = wp.tile_load(phases, 0, 0, m=PHASE_COUNT, n=1) + w = wp.tile_load(weights, i, 0, m=TILE_TETS, n=PHASE_COUNT) + out = wp.tile_matmul(w, p) + + # activation + activations = wp.tile_map(tanh, out) + wp.tile_store(tet_activations, i, 0, activations) + + +class Example: + def __init__(self, stage_path="example_tile_walker.usd", verbose=False, num_frames=300): + self.verbose = verbose + + fps = 60 + self.frame_dt = 1.0 / fps + self.num_frames = num_frames + + self.sim_substeps = 80 + self.sim_dt = self.frame_dt / self.sim_substeps + self.sim_time = 0.0 + + self.iter = 0 + self.train_rate = 0.025 + + self.phase_count = PHASE_COUNT + + self.render_time = 0.0 + + # bear + asset_stage = Usd.Stage.Open(os.path.join(warp.examples.get_asset_directory(), "bear.usd")) + + geom = UsdGeom.Mesh(asset_stage.GetPrimAtPath("/root/bear")) + points = geom.GetPointsAttr().Get() + + xform = Gf.Matrix4f(geom.ComputeLocalToWorldTransform(0.0)) + for i in range(len(points)): + points[i] = xform.Transform(points[i]) + + self.points = [wp.vec3(point) for point in points] + self.tet_indices = geom.GetPrim().GetAttribute("tetraIndices").Get() + + # sim model + builder = wp.sim.ModelBuilder() + builder.add_soft_mesh( + pos=wp.vec3(0.0, 0.5, 0.0), + rot=wp.quat_identity(), + scale=1.0, + vel=wp.vec3(0.0, 0.0, 0.0), + vertices=self.points, + indices=self.tet_indices, + density=1.0, + k_mu=2000.0, + k_lambda=2000.0, + k_damp=2.0, + tri_ke=0.0, + tri_ka=1e-8, + tri_kd=0.0, + tri_drag=0.0, + tri_lift=0.0, + ) + + # finalize model + self.model = builder.finalize(requires_grad=True) + self.control = self.model.control() + + self.model.soft_contact_ke = 2.0e3 + self.model.soft_contact_kd = 0.1 + self.model.soft_contact_kf = 10.0 + self.model.soft_contact_mu = 0.7 + + radii = wp.zeros(self.model.particle_count, dtype=float) + radii.fill_(0.05) + self.model.particle_radius = radii + self.model.ground = True + + # allocate sim states + self.states = [] + for _i in range(self.num_frames * self.sim_substeps + 1): + self.states.append(self.model.state(requires_grad=True)) + + # initialize the integrator. + self.integrator = wp.sim.SemiImplicitIntegrator() + + # model input + self.phases = [] + for _i in range(self.num_frames): + self.phases.append(wp.zeros(self.phase_count, dtype=float, requires_grad=True)) + + # weights matrix for linear network + rng = np.random.default_rng(42) + k = 1.0 / self.phase_count + weights = rng.uniform(-np.sqrt(k), np.sqrt(k), (self.model.tet_count, self.phase_count)) + self.weights = wp.array(weights, dtype=float, requires_grad=True) + + # tanh activation layer array + self.tet_activations = [] + for _i in range(self.num_frames): + self.tet_activations.append(wp.zeros(self.model.tet_count, dtype=float, requires_grad=True)) + + # optimization + self.loss = wp.zeros(1, dtype=float, requires_grad=True) + self.coms = [] + for _i in range(self.num_frames): + self.coms.append(wp.zeros(1, dtype=wp.vec3, requires_grad=True)) + self.optimizer = warp.optim.Adam([self.weights.flatten()], lr=self.train_rate) + + # rendering + if stage_path: + self.renderer = wp.sim.render.SimRenderer(self.model, stage_path) + else: + self.renderer = None + + # capture forward/backward passes + self.use_cuda_graph = wp.get_device().is_cuda + if self.use_cuda_graph: + with wp.ScopedCapture() as capture: + self.tape = wp.Tape() + with self.tape: + for i in range(self.num_frames): + self.forward(i) + self.tape.backward(self.loss) + self.graph = capture.graph + + def forward(self, frame): + with wp.ScopedTimer("network", active=self.verbose): + # build sinusoidal input phases + wp.launch(kernel=compute_phases, dim=self.phase_count, inputs=[self.phases[frame], self.sim_time]) + + # apply linear network with tanh activation + wp.launch_tiled( + kernel=network, + dim=math.ceil(self.model.tet_count / TILE_TETS), + inputs=[self.phases[frame].reshape((self.phase_count, 1)), self.weights], + outputs=[self.tet_activations[frame].reshape((self.model.tet_count, 1))], + block_dim=TILE_THREADS, + ) + self.control.tet_activations = self.tet_activations[frame] + + with wp.ScopedTimer("simulate", active=self.verbose): + # run simulation loop + for i in range(self.sim_substeps): + self.states[frame * self.sim_substeps + i].clear_forces() + self.integrator.simulate( + self.model, + self.states[frame * self.sim_substeps + i], + self.states[frame * self.sim_substeps + i + 1], + self.sim_dt, + self.control, + ) + self.sim_time += self.sim_dt + + with wp.ScopedTimer("loss", active=self.verbose): + # compute center of mass velocity + wp.launch( + com_kernel, + dim=self.model.particle_count, + inputs=[ + self.states[(frame + 1) * self.sim_substeps].particle_qd, + self.model.particle_count, + self.coms[frame], + ], + outputs=[], + ) + # compute loss + wp.launch(loss_kernel, dim=1, inputs=[self.coms[frame], self.loss], outputs=[]) + + def step(self): + with wp.ScopedTimer("step"): + if self.use_cuda_graph: + wp.capture_launch(self.graph) + else: + self.tape = wp.Tape() + with self.tape: + for i in range(self.num_frames): + self.forward(i) + self.tape.backward(self.loss) + + # optimization + x = self.weights.grad.flatten() + self.optimizer.step([x]) + + loss = self.loss.numpy() + if self.verbose: + print(f"Iteration {self.iter}: {loss}") + + # reset sim + self.sim_time = 0.0 + self.states[0] = self.model.state(requires_grad=True) + + # clear grads and zero arrays for next iteration + self.tape.zero() + self.loss.zero_() + for i in range(self.num_frames): + self.coms[i].zero_() + + self.iter += 1 + + def render(self): + if self.renderer is None: + return + + with wp.ScopedTimer("render"): + for i in range(self.num_frames + 1): + self.renderer.begin_frame(self.render_time) + self.renderer.render(self.states[i * self.sim_substeps]) + self.renderer.end_frame() + + self.render_time += self.frame_dt + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.") + parser.add_argument( + "--stage_path", + type=lambda x: None if x == "None" else str(x), + default="example_tile_walker.usd", + help="Path to the output USD file.", + ) + parser.add_argument("--num_frames", type=int, default=300, help="Total number of frames per training iteration.") + parser.add_argument("--train_iters", type=int, default=30, help="Total number of training iterations.") + parser.add_argument("--verbose", action="store_true", help="Print out additional status messages during execution.") + + args = parser.parse_known_args()[0] + + with wp.ScopedDevice(args.device): + example = Example(stage_path=args.stage_path, verbose=args.verbose, num_frames=args.num_frames) + + for _ in range(args.train_iters): + example.step() + example.render() + + if example.renderer: + example.renderer.save()