Skip to content

Commit

Permalink
Make protobuf python_api fully optional
Browse files Browse the repository at this point in the history
In case `PYBIND11_PROTOBUF_ENABLE_PYPROTO_API` is not defined, the
py_proto_api_ member of the GlobalState singleton is never changed from
its default nullptr value.

Any code protected by a `GlobalState::instance()->py_proto_api()` check
can thus also be made dependent on the `PYPROTO_API` define. This allows
to remove the dependency on the proto_api.h header file.

As the call to check_unknown_fields::CheckRecursively is also protected
by the `py_proto_api()` it can be stubbed out.

See pybind#127.
  • Loading branch information
StefanBruens committed Jun 14, 2024
1 parent af8ee51 commit 29997d6
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 3 deletions.
4 changes: 4 additions & 0 deletions pybind11_protobuf/check_unknown_fields.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ std::string MakeAllowListKey(
return absl::StrCat(top_message_descriptor_full_name, ":",
unknown_field_parent_message_fqn);
}
#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)

/// Recurses through the message Descriptor class looking for valid extensions.
/// Stores the result to `memoized`.
Expand Down Expand Up @@ -173,6 +174,7 @@ std::string HasUnknownFields::BuildErrorMessage() const {
return emsg;
}

#endif
} // namespace

void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
Expand All @@ -181,6 +183,7 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
unknown_field_parent_message_fqn));
}

#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
std::optional<std::string> CheckRecursively(
const ::google::protobuf::python::PyProto_API* py_proto_api,
const ::google::protobuf::Message* message) {
Expand All @@ -195,5 +198,6 @@ std::optional<std::string> CheckRecursively(
}
return search.BuildErrorMessage();
}
#endif

} // namespace pybind11_protobuf::check_unknown_fields
9 changes: 7 additions & 2 deletions pybind11_protobuf/check_unknown_fields.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@

#include <optional>

#include "absl/strings/string_view.h"
#include "google/protobuf/message.h"

#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
#include "python/google/protobuf/proto_api.h"
#include "absl/strings/string_view.h"
#endif // PYBIND11_PROTOBUF_ENABLE_PYPROTO_API

namespace pybind11_protobuf::check_unknown_fields {

Expand Down Expand Up @@ -45,9 +48,11 @@ class ExtensionsWithUnknownFieldsPolicy {
void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
absl::string_view unknown_field_parent_message_fqn);

#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
std::optional<std::string> CheckRecursively(
const ::google::protobuf::python::PyProto_API* py_proto_api,
const ::google::protobuf::Message* top_message);
const ::google::protobuf::Message* top_message, bool build_error_message_if_any);
#endif // PYBIND11_PROTOBUF_ENABLE_PYPROTO_API

} // namespace pybind11_protobuf::check_unknown_fields

Expand Down
18 changes: 17 additions & 1 deletion pybind11_protobuf/proto_cast_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@
#include "google/protobuf/descriptor.h"
#include "google/protobuf/descriptor_database.h"
#include "google/protobuf/dynamic_message.h"
#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
#include "python/google/protobuf/proto_api.h"
#else
namespace google::protobuf::python {
struct PyProto_API;
}
#endif
#include "pybind11_protobuf/check_unknown_fields.h"

#if defined(GOOGLE_PROTOBUF_VERSION)
Expand All @@ -46,7 +52,6 @@ using ::google::protobuf::FileDescriptorProto;
using ::google::protobuf::Message;
using ::google::protobuf::MessageFactory;
using ::google::protobuf::python::PyProto_API;
using ::google::protobuf::python::PyProtoAPICapsuleName;

namespace pybind11_protobuf {

Expand Down Expand Up @@ -266,6 +271,7 @@ GlobalState::GlobalState() {
//
// By default (3) is used, however if the define is set *and* the version
// matches, then pybind11_protobuf will assume that this will work.
using ::google::protobuf::python::PyProtoAPICapsuleName;
py_proto_api_ =
static_cast<PyProto_API*>(PyCapsule_Import(PyProtoAPICapsuleName(), 0));
if (py_proto_api_ == nullptr) {
Expand Down Expand Up @@ -355,6 +361,7 @@ py::object GlobalState::PyMessageInstance(const Descriptor* descriptor) {
module_name + "?");
}

#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
std::pair<py::object, Message*> GlobalState::PyFastCppProtoMessageInstance(
const Descriptor* descriptor) {
assert(descriptor != nullptr);
Expand Down Expand Up @@ -395,6 +402,7 @@ std::pair<py::object, Message*> GlobalState::PyFastCppProtoMessageInstance(
}
return {std::move(result), message};
}
#endif

// Create C++ DescriptorPools based on Python DescriptorPools.
// The Python pool will provide message definitions when they are needed.
Expand Down Expand Up @@ -534,6 +542,7 @@ class PythonDescriptorPoolWrapper {
private:
bool CopyToFileDescriptorProto(py::handle py_file_descriptor,
FileDescriptorProto* output) {
#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
if (GlobalState::instance()->py_proto_api()) {
try {
py::object c_proto = py::reinterpret_steal<py::object>(
Expand All @@ -552,6 +561,7 @@ class PythonDescriptorPoolWrapper {
PyErr_Print();
}
}
#endif

return output->ParsePartialFromString(
PyBytesAsStringView(py_file_descriptor.attr("serialized_pb")));
Expand Down Expand Up @@ -750,6 +760,7 @@ py::handle GenericPyProtoCast(Message* src, py::return_value_policy policy,
return py_proto.release();
}

#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
py::handle GenericFastCppProtoCast(Message* src, py::return_value_policy policy,
py::handle parent, bool is_const) {
assert(policy != pybind11::return_value_policy::automatic);
Expand Down Expand Up @@ -823,6 +834,7 @@ py::handle GenericFastCppProtoCast(Message* src, py::return_value_policy policy,
throw py::cast_error(message + ReturnValuePolicyName(policy));
}
}
#endif

py::handle GenericProtoCast(Message* src, py::return_value_policy policy,
py::handle parent, bool is_const) {
Expand All @@ -833,6 +845,9 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy,
// 1. The binary does not have a py_proto_api instance, or
// 2. a) the proto is from the default pool and
// b) the binary is not using fast_cpp_protos.
#if ! defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
return GenericPyProtoCast(src, policy, parent, is_const);
#else
if (GlobalState::instance()->py_proto_api() == nullptr ||
(src->GetDescriptor()->file()->pool() ==
DescriptorPool::generated_pool() &&
Expand Down Expand Up @@ -861,6 +876,7 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy,
// construct a mapping between C++ pool() and python pool(), and then
// use the PyProto_API to make it work.
return GenericFastCppProtoCast(src, policy, parent, is_const);
#endif
}

} // namespace pybind11_protobuf

0 comments on commit 29997d6

Please sign in to comment.