Merge branch 'GH-394' into 'main'
Non-Warp arrays passed to Warp kernels are unaffected by Tape.backward()

See merge request omniverse/warp!948
daedalus5 committed Jan 17, 2025
2 parents b7c9d2a + a7d4d54 commit 21ff033
Showing 4 changed files with 45 additions and 6 deletions.
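For context, here is a minimal sketch of the behavior this merge guarantees, adapted from the new `test_tensor_in_warp_kernel` test below (the `scale_kernel` name, sizes, and `"cuda:0"` device string are illustrative, not part of this change): a plain torch tensor passed straight into a Warp kernel launch is left untouched by `Tape.backward()`, while gradients still flow through tensors wrapped with `wp.from_torch(..., requires_grad=True)`.

```python
import torch

import warp as wp

wp.init()


@wp.kernel
def scale_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = 2.0 * x[tid]


device = "cuda:0"  # illustrative; any Torch-compatible Warp device works

# x is a plain torch tensor (no wp.from_torch); y is a differentiable Warp view
x = torch.ones(10, dtype=torch.float32, device=device)
y = wp.from_torch(torch.zeros(10, dtype=torch.float32, device=device), requires_grad=True)

tape = wp.Tape()
with tape:
    wp.launch(scale_kernel, dim=10, inputs=[x], outputs=[y], device=device)

tape.backward(grads={y: wp.ones_like(y)})

# With this fix, the raw torch tensor is not modified by the backward pass
assert torch.equal(x, torch.ones(10, device=device))
```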
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -42,6 +42,7 @@
 ([GH-412](https://github.com/NVIDIA/warp/issues/412)).
 - Fix scale and rotation issues with the rock geometry used in the granular collision SDF example
 ([GH-409](https://github.com/NVIDIA/warp/issues/409)).
+- Fix unintended modification of non-Warp arrays during the backward pass ([GH-394](https://github.com/NVIDIA/warp/issues/394)).
 - Fix so that `wp.Tape.zero()` zeroes gradients passed via the 'grads' parameter in `wp.Tape.backward()` ([GH-407](https://github.com/NVIDIA/warp/issues/407)).
 - Fix the OpenGL renderer not working when multiple instances exist at the same time ([GH-385](https://github.com/NVIDIA/warp/issues/385)).
 - Negative constants evaluate to compile-time constants (fixes [GH-403](https://github.com/NVIDIA/warp/issues/403))
9 changes: 6 additions & 3 deletions warp/context.py
@@ -5094,6 +5094,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
 
     # try to convert to a value type (vec3, mat33, etc)
     elif issubclass(arg_type, ctypes.Array):
+        # simple value types don't have gradient arrays, but native built-in signatures still expect a non-null adjoint value of the correct type
+        if value is None and adjoint:
+            return arg_type(0)
         if warp.types.types_equal(type(value), arg_type):
             return value
         else:
@@ -5103,9 +5106,6 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
             except Exception as e:
                 raise ValueError(f"Failed to convert argument for param {arg_name} to {type_str(arg_type)}") from e
 
-    elif isinstance(value, bool):
-        return ctypes.c_bool(value)
-
     elif isinstance(value, arg_type):
         try:
             # try to pack as a scalar type
@@ -5120,6 +5120,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
             ) from e
 
     else:
+        # scalar args don't have gradient arrays, but native built-in signatures still expect a non-null scalar adjoint
+        if value is None and adjoint:
+            return arg_type._type_(0)
         try:
             # try to pack as a scalar type
             if arg_type is warp.types.float16:
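A note on the `pack_arg()` additions above: value-type and scalar kernel parameters carry no gradient array, so when the tape assembles the backward launch their adjoint slot arrives as `None`; the new branches substitute a zero of the expected type so the native kernel signature still receives a non-null adjoint. A rough, hypothetical illustration of the same pattern (the `arg_type` and `adj_value` names are for illustration only):

```python
import ctypes

import warp as wp

# Warp value types such as wp.vec3 are ctypes.Array subclasses, so a missing
# adjoint can be backfilled with a zero-initialized instance, mirroring the
# `return arg_type(0)` branch added above.
arg_type = wp.vec3
adj_value = None  # no gradient is tracked for a plain value argument

if adj_value is None:
    adj_value = arg_type(0)

print(issubclass(arg_type, ctypes.Array))  # True
print(adj_value)                           # zero vector
```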
6 changes: 3 additions & 3 deletions warp/tape.py
@@ -220,9 +220,9 @@ def _check_kernel_array_access(self, kernel, args):
     # returns the adjoint of a kernel parameter
     def get_adjoint(self, a):
         if not wp.types.is_array(a) and not isinstance(a, wp.codegen.StructInstance):
-            # if input is a simple type (e.g.: float, vec3, etc) then
-            # no gradient needed (we only return gradients through arrays and structs)
-            return a
+            # if input is a simple type (e.g.: float, vec3, etc) or a non-Warp array,
+            # then no gradient needed (we only return gradients through Warp arrays and structs)
+            return None
 
         elif wp.types.is_array(a) and a.grad:
             # keep track of all gradients used by the tape (for zeroing)
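To make the `get_adjoint()` change concrete, here is a hedged sketch of what the tape now hands back for different argument kinds (behavior inferred from the diff above; the exact return for the Warp-array branch, assumed here to be the array's gradient buffer, is not shown in the hunk):

```python
import torch

import warp as wp

tape = wp.Tape()

# Simple value types and non-Warp arrays now map to None, so the backward
# launch packs a zeroed placeholder instead of writing adjoints into the
# original object (previously the object itself was echoed back and could
# be overwritten by Tape.backward()).
print(tape.get_adjoint(1.0))             # None
print(tape.get_adjoint(torch.zeros(8)))  # None

# A Warp array allocated with gradient storage still gets a real adjoint.
a = wp.zeros(8, dtype=float, requires_grad=True)
print(tape.get_adjoint(a))               # the array's gradient buffer
```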
35 changes: 35 additions & 0 deletions warp/tests/test_torch.py
@@ -403,6 +403,38 @@ def test_cuda_array_interface(test, device):
     assert a1.strides == a2.strides
 
 
+@wp.kernel
+def vec_sum_kernel(x: wp.array(dtype=wp.vec3), y: wp.array(dtype=wp.vec3), z: wp.array(dtype=wp.vec3)):
+    tid = wp.tid()
+    z[tid] = x[tid] + y[tid]
+
+
+# ensure torch arrays passed to Warp kernels are unchanged by Tape.backward()
+def test_tensor_in_warp_kernel(test, device):
+    torch_device = wp.device_to_torch(device)
+
+    x = torch.ones((10, 3), dtype=torch.float32, device=torch_device)
+    y = torch.ones((10, 3), dtype=torch.float32, device=torch_device)
+    wp_y = wp.from_torch(y, dtype=wp.vec3, requires_grad=True)
+    z = torch.zeros((10, 3), dtype=torch.float32, device=torch_device)
+    wp_z = wp.from_torch(z, dtype=wp.vec3, requires_grad=True)
+
+    tape = wp.Tape()
+
+    with tape:
+        wp.launch(vec_sum_kernel, dim=10, inputs=[x, wp_y], outputs=[wp_z], device=device)
+
+    assert_np_equal(x.cpu().numpy(), np.ones((10, 3), dtype=float))
+
+    tape.backward(grads={wp_z: wp.ones_like(wp_z)})
+
+    # x is unchanged by Tape.backward()
+    assert_np_equal(x.cpu().numpy(), np.ones((10, 3), dtype=float))
+
+    # we can still compute the gradient of y because Warp created an array for it
+    assert_np_equal(y.grad.cpu().numpy(), np.ones((10, 3), dtype=float))
+
+
 def test_to_torch(test, device):
     import torch
 
@@ -913,6 +945,9 @@ class TestTorch(unittest.TestCase):
         add_function_test(TestTorch, "test_torch_zerocopy", test_torch_zerocopy, devices=torch_compatible_devices)
         add_function_test(TestTorch, "test_torch_autograd", test_torch_autograd, devices=torch_compatible_devices)
         add_function_test(TestTorch, "test_direct", test_direct, devices=torch_compatible_devices)
+        add_function_test(
+            TestTorch, "test_tensor_in_warp_kernel", test_tensor_in_warp_kernel, devices=torch_compatible_devices
+        )
 
     if torch_compatible_cuda_devices:
         add_function_test(
