3030#include " common/type_introspector.h"
3131#include " internal/status_macros.h"
3232#include " google/protobuf/arena.h"
33+ #include " google/protobuf/descriptor.h"
3334
3435namespace cel ::checker_internal {
3536
@@ -59,8 +60,21 @@ absl::Nullable<const FunctionDecl*> TypeCheckEnv::LookupFunction(
5960
6061absl::StatusOr<absl::optional<Type>> TypeCheckEnv::LookupTypeName (
6162 TypeFactory& type_factory, absl::string_view name) const {
63+ {
64+ // Check the descriptor pool first, then fallback to custom type providers.
65+ absl::Nullable<const google::protobuf::Descriptor*> descriptor =
66+ descriptor_pool_->FindMessageTypeByName (name);
67+ if (descriptor != nullptr ) {
68+ return Type::Message (descriptor);
69+ }
70+ absl::Nullable<const google::protobuf::EnumDescriptor*> enum_descriptor =
71+ descriptor_pool_->FindEnumTypeByName (name);
72+ if (enum_descriptor != nullptr ) {
73+ return Type::Enum (enum_descriptor);
74+ }
75+ }
6276 const TypeCheckEnv* scope = this ;
63- while (scope != nullptr ) {
77+ do {
6478 for (auto iter = type_providers_.rbegin (); iter != type_providers_.rend ();
6579 ++iter) {
6680 auto type = (*iter)->FindType (type_factory, name);
@@ -69,15 +83,34 @@ absl::StatusOr<absl::optional<Type>> TypeCheckEnv::LookupTypeName(
6983 }
7084 }
7185 scope = scope->parent_ ;
72- }
86+ } while ((scope != nullptr ));
7387 return absl::nullopt ;
7488}
7589
7690absl::StatusOr<absl::optional<VariableDecl>> TypeCheckEnv::LookupEnumConstant (
7791 TypeFactory& type_factory, absl::string_view type,
7892 absl::string_view value) const {
93+ {
94+ // Check the descriptor pool first, then fallback to custom type providers.
95+ absl::Nullable<const google::protobuf::EnumDescriptor*> enum_descriptor =
96+ descriptor_pool_->FindEnumTypeByName (type);
97+ if (enum_descriptor != nullptr ) {
98+ absl::Nullable<const google::protobuf::EnumValueDescriptor*> enum_value_descriptor =
99+ enum_descriptor->FindValueByName (value);
100+ if (enum_value_descriptor == nullptr ) {
101+ return absl::nullopt ;
102+ }
103+ auto decl =
104+ MakeVariableDecl (absl::StrCat (enum_descriptor->full_name (), " ." ,
105+ enum_value_descriptor->name ()),
106+ Type::Enum (enum_descriptor));
107+ decl.set_value (
108+ Constant (static_cast <int64_t >(enum_value_descriptor->number ())));
109+ return decl;
110+ }
111+ }
79112 const TypeCheckEnv* scope = this ;
80- while (scope != nullptr ) {
113+ do {
81114 for (auto iter = type_providers_.rbegin (); iter != type_providers_.rend ();
82115 ++iter) {
83116 auto enum_constant = (*iter)->FindEnumConstant (type_factory, type, value);
@@ -95,7 +128,7 @@ absl::StatusOr<absl::optional<VariableDecl>> TypeCheckEnv::LookupEnumConstant(
95128 }
96129 }
97130 scope = scope->parent_ ;
98- }
131+ } while (scope != nullptr );
99132 return absl::nullopt ;
100133}
101134
@@ -122,8 +155,25 @@ absl::StatusOr<absl::optional<VariableDecl>> TypeCheckEnv::LookupTypeConstant(
122155absl::StatusOr<absl::optional<StructTypeField>> TypeCheckEnv::LookupStructField (
123156 TypeFactory& type_factory, absl::string_view type_name,
124157 absl::string_view field_name) const {
158+ {
159+ // Check the descriptor pool first, then fallback to custom type providers.
160+ absl::Nullable<const google::protobuf::Descriptor*> descriptor =
161+ descriptor_pool_->FindMessageTypeByName (type_name);
162+ if (descriptor != nullptr ) {
163+ absl::Nullable<const google::protobuf::FieldDescriptor*> field_descriptor =
164+ descriptor->FindFieldByName (field_name);
165+ if (field_descriptor == nullptr ) {
166+ field_descriptor = descriptor_pool_->FindExtensionByPrintableName (
167+ descriptor, field_name);
168+ if (field_descriptor == nullptr ) {
169+ return absl::nullopt ;
170+ }
171+ }
172+ return cel::MessageTypeField (field_descriptor);
173+ }
174+ }
125175 const TypeCheckEnv* scope = this ;
126- while (scope != nullptr ) {
176+ do {
127177 // Check the type providers in reverse registration order.
128178 // Note: this doesn't allow for shadowing a type with a subset type of the
129179 // same name -- the parent type provider will still be considered when
@@ -137,7 +187,7 @@ absl::StatusOr<absl::optional<StructTypeField>> TypeCheckEnv::LookupStructField(
137187 }
138188 }
139189 scope = scope->parent_ ;
140- }
190+ } while (scope != nullptr );
141191 return absl::nullopt ;
142192}
143193
0 commit comments