Skip to content

Commit

Permalink
Merge branch 'ccrouzet/future-annotations' into 'main'
Browse files Browse the repository at this point in the history
Add Support for PEP 563

See merge request omniverse/warp!619
  • Loading branch information
christophercrouzet committed Jul 14, 2024
2 parents 943561d + e10a583 commit 31266bc
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
- Improve error messages for unsupported constructs
- Update `wp.matmul()` CPU fallback to use dtype explicitly in `np.matmul()` call
- Fix ShapeInstancer `__new__()` method (missing instance return and `*args` parameter)
- Add support for PEP 563's `from __future__ import annotations`.

## [1.2.2] - 2024-07-04

Expand Down
90 changes: 86 additions & 4 deletions warp/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import ast
import builtins
import ctypes
import functools
import inspect
import math
import re
Expand Down Expand Up @@ -98,13 +99,94 @@ def op_str_is_chainable(op: str) -> builtins.bool:
return op in comparison_chain_strings


def get_closure_cell_contents(obj):
"""Retrieve a closure's cell contents or `None` if it's empty."""
try:
return obj.cell_contents
except ValueError:
pass

return None


def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
"""Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
# Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
if not annotations:
return {}

if not any(isinstance(x, str) for x in annotations.values()):
# No annotation to un-stringize.
return annotations

if isinstance(obj, type):
# class
globals = {}
module_name = getattr(obj, "__module__", None)
if module_name:
module = sys.modules.get(module_name, None)
if module:
globals = getattr(module, "__dict__", {})
locals = dict(vars(obj))
unwrap = obj
elif isinstance(obj, types.ModuleType):
# module
globals = obj.__dict__
locals = {}
unwrap = None
elif callable(obj):
# function
globals = getattr(obj, "__globals__", {})
# Capture the variables from the surrounding scope.
closure_vars = zip(
obj.__code__.co_freevars, tuple(get_closure_cell_contents(x) for x in (obj.__closure__ or ()))
)
locals = {k: v for k, v in closure_vars if v is not None}
unwrap = obj
else:
raise TypeError(f"{obj!r} is not a module, class, or callable.")

if unwrap is not None:
while True:
if hasattr(unwrap, "__wrapped__"):
unwrap = unwrap.__wrapped__
continue
if isinstance(unwrap, functools.partial):
unwrap = unwrap.func
continue
break
if hasattr(unwrap, "__globals__"):
globals = unwrap.__globals__

# "Inject" type parameters into the local namespace
# (unless they are shadowed by assignments *in* the local namespace),
# as a way of emulating annotation scopes when calling `eval()`
type_params = getattr(obj, "__type_params__", ())
if type_params:
locals = {param.__name__: param for param in type_params} | locals

return {k: v if not isinstance(v, str) else eval(v, globals, locals) for k, v in annotations.items()}


def get_annotations(obj: Any) -> Mapping[str, Any]:
"""Alternative to `inspect.get_annotations()` for Python 3.9 and older."""
"""Same as `inspect.get_annotations()` but always returning un-stringized annotations."""
# This backports `inspect.get_annotations()` for Python 3.9 and older.
# See https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
if isinstance(obj, type):
return obj.__dict__.get("__annotations__", {})
annotations = obj.__dict__.get("__annotations__", {})
else:
annotations = getattr(obj, "__annotations__", {})

# Evaluating annotations can be done using the `eval_str` parameter with
# the official function from the `inspect` module.
return eval_annotations(annotations, obj)


return getattr(obj, "__annotations__", {})
def get_full_arg_spec(func: Callable) -> inspect.FullArgSpec:
"""Same as `inspect.getfullargspec()` but always returning un-stringized annotations."""
# See https://docs.python.org/3/howto/annotations.html#manually-un-stringizing-stringized-annotations
spec = inspect.getfullargspec(func)
return spec._replace(annotations=eval_annotations(spec.annotations, func))


def struct_instance_repr_recursive(inst: StructInstance, depth: int) -> str:
Expand Down Expand Up @@ -698,7 +780,7 @@ def __init__(
adj.custom_reverse_num_input_args = custom_reverse_num_input_args

# parse argument types
argspec = inspect.getfullargspec(func)
argspec = get_full_arg_spec(func)

# ensure all arguments are annotated
if overload_annotations is None:
Expand Down
14 changes: 3 additions & 11 deletions warp/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def value_func(arg_types, arg_values):

def get_function_args(func):
"""Ensures that all function arguments are annotated and returns a dictionary mapping from argument name to its type."""
argspec = inspect.getfullargspec(func)
argspec = warp.codegen.get_full_arg_spec(func)

# use source-level argument annotations
if len(argspec.annotations) < len(argspec.args):
Expand Down Expand Up @@ -963,7 +963,7 @@ def overload(kernel, arg_types=None):
)

# ensure all arguments are annotated
argspec = inspect.getfullargspec(fn)
argspec = warp.codegen.get_full_arg_spec(fn)
if len(argspec.annotations) < len(argspec.args):
raise RuntimeError(f"Incomplete argument annotations on kernel overload {fn.__name__}")

Expand Down Expand Up @@ -1556,14 +1556,6 @@ def hash_module(self, recompute_content_hash=False):
computed ``content_hash`` will be used.
"""

def get_annotations(obj: Any) -> Mapping[str, Any]:
"""Alternative to `inspect.get_annotations()` for Python 3.9 and older."""
# See https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
if isinstance(obj, type):
return obj.__dict__.get("__annotations__", {})

return getattr(obj, "__annotations__", {})

def get_type_name(type_hint):
if isinstance(type_hint, warp.codegen.Struct):
return get_type_name(type_hint.cls)
Expand All @@ -1585,7 +1577,7 @@ def hash_recursive(module, visited):
for struct in module.structs.values():
s = ",".join(
"{}: {}".format(name, get_type_name(type_hint))
for name, type_hint in get_annotations(struct.cls).items()
for name, type_hint in warp.codegen.get_annotations(struct.cls).items()
)
ch.update(bytes(s, "utf-8"))

Expand Down
3 changes: 1 addition & 2 deletions warp/fem/operator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import inspect
from typing import Any, Callable

import warp as wp
Expand All @@ -15,7 +14,7 @@ def __init__(self, func: Callable):
self.func = func
self.name = wp.codegen.make_full_qualified_name(self.func)
self.module = wp.get_module(self.func.__module__)
self.argspec = inspect.getfullargspec(self.func)
self.argspec = wp.codegen.get_full_arg_spec(self.func)


class Operator:
Expand Down
90 changes: 90 additions & 0 deletions warp/tests/test_future_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# This is what we are actually testing.
from __future__ import annotations

import unittest

import warp as wp
from warp.tests.unittest_utils import *


@wp.struct
class FooData:
x: float
y: float


class Foo:
Data = FooData

@wp.func
def compute():
pass


@wp.kernel
def kernel_1(
out: wp.array(dtype=float),
):
tid = wp.tid()


@wp.kernel
def kernel_2(
out: wp.array(dtype=float),
):
tid = wp.tid()
out[tid] = 1.23


def create_kernel_3(foo: Foo):
def fn(
data: foo.Data,
out: wp.array(dtype=float),
):
tid = wp.tid()

# Referencing a variable in a type hint like `foo.Data` isn't officially
# accepted by Python but it's still being used in some places (e.g.: `warp.fem`)
# where it works only because the variable being referenced within the function,
# which causes it to be promoted to a closure variable. Without that,
# it wouldn't be possible to resolve `foo` and to evaluate the `foo.Data`
# string to its corresponding type.
foo.compute()

out[tid] = data.x + data.y

return wp.Kernel(func=fn)


def test_future_annotations(test, device):
foo = Foo()
foo_data = FooData()
foo_data.x = 1.23
foo_data.y = 2.34

out = wp.empty(1, dtype=float)

kernel_3 = create_kernel_3(foo)

wp.launch(kernel_1, dim=out.shape, outputs=(out,))
wp.launch(kernel_2, dim=out.shape, outputs=(out,))
wp.launch(kernel_3, dim=out.shape, inputs=(foo_data,), outputs=(out,))


class TestFutureAnnotations(unittest.TestCase):
pass


add_function_test(TestFutureAnnotations, "test_future_annotations", test_future_annotations)


if __name__ == "__main__":
wp.clear_kernel_cache()
unittest.main(verbosity=2)

0 comments on commit 31266bc

Please sign in to comment.