Skip to content

Commit

Permalink
Remove two pool-membership conditions guarding the C++ equivalent of …
Browse files Browse the repository at this point in the history
…`obj.SerializePartialToString()`

The main change in this CL is to remove two conditions in `PyProtoIsCompatible()`:

1.

```
  if (descriptor->file()->pool() != DescriptorPool::generated_pool()) {
```

2.

```
    return py_pool->is(GlobalState::instance()->global_pool());
```

Rationale for removing these conditions:

* All that matters for protobuf compatibility is that the `full_name` is the same. (Thanks @kmoffett for that insight!)

* Cross-extension-module ABI compatibility is not a concern because only the Python API is used in the relevant code paths serializing Python protobuf objects to Python `bytes` (equivalent to calling `obj.SerializePartialToString()` from Python).

All other changes in this CL are secondary: small-scale refactoring, slight naming changes, additional tests for error conditions.

PiperOrigin-RevId: 589898116
  • Loading branch information
Ralf W. Grosse-Kunstleve authored and copybara-github committed Dec 11, 2023
1 parent b713501 commit 8359a09
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 78 deletions.
105 changes: 43 additions & 62 deletions pybind11_protobuf/proto_cast_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "absl/strings/numbers.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "pybind11_protobuf/check_unknown_fields.h"

Expand Down Expand Up @@ -534,10 +535,8 @@ class PythonDescriptorPoolWrapper {
}
}

py::object wire = py_file_descriptor.attr("serialized_pb");
const char* bytes = PYBIND11_BYTES_AS_STRING(wire.ptr());
return output->ParsePartialFromArray(bytes,
PYBIND11_BYTES_SIZE(wire.ptr()));
return output->ParsePartialFromString(
PyBytesAsStringView(py_file_descriptor.attr("serialized_pb")));
}

py::object pool_; // never dereferenced.
Expand All @@ -549,6 +548,11 @@ class PythonDescriptorPoolWrapper {

} // namespace

// Returns a non-owning view of the byte payload held by `py_bytes`.
// Nothing is copied: the view refers to storage owned by the Python bytes
// object, so that object must be kept alive (by a reference other than the
// by-value `py_bytes` parameter) for as long as the view is used.
absl::string_view PyBytesAsStringView(py::bytes py_bytes) {
  auto* const raw_bytes = py_bytes.ptr();
  const char* const data = PyBytes_AsString(raw_bytes);
  const auto size = PyBytes_Size(raw_bytes);
  return absl::string_view(data, size);
}

void InitializePybindProtoCastUtil() {
assert(PyGILState_Check());
GlobalState::instance();
Expand Down Expand Up @@ -593,7 +597,7 @@ const Message* PyProtoGetCppMessagePointer(py::handle src) {
#endif
}

absl::optional<std::string> PyProtoDescriptorName(py::handle py_proto) {
absl::optional<std::string> PyProtoDescriptorFullName(py::handle py_proto) {
assert(PyGILState_Check());
auto py_full_name = ResolveAttrs(py_proto, {"DESCRIPTOR", "full_name"});
if (py_full_name) {
Expand All @@ -602,66 +606,42 @@ absl::optional<std::string> PyProtoDescriptorName(py::handle py_proto) {
return absl::nullopt;
}

bool PyProtoIsCompatible(py::handle py_proto, const Descriptor* descriptor) {
assert(PyGILState_Check());
if (descriptor->file()->pool() != DescriptorPool::generated_pool()) {
/// This indicates that the C++ descriptor does not come from the C++
/// DescriptorPool. This may happen if the C++ code has the same proto
/// in different descriptor pools, perhaps from different shared objects,
/// and could result in undefined behavior.
return false;
}

auto py_descriptor = ResolveAttrs(py_proto, {"DESCRIPTOR"});
if (!py_descriptor) {
// Not a valid protobuf -- missing DESCRIPTOR.
return false;
}

// Test full_name equivalence.
{
auto py_full_name = ResolveAttrs(*py_descriptor, {"full_name"});
if (!py_full_name) {
// Not a valid protobuf -- missing DESCRIPTOR.full_name
return false;
}
auto full_name = CastToOptionalString(*py_full_name);
if (!full_name || *full_name != descriptor->full_name()) {
// Name mismatch.
return false;
}
}

// The C++ descriptor is compiled in (see above assert), so the py_proto
// is expected to be from the global pool, i.e. the DESCRIPTOR.file.pool
// instance is the global python pool, and not a custom pool.
auto py_pool = ResolveAttrs(*py_descriptor, {"file", "pool"});
if (py_pool) {
return py_pool->is(GlobalState::instance()->global_pool());
}

// The py_proto is missing a DESCRIPTOR.file.pool, but the name matches.
// This will not happen with a native python implementation, but does
// occur with the deprecated :proto_casters, and could happen with other
// mocks. Returning true allows the caster to call PyProtoCopyToCProto.
return true;
// Reports whether `py_proto.DESCRIPTOR.full_name` can be resolved and is equal
// to the full name of the given C++ `descriptor`. Returns false when the
// Python-side full name is unavailable.
bool PyProtoHasMatchingFullName(py::handle py_proto,
                                const Descriptor* descriptor) {
  const auto py_full_name = PyProtoDescriptorFullName(py_proto);
  if (!py_full_name) {
    return false;
  }
  return *py_full_name == descriptor->full_name();
}

bool PyProtoCopyToCProto(py::handle py_proto, Message* message) {
assert(PyGILState_Check());
auto serialize_fn = ResolveAttrMRO(py_proto, "SerializePartialToString");
py::bytes PyProtoSerializePartialToString(py::handle py_proto,
bool raise_if_error) {
static const char* serialize_fn_name = "SerializePartialToString";
auto serialize_fn = ResolveAttrMRO(py_proto, serialize_fn_name);
if (!serialize_fn) {
throw py::type_error(
"SerializePartialToString method not found; is this a " +
message->GetDescriptor()->full_name());
}
auto wire = (*serialize_fn)();
const char* bytes = PYBIND11_BYTES_AS_STRING(wire.ptr());
if (!bytes) {
throw py::type_error("SerializePartialToString failed; is this a " +
message->GetDescriptor()->full_name());
return py::object();
}
auto serialized_bytes = py::reinterpret_steal<py::object>(
PyObject_CallObject(serialize_fn->ptr(), nullptr));
if (!serialized_bytes) {
if (raise_if_error) {
std::string msg = py::repr(py_proto).cast<std::string>() + "." +
serialize_fn_name + "() function call FAILED";
py::raise_from(PyExc_TypeError, msg.c_str());
throw py::error_already_set();
}
return py::object();
}
if (!PyBytes_Check(serialized_bytes.ptr())) {
if (raise_if_error) {
std::string msg = py::repr(py_proto).cast<std::string>() + "." +
serialize_fn_name +
"() function call is expected to return bytes, but the "
"returned value is " +
py::repr(serialized_bytes).cast<std::string>();
throw py::type_error(msg);
}
return py::object();
}
return message->ParsePartialFromArray(bytes, PYBIND11_BYTES_SIZE(wire.ptr()));
return serialized_bytes;
}

void CProtoCopyToPyProto(Message* message, py::handle py_proto) {
Expand All @@ -686,7 +666,8 @@ std::unique_ptr<Message> AllocateCProtoFromPythonSymbolDatabase(
assert(PyGILState_Check());
auto pool = ResolveAttrs(src, {"DESCRIPTOR", "file", "pool"});
if (!pool) {
throw py::type_error("Object is not a valid protobuf");
throw py::type_error(py::repr(src).cast<std::string>() +
" object is not a valid protobuf");
}

auto pool_data =
Expand Down
22 changes: 14 additions & 8 deletions pybind11_protobuf/proto_cast_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"

// PYBIND11_PROTOBUF_ASSUME_FULL_ABI_COMPATIBILITY can be defined by users
Expand All @@ -28,6 +29,10 @@

namespace pybind11_protobuf {

// Simple helper: returns a string_view over the contents of the Python bytes
// object that `py_bytes` refers to, without copying. Because `py_bytes` is
// taken by value, the caller has to ensure that the underlying Python bytes
// object (not merely the argument) outlives the returned string_view.
absl::string_view PyBytesAsStringView(pybind11::bytes py_bytes);

// Initialize internal proto cast dependencies, which includes importing
// various protobuf-related modules.
void InitializePybindProtoCastUtil();
Expand All @@ -39,22 +44,23 @@ void ImportProtoDescriptorModule(const ::google::protobuf::Descriptor *);
const ::google::protobuf::Message *PyProtoGetCppMessagePointer(pybind11::handle src);

// Returns the protocol buffer's py_proto.DESCRIPTOR.full_name attribute.
absl::optional<std::string> PyProtoDescriptorName(pybind11::handle py_proto);
absl::optional<std::string> PyProtoDescriptorFullName(
pybind11::handle py_proto);

// Returns true if py_proto full name matches descriptor full name.
bool PyProtoHasMatchingFullName(pybind11::handle py_proto,
const ::google::protobuf::Descriptor *descriptor);

// Return whether py_proto is compatible with the C++ descriptor.
// The py_proto name must match the C++ Descriptor::full_name(), and is
// expected to originate from the python default pool, which means that
// this method will return false for dynamic protos.
bool PyProtoIsCompatible(pybind11::handle py_proto,
const ::google::protobuf::Descriptor *descriptor);
// Calls py_proto.SerializePartialToString() via the Python API and returns the
// resulting Python bytes object.
// Returns a null (falsy) object if the method cannot be resolved, or if
// raise_if_error is false and the call fails or returns a non-bytes value.
// If raise_if_error is true, a failing call or a non-bytes return value is
// reported as a Python TypeError instead.
// Caller should enforce any type identity that is required.
pybind11::bytes PyProtoSerializePartialToString(pybind11::handle py_proto,
bool raise_if_error);

// Allocates a C++ protocol buffer for a given name.
std::unique_ptr<::google::protobuf::Message> AllocateCProtoFromPythonSymbolDatabase(
pybind11::handle src, const std::string &full_name);

// Serialize the py_proto and deserialize it into the provided message.
// Caller should enforce any type identity that is required.
bool PyProtoCopyToCProto(pybind11::handle py_proto, ::google::protobuf::Message *message);
void CProtoCopyToPyProto(::google::protobuf::Message *message, pybind11::handle py_proto);

// Returns a handle to a python protobuf suitably
Expand Down
26 changes: 18 additions & 8 deletions pybind11_protobuf/proto_caster_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,19 @@ struct proto_caster_load_impl {
}
}

// The incoming object is not a compatible fast_cpp_proto, so check whether
// it is otherwise compatible, then serialize it and deserialize into a
// native C++ proto type.
if (!pybind11_protobuf::PyProtoIsCompatible(src,
ProtoType::GetDescriptor())) {
if (!PyProtoHasMatchingFullName(src, ProtoType::GetDescriptor())) {
return false;
}
pybind11::bytes serialized_bytes =
PyProtoSerializePartialToString(src, convert);
if (!serialized_bytes) {
return false;
}

owned = std::unique_ptr<ProtoType>(new ProtoType());
value = owned.get();
return pybind11_protobuf::PyProtoCopyToCProto(src, owned.get());
return owned.get()->ParsePartialFromString(
PyBytesAsStringView(serialized_bytes));
}

// ensure_owned ensures that the owned member contains a copy of the
Expand Down Expand Up @@ -108,16 +111,23 @@ struct proto_caster_load_impl<::google::protobuf::Message> {

// `src` is not a C++ proto instance from the generated_pool,
// so create a compatible native C++ proto.
auto descriptor_name = pybind11_protobuf::PyProtoDescriptorName(src);
auto descriptor_name = pybind11_protobuf::PyProtoDescriptorFullName(src);
if (!descriptor_name) {
return false;
}
pybind11::bytes serialized_bytes =
PyProtoSerializePartialToString(src, convert);
if (!serialized_bytes) {
return false;
}

owned.reset(static_cast<ProtoType *>(
pybind11_protobuf::AllocateCProtoFromPythonSymbolDatabase(
src, *descriptor_name)
.release()));
value = owned.get();
return pybind11_protobuf::PyProtoCopyToCProto(src, owned.get());
return owned.get()->ParsePartialFromString(
PyBytesAsStringView(serialized_bytes));
}

// ensure_owned ensures that the owned member contains a copy of the
Expand Down
28 changes: 28 additions & 0 deletions pybind11_protobuf/tests/pass_by_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,34 @@ def test_pass_fake2(self, check_method):
def test_overload_fn(self, message_fn, expected):
self.assertEqual(expected, m.fn_overload(message_fn()))

def test_bad_serialize_partial_function_calls(self):
class FakeDescr:
full_name = 'fake_full_name'

class FakeProto:
DESCRIPTOR = FakeDescr()

def __init__(self, serialize_fn_return_value=None):
self.serialize_fn_return_value = serialize_fn_return_value

def SerializePartialToString(self): # pylint: disable=invalid-name
if self.serialize_fn_return_value is None:
raise RuntimeError('Broken serialize_fn.')
return self.serialize_fn_return_value

with self.assertRaisesRegex(
TypeError, r'\.SerializePartialToString\(\) function call FAILED$'
):
m.fn_overload(FakeProto())
with self.assertRaisesRegex(
TypeError,
r'\.SerializePartialToString\(\) function call is expected to return'
r' bytes, but the returned value is \[\]$',
):
m.fn_overload(FakeProto([]))
with self.assertRaisesRegex(TypeError, r' object is not a valid protobuf$'):
m.fn_overload(FakeProto(b''))


if __name__ == '__main__':
absltest.main()

0 comments on commit 8359a09

Please sign in to comment.