Skip to content

Commit

Permalink
Add maybe_apply function (#223)
Browse files Browse the repository at this point in the history
* Add `maybe_apply` function

* Add apply functions

* Add typing tests

* Rename excs as exceptions

* Default value is the object, use UNSET

* Default not fallback

* Rename apply_catch

* Fix pytest version

---------

Co-authored-by: DeviousStoat <[email protected]>
  • Loading branch information
DeviousStoat and DeviousStoat authored Mar 4, 2024
1 parent 079ab36 commit 81b4a99
Show file tree
Hide file tree
Showing 10 changed files with 213 additions and 42 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ dev = [
"furo",
"invoke",
"mypy",
"pytest",
# pytest 8+ is not supported by pytest-mypy-testing
"pytest <8",
"pytest-mypy-testing",
"pytest-cov",
"sphinx",
Expand Down
11 changes: 9 additions & 2 deletions src/pydash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
zip_object_deep,
zip_with,
)
from .chaining import _Dash, chain, tap, thru
from .chaining import _Dash, chain, tap
from .collections import (
at,
count_by,
Expand Down Expand Up @@ -168,6 +168,10 @@
zscore,
)
from .objects import (
apply,
apply_catch,
apply_if,
apply_if_not_none,
assign,
assign_with,
callables,
Expand Down Expand Up @@ -462,7 +466,6 @@
"_Dash",
"chain",
"tap",
"thru",
"at",
"count_by",
"every",
Expand Down Expand Up @@ -544,6 +547,10 @@
"transpose",
"variance",
"zscore",
"apply",
"apply_catch",
"apply_if",
"apply_if_not_none",
"assign",
"assign_with",
"callables",
Expand Down
3 changes: 1 addition & 2 deletions src/pydash/chaining/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from .chaining import _Dash, chain, tap, thru
from .chaining import _Dash, chain, tap


__all__ = (
"_Dash",
"chain",
"tap",
"thru",
)
33 changes: 30 additions & 3 deletions src/pydash/chaining/all_funcs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,6 @@ class AllFuncs:
def tap(self: "Chain[T]", interceptor: t.Callable[[T], t.Any]) -> "Chain[T]":
return self._wrap(pyd.tap)(interceptor)

def thru(self: "Chain[T]", interceptor: t.Callable[[T], T2]) -> "Chain[T2]":
return self._wrap(pyd.thru)(interceptor)

@t.overload
def at(self: "Chain[t.Mapping[T, T2]]", *paths: T) -> "Chain[t.List[t.Union[T2, None]]]": ...
@t.overload
Expand Down Expand Up @@ -2578,6 +2575,36 @@ class AllFuncs:
) -> "Chain[t.Any]":
return self._wrap(pyd.map_values_deep)(iteratee, property_path)

def apply(self: "Chain[T]", func: t.Callable[[T], T2]) -> "Chain[T2]":
return self._wrap(pyd.apply)(func)

def apply_if(
self: "Chain[T]", func: t.Callable[[T], T2], predicate: t.Callable[[T], bool]
) -> "Chain[t.Union[T, T2]]":
return self._wrap(pyd.apply_if)(func, predicate)

def apply_if_not_none(
self: "Chain[t.Optional[T]]", func: t.Callable[[T], T2]
) -> "Chain[t.Optional[T2]]":
return self._wrap(pyd.apply_if_not_none)(func)

@t.overload
def apply_catch(
self: "Chain[T]",
func: t.Callable[[T], T2],
exceptions: t.Iterable[t.Type[Exception]],
default: T3,
) -> "Chain[t.Union[T2, T3]]": ...
@t.overload
def apply_catch(
self: "Chain[T]",
func: t.Callable[[T], T2],
exceptions: t.Iterable[t.Type[Exception]],
default: Unset = UNSET,
) -> "Chain[t.Union[T, T2]]": ...
def apply_catch(self, func, exceptions, default=UNSET):
return self._wrap(pyd.apply_catch)(func, exceptions, default)

@t.overload
def merge(
self: "Chain[t.Mapping[T, T2]]", *sources: t.Mapping[T3, T4]
Expand Down
23 changes: 0 additions & 23 deletions src/pydash/chaining/chaining.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
__all__ = (
"chain",
"tap",
"thru",
)

ValueT_co = t.TypeVar("ValueT_co", covariant=True)
Expand Down Expand Up @@ -263,25 +262,3 @@ def tap(value: T, interceptor: t.Callable[[T], t.Any]) -> T:
"""
interceptor(value)
return value


def thru(value: T, interceptor: t.Callable[[T], T2]) -> T2:
"""
Returns the result of calling `interceptor` on `value`. The purpose of this method is to pass
`value` through a function during a method chain.
Args:
value: Current value of chain operation.
interceptor: Function called with `value`.
Returns:
Results of ``interceptor(value)``.
Example:
>>> chain([1, 2, 3, 4]).thru(lambda x: x * 2).value()
[1, 2, 3, 4, 1, 2, 3, 4]
.. versionadded:: 2.0.0
"""
return interceptor(value)
121 changes: 120 additions & 1 deletion src/pydash/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import pydash as pyd

from .helpers import UNSET, base_get, base_set, callit, getargcount, iterator, iteriteratee
from .helpers import UNSET, Unset, base_get, base_set, callit, getargcount, iterator, iteriteratee
from .types import IterateeObjT, PathT
from .utilities import PathToken, to_path, to_path_tokens

Expand All @@ -21,6 +21,10 @@
from _typeshed import SupportsRichComparisonT # pragma: no cover

__all__ = (
"apply",
"apply_catch",
"apply_if",
"apply_if_not_none",
"assign",
"assign_with",
"callables",
Expand Down Expand Up @@ -1251,6 +1255,121 @@ def deep_iteratee(value, key):
return callit(iteratee, obj, properties)


def apply(obj: T, func: t.Callable[[T], T2]) -> T2:
"""
Returns the result of calling `func` on `obj`. Particularly useful to pass
`obj` through a function during a method chain.
Args:
obj: Object to apply function to
func: Function called with `obj`.
Returns:
Results of ``func(value)``.
Example:
>>> apply(5, lambda x: x * 2)
10
.. versionadded:: 8.0.0
"""
return func(obj)


def apply_if(obj: T, func: t.Callable[[T], T2], predicate: t.Callable[[T], bool]) -> t.Union[T, T2]:
"""
Apply `func` to `obj` if `predicate` returns `True`.
Args:
obj: Object to apply `func` to.
func: Function to apply to `obj`.
predicate: Predicate applied to `obj`.
Returns:
Result of applying `func` to `obj` or `obj`.
Example:
>>> apply_if(2, lambda x: x * 2, lambda x: x > 1)
4
>>> apply_if(2, lambda x: x * 2, lambda x: x < 1)
2
.. versionadded:: 8.0.0
"""
return func(obj) if predicate(obj) else obj


def apply_if_not_none(obj: t.Optional[T], func: t.Callable[[T], T2]) -> t.Optional[T2]:
"""
Apply `func` to `obj` if `obj` is not ``None``.
Args:
obj: Object to apply `func` to.
func: Function to apply to `obj`.
Returns:
Result of applying `func` to `obj` or ``None``.
Example:
>>> apply_if_not_none(2, lambda x: x * 2)
4
>>> apply_if_not_none(None, lambda x: x * 2) is None
True
.. versionadded:: 8.0.0
"""
return apply_if(obj, func, lambda x: x is not None) # type: ignore


@t.overload
def apply_catch(
obj: T, func: t.Callable[[T], T2], exceptions: t.Iterable[t.Type[Exception]], default: T3
) -> t.Union[T2, T3]: ...


@t.overload
def apply_catch(
obj: T,
func: t.Callable[[T], T2],
exceptions: t.Iterable[t.Type[Exception]],
default: Unset = UNSET,
) -> t.Union[T, T2]: ...


def apply_catch(obj, func, exceptions, default=UNSET):
"""
Tries to apply `func` to `obj` if any of the exceptions in `excs` are raised, return `default`
or `obj` if not set.
Args:
obj: Object to apply `func` to.
func: Function to apply to `obj`.
excs: Exceptions to catch.
default: Value to return if exception is raised.
Returns:
Result of applying `func` to `obj` or ``default``.
Example:
>>> apply_catch(2, lambda x: x * 2, [ValueError])
4
>>> apply_catch(2, lambda x: x / 0, [ZeroDivisionError], "error")
'error'
>>> apply_catch(2, lambda x: x / 0, [ZeroDivisionError])
2
.. versionadded:: 8.0.0
"""
try:
return func(obj)
except tuple(exceptions):
return obj if default is UNSET else default


@t.overload
def merge(
obj: t.Mapping[T, T2], *sources: t.Mapping[T3, T4]
Expand Down
5 changes: 0 additions & 5 deletions tests/pytest_mypy_testing/test_chaining.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,3 @@ def log(value):
data.append(value)

reveal_type(_.chain([1, 2, 3, 4]).map(lambda x: x * 2).tap(log).value()) # R: builtins.list[builtins.int]


@pytest.mark.mypy_testing
def test_mypy_thru() -> None:
reveal_type(_.chain([1, 2, 3, 4]).thru(lambda x: x * 2).value()) # R: builtins.list[builtins.int]
25 changes: 25 additions & 0 deletions tests/pytest_mypy_testing/test_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,28 @@ def test_mypy_values() -> None:
reveal_type(_.values({'a': 1, 'b': 2, 'c': 3})) # R: builtins.list[builtins.int]
reveal_type(_.values([2, 4, 6, 8])) # R: builtins.list[builtins.int]
reveal_type(_.values(MyClass())) # R: builtins.list[Any]


@pytest.mark.mypy_testing
def test_mypy_apply() -> None:
reveal_type(_.apply("1", lambda x: int(x))) # R: builtins.int
reveal_type(_.apply(1, lambda x: x + 1)) # R: builtins.int
reveal_type(_.apply("hello", lambda x: x.upper())) # R: builtins.str


@pytest.mark.mypy_testing
def test_mypy_apply_if() -> None:
reveal_type(_.apply_if("5", lambda x: int(x), lambda x: x.isdecimal())) # R: Union[builtins.str, builtins.int]


@pytest.mark.mypy_testing
def test_mypy_apply_if_not_none() -> None:
reveal_type(_.apply_if_not_none(1, lambda x: x + 1)) # R: Union[builtins.int, None]
reveal_type(_.apply_if_not_none(None, lambda x: x + 1)) # R: Union[builtins.int, None]
reveal_type(_.apply_if_not_none("hello", lambda x: x.upper())) # R: Union[builtins.str, None]


@pytest.mark.mypy_testing
def test_mypy_apply_catch() -> None:
reveal_type(_.apply_catch(5, lambda x: x / 0, [ZeroDivisionError])) # R: Union[builtins.int, builtins.float]
reveal_type(_.apply_catch(5, lambda x: x / 0, [ZeroDivisionError], "error")) # R: Union[builtins.float, builtins.str]
5 changes: 0 additions & 5 deletions tests/test_chaining.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,3 @@ def test_chaining_value_to_string(case, expected):
def test_tap(value, interceptor, expected):
actual = _.chain(value).initial().tap(interceptor).last().value()
assert actual == expected


@parametrize("value,func,expected", [([1, 2, 3, 4, 5], lambda value: [sum(value)], 10)])
def test_thru(value, func, expected):
assert _.chain(value).initial().thru(func).last().value()
26 changes: 26 additions & 0 deletions tests/test_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,3 +944,29 @@ def test_unset(obj, path, expected, new_obj):
@parametrize("case,expected", [({"a": 1, "b": 2, "c": 3}, [1, 2, 3]), ([1, 2, 3], [1, 2, 3])])
def test_values(case, expected):
assert set(_.values(case)) == set(expected)


@parametrize("case,expected", [((5, lambda x: x * 2), 10)])
def test_apply(case, expected):
assert _.apply(*case) == expected


@parametrize(
"case,expected",
[((5, lambda x: x * 2, lambda x: x == 5), 10), ((5, lambda x: x * 2, lambda x: x == 10), 5)],
)
def test_apply_if(case, expected):
assert _.apply_if(*case) == expected


@parametrize("case,expected", [((5, lambda x: x * 2), 10), ((None, lambda x: x * 2), None)])
def test_apply_if_not_none(case, expected):
assert _.apply_if_not_none(*case) == expected


@parametrize(
"case,expected",
[((5, lambda x: x * 2, [ValueError]), 10), ((5, lambda x: x / 0, [ZeroDivisionError]), 5)],
)
def test_apply_catch(case, expected):
assert _.apply_catch(*case) == expected

0 comments on commit 81b4a99

Please sign in to comment.