Skip to content

Commit

Permalink
Merge branch 'ershi/stub-improvements' into 'main'
Browse files Browse the repository at this point in the history
Auto-completion of wp.config variables

See merge request omniverse/warp!593
  • Loading branch information
shi-eric committed Jul 2, 2024
2 parents 38efa77 + 4146063 commit 15ab93a
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 54 deletions.
12 changes: 7 additions & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@
## [Upcoming Release] - 2024-??-??

- Improve memory usage and performance for rigid body contact handling when `self.rigid_mesh_contact_max` is zero (default behavior)
- The `mask` argument to `wp.sim.eval_fk` now accepts both integer and bool arrays
- Support for NumPy >= 2.0
- Fix hashing of replay functions and snippets
- The `mask` argument to `wp.sim.eval_fk` now accepts both integer and boolean arrays.
- Support for NumPy >= 2.0.
- Fix hashing of replay functions and snippets.
- Add additional code comments for random number sampling functions in `rand.h`
- Add information to the module load printouts to indicate whether a module was
compiled `(compiled)`, loaded from the cache `(cached)`, or was unable to be
loaded `(error)`.
- `wp.config.verbose = True` now also prints out a message upon the entry to a `wp.ScopedTimer`.
- Add additional documentation and examples demonstrating wp.copy(), wp.clone(), and array.assign() differentiability
- Fix adding `__new__()` methods for all class `__del__()` methods to anticipate when a class instance is created but not instantiated before garbage collection
- Add additional documentation and examples demonstrating `wp.copy()`, `wp.clone()`, and `array.assign()` differentiability.
- Fix adding `__new__()` methods for all class `__del__()` methods to
anticipate when a class instance is created but not instantiated before garbage collection.
- Add code-completion support for wp.config variables.

## [1.2.1] - 2024-06-14

Expand Down
4 changes: 2 additions & 2 deletions warp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,6 @@

from . import builtins

import warp.config
import warp.config as config

__version__ = warp.config.version
__version__ = config.version
62 changes: 45 additions & 17 deletions warp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,61 @@
from typing import Optional

version: str = "1.2.1"
"""Warp version string"""

verify_fp: bool = False # verify inputs and outputs are finite after each launch
verify_cuda: bool = False # if true will check CUDA errors after each kernel launch / memory operation
print_launches: bool = False # if true will print out launch information

verify_fp: bool = False
"""If `True`, Warp will check that inputs and outputs are finite before and/or after various operations.
Has performance implications.
"""

verify_cuda: bool = False
"""If `True`, Warp will check for CUDA errors after every launch and memory operation.
CUDA error verification cannot be used during graph capture. Has performance implications.
"""

print_launches: bool = False
"""If `True`, Warp will print details of every kernel launch to standard out
(e.g. launch dimensions, inputs, outputs, device, etc.). Has performance implications.
"""

mode: str = "release"
verbose: bool = False # print extra informative messages
verbose_warnings: bool = False # whether file and line info gets included in Warp warnings
quiet: bool = False # suppress all output except errors and warnings
"""Controls whether to compile Warp kernels in debug or release mode.
Valid choices are `"release"` or `"debug"`. Has performance implications.
"""

verbose: bool = False
"""If `True`, additional information will be printed to standard out during code generation, compilation, etc."""

verbose_warnings: bool = False
"""If `True`, Warp warnings will include extra information such as the source file and line number."""

quiet: bool = False
"""Suppress all output except errors and warnings."""

cache_kernels: bool = True
kernel_cache_dir: Optional[str] = None # path to kernel cache directory, if None a default path will be used
"""If `True`, kernels that have already been compiled from previous application launches will not be recompiled."""

kernel_cache_dir: Optional[str] = None
"""Path to kernel cache directory, if `None`, a default path will be used."""

cuda_output: Optional[str] = (
None # preferred CUDA output format for kernels ("ptx" or "cubin"), determined automatically if unspecified
)
cuda_output: Optional[str] = None
"""Preferred CUDA output format for kernels (`"ptx"` or `"cubin"`), determined automatically if unspecified"""

ptx_target_arch: int = 75 # target architecture for PTX generation, defaults to the lowest architecture that supports all of Warp's features
ptx_target_arch: int = 75
"""Target architecture for PTX generation, defaults to the lowest architecture that supports all of Warp's features."""

enable_backward: bool = True  # whether to compile the backward passes of the kernels
enable_backward: bool = True
"""Whether to compiler the backward passes of the kernels."""

llvm_cuda: bool = False # use Clang/LLVM instead of NVRTC to compile CUDA
llvm_cuda: bool = False
"""Use Clang/LLVM instead of NVRTC to compile CUDA."""

enable_graph_capture_module_load_by_default: bool = (
True # Default value of force_module_load for capture_begin() if CUDA driver does not support at least CUDA 12.3
)
enable_graph_capture_module_load_by_default: bool = True
"""Default value of `force_module_load` for `capture_begin()` if CUDA driver does not support at least CUDA 12.3."""

enable_mempools_at_init: bool = True # Whether CUDA devices will be initialized with mempools enabled (if supported)
enable_mempools_at_init: bool = True
"""Whether CUDA devices will be initialized with mempools enabled (if supported)."""

max_unroll: int = 16
"""Maximum unroll factor for loops."""
43 changes: 15 additions & 28 deletions warp/fem/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,24 +58,11 @@ def _resolve_path(func, node):
return None, path


def _path_to_ast_attribute(name: str) -> ast.Attribute:
path = name.split(".")
path.reverse()

node = ast.Name(id=path.pop(), ctx=ast.Load())
while len(path):
node = ast.Attribute(
value=node,
attr=path.pop(),
ctx=ast.Load(),
)
return node


class IntegrandTransformer(ast.NodeTransformer):
def __init__(self, integrand: Integrand, field_args: Dict[str, FieldLike]):
def __init__(self, integrand: Integrand, field_args: Dict[str, FieldLike], annotations: Dict[str, Any]):
self._integrand = integrand
self._field_args = field_args
self._annotations = annotations

def visit_Call(self, call: ast.Call):
call = self.generic_visit(call)
Expand All @@ -85,18 +72,15 @@ def visit_Call(self, call: ast.Call):
# Shortcut for evaluating fields as f(x...)
field = self._field_args[callee]

arg_type = self._integrand.argspec.annotations[callee]
operator = arg_type.call_operator
# Replace with default call operator
abstract_arg_type = self._integrand.argspec.annotations[callee]
default_operator = abstract_arg_type.call_operator
concrete_arg_type = self._annotations[callee]
self._replace_call_func(call, concrete_arg_type, default_operator, field)

call.func = ast.Attribute(
value=_path_to_ast_attribute(f"{arg_type.__module__}.{arg_type.__qualname__}"),
attr="call_operator",
ctx=ast.Load(),
)
# insert callee as first argument
call.args = [ast.Name(id=callee, ctx=ast.Load())] + call.args

self._replace_call_func(call, operator, field)

return call

func, _ = _resolve_path(self._integrand.func, call.func)
Expand All @@ -106,7 +90,7 @@ def visit_Call(self, call: ast.Call):
callee = getattr(call.args[0], "id", None)
if callee in self._field_args:
field = self._field_args[callee]
self._replace_call_func(call, func, field)
self._replace_call_func(call, func, func, field)

if isinstance(func, Integrand):
key = self._translate_callee(func, call.args)
Expand All @@ -120,12 +104,15 @@ def visit_Call(self, call: ast.Call):

return call

def _replace_call_func(self, call: ast.Call, operator: Operator, field: FieldLike):
def _replace_call_func(self, call: ast.Call, callee: Union[type, Operator], operator: Operator, field: FieldLike):
try:
# Retrieve the function pointer corresponding to the operator implementation for the field type
pointer = operator.resolver(field)
setattr(operator, pointer.key, pointer)
except AttributeError as e:
raise ValueError(f"Operator {operator.func.__name__} is not defined for field {field.name}") from e
# Save the pointer as an attribute that can be accessed from the callee scope
setattr(callee, pointer.key, pointer)
# Update the ast Call node to use the new function pointer
call.func = ast.Attribute(value=call.func, attr=pointer.key, ctx=ast.Load())

def _translate_callee(self, callee: Integrand, args: List[ast.AST]):
Expand Down Expand Up @@ -162,7 +149,7 @@ def _translate_integrand(integrand: Integrand, field_args: Dict[str, FieldLike])
annotations[arg] = arg_type

# Transform field evaluation calls
transformer = IntegrandTransformer(integrand, field_args)
transformer = IntegrandTransformer(integrand, field_args, annotations)

suffix = "_".join([f.name for f in field_args.values()])

Expand Down
4 changes: 2 additions & 2 deletions warp/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@

from . import builtins

import warp.config
import warp.config as config

__version__ = warp.config.version
__version__ = config.version


@over
Expand Down

0 comments on commit 15ab93a

Please sign in to comment.