Skip to content

Commit

Permalink
Merge branch 'copy_fix' into 'main'
Browse files Browse the repository at this point in the history
Copy fix

See merge request omniverse/warp!551
  • Loading branch information
daedalus5 committed Jun 6, 2024
2 parents d52bf7a + 42cdb35 commit 609d0de
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
- Improve validation of user-provided fields and values in warp.fem
- Support headless rendering of `OpenGLRenderer` via `pyglet.options["headless"] = True`
- `RegisteredGLBuffer` can fall back to CPU-bound copying if CUDA/OpenGL interop is not available
- Fix to forward `wp.copy()` params to gradient and adjoint copy function calls.

## [1.1.1] - 2024-05-24

Expand Down
15 changes: 11 additions & 4 deletions warp/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5058,21 +5058,28 @@ def copy(

# copy gradient, if needed
if hasattr(src, "grad") and src.grad is not None and hasattr(dest, "grad") and dest.grad is not None:
copy(dest.grad, src.grad, stream=stream)
copy(dest.grad, src.grad, dest_offset=dest_offset, src_offset=src_offset, count=count, stream=stream)

if runtime.tape:
runtime.tape.record_func(backward=lambda: adj_copy(dest.grad, src.grad, stream=stream), arrays=[dest, src])
runtime.tape.record_func(
backward=lambda: adj_copy(
dest.grad, src.grad, dest_offset=dest_offset, src_offset=src_offset, count=count, stream=stream
),
arrays=[dest, src],
)


def adj_copy(adj_dest: warp.array, adj_src: warp.array, stream: Stream = None):
def adj_copy(
adj_dest: warp.array, adj_src: warp.array, dest_offset: int, src_offset: int, count: int, stream: Stream = None
):
"""Copy adjoint operation for wp.copy() calls on the tape.
Args:
adj_dest: Destination array adjoint
adj_src: Source array adjoint
stream: The stream on which the copy was performed in the forward pass
"""
copy(adj_src, adj_dest, stream=stream)
copy(adj_src, adj_dest, dest_offset=dest_offset, src_offset=src_offset, count=count, stream=stream)


def type_str(t):
Expand Down

0 comments on commit 609d0de

Please sign in to comment.