Review notes (free text ahead of the first "diff --git" header; patch tools skip it).

What this patch does, as far as the hunks below show:
  * generic.pyi: adds Protocol to the typing import list and introduces three
    callback protocols, DFCallable1/2/3, each generic over the ParamSpec P.
    Their first parameter is positional-only; the rest are P.args / P.kwargs.
  * The three DataFrameGroupBy.apply overloads take these protocols instead of
    Callable[[...], ...]. func becomes positional-only, and *args / **kwargs are
    typed as P.args / P.kwargs, so extra arguments are checked against the
    callback's own signature.
  * The pyright reportOverlappingOverload ignore on overload 3 is removed.
  * tests: include_groups=False is dropped from two existing apply calls, and
    test_dataframe_apply_kwargs (GH 1266) is added with one accepted and one
    rejected keyword argument.

NOTE(review): P is not defined or imported by this patch; presumably it already
  exists in generic.pyi - confirm.
NOTE(review): include_groups=False is presumably dropped because neither callback
  declares that parameter, so the ParamSpec-typed overloads would reject it -
  confirm this loss of the keyword is intended.
NOTE(review): line structure was restored from a whitespace-collapsed copy.
  Indentation and blank lines are inferred from black formatting and the hunk
  header counts; the nesting depth of the context lines in the first two
  tests/test_groupby.py hunks is a guess. Verify against the tree, or apply
  with --ignore-whitespace.

diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi
index b07554de..2e4864cd 100644
--- a/pandas-stubs/core/groupby/generic.pyi
+++ b/pandas-stubs/core/groupby/generic.pyi
@@ -11,6 +11,7 @@ from typing import (
     Generic,
     Literal,
     NamedTuple,
+    Protocol,
     TypeVar,
     final,
     overload,
@@ -208,28 +209,45 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):
 
 _TT = TypeVar("_TT", bound=Literal[True, False])
 
+# ty ignore needed because of https://github.com/astral-sh/ty/issues/157#issuecomment-3017337945
+class DFCallable1(Protocol[P]):  # ty: ignore[invalid-argument-type]
+    def __call__(
+        self, df: DataFrame, /, *args: P.args, **kwargs: P.kwargs
+    ) -> Scalar | list | dict: ...
+
+class DFCallable2(Protocol[P]):  # ty: ignore[invalid-argument-type]
+    def __call__(
+        self, df: DataFrame, /, *args: P.args, **kwargs: P.kwargs
+    ) -> DataFrame | Series: ...
+
+class DFCallable3(Protocol[P]):  # ty: ignore[invalid-argument-type]
+    def __call__(self, df: Iterable, /, *args: P.args, **kwargs: P.kwargs) -> float: ...
+
 class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
     # error: Overload 3 for "apply" will never be used because its parameters overlap overload 1
     @overload  # type: ignore[override]
     def apply(
         self,
-        func: Callable[[DataFrame], Scalar | list | dict],
-        *args,
-        **kwargs,
+        func: DFCallable1[P],
+        /,
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> Series: ...
     @overload
     def apply(
         self,
-        func: Callable[[DataFrame], Series | DataFrame],
-        *args,
-        **kwargs,
+        func: DFCallable2[P],
+        /,
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> DataFrame: ...
     @overload
-    def apply(  # pyright: ignore[reportOverlappingOverload]
+    def apply(
         self,
-        func: Callable[[Iterable], float],
-        *args,
-        **kwargs,
+        func: DFCallable3[P],
+        /,
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> DataFrame: ...
     # error: overload 1 overlaps overload 2 because of different return types
     @overload
diff --git a/tests/test_groupby.py b/tests/test_groupby.py
index e6dbd0e1..2da00450 100644
--- a/tests/test_groupby.py
+++ b/tests/test_groupby.py
@@ -273,7 +273,7 @@ def resample_interpolate(x: DataFrame) -> DataFrame:
 
         check(
             assert_type(
-                GB_DF.apply(resample_interpolate, include_groups=False),
+                GB_DF.apply(resample_interpolate),
                 DataFrame,
             ),
             DataFrame,
@@ -286,7 +286,6 @@ def resample_interpolate_linear(x: DataFrame) -> DataFrame:
             assert_type(
                 GB_DF.apply(
                     resample_interpolate_linear,
-                    include_groups=False,
                 ),
                 DataFrame,
             ),
@@ -1102,3 +1101,28 @@ def test_dataframe_value_counts() -> None:
         Series,
         np.int64,
     )
+
+
+def test_dataframe_apply_kwargs() -> None:
+    # GH 1266
+    df = DataFrame({"group": ["A", "A", "B", "B", "C"], "value": [10, 15, 10, 25, 30]})
+
+    def add_constant_to_mean(group: DataFrame, constant: int) -> DataFrame:
+        mean_val = group["value"].mean()
+        group["adjusted"] = mean_val + constant
+        return group
+
+    check(
+        assert_type(
+            df.groupby("group", group_keys=False)[["group", "value"]].apply(
+                add_constant_to_mean, constant=5
+            ),
+            DataFrame,
+        ),
+        DataFrame,
+    )
+    if TYPE_CHECKING_INVALID_USAGE:
+        df.groupby("group", group_keys=False)[["group", "value"]].apply(
+            add_constant_to_mean,
+            constant="5",  # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
+        )