Skip to content

Commit 02b1e6f

Browse files
authored
Refactor encryption executor (#2011)
1 parent 138517e commit 02b1e6f

File tree

16 files changed

+698
-117
lines changed

16 files changed

+698
-117
lines changed

src/confluent_kafka/schema_registry/_async/avro.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17-
17+
import io
1818
from json import loads
1919
from typing import Dict, Union, Optional, Callable
2020

@@ -29,6 +29,8 @@
2929
AsyncSchemaRegistryClient,
3030
prefix_schema_id_serializer,
3131
dual_schema_id_deserializer)
32+
from confluent_kafka.schema_registry.common.schema_registry_client import \
33+
RulePhase
3234
from confluent_kafka.serialization import (SerializationError,
3335
SerializationContext)
3436
from confluent_kafka.schema_registry.rule_registry import RuleRegistry
@@ -348,8 +350,14 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731
348350
with _ContextStringIO() as fo:
349351
# write the record to the rest of the buffer
350352
schemaless_writer(fo, parsed_schema, value)
353+
buffer = fo.getvalue()
354+
355+
if latest_schema is not None:
356+
buffer = self._execute_rules_with_phase(
357+
ctx, subject, RulePhase.ENCODING, RuleMode.WRITE,
358+
None, latest_schema.schema, buffer, None, None)
351359

352-
return self._schema_id_serializer(fo.getvalue(), ctx, self._schema_id)
360+
return self._schema_id_serializer(buffer, ctx, self._schema_id)
353361

354362
async def _get_parsed_schema(self, schema: Schema) -> AvroSchema:
355363
parsed_schema = self._parsed_schemas.get_parsed_schema(schema)
@@ -557,6 +565,12 @@ async def __deserialize(
557565
if subject is not None:
558566
latest_schema = await self._get_reader_schema(subject)
559567

568+
payload = self._execute_rules_with_phase(
569+
ctx, subject, RulePhase.ENCODING, RuleMode.READ,
570+
None, writer_schema_raw, payload, None, None)
571+
if isinstance(payload, bytes):
572+
payload = io.BytesIO(payload)
573+
560574
if latest_schema is not None:
561575
migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None)
562576
reader_schema_raw = latest_schema.schema

src/confluent_kafka/schema_registry/_async/json_schema.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17-
17+
import io
1818
import json
1919
from typing import Union, Optional, Tuple, Callable
2020

@@ -33,6 +33,8 @@
3333
from confluent_kafka.schema_registry.common.json_schema import (
3434
DEFAULT_SPEC, JsonSchema, _retrieve_via_httpx, transform, _ContextStringIO, JSON_TYPE
3535
)
36+
from confluent_kafka.schema_registry.common.schema_registry_client import \
37+
RulePhase
3638
from confluent_kafka.schema_registry.rule_registry import RuleRegistry
3739
from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, \
3840
ParsedSchemaCache, SchemaId
@@ -374,8 +376,14 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731
374376
if isinstance(encoded_value, str):
375377
encoded_value = encoded_value.encode("utf8")
376378
fo.write(encoded_value)
379+
buffer = fo.getvalue()
380+
381+
if latest_schema is not None:
382+
buffer = self._execute_rules_with_phase(
383+
ctx, subject, RulePhase.ENCODING, RuleMode.WRITE,
384+
None, latest_schema.schema, buffer, None, None)
377385

378-
return self._schema_id_serializer(fo.getvalue(), ctx, self._schema_id)
386+
return self._schema_id_serializer(buffer, ctx, self._schema_id)
379387

380388
async def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]:
381389
if schema is None:
@@ -552,9 +560,9 @@ async def __init_impl(
552560
__init__ = __init_impl
553561

554562
def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
555-
return self.__serialize(data, ctx)
563+
return self.__deserialize(data, ctx)
556564

557-
async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
565+
async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
558566
"""
559567
Deserialize a JSON encoded record with Confluent Schema Registry framing to
560568
a dict, or object instance according to from_dict if from_dict is specified.
@@ -583,9 +591,6 @@ async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = N
583591
schema_id = SchemaId(JSON_TYPE)
584592
payload = self._schema_id_deserializer(data, ctx, schema_id)
585593

586-
# JSON documents are self-describing; no need to query schema
587-
obj_dict = self._json_decode(payload.read())
588-
589594
if self._registry is not None:
590595
writer_schema_raw = await self._get_writer_schema(schema_id, subject)
591596
writer_schema, writer_ref_registry = await self._get_parsed_schema(writer_schema_raw)
@@ -597,6 +602,15 @@ async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = N
597602
writer_schema_raw = None
598603
writer_schema, writer_ref_registry = None, None
599604

605+
payload = self._execute_rules_with_phase(
606+
ctx, subject, RulePhase.ENCODING, RuleMode.READ,
607+
None, writer_schema_raw, payload, None, None)
608+
if isinstance(payload, bytes):
609+
payload = io.BytesIO(payload)
610+
611+
# JSON documents are self-describing; no need to query schema
612+
obj_dict = self._json_decode(payload.read())
613+
600614
if latest_schema is not None:
601615
migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None)
602616
reader_schema_raw = latest_schema.schema

src/confluent_kafka/schema_registry/_async/protobuf.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from confluent_kafka.schema_registry import (reference_subject_name_strategy,
2828
topic_subject_name_strategy,
2929
prefix_schema_id_serializer, dual_schema_id_deserializer)
30+
from confluent_kafka.schema_registry.common.schema_registry_client import \
31+
RulePhase
3032
from confluent_kafka.schema_registry.schema_registry_client import AsyncSchemaRegistryClient
3133
from confluent_kafka.schema_registry.common.protobuf import _bytes, _create_index_array, \
3234
_init_pool, _is_builtin, _schema_to_str, _str_to_proto, transform, _ContextStringIO, PROTOBUF_TYPE
@@ -426,7 +428,14 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731
426428
with _ContextStringIO() as fo:
427429
fo.write(message.SerializeToString())
428430
self._schema_id.message_indexes = self._index_array
429-
return self._schema_id_serializer(fo.getvalue(), ctx, self._schema_id)
431+
buffer = fo.getvalue()
432+
433+
if latest_schema is not None:
434+
buffer = self._execute_rules_with_phase(
435+
ctx, subject, RulePhase.ENCODING, RuleMode.WRITE,
436+
None, latest_schema.schema, buffer, None, None)
437+
438+
return self._schema_id_serializer(buffer, ctx, self._schema_id)
430439

431440
async def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]:
432441
result = self._parsed_schemas.get_parsed_schema(schema)
@@ -549,9 +558,9 @@ async def __init_impl(
549558
__init__ = __init_impl
550559

551560
def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
552-
return self.__serialize(data, ctx)
561+
return self.__deserialize(data, ctx)
553562

554-
async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
563+
async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
555564
"""
556565
Deserialize a serialized protobuf message with Confluent Schema Registry
557566
framing.
@@ -596,6 +605,12 @@ async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = N
596605
writer_schema_raw = None
597606
writer_schema = None
598607

608+
payload = self._execute_rules_with_phase(
609+
ctx, subject, RulePhase.ENCODING, RuleMode.READ,
610+
None, writer_schema_raw, payload, None, None)
611+
if isinstance(payload, bytes):
612+
payload = io.BytesIO(payload)
613+
599614
if latest_schema is not None:
600615
migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None)
601616
reader_schema_raw = latest_schema.schema

src/confluent_kafka/schema_registry/_async/serde.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from typing import List, Optional, Set, Dict, Any
2121

2222
from confluent_kafka.schema_registry import RegisteredSchema
23+
from confluent_kafka.schema_registry.common.schema_registry_client import \
24+
RulePhase
2325
from confluent_kafka.schema_registry.common.serde import ErrorAction, \
2426
FieldTransformer, Migration, NoneAction, RuleAction, \
2527
RuleConditionError, RuleContext, RuleError, SchemaId
@@ -59,6 +61,17 @@ def _execute_rules(
5961
source: Optional[Schema], target: Optional[Schema],
6062
message: Any, inline_tags: Optional[Dict[str, Set[str]]],
6163
field_transformer: Optional[FieldTransformer]
64+
) -> Any:
65+
return self._execute_rules_with_phase(
66+
ser_ctx, subject, RulePhase.DOMAIN, rule_mode,
67+
source, target, message, inline_tags, field_transformer)
68+
69+
def _execute_rules_with_phase(
70+
self, ser_ctx: SerializationContext, subject: str,
71+
rule_phase: RulePhase, rule_mode: RuleMode,
72+
source: Optional[Schema], target: Optional[Schema],
73+
message: Any, inline_tags: Optional[Dict[str, Set[str]]],
74+
field_transformer: Optional[FieldTransformer]
6275
) -> Any:
6376
if message is None or target is None:
6477
return message
@@ -73,7 +86,10 @@ def _execute_rules(
7386
rules.reverse()
7487
else:
7588
if target is not None and target.rule_set is not None:
76-
rules = target.rule_set.domain_rules
89+
if rule_phase == RulePhase.ENCODING:
90+
rules = target.rule_set.encoding_rules
91+
else:
92+
rules = target.rule_set.domain_rules
7793
if rule_mode == RuleMode.READ:
7894
# Execute read rules in reverse order for symmetry
7995
rules = rules[:] if rules else []
@@ -197,19 +213,25 @@ async def _get_writer_schema(
197213
else:
198214
raise SerializationError("Schema ID or GUID is not set")
199215

200-
def _has_rules(self, rule_set: RuleSet, mode: RuleMode) -> bool:
216+
def _has_rules(self, rule_set: RuleSet, phase: RulePhase, mode: RuleMode) -> bool:
201217
if rule_set is None:
202218
return False
219+
if phase == RulePhase.MIGRATION:
220+
rules = rule_set.migration_rules
221+
elif phase == RulePhase.DOMAIN:
222+
rules = rule_set.domain_rules
223+
elif phase == RulePhase.ENCODING:
224+
rules = rule_set.encoding_rules
203225
if mode in (RuleMode.UPGRADE, RuleMode.DOWNGRADE):
204226
return any(rule.mode == mode or rule.mode == RuleMode.UPDOWN
205-
for rule in rule_set.migration_rules or [])
227+
for rule in rules or [])
206228
elif mode == RuleMode.UPDOWN:
207-
return any(rule.mode == mode for rule in rule_set.migration_rules or [])
229+
return any(rule.mode == mode for rule in rules or [])
208230
elif mode in (RuleMode.WRITE, RuleMode.READ):
209231
return any(rule.mode == mode or rule.mode == RuleMode.WRITEREAD
210-
for rule in rule_set.domain_rules or [])
232+
for rule in rules or [])
211233
elif mode == RuleMode.WRITEREAD:
212-
return any(rule.mode == mode for rule in rule_set.migration_rules or [])
234+
return any(rule.mode == mode for rule in rules or [])
213235
return False
214236

215237
async def _get_migrations(
@@ -235,7 +257,8 @@ async def _get_migrations(
235257
if i == 0:
236258
previous = version
237259
continue
238-
if version.schema.rule_set is not None and self._has_rules(version.schema.rule_set, migration_mode):
260+
if version.schema.rule_set is not None and self._has_rules(
261+
version.schema.rule_set, RulePhase.MIGRATION, migration_mode):
239262
if migration_mode == RuleMode.UPGRADE:
240263
migration = Migration(migration_mode, previous, version)
241264
else:
@@ -265,7 +288,8 @@ def _execute_migrations(
265288
migrations: List[Migration], message: Any
266289
) -> Any:
267290
for migration in migrations:
268-
message = self._execute_rules(ser_ctx, subject, migration.rule_mode,
269-
migration.source.schema, migration.target.schema,
270-
message, None, None)
291+
message = self._execute_rules_with_phase(
292+
ser_ctx, subject, RulePhase.MIGRATION, migration.rule_mode,
293+
migration.source.schema, migration.target.schema,
294+
message, None, None)
271295
return message

src/confluent_kafka/schema_registry/_sync/avro.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17-
17+
import io
1818
from json import loads
1919
from typing import Dict, Union, Optional, Callable
2020

@@ -29,6 +29,8 @@
2929
SchemaRegistryClient,
3030
prefix_schema_id_serializer,
3131
dual_schema_id_deserializer)
32+
from confluent_kafka.schema_registry.common.schema_registry_client import \
33+
RulePhase
3234
from confluent_kafka.serialization import (SerializationError,
3335
SerializationContext)
3436
from confluent_kafka.schema_registry.rule_registry import RuleRegistry
@@ -348,8 +350,14 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731
348350
with _ContextStringIO() as fo:
349351
# write the record to the rest of the buffer
350352
schemaless_writer(fo, parsed_schema, value)
353+
buffer = fo.getvalue()
354+
355+
if latest_schema is not None:
356+
buffer = self._execute_rules_with_phase(
357+
ctx, subject, RulePhase.ENCODING, RuleMode.WRITE,
358+
None, latest_schema.schema, buffer, None, None)
351359

352-
return self._schema_id_serializer(fo.getvalue(), ctx, self._schema_id)
360+
return self._schema_id_serializer(buffer, ctx, self._schema_id)
353361

354362
def _get_parsed_schema(self, schema: Schema) -> AvroSchema:
355363
parsed_schema = self._parsed_schemas.get_parsed_schema(schema)
@@ -557,6 +565,12 @@ def __deserialize(
557565
if subject is not None:
558566
latest_schema = self._get_reader_schema(subject)
559567

568+
payload = self._execute_rules_with_phase(
569+
ctx, subject, RulePhase.ENCODING, RuleMode.READ,
570+
None, writer_schema_raw, payload, None, None)
571+
if isinstance(payload, bytes):
572+
payload = io.BytesIO(payload)
573+
560574
if latest_schema is not None:
561575
migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None)
562576
reader_schema_raw = latest_schema.schema

src/confluent_kafka/schema_registry/_sync/json_schema.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17-
17+
import io
1818
import json
1919
from typing import Union, Optional, Tuple, Callable
2020

@@ -33,6 +33,8 @@
3333
from confluent_kafka.schema_registry.common.json_schema import (
3434
DEFAULT_SPEC, JsonSchema, _retrieve_via_httpx, transform, _ContextStringIO, JSON_TYPE
3535
)
36+
from confluent_kafka.schema_registry.common.schema_registry_client import \
37+
RulePhase
3638
from confluent_kafka.schema_registry.rule_registry import RuleRegistry
3739
from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, \
3840
ParsedSchemaCache, SchemaId
@@ -374,8 +376,14 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731
374376
if isinstance(encoded_value, str):
375377
encoded_value = encoded_value.encode("utf8")
376378
fo.write(encoded_value)
379+
buffer = fo.getvalue()
380+
381+
if latest_schema is not None:
382+
buffer = self._execute_rules_with_phase(
383+
ctx, subject, RulePhase.ENCODING, RuleMode.WRITE,
384+
None, latest_schema.schema, buffer, None, None)
377385

378-
return self._schema_id_serializer(fo.getvalue(), ctx, self._schema_id)
386+
return self._schema_id_serializer(buffer, ctx, self._schema_id)
379387

380388
def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]:
381389
if schema is None:
@@ -552,9 +560,9 @@ def __init_impl(
552560
__init__ = __init_impl
553561

554562
def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
555-
return self.__serialize(data, ctx)
563+
return self.__deserialize(data, ctx)
556564

557-
def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
565+
def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
558566
"""
559567
Deserialize a JSON encoded record with Confluent Schema Registry framing to
560568
a dict, or object instance according to from_dict if from_dict is specified.
@@ -583,9 +591,6 @@ def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -
583591
schema_id = SchemaId(JSON_TYPE)
584592
payload = self._schema_id_deserializer(data, ctx, schema_id)
585593

586-
# JSON documents are self-describing; no need to query schema
587-
obj_dict = self._json_decode(payload.read())
588-
589594
if self._registry is not None:
590595
writer_schema_raw = self._get_writer_schema(schema_id, subject)
591596
writer_schema, writer_ref_registry = self._get_parsed_schema(writer_schema_raw)
@@ -597,6 +602,15 @@ def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -
597602
writer_schema_raw = None
598603
writer_schema, writer_ref_registry = None, None
599604

605+
payload = self._execute_rules_with_phase(
606+
ctx, subject, RulePhase.ENCODING, RuleMode.READ,
607+
None, writer_schema_raw, payload, None, None)
608+
if isinstance(payload, bytes):
609+
payload = io.BytesIO(payload)
610+
611+
# JSON documents are self-describing; no need to query schema
612+
obj_dict = self._json_decode(payload.read())
613+
600614
if latest_schema is not None:
601615
migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None)
602616
reader_schema_raw = latest_schema.schema

0 commit comments

Comments
 (0)