Skip to content

Commit f34a039

Browse files
committed
Improve performance of enum_ operators by going back to specific implementation
test_enum needs a patch because ops are now overloaded and this affects their docstrings.
1 parent e6984c8 commit f34a039

File tree

3 files changed

+74
-27
lines changed

3 files changed

+74
-27
lines changed

include/pybind11/detail/common.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,14 @@
163163
# define PYBIND11_NOINLINE __attribute__((noinline)) inline
164164
#endif
165165

166+
#if defined(_MSC_VER)
167+
# define PYBIND11_ALWAYS_INLINE __forceinline
168+
#elif defined(__GNUC__)
169+
# define PYBIND11_ALWAYS_INLINE __attribute__((__always_inline__)) inline
170+
#else
171+
# define PYBIND11_ALWAYS_INLINE inline
172+
#endif
173+
166174
#if defined(__MINGW32__)
167175
// For unknown reasons all PYBIND11_DEPRECATED member trigger a warning when declared
168176
// whether it is used or not

include/pybind11/pybind11.h

Lines changed: 54 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2218,7 +2218,7 @@ class class_ : public detail::generic_type {
22182218
static void add_base(detail::type_record &) {}
22192219

22202220
template <typename Func, typename... Extra>
2221-
class_ &def(const char *name_, Func &&f, const Extra &...extra) {
2221+
PYBIND11_ALWAYS_INLINE class_ &def(const char *name_, Func &&f, const Extra &...extra) {
22222222
cpp_function cf(method_adaptor<type>(std::forward<Func>(f)),
22232223
name(name_),
22242224
is_method(*this),
@@ -2797,38 +2797,13 @@ struct enum_base {
27972797
pos_only())
27982798

27992799
if (is_convertible) {
2800-
PYBIND11_ENUM_OP_CONV_LHS("__eq__", !b.is_none() && a.equal(b));
2801-
PYBIND11_ENUM_OP_CONV_LHS("__ne__", b.is_none() || !a.equal(b));
2802-
28032800
if (is_arithmetic) {
2804-
PYBIND11_ENUM_OP_CONV("__lt__", a < b);
2805-
PYBIND11_ENUM_OP_CONV("__gt__", a > b);
2806-
PYBIND11_ENUM_OP_CONV("__le__", a <= b);
2807-
PYBIND11_ENUM_OP_CONV("__ge__", a >= b);
2808-
PYBIND11_ENUM_OP_CONV("__and__", a & b);
2809-
PYBIND11_ENUM_OP_CONV("__rand__", a & b);
2810-
PYBIND11_ENUM_OP_CONV("__or__", a | b);
2811-
PYBIND11_ENUM_OP_CONV("__ror__", a | b);
2812-
PYBIND11_ENUM_OP_CONV("__xor__", a ^ b);
2813-
PYBIND11_ENUM_OP_CONV("__rxor__", a ^ b);
28142801
m_base.attr("__invert__")
28152802
= cpp_function([](const object &arg) { return ~(int_(arg)); },
28162803
name("__invert__"),
28172804
is_method(m_base),
28182805
pos_only());
28192806
}
2820-
} else {
2821-
PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)), return false);
2822-
PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b)), return true);
2823-
2824-
if (is_arithmetic) {
2825-
#define PYBIND11_THROW throw type_error("Expected an enumeration of matching type!");
2826-
PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b), PYBIND11_THROW);
2827-
PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b), PYBIND11_THROW);
2828-
PYBIND11_ENUM_OP_STRICT("__le__", int_(a) <= int_(b), PYBIND11_THROW);
2829-
PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b), PYBIND11_THROW);
2830-
#undef PYBIND11_THROW
2831-
}
28322807
}
28332808

28342809
#undef PYBIND11_ENUM_OP_CONV_LHS
@@ -2944,6 +2919,59 @@ class enum_ : public class_<Type> {
29442919

29452920
def(init([](Scalar i) { return static_cast<Type>(i); }), arg("value"));
29462921
def_property_readonly("value", [](Type value) { return (Scalar) value; }, pos_only());
2922+
#define PYBIND11_ENUM_OP_SAME_TYPE(op, expr) \
2923+
def(op, [](Type a, Type b) { return expr; }, pybind11::name(op), arg("other"), pos_only())
2924+
#define PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE(op, expr) \
2925+
def(op, [](Type a, Type *b_ptr) { return expr; }, pybind11::name(op), arg("other"), pos_only())
2926+
#define PYBIND11_ENUM_OP_SCALAR(op, op_expr) \
2927+
def( \
2928+
op, \
2929+
[](Type a, Scalar b) { return static_cast<Scalar>(a) op_expr b; }, \
2930+
pybind11::name(op), \
2931+
arg("other"), \
2932+
pos_only())
2933+
#define PYBIND11_ENUM_OP_CONV_ARITHMETIC(op, op_expr) \
2934+
PYBIND11_ENUM_OP_SAME_TYPE(op, static_cast<Scalar>(a) op_expr static_cast<Scalar>(b)); \
2935+
PYBIND11_ENUM_OP_SCALAR(op, op_expr)
2936+
#define PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE(op, strict_behavior) \
2937+
def( \
2938+
op, \
2939+
[](Type, const object &) { strict_behavior; }, \
2940+
pybind11::name(op), \
2941+
arg("other"), \
2942+
pos_only())
2943+
#define PYBIND11_ENUM_OP_STRICT_ARITHMETIC(op, op_expr, strict_behavior) \
2944+
PYBIND11_ENUM_OP_SAME_TYPE(op, static_cast<Scalar>(a) op_expr static_cast<Scalar>(b)); \
2945+
PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE(op, strict_behavior);
2946+
2947+
PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE("__eq__", b_ptr && a == *b_ptr);
2948+
PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE("__ne__", !b_ptr || a != *b_ptr);
2949+
if (std::is_convertible<Type, Scalar>::value) {
2950+
PYBIND11_ENUM_OP_SCALAR("__eq__", ==);
2951+
PYBIND11_ENUM_OP_SCALAR("__ne__", !=);
2952+
if (is_arithmetic) {
2953+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__lt__", <);
2954+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__gt__", >);
2955+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__le__", <=);
2956+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__ge__", >=);
2957+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__and__", &);
2958+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__rand__", &);
2959+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__or__", |);
2960+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__ror__", |);
2961+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__xor__", ^);
2962+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__rxor__", ^);
2963+
}
2964+
} else if (is_arithmetic) {
2965+
#define PYBIND11_THROW throw type_error("Expected an enumeration of matching type!");
2966+
PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__lt__", <, PYBIND11_THROW);
2967+
PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__gt__", >, PYBIND11_THROW);
2968+
PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__le__", <=, PYBIND11_THROW);
2969+
PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__ge__", >=, PYBIND11_THROW);
2970+
#undef PYBIND11_THROW
2971+
}
2972+
PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE("__eq__", return false);
2973+
PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE("__ne__", return true);
2974+
29472975
def("__int__", [](Type value) { return (Scalar) value; }, pos_only());
29482976
def("__index__", [](Type value) { return (Scalar) value; }, pos_only());
29492977
attr("__setstate__") = cpp_function(

tests/test_enum.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,10 +295,21 @@ def test_generated_dunder_methods_pos_only():
295295
"__rxor__",
296296
]:
297297
method = getattr(enum_type, binary_op, None)
298+
# TODO: docs now show overloading; update
298299
if method is not None:
300+
# 1) The docs must start with the name of the op.
299301
assert (
300302
re.match(
301-
rf"^{binary_op}\(self: [\w\.]+, other: [\w\.]+, /\)",
303+
rf"^{binary_op}\(",
304+
method.__doc__,
305+
)
306+
is not None
307+
)
308+
# 2) The docs must contain the op's signature. This is a separate check
309+
# and not anchored at the start because the op may be overloaded.
310+
assert (
311+
re.search(
312+
rf"{binary_op}\(self: [\w\.]+, other: [\w\.]+, /\)",
302313
method.__doc__,
303314
)
304315
is not None

0 commit comments

Comments
 (0)