diff --git a/gtsam/3rdparty/cephes/CMakeLists.txt b/gtsam/3rdparty/cephes/CMakeLists.txt index 9464481602..5940d39d2d 100644 --- a/gtsam/3rdparty/cephes/CMakeLists.txt +++ b/gtsam/3rdparty/cephes/CMakeLists.txt @@ -8,7 +8,6 @@ project( set(CEPHES_HEADER_FILES cephes.h - cephes/cephes_names.h cephes/dd_idefs.h cephes/dd_real.h cephes/dd_real_idefs.h diff --git a/gtsam/3rdparty/cephes/cephes.h b/gtsam/3rdparty/cephes/cephes.h index d5b59d895a..ed53e521bb 100644 --- a/gtsam/3rdparty/cephes/cephes.h +++ b/gtsam/3rdparty/cephes/cephes.h @@ -1,7 +1,6 @@ #ifndef CEPHES_H #define CEPHES_H -#include "cephes/cephes_names.h" #include "dllexport.h" #ifdef __cplusplus diff --git a/gtsam/3rdparty/cephes/cephes/cephes_names.h b/gtsam/3rdparty/cephes/cephes/cephes_names.h deleted file mode 100644 index 94be8c880a..0000000000 --- a/gtsam/3rdparty/cephes/cephes/cephes_names.h +++ /dev/null @@ -1,114 +0,0 @@ -#ifndef CEPHES_NAMES_H -#define CEPHES_NAMES_H - -#define airy cephes_airy -#define bdtrc cephes_bdtrc -#define bdtr cephes_bdtr -#define bdtri cephes_bdtri -#define besselpoly cephes_besselpoly -#define beta cephes_beta -#define lbeta cephes_lbeta -#define btdtr cephes_btdtr -#define cbrt cephes_cbrt -#define chdtrc cephes_chdtrc -#define chbevl cephes_chbevl -#define chdtr cephes_chdtr -#define chdtri cephes_chdtri -#define dawsn cephes_dawsn -#define ellie cephes_ellie -#define ellik cephes_ellik -#define ellpe cephes_ellpe -#define ellpj cephes_ellpj -#define ellpk cephes_ellpk -#define exp10 cephes_exp10 -#define exp2 cephes_exp2 -#define expn cephes_expn -#define fdtrc cephes_fdtrc -#define fdtr cephes_fdtr -#define fdtri cephes_fdtri -#define fresnl cephes_fresnl -#define Gamma cephes_Gamma -#define lgam cephes_lgam -#define lgam_sgn cephes_lgam_sgn -#define gammasgn cephes_gammasgn -#define gdtr cephes_gdtr -#define gdtrc cephes_gdtrc -#define gdtri cephes_gdtri -#define hyp2f1 cephes_hyp2f1 -#define hyperg cephes_hyperg -#define i0 cephes_i0 -#define i0e cephes_i0e -#define i1 cephes_i1 -#define i1e cephes_i1e -#define igamc cephes_igamc -#define igam cephes_igam -#define igami cephes_igami -#define incbet cephes_incbet -#define incbi cephes_incbi -#define iv cephes_iv -#define j0 cephes_j0 -#define y0 cephes_y0 -#define j1 cephes_j1 -#define y1 cephes_y1 -#define jn cephes_jn -#define jv cephes_jv -#define k0 cephes_k0 -#define k0e cephes_k0e -#define k1 cephes_k1 -#define k1e cephes_k1e -#define kn cephes_kn -#define nbdtrc cephes_nbdtrc -#define nbdtr cephes_nbdtr -#define nbdtri cephes_nbdtri -#define ndtr cephes_ndtr -#define erfc cephes_erfc -#define erf cephes_erf -#define erfinv cephes_erfinv -#define erfcinv cephes_erfcinv -#define ndtri cephes_ndtri -#define pdtrc cephes_pdtrc -#define pdtr cephes_pdtr -#define pdtri cephes_pdtri -#define poch cephes_poch -#define psi cephes_psi -#define rgamma cephes_rgamma -#define riemann_zeta cephes_riemann_zeta -// #define round cephes_round // Commented out since it clashes with std::round -#define shichi cephes_shichi -#define sici cephes_sici -#define radian cephes_radian -#define sindg cephes_sindg -#define sinpi cephes_sinpi -#define cosdg cephes_cosdg -#define cospi cephes_cospi -#define sincos cephes_sincos -#define spence cephes_spence -#define stdtr cephes_stdtr -#define stdtri cephes_stdtri -#define struve_h cephes_struve_h -#define struve_l cephes_struve_l -#define struve_power_series cephes_struve_power_series -#define struve_asymp_large_z cephes_struve_asymp_large_z -#define struve_bessel_series cephes_struve_bessel_series -#define yv cephes_yv -#define tandg cephes_tandg -#define cotdg cephes_cotdg -#define log1p cephes_log1p -#define expm1 cephes_expm1 -#define cosm1 cephes_cosm1 -#define yn cephes_yn -#define zeta cephes_zeta -#define zetac cephes_zetac -#define smirnov cephes_smirnov -#define smirnovc cephes_smirnovc -#define smirnovi cephes_smirnovi -#define smirnovci cephes_smirnovci -#define smirnovp cephes_smirnovp -#define kolmogorov cephes_kolmogorov -#define kolmogi cephes_kolmogi -#define kolmogp cephes_kolmogp -#define kolmogc cephes_kolmogc -#define kolmogci cephes_kolmogci -#define owens_t cephes_owens_t - -#endif diff --git a/gtsam/3rdparty/cephes/cephes/mconf.h b/gtsam/3rdparty/cephes/cephes/mconf.h index c59d17a470..5e971afadf 100644 --- a/gtsam/3rdparty/cephes/cephes/mconf.h +++ b/gtsam/3rdparty/cephes/cephes/mconf.h @@ -56,7 +56,6 @@ #include #include -#include "cephes_names.h" #include "cephes.h" #include "polevl.h" #include "sf_error.h" diff --git a/gtsam/gtsam.i b/gtsam/gtsam.i index 834d5a1476..6d77e8eda0 100644 --- a/gtsam/gtsam.i +++ b/gtsam/gtsam.i @@ -39,6 +39,11 @@ class KeyList { void remove(size_t key); void serialize() const; + + // Special dunder methods for Python wrapping + __len__(); + __contains__(size_t key); + __iter__(); }; // Actually a FastSet @@ -64,6 +69,11 @@ class KeySet { bool count(size_t key) const; // returns true if value exists void serialize() const; + + // Special dunder methods for Python wrapping + __len__(); + __contains__(size_t key); + __iter__(); }; // Actually a vector, needed for Matlab @@ -85,6 +95,11 @@ class KeyVector { void push_back(size_t key) const; void serialize() const; + + // Special dunder methods for Python wrapping + __len__(); + __contains__(size_t key); + __iter__(); }; // Actually a FastMap diff --git a/python/gtsam/tests/test_Utilities.py b/python/gtsam/tests/test_Utilities.py index 851684f124..3dd472c75e 100644 --- a/python/gtsam/tests/test_Utilities.py +++ b/python/gtsam/tests/test_Utilities.py @@ -12,13 +12,14 @@ import unittest import numpy as np +from gtsam.utils.test_case import GtsamTestCase import gtsam -from gtsam.utils.test_case import GtsamTestCase class TestUtilites(GtsamTestCase): """Test various GTSAM utilities.""" + def test_createKeyList(self): """Test createKeyList.""" I = [0, 1, 2] @@ -28,6 +29,17 @@ def test_createKeyList(self): kl = gtsam.utilities.createKeyList("s", I) self.assertEqual(kl.size(), 3) + def test_KeyList_iteration(self): + """Tests for KeyList iteration""" + I = [0, 1, 2] + kl = gtsam.utilities.createKeyList(I) + + self.assertEqual(len(kl), len(I)) + + for i, key in enumerate(kl): + self.assertTrue(key in kl) + self.assertEqual(I[i], key) + def test_createKeyVector(self): """Test createKeyVector.""" I = [0, 1, 2] @@ -37,6 +49,17 @@ def test_createKeyVector(self): kl = gtsam.utilities.createKeyVector("s", I) self.assertEqual(len(kl), 3) + def test_KeyVector_iteration(self): + """Tests for KeyVector iteration""" + I = [0, 1, 2] + kv = gtsam.utilities.createKeyVector(I) + + self.assertEqual(len(kv), len(I)) + + for i, key in enumerate(kv): + self.assertTrue(key in kv) + self.assertEqual(I[i], key) + def test_createKeySet(self): """Test createKeySet.""" I = [0, 1, 2] @@ -46,6 +69,17 @@ def test_createKeySet(self): kl = gtsam.utilities.createKeySet("s", I) self.assertEqual(kl.size(), 3) + def test_KeySet_iteration(self): + """Tests for KeySet iteration""" + I = [0, 1, 2] + ks = gtsam.utilities.createKeySet(I) + + self.assertEqual(len(ks), len(I)) + + for i, key in enumerate(ks): + self.assertTrue(key in ks) + self.assertEqual(I[i], key) + def test_extractPoint2(self): """Test extractPoint2.""" initial = gtsam.Values() diff --git a/wrap/gtwrap/interface_parser/classes.py b/wrap/gtwrap/interface_parser/classes.py index b63a0b5eb2..8967bea93b 100644 --- a/wrap/gtwrap/interface_parser/classes.py +++ b/wrap/gtwrap/interface_parser/classes.py @@ -12,13 +12,14 @@ from typing import Any, Iterable, List, Union -from pyparsing import Literal, Optional, ZeroOrMore # type: ignore +from pyparsing import ZeroOrMore # type: ignore +from pyparsing import Literal, Optional, Word, alphas from .enum import Enum from .function import ArgumentList, ReturnType from .template import Template -from .tokens import (CLASS, COLON, CONST, IDENT, LBRACE, LPAREN, OPERATOR, - RBRACE, RPAREN, SEMI_COLON, STATIC, VIRTUAL) +from .tokens import (CLASS, COLON, CONST, DUNDER, IDENT, LBRACE, LPAREN, + OPERATOR, RBRACE, RPAREN, SEMI_COLON, STATIC, VIRTUAL) from .type import TemplatedType, Typename from .utils import collect_namespaces from .variable import Variable @@ -212,6 +213,26 @@ def __repr__(self) -> str: ) +class DunderMethod: + """Special Python double-underscore (dunder) methods, e.g. __iter__, __contains__""" + rule = ( + DUNDER # + + (Word(alphas))("name") # + + DUNDER # + + LPAREN # + + ArgumentList.rule("args_list") # + + RPAREN # + + SEMI_COLON # BR + ).setParseAction(lambda t: DunderMethod(t.name, t.args_list)) + + def __init__(self, name: str, args: ArgumentList): + self.name = name + self.args = args + + def __repr__(self) -> str: + return f"DunderMethod: __{self.name}__({self.args})" + + class Class: """ Rule to parse a class defined in the interface file. @@ -223,11 +244,13 @@ class Hello { }; ``` """ + class Members: """ Rule for all the members within a class. """ - rule = ZeroOrMore(Constructor.rule # + rule = ZeroOrMore(DunderMethod.rule # + ^ Constructor.rule # ^ Method.rule # ^ StaticMethod.rule # ^ Variable.rule # @@ -235,11 +258,12 @@ class Members: ^ Enum.rule # ).setParseAction(lambda t: Class.Members(t.asList())) - def __init__(self, - members: List[Union[Constructor, Method, StaticMethod, - Variable, Operator]]): + def __init__(self, members: List[Union[Constructor, Method, + StaticMethod, Variable, + Operator, Enum, DunderMethod]]): self.ctors = [] self.methods = [] + self.dunder_methods = [] self.static_methods = [] self.properties = [] self.operators = [] @@ -251,6 +275,8 @@ def __init__(self, self.methods.append(m) elif isinstance(m, StaticMethod): self.static_methods.append(m) + elif isinstance(m, DunderMethod): + self.dunder_methods.append(m) elif isinstance(m, Variable): self.properties.append(m) elif isinstance(m, Operator): @@ -271,8 +297,8 @@ def __init__(self, + SEMI_COLON # BR ).setParseAction(lambda t: Class( t.template, t.is_virtual, t.name, t.parent_class, t.members.ctors, t. - members.methods, t.members.static_methods, t.members.properties, t. - members.operators, t.members.enums)) + members.methods, t.members.static_methods, t.members.dunder_methods, t. + members.properties, t.members.operators, t.members.enums)) def __init__( self, @@ -283,6 +309,7 @@ def __init__( ctors: List[Constructor], methods: List[Method], static_methods: List[StaticMethod], + dunder_methods: List[DunderMethod], properties: List[Variable], operators: List[Operator], enums: List[Enum], @@ -308,6 +335,7 @@ def __init__( self.ctors = ctors self.methods = methods self.static_methods = static_methods + self.dunder_methods = dunder_methods self.properties = properties self.operators = operators self.enums = enums @@ -326,6 +354,8 @@ def __init__( method.parent = self for static_method in self.static_methods: static_method.parent = self + for dunder_method in self.dunder_methods: + dunder_method.parent = self for _property in self.properties: _property.parent = self diff --git a/wrap/gtwrap/interface_parser/function.py b/wrap/gtwrap/interface_parser/function.py index b408844886..5385c744f1 100644 --- a/wrap/gtwrap/interface_parser/function.py +++ b/wrap/gtwrap/interface_parser/function.py @@ -82,7 +82,7 @@ def from_parse_result(parse_result: ParseResults): return ArgumentList([]) def __repr__(self) -> str: - return repr(tuple(self.args_list)) + return ",".join([repr(x) for x in self.args_list]) def __len__(self) -> int: return len(self.args_list) diff --git a/wrap/gtwrap/interface_parser/tokens.py b/wrap/gtwrap/interface_parser/tokens.py index 02e6d82f84..11c99d19c2 100644 --- a/wrap/gtwrap/interface_parser/tokens.py +++ b/wrap/gtwrap/interface_parser/tokens.py @@ -22,6 +22,7 @@ LPAREN, RPAREN, LBRACE, RBRACE, COLON, SEMI_COLON = map(Suppress, "(){}:;") LOPBRACK, ROPBRACK, COMMA, EQUAL = map(Suppress, "<>,=") +DUNDER = Suppress(Literal("__")) # Default argument passed to functions/methods. # Allow anything up to ',' or ';' except when they diff --git a/wrap/gtwrap/pybind_wrapper.py b/wrap/gtwrap/pybind_wrapper.py index 78730a909f..479c2d67d4 100755 --- a/wrap/gtwrap/pybind_wrapper.py +++ b/wrap/gtwrap/pybind_wrapper.py @@ -45,6 +45,8 @@ def __init__(self, 'continue', 'global', 'pass' ] + self.dunder_methods = ('len', 'contains', 'iter') + # amount of indentation to add before each function/method declaration. self.method_indent = '\n' + (' ' * 8) @@ -153,6 +155,51 @@ def _wrap_print(self, ret: str, method: parser.Method, cpp_class: str, suffix=suffix) return ret + def _wrap_dunder(self, + method, + cpp_class, + prefix, + suffix, + method_suffix=""): + """ + Wrap a Python double-underscore (dunder) method. + + E.g. __len__() gets wrapped as `.def("__len__", [](gtsam::KeySet* self) {return self->size();})` + + Supported methods are: + - __contains__(T x) + - __len__() + - __iter__() + """ + py_method = method.name + method_suffix + args_names = method.args.names() + py_args_names = self._py_args_names(method.args) + args_signature_with_names = self._method_args_signature(method.args) + + if method.name == 'len': + function_call = "return std::distance(self->begin(), self->end());" + elif method.name == 'contains': + function_call = f"return std::find(self->begin(), self->end(), {method.args.args_list[0].name}) != self->end();" + elif method.name == 'iter': + function_call = "return py::make_iterator(self->begin(), self->end());" + + ret = ('{prefix}.def("__{py_method}__",' + '[]({self}{opt_comma}{args_signature_with_names}){{' + '{function_call}' + '}}' + '{py_args_names}){suffix}'.format( + prefix=prefix, + py_method=py_method, + self=f"{cpp_class}* self", + opt_comma=', ' if args_names else '', + args_signature_with_names=args_signature_with_names, + function_call=function_call, + py_args_names=py_args_names, + suffix=suffix, + )) + + return ret + def _wrap_method(self, method, cpp_class, @@ -235,6 +282,20 @@ def _wrap_method(self, return ret + def wrap_dunder_methods(self, + methods, + cpp_class, + prefix='\n' + ' ' * 8, + suffix=''): + res = "" + for method in methods: + res += self._wrap_dunder(method=method, + cpp_class=cpp_class, + prefix=prefix, + suffix=suffix) + + return res + def wrap_methods(self, methods, cpp_class, @@ -398,6 +459,7 @@ def wrap_instantiated_class( '{wrapped_ctors}' '{wrapped_methods}' '{wrapped_static_methods}' + '{wrapped_dunder_methods}' '{wrapped_properties}' '{wrapped_operators};\n'.format( class_declaration=class_declaration, @@ -406,6 +468,8 @@ def wrap_instantiated_class( instantiated_class.methods, cpp_class), wrapped_static_methods=self.wrap_methods( instantiated_class.static_methods, cpp_class), + wrapped_dunder_methods=self.wrap_dunder_methods( + instantiated_class.dunder_methods, cpp_class), wrapped_properties=self.wrap_properties( instantiated_class.properties, cpp_class), wrapped_operators=self.wrap_operators( diff --git a/wrap/gtwrap/template_instantiator/classes.py b/wrap/gtwrap/template_instantiator/classes.py index ce51d5b967..7026546785 100644 --- a/wrap/gtwrap/template_instantiator/classes.py +++ b/wrap/gtwrap/template_instantiator/classes.py @@ -57,6 +57,8 @@ def __init__(self, original: parser.Class, instantiations=(), new_name=''): # Instantiate all instance methods self.methods = self.instantiate_methods(typenames) + + self.dunder_methods = original.dunder_methods super().__init__( self.template, @@ -66,6 +68,7 @@ def __init__(self, original: parser.Class, instantiations=(), new_name=''): self.ctors, self.methods, self.static_methods, + self.dunder_methods, self.properties, self.operators, self.enums, diff --git a/wrap/requirements.txt b/wrap/requirements.txt index 0aac9302e5..f43fdda617 100644 --- a/wrap/requirements.txt +++ b/wrap/requirements.txt @@ -1,2 +1,2 @@ -pyparsing==2.4.7 +pyparsing==3.1.1 pytest>=6.2.4 diff --git a/wrap/tests/expected/matlab/FastSet.m b/wrap/tests/expected/matlab/FastSet.m new file mode 100644 index 0000000000..4d2a1813e8 --- /dev/null +++ b/wrap/tests/expected/matlab/FastSet.m @@ -0,0 +1,36 @@ +%class FastSet, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +%-------Constructors------- +%FastSet() +% +classdef FastSet < handle + properties + ptr_FastSet = 0 + end + methods + function obj = FastSet(varargin) + if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + my_ptr = varargin{2}; + class_wrapper(73, my_ptr); + elseif nargin == 0 + my_ptr = class_wrapper(74); + else + error('Arguments do not match any overload of FastSet constructor'); + end + obj.ptr_FastSet = my_ptr; + end + + function delete(obj) + class_wrapper(75, obj.ptr_FastSet); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/MyFactorPosePoint2.m b/wrap/tests/expected/matlab/MyFactorPosePoint2.m index 4a30bd4894..ac5b134f9b 100644 --- a/wrap/tests/expected/matlab/MyFactorPosePoint2.m +++ b/wrap/tests/expected/matlab/MyFactorPosePoint2.m @@ -15,9 +15,9 @@ function obj = MyFactorPosePoint2(varargin) if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) my_ptr = varargin{2}; - class_wrapper(73, my_ptr); + class_wrapper(76, my_ptr); elseif nargin == 4 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') && isa(varargin{3},'double') && isa(varargin{4},'gtsam.noiseModel.Base') - my_ptr = class_wrapper(74, varargin{1}, varargin{2}, varargin{3}, varargin{4}); + my_ptr = class_wrapper(77, varargin{1}, varargin{2}, varargin{3}, varargin{4}); else error('Arguments do not match any overload of MyFactorPosePoint2 constructor'); end @@ -25,7 +25,7 @@ end function delete(obj) - class_wrapper(75, obj.ptr_MyFactorPosePoint2); + class_wrapper(78, obj.ptr_MyFactorPosePoint2); end function display(obj), obj.print(''); end @@ -36,19 +36,19 @@ function delete(obj) % PRINT usage: print(string s, KeyFormatter keyFormatter) : returns void % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 2 && isa(varargin{1},'char') && isa(varargin{2},'gtsam.KeyFormatter') - class_wrapper(76, this, varargin{:}); + class_wrapper(79, this, varargin{:}); return end % PRINT usage: print(string s) : returns void % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'char') - class_wrapper(77, this, varargin{:}); + class_wrapper(80, this, varargin{:}); return end % PRINT usage: print() : returns void % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 - class_wrapper(78, this, varargin{:}); + class_wrapper(81, this, varargin{:}); return end error('Arguments do not match any overload of function MyFactorPosePoint2.print'); diff --git a/wrap/tests/expected/matlab/class_wrapper.cpp b/wrap/tests/expected/matlab/class_wrapper.cpp index c4be52018c..e33f14238d 100644 --- a/wrap/tests/expected/matlab/class_wrapper.cpp +++ b/wrap/tests/expected/matlab/class_wrapper.cpp @@ -31,6 +31,8 @@ typedef std::set*> Collector_ForwardKinematic static Collector_ForwardKinematics collector_ForwardKinematics; typedef std::set*> Collector_TemplatedConstructor; static Collector_TemplatedConstructor collector_TemplatedConstructor; +typedef std::set*> Collector_FastSet; +static Collector_FastSet collector_FastSet; typedef std::set*> Collector_MyFactorPosePoint2; static Collector_MyFactorPosePoint2 collector_MyFactorPosePoint2; @@ -101,6 +103,12 @@ void _deleteAllObjects() collector_TemplatedConstructor.erase(iter++); anyDeleted = true; } } + { for(Collector_FastSet::iterator iter = collector_FastSet.begin(); + iter != collector_FastSet.end(); ) { + delete *iter; + collector_FastSet.erase(iter++); + anyDeleted = true; + } } { for(Collector_MyFactorPosePoint2::iterator iter = collector_MyFactorPosePoint2.begin(); iter != collector_MyFactorPosePoint2.end(); ) { delete *iter; @@ -844,7 +852,40 @@ void TemplatedConstructor_deconstructor_72(int nargout, mxArray *out[], int narg delete self; } -void MyFactorPosePoint2_collectorInsertAndMakeBase_73(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void FastSet_collectorInsertAndMakeBase_73(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr Shared; + + Shared *self = *reinterpret_cast (mxGetData(in[0])); + collector_FastSet.insert(self); +} + +void FastSet_constructor_74(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr Shared; + + Shared *self = new Shared(new FastSet()); + collector_FastSet.insert(self); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + *reinterpret_cast (mxGetData(out[0])) = self; +} + +void FastSet_deconstructor_75(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + typedef std::shared_ptr Shared; + checkArguments("delete_FastSet",nargout,nargin,1); + Shared *self = *reinterpret_cast(mxGetData(in[0])); + Collector_FastSet::iterator item; + item = collector_FastSet.find(self); + if(item != collector_FastSet.end()) { + collector_FastSet.erase(item); + } + delete self; +} + +void MyFactorPosePoint2_collectorInsertAndMakeBase_76(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef std::shared_ptr> Shared; @@ -853,7 +894,7 @@ void MyFactorPosePoint2_collectorInsertAndMakeBase_73(int nargout, mxArray *out[ collector_MyFactorPosePoint2.insert(self); } -void MyFactorPosePoint2_constructor_74(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyFactorPosePoint2_constructor_77(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef std::shared_ptr> Shared; @@ -868,7 +909,7 @@ void MyFactorPosePoint2_constructor_74(int nargout, mxArray *out[], int nargin, *reinterpret_cast (mxGetData(out[0])) = self; } -void MyFactorPosePoint2_deconstructor_75(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyFactorPosePoint2_deconstructor_78(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef std::shared_ptr> Shared; checkArguments("delete_MyFactorPosePoint2",nargout,nargin,1); @@ -881,7 +922,7 @@ void MyFactorPosePoint2_deconstructor_75(int nargout, mxArray *out[], int nargin delete self; } -void MyFactorPosePoint2_print_76(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyFactorPosePoint2_print_79(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("print",nargout,nargin-1,2); auto obj = unwrap_shared_ptr>(in[0], "ptr_MyFactorPosePoint2"); @@ -890,7 +931,7 @@ void MyFactorPosePoint2_print_76(int nargout, mxArray *out[], int nargin, const obj->print(s,keyFormatter); } -void MyFactorPosePoint2_print_77(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyFactorPosePoint2_print_80(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("print",nargout,nargin-1,1); auto obj = unwrap_shared_ptr>(in[0], "ptr_MyFactorPosePoint2"); @@ -898,7 +939,7 @@ void MyFactorPosePoint2_print_77(int nargout, mxArray *out[], int nargin, const obj->print(s,gtsam::DefaultKeyFormatter); } -void MyFactorPosePoint2_print_78(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyFactorPosePoint2_print_81(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("print",nargout,nargin-1,0); auto obj = unwrap_shared_ptr>(in[0], "ptr_MyFactorPosePoint2"); @@ -1137,22 +1178,31 @@ void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) TemplatedConstructor_deconstructor_72(nargout, out, nargin-1, in+1); break; case 73: - MyFactorPosePoint2_collectorInsertAndMakeBase_73(nargout, out, nargin-1, in+1); + FastSet_collectorInsertAndMakeBase_73(nargout, out, nargin-1, in+1); break; case 74: - MyFactorPosePoint2_constructor_74(nargout, out, nargin-1, in+1); + FastSet_constructor_74(nargout, out, nargin-1, in+1); break; case 75: - MyFactorPosePoint2_deconstructor_75(nargout, out, nargin-1, in+1); + FastSet_deconstructor_75(nargout, out, nargin-1, in+1); break; case 76: - MyFactorPosePoint2_print_76(nargout, out, nargin-1, in+1); + MyFactorPosePoint2_collectorInsertAndMakeBase_76(nargout, out, nargin-1, in+1); break; case 77: - MyFactorPosePoint2_print_77(nargout, out, nargin-1, in+1); + MyFactorPosePoint2_constructor_77(nargout, out, nargin-1, in+1); break; case 78: - MyFactorPosePoint2_print_78(nargout, out, nargin-1, in+1); + MyFactorPosePoint2_deconstructor_78(nargout, out, nargin-1, in+1); + break; + case 79: + MyFactorPosePoint2_print_79(nargout, out, nargin-1, in+1); + break; + case 80: + MyFactorPosePoint2_print_80(nargout, out, nargin-1, in+1); + break; + case 81: + MyFactorPosePoint2_print_81(nargout, out, nargin-1, in+1); break; } } catch(const std::exception& e) { diff --git a/wrap/tests/expected/python/class_pybind.cpp b/wrap/tests/expected/python/class_pybind.cpp index 86d69c2e0e..2292f46be0 100644 --- a/wrap/tests/expected/python/class_pybind.cpp +++ b/wrap/tests/expected/python/class_pybind.cpp @@ -91,6 +91,12 @@ PYBIND11_MODULE(class_py, m_) { .def(py::init(), py::arg("arg")) .def(py::init(), py::arg("arg")); + py::class_>(m_, "FastSet") + .def(py::init<>()) + .def("__len__",[](FastSet* self){return std::distance(self->begin(), self->end());}) + .def("__contains__",[](FastSet* self, size_t key){return std::find(self->begin(), self->end(), key) != self->end();}, py::arg("key")) + .def("__iter__",[](FastSet* self){return py::make_iterator(self->begin(), self->end());}); + py::class_, std::shared_ptr>>(m_, "MyFactorPosePoint2") .def(py::init>(), py::arg("key1"), py::arg("key2"), py::arg("measured"), py::arg("noiseModel")) .def("print",[](MyFactor* self, const string& s, const gtsam::KeyFormatter& keyFormatter){ py::scoped_ostream_redirect output; self->print(s, keyFormatter);}, py::arg("s") = "factor: ", py::arg("keyFormatter") = gtsam::DefaultKeyFormatter) diff --git a/wrap/tests/fixtures/class.i b/wrap/tests/fixtures/class.i index 766f55329a..775bbc737c 100644 --- a/wrap/tests/fixtures/class.i +++ b/wrap/tests/fixtures/class.i @@ -145,3 +145,12 @@ class TemplatedConstructor { class SuperCoolFactor; typedef SuperCoolFactor SuperCoolFactorPose3; + +/// @brief class with dunder methods for container behavior +class FastSet { + FastSet(); + + __len__(); + __contains__(size_t key); + __iter__(); +}; \ No newline at end of file diff --git a/wrap/tests/test_interface_parser.py b/wrap/tests/test_interface_parser.py index 45415995fe..2a923b3c5f 100644 --- a/wrap/tests/test_interface_parser.py +++ b/wrap/tests/test_interface_parser.py @@ -18,11 +18,12 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from gtwrap.interface_parser import (ArgumentList, Class, Constructor, Enum, - Enumerator, ForwardDeclaration, - GlobalFunction, Include, Method, Module, - Namespace, Operator, ReturnType, - StaticMethod, TemplatedType, Type, +from gtwrap.interface_parser import (ArgumentList, Class, Constructor, + DunderMethod, Enum, Enumerator, + ForwardDeclaration, GlobalFunction, + Include, Method, Module, Namespace, + Operator, ReturnType, StaticMethod, + TemplatedType, Type, TypedefTemplateInstantiation, Typename, Variable) from gtwrap.template_instantiator.classes import InstantiatedClass @@ -344,6 +345,17 @@ def test_constructor_templated(self): self.assertEqual(1, len(ret.args)) self.assertEqual("const T & name", ret.args.args_list[0].to_cpp()) + def test_dunder_method(self): + """Test for special python dunder methods.""" + iter_string = "__iter__();" + ret = DunderMethod.rule.parse_string(iter_string)[0] + self.assertEqual("iter", ret.name) + + contains_string = "__contains__(size_t key);" + ret = DunderMethod.rule.parse_string(contains_string)[0] + self.assertEqual("contains", ret.name) + self.assertTrue(len(ret.args) == 1) + def test_operator_overload(self): """Test for operator overloading.""" # Unary operator