From 4081c27ec1335ceeaa91e6ae2bce80db4855723b Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sat, 28 Jun 2025 16:45:10 -0400 Subject: [PATCH 1/4] allow args and kwargs in groupby.apply --- pandas-stubs/core/groupby/generic.pyi | 16 +++++++++++++--- tests/test_groupby.py | 20 ++++++++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index b07554de..d7641406 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -11,6 +11,7 @@ from typing import ( Generic, Literal, NamedTuple, + Protocol, TypeVar, final, overload, @@ -208,26 +209,35 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]): _TT = TypeVar("_TT", bound=Literal[True, False]) +class DFCallable1(Protocol): + def __call__(self, df: DataFrame, /, *args, **kwargs) -> Scalar | list | dict: ... + +class DFCallable2(Protocol): + def __call__(self, df: DataFrame, /, *args, **kwargs) -> DataFrame | Series: ... + +class DFCallable3(Protocol): + def __call__(self, df: Iterable, /, *args, **kwargs) -> float: ... + class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]): # error: Overload 3 for "apply" will never be used because its parameters overlap overload 1 @overload # type: ignore[override] def apply( self, - func: Callable[[DataFrame], Scalar | list | dict], + func: DFCallable1, *args, **kwargs, ) -> Series: ... @overload def apply( self, - func: Callable[[DataFrame], Series | DataFrame], + func: DFCallable2, *args, **kwargs, ) -> DataFrame: ... @overload def apply( # pyright: ignore[reportOverlappingOverload] self, - func: Callable[[Iterable], float], + func: DFCallable3, *args, **kwargs, ) -> DataFrame: ... diff --git a/tests/test_groupby.py b/tests/test_groupby.py index e6dbd0e1..90aa69f5 100644 --- a/tests/test_groupby.py +++ b/tests/test_groupby.py @@ -1102,3 +1102,23 @@ def test_dataframe_value_counts() -> None: Series, np.int64, ) + + +def test_dataframe_apply_kwargs() -> None: + # GH 1266 + df = DataFrame({"group": ["A", "A", "B", "B", "C"], "value": [10, 15, 10, 25, 30]}) + + def add_constant_to_mean(group: DataFrame, constant: int) -> DataFrame: + mean_val = group["value"].mean() + group["adjusted"] = mean_val + constant + return group + + check( + assert_type( + df.groupby("group", group_keys=False)[["group", "value"]].apply( + add_constant_to_mean, constant=5 + ), + DataFrame, + ), + DataFrame, + ) From d13dfe29b2b5df8596af0d2faaae461e4600a9d2 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sun, 29 Jun 2025 20:25:25 -0400 Subject: [PATCH 2/4] use paramspec on args, kwargs --- pandas-stubs/core/groupby/generic.pyi | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index d7641406..83520863 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -209,14 +209,18 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]): _TT = TypeVar("_TT", bound=Literal[True, False]) -class DFCallable1(Protocol): - def __call__(self, df: DataFrame, /, *args, **kwargs) -> Scalar | list | dict: ... +class DFCallable1(Protocol[P]): # ty: ignore[invalid-argument-type] + def __call__( + self, df: DataFrame, /, *args: P.args, **kwargs: P.kwargs + ) -> Scalar | list | dict: ... -class DFCallable2(Protocol): - def __call__(self, df: DataFrame, /, *args, **kwargs) -> DataFrame | Series: ... +class DFCallable2(Protocol[P]): # ty: ignore[invalid-argument-type] + def __call__( + self, df: DataFrame, /, *args: P.args, **kwargs: P.kwargs + ) -> DataFrame | Series: ... -class DFCallable3(Protocol): - def __call__(self, df: Iterable, /, *args, **kwargs) -> float: ... +class DFCallable3(Protocol[P]): # ty: ignore[invalid-argument-type] + def __call__(self, df: Iterable, /, *args: P.args, **kwargs: P.kwargs) -> float: ... class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]): # error: Overload 3 for "apply" will never be used because its parameters overlap overload 1 From 94a665ce5818b5c7cbc4a747cda5326b1c1fe919 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sun, 29 Jun 2025 20:34:47 -0400 Subject: [PATCH 3/4] remove include_groups. use P.args and P.kwargs in apply defs --- pandas-stubs/core/groupby/generic.pyi | 24 ++++++++++++++---------- tests/test_groupby.py | 3 +-- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index 83520863..2e4864cd 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -209,6 +209,7 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]): _TT = TypeVar("_TT", bound=Literal[True, False]) +# ty ignore needed because of https://github.com/astral-sh/ty/issues/157#issuecomment-3017337945 class DFCallable1(Protocol[P]): # ty: ignore[invalid-argument-type] def __call__( self, df: DataFrame, /, *args: P.args, **kwargs: P.kwargs @@ -227,23 +228,26 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]): @overload # type: ignore[override] def apply( self, - func: DFCallable1, - *args, - **kwargs, + func: DFCallable1[P], + /, + *args: P.args, + **kwargs: P.kwargs, ) -> Series: ... @overload def apply( self, - func: DFCallable2, - *args, - **kwargs, + func: DFCallable2[P], + /, + *args: P.args, + **kwargs: P.kwargs, ) -> DataFrame: ... @overload - def apply( # pyright: ignore[reportOverlappingOverload] + def apply( self, - func: DFCallable3, - *args, - **kwargs, + func: DFCallable3[P], + /, + *args: P.args, + **kwargs: P.kwargs, ) -> DataFrame: ... # error: overload 1 overlaps overload 2 because of different return types @overload diff --git a/tests/test_groupby.py b/tests/test_groupby.py index 90aa69f5..93ae9d76 100644 --- a/tests/test_groupby.py +++ b/tests/test_groupby.py @@ -273,7 +273,7 @@ def resample_interpolate(x: DataFrame) -> DataFrame: check( assert_type( - GB_DF.apply(resample_interpolate, include_groups=False), + GB_DF.apply(resample_interpolate), DataFrame, ), DataFrame, @@ -286,7 +286,6 @@ def resample_interpolate_linear(x: DataFrame) -> DataFrame: assert_type( GB_DF.apply( resample_interpolate_linear, - include_groups=False, ), DataFrame, ), From ddebefa27c3a38bf45977cf43ec59dec8e4b7f1e Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Mon, 30 Jun 2025 11:20:54 -0400 Subject: [PATCH 4/4] add test to see if failures are picked up --- tests/test_groupby.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_groupby.py b/tests/test_groupby.py index 93ae9d76..2da00450 100644 --- a/tests/test_groupby.py +++ b/tests/test_groupby.py @@ -1121,3 +1121,8 @@ def add_constant_to_mean(group: DataFrame, constant: int) -> DataFrame: ), DataFrame, ) + if TYPE_CHECKING_INVALID_USAGE: + df.groupby("group", group_keys=False)[["group", "value"]].apply( + add_constant_to_mean, + constant="5", # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType] + )