Skip to content

Commit 824065c

Browse files
authored
allow args and kwargs in groupby.apply (#1268)
* allow args and kwargs in groupby.apply * use paramspec on args, kwargs * remove include_groups. use P.args and P.kwargs in apply defs * add test to see if failures are picked up
1 parent 08c673c commit 824065c

File tree

2 files changed

+54
-12
lines changed

2 files changed

+54
-12
lines changed

pandas-stubs/core/groupby/generic.pyi

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ from typing import (
1111
Generic,
1212
Literal,
1313
NamedTuple,
14+
Protocol,
1415
TypeVar,
1516
final,
1617
overload,
@@ -208,28 +209,45 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):
208209

209210
_TT = TypeVar("_TT", bound=Literal[True, False])
210211

212+
# ty ignore needed because of https://github.com/astral-sh/ty/issues/157#issuecomment-3017337945
213+
class DFCallable1(Protocol[P]): # ty: ignore[invalid-argument-type]
214+
def __call__(
215+
self, df: DataFrame, /, *args: P.args, **kwargs: P.kwargs
216+
) -> Scalar | list | dict: ...
217+
218+
class DFCallable2(Protocol[P]): # ty: ignore[invalid-argument-type]
219+
def __call__(
220+
self, df: DataFrame, /, *args: P.args, **kwargs: P.kwargs
221+
) -> DataFrame | Series: ...
222+
223+
class DFCallable3(Protocol[P]): # ty: ignore[invalid-argument-type]
224+
def __call__(self, df: Iterable, /, *args: P.args, **kwargs: P.kwargs) -> float: ...
225+
211226
class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
212227
# error: Overload 3 for "apply" will never be used because its parameters overlap overload 1
213228
@overload # type: ignore[override]
214229
def apply(
215230
self,
216-
func: Callable[[DataFrame], Scalar | list | dict],
217-
*args,
218-
**kwargs,
231+
func: DFCallable1[P],
232+
/,
233+
*args: P.args,
234+
**kwargs: P.kwargs,
219235
) -> Series: ...
220236
@overload
221237
def apply(
222238
self,
223-
func: Callable[[DataFrame], Series | DataFrame],
224-
*args,
225-
**kwargs,
239+
func: DFCallable2[P],
240+
/,
241+
*args: P.args,
242+
**kwargs: P.kwargs,
226243
) -> DataFrame: ...
227244
@overload
228-
def apply( # pyright: ignore[reportOverlappingOverload]
245+
def apply(
229246
self,
230-
func: Callable[[Iterable], float],
231-
*args,
232-
**kwargs,
247+
func: DFCallable3[P],
248+
/,
249+
*args: P.args,
250+
**kwargs: P.kwargs,
233251
) -> DataFrame: ...
234252
# error: overload 1 overlaps overload 2 because of different return types
235253
@overload

tests/test_groupby.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def resample_interpolate(x: DataFrame) -> DataFrame:
271271

272272
check(
273273
assert_type(
274-
GB_DF.apply(resample_interpolate, include_groups=False),
274+
GB_DF.apply(resample_interpolate),
275275
DataFrame,
276276
),
277277
DataFrame,
@@ -284,7 +284,6 @@ def resample_interpolate_linear(x: DataFrame) -> DataFrame:
284284
assert_type(
285285
GB_DF.apply(
286286
resample_interpolate_linear,
287-
include_groups=False,
288287
),
289288
DataFrame,
290289
),
@@ -1099,3 +1098,28 @@ def test_dataframe_value_counts() -> None:
10991098
Series,
11001099
np.int64,
11011100
)
1101+
1102+
1103+
def test_dataframe_apply_kwargs() -> None:
1104+
# GH 1266
1105+
df = DataFrame({"group": ["A", "A", "B", "B", "C"], "value": [10, 15, 10, 25, 30]})
1106+
1107+
def add_constant_to_mean(group: DataFrame, constant: int) -> DataFrame:
1108+
mean_val = group["value"].mean()
1109+
group["adjusted"] = mean_val + constant
1110+
return group
1111+
1112+
check(
1113+
assert_type(
1114+
df.groupby("group", group_keys=False)[["group", "value"]].apply(
1115+
add_constant_to_mean, constant=5
1116+
),
1117+
DataFrame,
1118+
),
1119+
DataFrame,
1120+
)
1121+
if TYPE_CHECKING_INVALID_USAGE:
1122+
df.groupby("group", group_keys=False)[["group", "value"]].apply(
1123+
add_constant_to_mean,
1124+
constant="5", # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
1125+
)

0 commit comments

Comments
 (0)