Skip to content

Commit

Permalink
Add `pybind11_protobuf::check_unknown_fields::ExtensionsWithUnknownFi…
Browse files Browse the repository at this point in the history
…eldsPolicy`.

PiperOrigin-RevId: 595557800
  • Loading branch information
Ralf W. Grosse-Kunstleve authored and copybara-github committed Jan 4, 2024
1 parent 2f613ca commit 3b11990
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 29 deletions.
48 changes: 29 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,18 @@ protobuf extensions
are involved, a well-known pitfall is that extensions are silently moved
to the `proto2::UnknownFieldSet` when a message is deserialized in C++,
but the `cc_proto_library` for the extensions is not linked in. The root
cause is an asymmetry in the handling of Python protos vs C++ protos: when
a Python proto is deserialized, both the Python descriptor pool and the C++
descriptor pool are inspected, but when a C++ proto is deserialized, only
cause is an asymmetry in the handling of Python protos vs C++
protos:
when a Python proto is deserialized, both the Python descriptor pool and the
C++ descriptor pool are inspected, but when a C++ proto is deserialized, only
the C++ descriptor pool is inspected. Until this asymmetry is resolved, the
`cc_proto_library` for all extensions involved must be added to the `deps` of
the relevant `pybind_library` or `pybind_extension`, but this is sufficiently
unobvious to be a setup for regular accidents, potentially with critical
the relevant `pybind_library` or `pybind_extension`, or if this is impractial,
`pybind11_protobuf::check_unknown_fields::ExtensionsWithUnknownFieldsPolicy::WeakEnableFallbackToSerializeParse`
or `pybind11_protobuf::AllowUnknownFieldsFor` can be used.

The pitfall is sufficiently unobvious to be a setup for regular accidents,
potentially with critical
consequences.

To guard against the most common type of accident, native_proto_caster.h
Expand All @@ -97,14 +102,20 @@ in certain situations:
* and the `proto2::UnknownFieldSet` for the message or any of its submessages
is not empty.

`pybind11_protobuf::AllowUnknownFieldsFor` is an escape hatch for situations in
which
`pybind11_protobuf::check_unknown_fields::ExtensionsWithUnknownFieldsPolicy::WeakEnableFallbackToSerializeParse`
is a **global** escape hatch trading off convenience and runtime overhead: the
convenience is that it is not necessary to determine what `cc_proto_library`
dependencies need to be added, the runtime overhead is that
`SerializePartialToString`/`ParseFromString` is used for messages with unknown
fields, instead of the much faster `CopyFrom`.

* unknown fields existed before the safety mechanism was
introduced.
* unknown fields are needed in the future.
Another escape hatch is `pybind11_protobuf::AllowUnknownFieldsFor`, which
simply disables the safety mechanism for **specific message types**, without
a runtime overhead. This is useful for situations in which unknown fields
are acceptable.

An example of a full error message (with lines breaks here for readability):
An example of a full error message generated by the safety mechanism
(with lines breaks here for readability):

```
Proto Message of type pybind11.test.NestRepeated has an Unknown Field with
Expand All @@ -117,14 +128,13 @@ Only if there is no alternative to suppressing this error, use
(Warning: suppressions may mask critical bugs.)
```

The current implementation is a compromise solution, trading off simplicity
of implementation, runtime performance, and precision. Generally, the runtime
overhead is expected to be very small, but fields flagged as unknown may not
necessarily be in extensions.
Alerting developers of new code to unknown fields is assumed to be generally
helpful, but the unknown fields detection is limited to messages with
extensions, to avoid the runtime overhead for the presumably much more common
case that no extensions are involved.
Note that the current implementation of the safety mechanism is a compromise
solution, trading off simplicity of implementation, runtime performance,
and precision. Alerting developers of new code to unknown fields is assumed
to be generally helpful, but the unknown fields detection is limited to
messages with extensions, to avoid the runtime overhead for the presumably
much more common case that no extensions are involved. Because of this,
the runtime overhead for the safety mechanism is expected to be very small.

### Enumerations

Expand Down
12 changes: 12 additions & 0 deletions pybind11_protobuf/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,15 @@ cc_library(
"@com_google_protobuf//python:proto_api",
],
)

cc_library(
name = "disallow_extensions_with_unknown_fields",
srcs = ["disallow_extensions_with_unknown_fields.cc"],
visibility = [
"//visibility:public",
],
deps = [
":check_unknown_fields",
],
alwayslink = 1,
)
7 changes: 5 additions & 2 deletions pybind11_protobuf/check_unknown_fields.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
unknown_field_parent_message_fqn));
}

std::optional<std::string> CheckAndBuildErrorMessageIfAny(
std::optional<std::string> CheckRecursively(
const ::google::protobuf::python::PyProto_API* py_proto_api,
const ::google::protobuf::Message* message) {
const ::google::protobuf::Message* message, bool build_error_message_if_any) {
const auto* root_descriptor = message->GetDescriptor();
HasUnknownFields search{py_proto_api, root_descriptor};
if (!search.FindUnknownFieldsRecursive(message, 0u)) {
Expand All @@ -193,6 +193,9 @@ std::optional<std::string> CheckAndBuildErrorMessageIfAny(
search.FieldFQN())) != 0) {
return std::nullopt;
}
if (!build_error_message_if_any) {
return ""; // This indicates that an unknown field was found.
}
return search.BuildErrorMessage();
}

Expand Down
37 changes: 35 additions & 2 deletions pybind11_protobuf/check_unknown_fields.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,45 @@

namespace pybind11_protobuf::check_unknown_fields {

class ExtensionsWithUnknownFieldsPolicy {
enum State {
// Initial state.
kWeakDisallow,

// Primary use case: PyCLIF extensions might set this when being imported.
kWeakEnableFallbackToSerializeParse,

// Primary use case: `:disallow_extensions_with_unknown_fields` in `deps`
// of a binary (or test).
kStrongDisallow
};

static State& GetStateSingleton() {
static State singleton = kWeakDisallow;
return singleton;
}

public:
static void WeakEnableFallbackToSerializeParse() {
State& policy = GetStateSingleton();
if (policy == kWeakDisallow) {
policy = kWeakEnableFallbackToSerializeParse;
}
}

static void StrongSetDisallow() { GetStateSingleton() = kStrongDisallow; }

static bool UnknownFieldsAreDisallowed() {
return GetStateSingleton() != kWeakEnableFallbackToSerializeParse;
}
};

void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
absl::string_view unknown_field_parent_message_fqn);

std::optional<std::string> CheckAndBuildErrorMessageIfAny(
std::optional<std::string> CheckRecursively(
const ::google::protobuf::python::PyProto_API* py_proto_api,
const ::google::protobuf::Message* top_message);
const ::google::protobuf::Message* top_message, bool build_error_message_if_any);

} // namespace pybind11_protobuf::check_unknown_fields

Expand Down
12 changes: 12 additions & 0 deletions pybind11_protobuf/disallow_extensions_with_unknown_fields.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#include "pybind11_protobuf/check_unknown_fields.h"

namespace pybind11_protobuf::check_unknown_fields {
namespace {

static int kSetConfigDone = []() {
ExtensionsWithUnknownFieldsPolicy::StrongSetDisallow();
return 0;
}();

} // namespace
} // namespace pybind11_protobuf::check_unknown_fields
18 changes: 12 additions & 6 deletions pybind11_protobuf/proto_cast_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -819,18 +819,24 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy,
// 1. The binary does not have a py_proto_api instance, or
// 2. a) the proto is from the default pool and
// b) the binary is not using fast_cpp_protos.
if ((GlobalState::instance()->py_proto_api() == nullptr) ||
if (GlobalState::instance()->py_proto_api() == nullptr ||
(src->GetDescriptor()->file()->pool() ==
DescriptorPool::generated_pool() &&
!GlobalState::instance()->using_fast_cpp())) {
return GenericPyProtoCast(src, policy, parent, is_const);
}

std::optional<std::string> emsg =
check_unknown_fields::CheckAndBuildErrorMessageIfAny(
GlobalState::instance()->py_proto_api(), src);
if (emsg) {
throw py::value_error(*emsg);
std::optional<std::string> unknown_field_message =
check_unknown_fields::CheckRecursively(
GlobalState::instance()->py_proto_api(), src,
check_unknown_fields::ExtensionsWithUnknownFieldsPolicy::
UnknownFieldsAreDisallowed());
if (unknown_field_message) {
if (!unknown_field_message->empty()) {
throw py::value_error(*unknown_field_message);
}
// Fall back to serialize/parse.
return GenericPyProtoCast(src, policy, parent, is_const);
}

// If this is a dynamically generated proto, then we're going to need to
Expand Down

0 comments on commit 3b11990

Please sign in to comment.