diff --git a/README.md b/README.md index 01f13bc..275d766 100644 --- a/README.md +++ b/README.md @@ -149,7 +149,7 @@ class Address: number: int zip_code: int city: str - + class PersonInfo: def __init__(self, name: str, age: int, address: Address): self.name = name @@ -181,6 +181,43 @@ print("Target public_info.address is same as source address: ", address is publi * [TortoiseORM](https://github.com/tortoise/tortoise-orm) * [SQLAlchemy](https://www.sqlalchemy.org/) +## Complex mapping support + +Support to pass a model factory method or cunstructor when registering a mapping. +This allows for mapping to value objects or other complex types. + +```python +class SourceEnum(Enum): + VALUE1 = "value1" + VALUE2 = "value2" + VALUE3 = "value3" + +class NameEnum(Enum): + VALUE1 = 1 + VALUE2 = 2 + VALUE3 = 3 + +class ValueEnum(Enum): + A = "value1" + B = "value2" + C = "value3" + + +class ValueObject: + value: str + + def __init__(self, value: Union[float, int, Decimal]): + self.value = str(value) + +mapper.add(SourceEnum, NameEnum, model_factory=lambda x: NameEnum[x.name]) +mapper.map(SourceEnum.VALUE1) # NameEnum.VALUE1 + +mapper.add(ValueEnum, SourceEnum, model_factory=lambda x: SourceEnum(x.value)) +mapper.map(ValueEnum.B) # SourceEnum.VALUE2 + +mapper.to(ValueObject).map(Decimal("42"), model_factory=ValueObject) # ValueObject(42) +``` + ## Pydantic/FastAPI Support Out of the box Pydantic models support: ```python @@ -273,7 +310,7 @@ class PublicUserInfo(Base): id = Column(Integer, primary_key=True) public_name = Column(String) hobbies = Column(String) - + obj = UserInfo( id=2, full_name="Danny DeVito", @@ -304,7 +341,7 @@ class TargetClass: def __init__(self, **kwargs): self.name = kwargs["name"] self.age = kwargs["age"] - + @staticmethod def get_fields(cls): return ["name", "age"] @@ -358,7 +395,7 @@ T = TypeVar("T") def class_has_fields_property(target_cls: Type[T]) -> bool: return callable(getattr(target_cls, "fields", None)) - + mapper.add_spec(class_has_fields_property, lambda t: getattr(t, "fields")()) target_obj = mapper.to(TargetClass).map(source_obj) diff --git a/automapper/mapper.py b/automapper/mapper.py index ea1be9a..189a34c 100644 --- a/automapper/mapper.py +++ b/automapper/mapper.py @@ -62,6 +62,7 @@ def map( skip_none_values: bool = False, fields_mapping: FieldsMap = None, use_deepcopy: bool = True, + model_factory: Optional[Callable[[S], T]] = None, ) -> T: """Produces output object mapped from source object and custom arguments. @@ -72,6 +73,9 @@ def map( Specify dictionary in format {"field_name": value_object}. Defaults to None. use_deepcopy (bool, optional): Apply deepcopy to all child objects when copy from source to target object. Defaults to True. + model_factory (Callable, optional): Custom factory funtion + factory function that is used to create the target_obj. Called with the + source as parameter. Mutually exclusive with fields_mapping. Defaults to None Raises: CircularReferenceError: Circular references in `source class` object are not allowed yet. @@ -86,13 +90,14 @@ def map( skip_none_values=skip_none_values, custom_mapping=fields_mapping, use_deepcopy=use_deepcopy, + model_factory=model_factory, ) class Mapper: def __init__(self) -> None: """Initializes internal containers""" - self._mappings: Dict[Type[S], Tuple[T, FieldsMap]] = {} # type: ignore [valid-type] + self._mappings: Dict[Type[S], Tuple[T, FieldsMap, Callable[[S], [T]]]] = {} # type: ignore [valid-type] self._class_specs: Dict[Type[T], SpecFunction[T]] = {} # type: ignore [valid-type] self._classifier_specs: Dict[ # type: ignore [valid-type] ClassifierFunction[T], SpecFunction[T] @@ -147,6 +152,7 @@ def add( target_cls: Type[T], override: bool = False, fields_mapping: FieldsMap = None, + model_factory: Optional[Callable[[S], T]] = None, ) -> None: """Adds mapping between object of `source class` to an object of `target class`. @@ -156,7 +162,12 @@ def add( override (bool, optional): Override existing `source class` mapping to use new `target class`. Defaults to False. fields_mapping (FieldsMap, optional): Custom mapping. - Specify dictionary in format {"field_name": value_object}. Defaults to None. + Specify dictionary in format {"field_name": value_objecture_obj}. + Can take a lamdba funtion as argument, that will get the source_cls + as argument. Defaults to None. + model_factory (Callable, optional): Custom factory funtion + factory function that is used to create the target_obj. Called with the + source as parameter. Mutually exclusive with fields_mapping. Defaults to None Raises: DuplicatedRegistrationError: Same mapping for `source class` was added. @@ -168,7 +179,7 @@ def add( raise DuplicatedRegistrationError( f"source_cls {source_cls} was already added for mapping" ) - self._mappings[source_cls] = (target_cls, fields_mapping) + self._mappings[source_cls] = (target_cls, fields_mapping, model_factory) def map( self, @@ -201,7 +212,7 @@ def map( raise MappingError(f"Missing mapping type for input type {obj_type}") obj_type_prefix = f"{obj_type.__name__}." - target_cls, target_cls_field_mappings = self._mappings[obj_type] + target_cls, target_cls_field_mappings, target_cls_model_factory= self._mappings[obj_type] common_fields_mapping = fields_mapping if target_cls_field_mappings: @@ -221,6 +232,8 @@ def map( **fields_mapping, } # merge two dict into one, fields_mapping has priority + + return self._map_common( obj, target_cls, @@ -228,6 +241,7 @@ def map( skip_none_values=skip_none_values, custom_mapping=common_fields_mapping, use_deepcopy=use_deepcopy, + model_factory=target_cls_model_factory, ) def _get_fields(self, target_cls: Type[T]) -> Iterable[str]: @@ -257,7 +271,7 @@ def _map_subobject( raise CircularReferenceError() if type(obj) in self._mappings: - target_cls, _ = self._mappings[type(obj)] + target_cls, _, _ = self._mappings[type(obj)] result: Any = self._map_common( obj, target_cls, _visited_stack, skip_none_values=skip_none_values ) @@ -297,6 +311,7 @@ def _map_common( skip_none_values: bool = False, custom_mapping: FieldsMap = None, use_deepcopy: bool = True, + model_factory: Optional[Callable[[S], T]] = None, ) -> T: """Produces output object mapped from source object and custom arguments. @@ -309,6 +324,9 @@ def _map_common( Specify dictionary in format {"field_name": value_object}. Defaults to None. use_deepcopy (bool, optional): Apply deepcopy to all child objects when copy from source to target object. Defaults to True. + model_factory (Callable, optional): Custom factory funtion + factory function that is used to create the target_obj. Called with the + source as parameter. Mutually exclusive with fields_mapping. Defaults to None Raises: CircularReferenceError: Circular references in `source class` object are not allowed yet. @@ -320,10 +338,25 @@ def _map_common( if obj_id in _visited_stack: raise CircularReferenceError() + + target_cls_fields_mapping = None + if type(obj) in self._mappings: + _, target_cls_fields_mapping, a = self._mappings[type(obj)] + + if model_factory is not None and target_cls_fields_mapping: + raise ValueError( + "Cannot specify both model_factory and fields_mapping. " + "Use one of them to customize mapping." + ) + + if model_factory is not None and callable(model_factory): + return model_factory(obj) + _visited_stack.add(obj_id) target_cls_fields = self._get_fields(target_cls) + mapped_values: Dict[str, Any] = {} for field_name in target_cls_fields: value_found, value = _try_get_field_value(field_name, obj, custom_mapping) diff --git a/tests/test_model_factory_mapping.py b/tests/test_model_factory_mapping.py new file mode 100644 index 0000000..6c07c09 --- /dev/null +++ b/tests/test_model_factory_mapping.py @@ -0,0 +1,80 @@ +from decimal import Decimal +from enum import Enum +from typing import Union +from unittest import TestCase + +import pytest + +from automapper import create_mapper + + +class SourceEnum(Enum): + VALUE1 = "value1" + VALUE2 = "value2" + VALUE3 = "value3" + +class NameEnum(Enum): + VALUE1 = 1 + VALUE2 = 2 + VALUE3 = 3 + +class ValueEnum(Enum): + A = "value1" + B = "value2" + C = "value3" + +class ValueObject: + value: str + + def __init__(self, value: Union[float, int, Decimal]): + self.value = str(value) + + def __repr__(self): + return f"ValueObject(value={self.value})" + + def __str__(self): + return f"ValueObject(value={self.value})" + +class AutomapperModelFactoryTest(TestCase): + def setUp(self) -> None: + self.mapper = create_mapper() + + def test_map__with_registered_lambda_factory(self): + self.mapper.add(SourceEnum, NameEnum, model_factory=lambda x: NameEnum[x.name]) + self.mapper.add(ValueEnum, SourceEnum, model_factory=lambda x: SourceEnum(x.value)) + + self.assertEqual(self.mapper.map(SourceEnum.VALUE3), NameEnum.VALUE3) + self.assertEqual(self.mapper.map(ValueEnum.B), SourceEnum.VALUE2) + + + def test_map__with_lambda_factory(self): + name_enum = self.mapper.to(NameEnum).map(SourceEnum.VALUE3, model_factory=lambda x: NameEnum[x.name]) + value_enum = self.mapper.to(SourceEnum).map(ValueEnum.B, model_factory=lambda x: SourceEnum(x.value)) + + self.assertEqual(name_enum, NameEnum.VALUE3) + self.assertEqual(value_enum, SourceEnum.VALUE2) + + + def test_map__with_registered_constructor_factory(self): + self.mapper.add(Decimal, ValueObject, model_factory=ValueObject) # pyright: ignore[reportArgumentType] + + self.assertEqual(self.mapper.map(Decimal("42")).value, ValueObject(42).value) + + + def test_map__with_constructor_factory(self): + result = self.mapper.to(ValueObject).map(Decimal("42"), model_factory=ValueObject) # pyright: ignore[reportArgumentType] + + print(result) + self.assertEqual(result.value, ValueObject(42).value) + + + def test_map__with_factory_and_fields_mapping_raises_error(self): + self.mapper.add( + ValueEnum, + ValueObject, + model_factory=lambda s: ValueObject(int(s.value)), + fields_mapping={"value": lambda x: x.value} + ) + + with pytest.raises(ValueError): + self.mapper.map(ValueEnum.A)