Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse dunder methods in interface file #163

Merged
merged 5 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions gtwrap/interface_parser/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@

from typing import Any, Iterable, List, Union

from pyparsing import Literal, Optional, ZeroOrMore # type: ignore
from pyparsing import ZeroOrMore # type: ignore
from pyparsing import Literal, Optional, Word, alphas

from .enum import Enum
from .function import ArgumentList, ReturnType
from .template import Template
from .tokens import (CLASS, COLON, CONST, IDENT, LBRACE, LPAREN, OPERATOR,
RBRACE, RPAREN, SEMI_COLON, STATIC, VIRTUAL)
from .tokens import (CLASS, COLON, CONST, DUNDER, IDENT, LBRACE, LPAREN,
OPERATOR, RBRACE, RPAREN, SEMI_COLON, STATIC, VIRTUAL)
from .type import TemplatedType, Typename
from .utils import collect_namespaces
from .variable import Variable
Expand Down Expand Up @@ -212,6 +213,26 @@ def __repr__(self) -> str:
)


class DunderMethod:
"""Special Python double-underscore (dunder) methods, e.g. __iter__, __contains__"""
rule = (
DUNDER #
+ (Word(alphas))("name") #
+ DUNDER #
+ LPAREN #
+ ArgumentList.rule("args_list") #
+ RPAREN #
+ SEMI_COLON # BR
).setParseAction(lambda t: DunderMethod(t.name, t.args_list))

def __init__(self, name: str, args: ArgumentList):
self.name = name
self.args = args

def __repr__(self) -> str:
return f"DunderMethod: __{self.name}__({self.args})"


class Class:
"""
Rule to parse a class defined in the interface file.
Expand All @@ -223,23 +244,26 @@ class Hello {
};
```
"""

class Members:
"""
Rule for all the members within a class.
"""
rule = ZeroOrMore(Constructor.rule #
rule = ZeroOrMore(DunderMethod.rule #
^ Constructor.rule #
^ Method.rule #
^ StaticMethod.rule #
^ Variable.rule #
^ Operator.rule #
^ Enum.rule #
).setParseAction(lambda t: Class.Members(t.asList()))

def __init__(self,
members: List[Union[Constructor, Method, StaticMethod,
Variable, Operator]]):
def __init__(self, members: List[Union[Constructor, Method,
StaticMethod, Variable,
Operator, Enum, DunderMethod]]):
self.ctors = []
self.methods = []
self.dunder_methods = []
self.static_methods = []
self.properties = []
self.operators = []
Expand All @@ -251,6 +275,8 @@ def __init__(self,
self.methods.append(m)
elif isinstance(m, StaticMethod):
self.static_methods.append(m)
elif isinstance(m, DunderMethod):
self.dunder_methods.append(m)
elif isinstance(m, Variable):
self.properties.append(m)
elif isinstance(m, Operator):
Expand All @@ -271,8 +297,8 @@ def __init__(self,
+ SEMI_COLON # BR
).setParseAction(lambda t: Class(
t.template, t.is_virtual, t.name, t.parent_class, t.members.ctors, t.
members.methods, t.members.static_methods, t.members.properties, t.
members.operators, t.members.enums))
members.methods, t.members.static_methods, t.members.dunder_methods, t.
members.properties, t.members.operators, t.members.enums))

def __init__(
self,
Expand All @@ -283,6 +309,7 @@ def __init__(
ctors: List[Constructor],
methods: List[Method],
static_methods: List[StaticMethod],
dunder_methods: List[DunderMethod],
properties: List[Variable],
operators: List[Operator],
enums: List[Enum],
Expand All @@ -308,6 +335,7 @@ def __init__(
self.ctors = ctors
self.methods = methods
self.static_methods = static_methods
self.dunder_methods = dunder_methods
self.properties = properties
self.operators = operators
self.enums = enums
Expand All @@ -326,6 +354,8 @@ def __init__(
method.parent = self
for static_method in self.static_methods:
static_method.parent = self
for dunder_method in self.dunder_methods:
dunder_method.parent = self
for _property in self.properties:
_property.parent = self

Expand Down
2 changes: 1 addition & 1 deletion gtwrap/interface_parser/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def from_parse_result(parse_result: ParseResults):
return ArgumentList([])

def __repr__(self) -> str:
return repr(tuple(self.args_list))
return ",".join([repr(x) for x in self.args_list])

def __len__(self) -> int:
return len(self.args_list)
Expand Down
1 change: 1 addition & 0 deletions gtwrap/interface_parser/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

LPAREN, RPAREN, LBRACE, RBRACE, COLON, SEMI_COLON = map(Suppress, "(){}:;")
LOPBRACK, ROPBRACK, COMMA, EQUAL = map(Suppress, "<>,=")
DUNDER = Suppress(Literal("__"))

# Default argument passed to functions/methods.
# Allow anything up to ',' or ';' except when they
Expand Down
64 changes: 64 additions & 0 deletions gtwrap/pybind_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def __init__(self,
'continue', 'global', 'pass'
]

self.dunder_methods = ('len', 'contains', 'iter')

# amount of indentation to add before each function/method declaration.
self.method_indent = '\n' + (' ' * 8)

Expand Down Expand Up @@ -153,6 +155,51 @@ def _wrap_print(self, ret: str, method: parser.Method, cpp_class: str,
suffix=suffix)
return ret

def _wrap_dunder(self,
method,
cpp_class,
prefix,
suffix,
method_suffix=""):
"""
Wrap a Python double-underscore (dunder) method.

E.g. __len__() gets wrapped as `.def("__len__", [](gtsam::KeySet* self) {return self->size();})`

Supported methods are:
- __contains__(T x)
- __len__()
- __iter__()
"""
py_method = method.name + method_suffix
args_names = method.args.names()
py_args_names = self._py_args_names(method.args)
args_signature_with_names = self._method_args_signature(method.args)

if method.name == 'len':
function_call = "return self->size();"
elif method.name == 'contains':
function_call = f"return self->find({method.args.args_list[0].name}) != self->end();"
elif method.name == 'iter':
function_call = "return py::make_iterator(self->begin(), self->end());"
gchenfc marked this conversation as resolved.
Show resolved Hide resolved

ret = ('{prefix}.def("__{py_method}__",'
'[]({self}{opt_comma}{args_signature_with_names}){{'
'{function_call}'
'}}'
'{py_args_names}){suffix}'.format(
prefix=prefix,
py_method=py_method,
self=f"{cpp_class}* self",
opt_comma=', ' if args_names else '',
args_signature_with_names=args_signature_with_names,
function_call=function_call,
py_args_names=py_args_names,
suffix=suffix,
))

return ret

def _wrap_method(self,
method,
cpp_class,
Expand Down Expand Up @@ -235,6 +282,20 @@ def _wrap_method(self,

return ret

def wrap_dunder_methods(self,
methods,
cpp_class,
prefix='\n' + ' ' * 8,
suffix=''):
res = ""
for method in methods:
res += self._wrap_dunder(method=method,
cpp_class=cpp_class,
prefix=prefix,
suffix=suffix)

return res

def wrap_methods(self,
methods,
cpp_class,
Expand Down Expand Up @@ -398,6 +459,7 @@ def wrap_instantiated_class(
'{wrapped_ctors}'
'{wrapped_methods}'
'{wrapped_static_methods}'
'{wrapped_dunder_methods}'
'{wrapped_properties}'
'{wrapped_operators};\n'.format(
class_declaration=class_declaration,
Expand All @@ -406,6 +468,8 @@ def wrap_instantiated_class(
instantiated_class.methods, cpp_class),
wrapped_static_methods=self.wrap_methods(
instantiated_class.static_methods, cpp_class),
wrapped_dunder_methods=self.wrap_dunder_methods(
instantiated_class.dunder_methods, cpp_class),
wrapped_properties=self.wrap_properties(
instantiated_class.properties, cpp_class),
wrapped_operators=self.wrap_operators(
Expand Down
3 changes: 3 additions & 0 deletions gtwrap/template_instantiator/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def __init__(self, original: parser.Class, instantiations=(), new_name=''):

# Instantiate all instance methods
self.methods = self.instantiate_methods(typenames)

self.dunder_methods = original.dunder_methods

super().__init__(
self.template,
Expand All @@ -66,6 +68,7 @@ def __init__(self, original: parser.Class, instantiations=(), new_name=''):
self.ctors,
self.methods,
self.static_methods,
self.dunder_methods,
self.properties,
self.operators,
self.enums,
Expand Down
36 changes: 36 additions & 0 deletions tests/expected/matlab/FastSet.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
%class FastSet, see Doxygen page for details
%at https://gtsam.org/doxygen/
%
%-------Constructors-------
%FastSet()
%
classdef FastSet < handle
properties
ptr_FastSet = 0
end
methods
function obj = FastSet(varargin)
if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682)
my_ptr = varargin{2};
class_wrapper(73, my_ptr);
elseif nargin == 0
my_ptr = class_wrapper(74);
else
error('Arguments do not match any overload of FastSet constructor');
end
obj.ptr_FastSet = my_ptr;
end

function delete(obj)
class_wrapper(75, obj.ptr_FastSet);
end

function display(obj), obj.print(''); end
%DISPLAY Calls print on the object
function disp(obj), obj.display; end
%DISP Calls print on the object
end

methods(Static = true)
end
end
12 changes: 6 additions & 6 deletions tests/expected/matlab/MyFactorPosePoint2.m
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
function obj = MyFactorPosePoint2(varargin)
if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682)
my_ptr = varargin{2};
class_wrapper(73, my_ptr);
class_wrapper(76, my_ptr);
elseif nargin == 4 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') && isa(varargin{3},'double') && isa(varargin{4},'gtsam.noiseModel.Base')
my_ptr = class_wrapper(74, varargin{1}, varargin{2}, varargin{3}, varargin{4});
my_ptr = class_wrapper(77, varargin{1}, varargin{2}, varargin{3}, varargin{4});
else
error('Arguments do not match any overload of MyFactorPosePoint2 constructor');
end
obj.ptr_MyFactorPosePoint2 = my_ptr;
end

function delete(obj)
class_wrapper(75, obj.ptr_MyFactorPosePoint2);
class_wrapper(78, obj.ptr_MyFactorPosePoint2);
end

function display(obj), obj.print(''); end
Expand All @@ -36,19 +36,19 @@ function delete(obj)
% PRINT usage: print(string s, KeyFormatter keyFormatter) : returns void
% Doxygen can be found at https://gtsam.org/doxygen/
if length(varargin) == 2 && isa(varargin{1},'char') && isa(varargin{2},'gtsam.KeyFormatter')
class_wrapper(76, this, varargin{:});
class_wrapper(79, this, varargin{:});
return
end
% PRINT usage: print(string s) : returns void
% Doxygen can be found at https://gtsam.org/doxygen/
if length(varargin) == 1 && isa(varargin{1},'char')
class_wrapper(77, this, varargin{:});
class_wrapper(80, this, varargin{:});
return
end
% PRINT usage: print() : returns void
% Doxygen can be found at https://gtsam.org/doxygen/
if length(varargin) == 0
class_wrapper(78, this, varargin{:});
class_wrapper(81, this, varargin{:});
return
end
error('Arguments do not match any overload of function MyFactorPosePoint2.print');
Expand Down
Loading
Loading