Unable to wrap a function that's called with an enum #5492
I find myself in need of wrapping a heap of functions in a generic wrapper (think "take a lock before calling this API"). This works quite nicely, except when a parameter of the wrapped function is an enum. Hints how to fix this (I'm not that good a C++ programmer) would be appreciated.
Using the command line
the result is this error:
Replies: 1 comment
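Hi @smurfix, I have been looking at this issue. This seems to be caused by the fact that, when compiling the wrapped function, the compiler only generates the template instance `wrap<int, int, int, xfoo>(int&&, int&&, xfoo&&)`, which is what you would expect in C++ by the resolution of the universal reference. However, pybind11 tries to pass the enum as an lvalue, so it looks for `wrap<int, int, int, xfoo>(int, int, xfoo&)`, which hasn't been instantiated by the compiler.

I've attempted multiple solutions (see below), but most of them ultimately require providing pybind11 with a non-templated function or `operator()` so that it can correctly infer its signature with `decltype(&F::operator())`.

My main takeaway here is that there isn't really any point in forcing a pure C++ behaviour (perfect forwarding) over pybind11. As mentioned here, data is copied at the boundary between Python and C++. So, the natural solution would be to define a non-templated wrapper function to be used by pybind11, while keeping the templated wrapper class:

    #include <pybind11/pybind11.h>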
    #include <iostream>
    #include <utility>  // std::forward, std::move

    namespace py = pybind11;

    template <class R, class... Args>
    struct wrap {
        using funct_type = R (*)(Args...);
        funct_type func;

        wrap(funct_type f) : func(f) {}

        // operator() to be used in C++ code. Args is fixed by the class
        // template, so a by-value parameter T is taken here as T&&.
        R operator()(Args&&... args) {
            std::cout << "before calling\n";
            // Perfect forwarding
            R ret = func(std::forward<Args>(args)...);
            std::cout << "after calling\n";
            return ret;
        }
    };

    enum xfoo {
        FOO = 1,
        BAR = 2,
        BAZ = 3,
    };

    // Example function we want to wrap
    int add(int a, int b, enum xfoo c) {
        return a + b + c;
    }
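
    // Non-templated wrapper function for pybind11
    int add_wrapped(int a, int b, enum xfoo c) {
        wrap w{&add};
        // a, b and c are lvalues here, while operator() takes rvalue
        // references, so the by-value copies are moved in.
        return w(std::move(a), std::move(b), std::move(c));
    }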

    PYBIND11_MODULE(example, m) {
        py::enum_<xfoo>(m, "xfoo")
            .value("FOO", xfoo::FOO)
            .value("BAR", xfoo::BAR)
            .value("BAZ", xfoo::BAZ)
            .export_values();

        m.def("add", &add_wrapped);

        /*
        Alternative: use a lambda
        m.def("add", [](int a, int b, enum xfoo c) {
            wrap w{&add};
            return w(std::move(a), std::move(b), std::move(c));
        });
        */
    }

You'll have to define a helper function like `add_wrapped`. I'll provide a few alternatives anyway (among the ones I've attempted), though I'm not sure if any of them is better than the one above.

Possible Alternatives

1. Modify the add function

One quick solution would be to modify the `add` function to take the enum by lvalue reference instead of by value:

    int add(int a, int b, enum xfoo& c) {
        return a + b + c;
    }

This works because of reference collapsing: with `Args = xfoo&`, the parameter type `Args&&` becomes `xfoo&`, which accepts the lvalue that pybind11 passes.

2. Drop the universal reference when encountering enums

One solution would be to replace the `Args&&...` parameters with by-value `Args...` parameters whenever the parameter pack contains an enum:

    #include <iostream>
    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>
    #include <type_traits>
    #include <utility>  // std::forward

    namespace py = pybind11;

    // Helper traits to detect enums in a parameter pack
    template <typename T>
    constexpr bool is_enum_v = std::is_enum_v<std::remove_reference_t<T>>;

    template <typename... Args>
    constexpr bool has_enum_v = (is_enum_v<Args> || ...);

    enum xfoo {
        FOO = 1,
        BAR = 2,
        BAZ = 3,
    };

    // Generic template
    template <class R, class... Args>
    struct wrap
    {
        using funct_type = R (*)(Args...);
        funct_type func;

        wrap(funct_type f) : func(f) {}

        R operator()(Args&&... args)
        {
            // Before calling
            std::cout << "before calling\n";
            R ret = func(std::forward<Args>(args)...);
            // After calling
            std::cout << "After calling\n";
            return ret;
        }
    };

    // Variant used when Args contains enums: a separate template (not a
    // true specialization) whose operator() takes its parameters by value
    template <class R, class... Args>
    struct wrap_with_enum
    {
        using funct_type = R (*)(Args...);
        funct_type func;

        wrap_with_enum(funct_type f) : func(f) {}

        R operator()(Args... args)
        {
            // Before calling
            std::cout << "before calling (specialized for enums)\n";
            R ret = func(args...);  // Call without forwarding
            // After calling
            std::cout << "After calling (specialized for enums)\n";
            return ret;
        }
    };

    // Alias that chooses the correct wrap implementation at compile time
    template <class R, class... Args>
    using wrap_selector = std::conditional_t<
        has_enum_v<Args...>,
        wrap_with_enum<R, Args...>,
        wrap<R, Args...>>;

    int add(int a, int b) {
        return a + b;
    }

    int add_with_enum(int a, int b, enum xfoo c) {
        return a + b + c;
    }

    PYBIND11_MODULE(example, m) {
        m.doc() = "pybind11 example plugin";  // Optional module docstring

        py::enum_<xfoo>(m, "xfoo")
            .value("FOO", xfoo::FOO)
            .value("BAR", xfoo::BAR)
            .value("BAZ", xfoo::BAZ)
            .export_values();

        // Use wrap_selector to choose the appropriate implementation
        m.def("add", wrap_selector<int, int, int>{add}, "A function that adds two numbers");
        m.def("add_with_enum", wrap_selector<int, int, int, xfoo>{add_with_enum}, "A function that adds two numbers and an enum");
    }

This solution uses `std::conditional_t` to select `wrap_with_enum` whenever the parameter pack contains an enum, and `wrap` otherwise. From Python:

    import example

    print(example.add(1, 2))
    print(example.add_with_enum(1, 2, example.xfoo.FOO))
    # Output:
    # before calling
    # After calling
    # 3
    # before calling (specialized for enums)
    # After calling (specialized for enums)
    # 4

3. Partial specialization/partial forwarding

This is a bit redundant, but I'll include it here for completeness. Another route I attempted was to rewrite the wrapper class so that it only forwards non-enum types, while enums are passed by value. This would allow using the wrapper class in C++ and Python code as-is. I was able to come up with the following code:

    #include <cstddef>
    #include <iostream>
    #include <tuple>
    #include <type_traits>
    #include <utility>

    template <typename R, typename... Ts>
    struct partial_forward_wrapper {
        using funct_type = R (*)(Ts...);
        funct_type func;

        explicit partial_forward_wrapper(funct_type f) : func(f) {}

        template <typename... UArgs>
        R operator()(UArgs&&... uargs) {
            static_assert(sizeof...(UArgs) == sizeof...(Ts),
                          "Mismatched argument count");
            std::cout << "[before calling wrapped function]\n";
            R ret = call_impl(std::index_sequence_for<Ts...>{}, std::forward<UArgs>(uargs)...);
            std::cout << "[after calling wrapped function]\n";
            return ret;
        }

    private:
        // Expand each argument with a helper that either perfect-forwards
        // or passes-by-value if it's an enum.
        template <std::size_t... I, typename... UArgs>
        R call_impl(std::index_sequence<I...>, UArgs&&... uargs) {
            // Call 'func' with each argument processed by 'select_arg'
            return func(
                select_arg<std::tuple_element_t<I, std::tuple<Ts...>>>(
                    std::forward<UArgs>(uargs)
                )...
            );
        }

        // If T is an enum, cast & pass by value
        template <typename T, typename U>
        auto select_arg(U&& u)
            -> std::enable_if_t<std::is_enum_v<T>, T>
        {
            return static_cast<T>(u);
        }

        // Otherwise, perfect-forward. The return type is U&& rather than T&&
        // so that lvalue arguments also compile (an lvalue cannot bind to T&&).
        template <typename T, typename U>
        auto select_arg(U&& u)
            -> std::enable_if_t<!std::is_enum_v<T>, U&&>
        {
            return std::forward<U>(u);
        }
    };

This works in C++, but unfortunately doesn't work with pybind11: the issue here is that `operator()` is now a template, so pybind11 cannot infer its signature. A possible solution here would be to provide a non-templated function for pybind11 use, as done in the very first solution with `add_wrapped`.
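
One way to avoid hand-writing a helper like `add_wrapped` for every function might be a small factory that builds the non-templated callable from the function pointer. The sketch below is plain C++ and uses nothing from pybind11. `make_wrapped` and `demo` are hypothetical names that do not appear in the thread, and the sketch assumes the wrapped function returns a non-void type.

```cpp
#include <cassert>
#include <iostream>
#include <type_traits>
#include <utility>

enum xfoo { FOO = 1, BAR = 2, BAZ = 3 };

// Example function, same shape as the one in the thread
int add(int a, int b, enum xfoo c) { return a + b + c; }

// Hypothetical helper (not part of pybind11): returns a closure whose
// operator() is not a template and takes the parameters exactly as the
// wrapped function declares them. Assumes R is not void.
template <class R, class... Args>
auto make_wrapped(R (*func)(Args...)) {
    return [func](Args... args) -> R {
        std::cout << "before calling\n";
        // By-value parameters are moved on, reference parameters stay references
        R ret = func(std::forward<Args>(args)...);
        std::cout << "after calling\n";
        return ret;
    };
}

// Compile-time check: the closure accepts lvalues, including the enum
static_assert(
    std::is_invocable_r_v<int, decltype(make_wrapped(&add)), int&, int&, xfoo&>,
    "closure should be callable with lvalue arguments");

// Small usage example with lvalue arguments
int demo() {
    auto wrapped = make_wrapped(&add);
    int a = 1;
    int b = 2;
    xfoo c = BAR;
    int result = wrapped(a, b, c);  // 1 + 2 + 2
    assert(result == 5);
    return result;
}
```

Because the closure's `operator()` is not a template, it should in principle be possible to register it directly with `m.def("add", make_wrapped(&add));`, though that registration is an untested assumption here. Taking the parameters by value costs one extra copy or move per argument, which seems acceptable given that data is copied at the Python boundary anyway.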