
Commit

Merge branch 'lwawrzyniak/fix-overload-caching' into 'main'
Fix function overload caching

See merge request omniverse/warp!556
mmacklin committed Jun 6, 2024
2 parents 7829243 + e515ee3 commit 5cb4670
Showing 5 changed files with 142 additions and 25 deletions.
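For context, Warp allows several @wp.func definitions to share one name and be resolved by argument types, which is exactly the pattern the new test below exercises. Judging from the diff, the old module hash covered only a single definition per function name, so editing another overload could leave a stale cached module in use; the fix folds every overload into the hash. A minimal sketch of the overload pattern, modeled on the test sources added in this commit (the kernel k is a hypothetical addition for illustration):

import warp as wp

@wp.func
def fn():
    wp.print(17)

@wp.func
def fn(value: int):
    # second overload of the same name, selected by argument type
    wp.print(value)

@wp.kernel
def k():
    fn()   # resolves to the zero-argument overload
    fn(7)  # resolves to the int overload

wp.init()
wp.launch(k, dim=1)
# with this fix, editing either overload changes the module hash,
# so the module is rebuilt instead of reusing a stale cache entry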
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -26,6 +26,7 @@
- Fix to forward `wp.copy()` params to gradient and adjoint copy function calls.
- Fix so that `wp.randn()` doesn't return inf
- Fix slicing of arrays with gradients in kernels
- Fix function overload caching: ensure module is rebuilt if any function overloads are modified.

## [1.1.1] - 2024-05-24

49 changes: 25 additions & 24 deletions warp/context.py
@@ -1487,32 +1487,33 @@ def hash_recursive(module, visited):
ch.update(bytes(s, "utf-8"))

# functions source
-        for func in module.functions.values():
-            s = func.adj.source
-            ch.update(bytes(s, "utf-8"))
-
-            if func.custom_grad_func:
-                s = func.custom_grad_func.adj.source
-                ch.update(bytes(s, "utf-8"))
-            if func.custom_replay_func:
-                s = func.custom_replay_func.adj.source
-            if func.replay_snippet:
-                s = func.replay_snippet
-            if func.native_snippet:
-                s = func.native_snippet
-                ch.update(bytes(s, "utf-8"))
-            if func.adj_native_snippet:
-                s = func.adj_native_snippet
-                ch.update(bytes(s, "utf-8"))
-
-            # cache func arg types
-            for arg, arg_type in func.adj.arg_types.items():
-                s = f"{arg}: {get_type_name(arg_type)}"
-                ch.update(bytes(s, "utf-8"))
-
-            # Populate constants referenced in this function
-            if func.adj:
-                module.constants.update(func.adj.get_constant_references())
+        for function in module.functions.values():
+            # include all overloads
+            for sig, func in function.user_overloads.items():
+                # signature
+                ch.update(bytes(sig, "utf-8"))
+
+                # source
+                s = func.adj.source
+                ch.update(bytes(s, "utf-8"))
+
+                if func.custom_grad_func:
+                    s = func.custom_grad_func.adj.source
+                    ch.update(bytes(s, "utf-8"))
+                if func.custom_replay_func:
+                    s = func.custom_replay_func.adj.source
+                if func.replay_snippet:
+                    s = func.replay_snippet
+                if func.native_snippet:
+                    s = func.native_snippet
+                    ch.update(bytes(s, "utf-8"))
+                if func.adj_native_snippet:
+                    s = func.adj_native_snippet
+                    ch.update(bytes(s, "utf-8"))
+
+                # Populate constants referenced in this function
+                if func.adj:
+                    module.constants.update(func.adj.get_constant_references())

# kernel source
for kernel in module.kernels.values():
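For illustration only, the per-overload hashing added above amounts to folding each overload's signature string and its source into one running digest, so a change to any overload changes the module hash. A standalone sketch of that pattern using a hypothetical overloads dict (plain hashlib, not the Warp API itself):

import hashlib

def hash_overloads(overloads):
    # overloads maps a signature string to that overload's source code;
    # both feed the digest, so editing any overload changes the result
    ch = hashlib.sha256()
    for sig, source in overloads.items():
        ch.update(sig.encode("utf-8"))
        ch.update(source.encode("utf-8"))
    return ch.digest()

h1 = hash_overloads({"fn()": "wp.print(17)", "fn(value: int)": "wp.print(value)"})
h3 = hash_overloads({"fn()": "wp.print(42)", "fn(value: int)": "wp.print(value)"})
assert h1 != h3  # mirrors FUNC_OVERLOAD_1 vs FUNC_OVERLOAD_3 in the new test below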
2 changes: 1 addition & 1 deletion warp/tests/test_grad.py
@@ -742,7 +742,7 @@ def test_gradient_slice_2d(test, device):
wp.launch(square_slice_2d_kernel, dim=a.shape[1], inputs=[a, b, 1])

# use internal gradients (.grad), adj_inputs are None
-    wp.launch(square_slice_2d_kernel, dim=a.shape[1], inputs=[a, b, 1], adjoint=True, adj_inputs=[None, None])
+    wp.launch(square_slice_2d_kernel, dim=a.shape[1], inputs=[a, b, 1], adjoint=True, adj_inputs=[None, None, 1])

assert_np_equal(a.grad.numpy(), np.array([[0.0, 0.0], [6.0, 8.0], [0.0, 0.0]]))

111 changes: 111 additions & 0 deletions warp/tests/test_module_hashing.py
@@ -0,0 +1,111 @@
# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# TODO: add more tests for kernels and generics

import os
import tempfile
import unittest
from importlib import util

import warp as wp
from warp.tests.unittest_utils import *

FUNC_OVERLOAD_1 = """# -*- coding: utf-8 -*-
import warp as wp
@wp.func
def fn():
wp.print(17)
@wp.func
def fn(value: int):
wp.print(value)
"""

# should be same hash as FUNC_OVERLOAD_1
FUNC_OVERLOAD_2 = """# -*- coding: utf-8 -*-
import warp as wp
@wp.func
def fn():
wp.print(17)
@wp.func
def fn(value: int):
wp.print(value)
"""

# should be different hash than FUNC_OVERLOAD_1 (first overload is different)
FUNC_OVERLOAD_3 = """# -*- coding: utf-8 -*-
import warp as wp
@wp.func
def fn():
wp.print(42)
@wp.func
def fn(value: int):
wp.print(value)
"""

# should be different hash than FUNC_OVERLOAD_1 (second overload is different)
FUNC_OVERLOAD_4 = """# -*- coding: utf-8 -*-
import warp as wp
@wp.func
def fn():
wp.print(17)
@wp.func
def fn(value: int):
wp.print(value + 1)
"""


def load_code_as_module(code, name):
file, file_path = tempfile.mkstemp(suffix=".py")

try:
with os.fdopen(file, "w") as f:
f.write(code)

spec = util.spec_from_file_location(name, file_path)
module = util.module_from_spec(spec)
spec.loader.exec_module(module)
finally:
os.remove(file_path)

return wp.get_module(module.__name__)


def test_function_overload_hashing(test, device):
m1 = load_code_as_module(FUNC_OVERLOAD_1, "func_overload_1")
m2 = load_code_as_module(FUNC_OVERLOAD_2, "func_overload_2")
m3 = load_code_as_module(FUNC_OVERLOAD_3, "func_overload_3")
m4 = load_code_as_module(FUNC_OVERLOAD_4, "func_overload_4")

hash1 = m1.hash_module()
hash2 = m2.hash_module()
hash3 = m3.hash_module()
hash4 = m4.hash_module()

test.assertEqual(hash2, hash1)
test.assertNotEqual(hash3, hash1)
test.assertNotEqual(hash4, hash1)


class TestModuleHashing(unittest.TestCase):
pass


add_function_test(TestModuleHashing, "test_function_overload_hashing", test_function_overload_hashing)


if __name__ == "__main__":
wp.build.clear_kernel_cache()
unittest.main(verbosity=2)
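If needed, the new test case can also be driven directly with the standard unittest runner rather than through the suites below (assuming a Warp checkout or install where warp.tests is importable); a small driver sketch:

import unittest

import warp as wp
from warp.tests.test_module_hashing import TestModuleHashing

wp.build.clear_kernel_cache()
tests = unittest.defaultTestLoader.loadTestsFromTestCase(TestModuleHashing)
unittest.TextTestRunner(verbosity=2).run(tests)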
4 changes: 4 additions & 0 deletions warp/tests/unittest_suites.py
@@ -144,6 +144,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
from warp.tests.test_mesh_query_ray import TestMeshQueryRay
from warp.tests.test_mlp import TestMLP
from warp.tests.test_model import TestModel
from warp.tests.test_module_hashing import TestModuleHashing
from warp.tests.test_modules_lite import TestModuleLite
from warp.tests.test_multigpu import TestMultiGPU
from warp.tests.test_noise import TestNoise
@@ -234,6 +235,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
TestMeshQueryRay,
TestMLP,
TestModel,
TestModuleHashing,
TestModuleLite,
TestMultiGPU,
TestNoise,
@@ -300,6 +302,7 @@ def kit_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader):
from warp.tests.test_mesh_query_aabb import TestMeshQueryAABBMethods
from warp.tests.test_mesh_query_point import TestMeshQueryPoint
from warp.tests.test_mesh_query_ray import TestMeshQueryRay
from warp.tests.test_module_hashing import TestModuleHashing
from warp.tests.test_modules_lite import TestModuleLite
from warp.tests.test_noise import TestNoise
from warp.tests.test_operators import TestOperators
@@ -342,6 +345,7 @@ def kit_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader):
TestMeshQueryAABBMethods,
TestMeshQueryPoint,
TestMeshQueryRay,
TestModuleHashing,
TestModuleLite,
TestNoise,
TestOperators,
