Skip to content

feat: Allow factory method when defining a mapping #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 41 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class Address:
number: int
zip_code: int
city: str

class PersonInfo:
def __init__(self, name: str, age: int, address: Address):
self.name = name
Expand Down Expand Up @@ -181,6 +181,43 @@ print("Target public_info.address is same as source address: ", address is publi
* [TortoiseORM](https://github.com/tortoise/tortoise-orm)
* [SQLAlchemy](https://www.sqlalchemy.org/)

## Complex mapping support

Support to pass a model factory method or cunstructor when registering a mapping.
This allows for mapping to value objects or other complex types.

```python
class SourceEnum(Enum):
VALUE1 = "value1"
VALUE2 = "value2"
VALUE3 = "value3"

class NameEnum(Enum):
VALUE1 = 1
VALUE2 = 2
VALUE3 = 3

class ValueEnum(Enum):
A = "value1"
B = "value2"
C = "value3"


class ValueObject:
value: str

def __init__(self, value: Union[float, int, Decimal]):
self.value = str(value)

mapper.add(SourceEnum, NameEnum, model_factory=lambda x: NameEnum[x.name])
mapper.map(SourceEnum.VALUE1) # NameEnum.VALUE1

mapper.add(ValueEnum, SourceEnum, model_factory=lambda x: SourceEnum(x.value))
mapper.map(ValueEnum.B) # SourceEnum.VALUE2

mapper.to(ValueObject).map(Decimal("42"), model_factory=ValueObject) # ValueObject(42)
```

## Pydantic/FastAPI Support
Out of the box Pydantic models support:
```python
Expand Down Expand Up @@ -273,7 +310,7 @@ class PublicUserInfo(Base):
id = Column(Integer, primary_key=True)
public_name = Column(String)
hobbies = Column(String)

obj = UserInfo(
id=2,
full_name="Danny DeVito",
Expand Down Expand Up @@ -304,7 +341,7 @@ class TargetClass:
def __init__(self, **kwargs):
self.name = kwargs["name"]
self.age = kwargs["age"]

@staticmethod
def get_fields(cls):
return ["name", "age"]
Expand Down Expand Up @@ -358,7 +395,7 @@ T = TypeVar("T")

def class_has_fields_property(target_cls: Type[T]) -> bool:
return callable(getattr(target_cls, "fields", None))

mapper.add_spec(class_has_fields_property, lambda t: getattr(t, "fields")())

target_obj = mapper.to(TargetClass).map(source_obj)
Expand Down
43 changes: 38 additions & 5 deletions automapper/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def map(
skip_none_values: bool = False,
fields_mapping: FieldsMap = None,
use_deepcopy: bool = True,
model_factory: Optional[Callable[[S], T]] = None,
) -> T:
"""Produces output object mapped from source object and custom arguments.

Expand All @@ -72,6 +73,9 @@ def map(
Specify dictionary in format {"field_name": value_object}. Defaults to None.
use_deepcopy (bool, optional): Apply deepcopy to all child objects when copy from source to target object.
Defaults to True.
model_factory (Callable, optional): Custom factory funtion
factory function that is used to create the target_obj. Called with the
source as parameter. Mutually exclusive with fields_mapping. Defaults to None

Raises:
CircularReferenceError: Circular references in `source class` object are not allowed yet.
Expand All @@ -86,13 +90,14 @@ def map(
skip_none_values=skip_none_values,
custom_mapping=fields_mapping,
use_deepcopy=use_deepcopy,
model_factory=model_factory,
)


class Mapper:
def __init__(self) -> None:
"""Initializes internal containers"""
self._mappings: Dict[Type[S], Tuple[T, FieldsMap]] = {} # type: ignore [valid-type]
self._mappings: Dict[Type[S], Tuple[T, FieldsMap, Callable[[S], [T]]]] = {} # type: ignore [valid-type]
self._class_specs: Dict[Type[T], SpecFunction[T]] = {} # type: ignore [valid-type]
self._classifier_specs: Dict[ # type: ignore [valid-type]
ClassifierFunction[T], SpecFunction[T]
Expand Down Expand Up @@ -147,6 +152,7 @@ def add(
target_cls: Type[T],
override: bool = False,
fields_mapping: FieldsMap = None,
model_factory: Optional[Callable[[S], T]] = None,
) -> None:
"""Adds mapping between object of `source class` to an object of `target class`.

Expand All @@ -156,7 +162,12 @@ def add(
override (bool, optional): Override existing `source class` mapping to use new `target class`.
Defaults to False.
fields_mapping (FieldsMap, optional): Custom mapping.
Specify dictionary in format {"field_name": value_object}. Defaults to None.
Specify dictionary in format {"field_name": value_objecture_obj}.
Can take a lamdba funtion as argument, that will get the source_cls
as argument. Defaults to None.
model_factory (Callable, optional): Custom factory funtion
factory function that is used to create the target_obj. Called with the
source as parameter. Mutually exclusive with fields_mapping. Defaults to None

Raises:
DuplicatedRegistrationError: Same mapping for `source class` was added.
Expand All @@ -168,7 +179,7 @@ def add(
raise DuplicatedRegistrationError(
f"source_cls {source_cls} was already added for mapping"
)
self._mappings[source_cls] = (target_cls, fields_mapping)
self._mappings[source_cls] = (target_cls, fields_mapping, model_factory)

def map(
self,
Expand Down Expand Up @@ -201,7 +212,7 @@ def map(
raise MappingError(f"Missing mapping type for input type {obj_type}")
obj_type_prefix = f"{obj_type.__name__}."

target_cls, target_cls_field_mappings = self._mappings[obj_type]
target_cls, target_cls_field_mappings, target_cls_model_factory= self._mappings[obj_type]

common_fields_mapping = fields_mapping
if target_cls_field_mappings:
Expand All @@ -221,13 +232,16 @@ def map(
**fields_mapping,
} # merge two dict into one, fields_mapping has priority



return self._map_common(
obj,
target_cls,
set(),
skip_none_values=skip_none_values,
custom_mapping=common_fields_mapping,
use_deepcopy=use_deepcopy,
model_factory=target_cls_model_factory,
)

def _get_fields(self, target_cls: Type[T]) -> Iterable[str]:
Expand Down Expand Up @@ -257,7 +271,7 @@ def _map_subobject(
raise CircularReferenceError()

if type(obj) in self._mappings:
target_cls, _ = self._mappings[type(obj)]
target_cls, _, _ = self._mappings[type(obj)]
result: Any = self._map_common(
obj, target_cls, _visited_stack, skip_none_values=skip_none_values
)
Expand Down Expand Up @@ -297,6 +311,7 @@ def _map_common(
skip_none_values: bool = False,
custom_mapping: FieldsMap = None,
use_deepcopy: bool = True,
model_factory: Optional[Callable[[S], T]] = None,
) -> T:
"""Produces output object mapped from source object and custom arguments.

Expand All @@ -309,6 +324,9 @@ def _map_common(
Specify dictionary in format {"field_name": value_object}. Defaults to None.
use_deepcopy (bool, optional): Apply deepcopy to all child objects when copy from source to target object.
Defaults to True.
model_factory (Callable, optional): Custom factory funtion
factory function that is used to create the target_obj. Called with the
source as parameter. Mutually exclusive with fields_mapping. Defaults to None

Raises:
CircularReferenceError: Circular references in `source class` object are not allowed yet.
Expand All @@ -320,10 +338,25 @@ def _map_common(

if obj_id in _visited_stack:
raise CircularReferenceError()

target_cls_fields_mapping = None
if type(obj) in self._mappings:
_, target_cls_fields_mapping, a = self._mappings[type(obj)]

if model_factory is not None and target_cls_fields_mapping:
raise ValueError(
"Cannot specify both model_factory and fields_mapping. "
"Use one of them to customize mapping."
)

if model_factory is not None and callable(model_factory):
return model_factory(obj)

_visited_stack.add(obj_id)

target_cls_fields = self._get_fields(target_cls)


mapped_values: Dict[str, Any] = {}
for field_name in target_cls_fields:
value_found, value = _try_get_field_value(field_name, obj, custom_mapping)
Expand Down
80 changes: 80 additions & 0 deletions tests/test_model_factory_mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from decimal import Decimal
from enum import Enum
from typing import Union
from unittest import TestCase

import pytest

from automapper import create_mapper


class SourceEnum(Enum):
VALUE1 = "value1"
VALUE2 = "value2"
VALUE3 = "value3"

class NameEnum(Enum):
VALUE1 = 1
VALUE2 = 2
VALUE3 = 3

class ValueEnum(Enum):
A = "value1"
B = "value2"
C = "value3"

class ValueObject:
value: str

def __init__(self, value: Union[float, int, Decimal]):
self.value = str(value)

def __repr__(self):
return f"ValueObject(value={self.value})"

def __str__(self):
return f"ValueObject(value={self.value})"

class AutomapperModelFactoryTest(TestCase):
def setUp(self) -> None:
self.mapper = create_mapper()

def test_map__with_registered_lambda_factory(self):
self.mapper.add(SourceEnum, NameEnum, model_factory=lambda x: NameEnum[x.name])
self.mapper.add(ValueEnum, SourceEnum, model_factory=lambda x: SourceEnum(x.value))

self.assertEqual(self.mapper.map(SourceEnum.VALUE3), NameEnum.VALUE3)
self.assertEqual(self.mapper.map(ValueEnum.B), SourceEnum.VALUE2)


def test_map__with_lambda_factory(self):
name_enum = self.mapper.to(NameEnum).map(SourceEnum.VALUE3, model_factory=lambda x: NameEnum[x.name])
value_enum = self.mapper.to(SourceEnum).map(ValueEnum.B, model_factory=lambda x: SourceEnum(x.value))

self.assertEqual(name_enum, NameEnum.VALUE3)
self.assertEqual(value_enum, SourceEnum.VALUE2)


def test_map__with_registered_constructor_factory(self):
self.mapper.add(Decimal, ValueObject, model_factory=ValueObject) # pyright: ignore[reportArgumentType]

self.assertEqual(self.mapper.map(Decimal("42")).value, ValueObject(42).value)


def test_map__with_constructor_factory(self):
result = self.mapper.to(ValueObject).map(Decimal("42"), model_factory=ValueObject) # pyright: ignore[reportArgumentType]

print(result)
self.assertEqual(result.value, ValueObject(42).value)


def test_map__with_factory_and_fields_mapping_raises_error(self):
self.mapper.add(
ValueEnum,
ValueObject,
model_factory=lambda s: ValueObject(int(s.value)),
fields_mapping={"value": lambda x: x.value}
)

with pytest.raises(ValueError):
self.mapper.map(ValueEnum.A)