Skip to content

Commit a2315af

Browse files
authored
BUG: Fix pivot_table margins to include NaN groups when dropna=False (#61524)
1 parent 4f2aa4d commit a2315af

File tree

4 files changed

+52
-12
lines changed

4 files changed

+52
-12
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,7 @@ Reshaping
869869
- Bug in :meth:`DataFrame.merge` when merging two :class:`DataFrame` on ``intc`` or ``uintc`` types on Windows (:issue:`60091`, :issue:`58713`)
870870
- Bug in :meth:`DataFrame.pivot_table` incorrectly subaggregating results when called without an ``index`` argument (:issue:`58722`)
871871
- Bug in :meth:`DataFrame.pivot_table` incorrectly ignoring the ``values`` argument when also supplied to the ``index`` or ``columns`` parameters (:issue:`57876`, :issue:`61292`)
872+
- Bug in :meth:`DataFrame.pivot_table` where ``margins=True`` did not correctly include groups with ``NaN`` values in the index or columns when ``dropna=False`` was explicitly passed (:issue:`61509`)
872873
- Bug in :meth:`DataFrame.stack` with the new implementation where ``ValueError`` is raised when ``level=[]`` (:issue:`60740`)
873874
- Bug in :meth:`DataFrame.unstack` producing incorrect results when manipulating empty :class:`DataFrame` with an :class:`ExtensionDtype` (:issue:`59123`)
874875
- Bug in :meth:`concat` where concatenating DataFrame and Series with ``ignore_index = True`` drops the series name (:issue:`60723`, :issue:`56257`)

pandas/core/reshape/pivot.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,7 @@ def __internal_pivot_table(
396396
observed=dropna,
397397
margins_name=margins_name,
398398
fill_value=fill_value,
399+
dropna=dropna,
399400
)
400401

401402
# discard the top level
@@ -422,6 +423,7 @@ def _add_margins(
422423
observed: bool,
423424
margins_name: Hashable = "All",
424425
fill_value=None,
426+
dropna: bool = True,
425427
):
426428
if not isinstance(margins_name, str):
427429
raise ValueError("margins_name argument must be a string")
@@ -461,6 +463,7 @@ def _add_margins(
461463
kwargs,
462464
observed,
463465
margins_name,
466+
dropna,
464467
)
465468
if not isinstance(marginal_result_set, tuple):
466469
return marginal_result_set
@@ -469,7 +472,7 @@ def _add_margins(
469472
# no values, and table is a DataFrame
470473
assert isinstance(table, ABCDataFrame)
471474
marginal_result_set = _generate_marginal_results_without_values(
472-
table, data, rows, cols, aggfunc, kwargs, observed, margins_name
475+
table, data, rows, cols, aggfunc, kwargs, observed, margins_name, dropna
473476
)
474477
if not isinstance(marginal_result_set, tuple):
475478
return marginal_result_set
@@ -538,6 +541,7 @@ def _generate_marginal_results(
538541
kwargs,
539542
observed: bool,
540543
margins_name: Hashable = "All",
544+
dropna: bool = True,
541545
):
542546
margin_keys: list | Index
543547
if len(cols) > 0:
@@ -551,7 +555,7 @@ def _all_key(key):
551555
if len(rows) > 0:
552556
margin = (
553557
data[rows + values]
554-
.groupby(rows, observed=observed)
558+
.groupby(rows, observed=observed, dropna=dropna)
555559
.agg(aggfunc, **kwargs)
556560
)
557561
cat_axis = 1
@@ -567,7 +571,7 @@ def _all_key(key):
567571
else:
568572
margin = (
569573
data[cols[:1] + values]
570-
.groupby(cols[:1], observed=observed)
574+
.groupby(cols[:1], observed=observed, dropna=dropna)
571575
.agg(aggfunc, **kwargs)
572576
.T
573577
)
@@ -610,7 +614,9 @@ def _all_key(key):
610614

611615
if len(cols) > 0:
612616
row_margin = (
613-
data[cols + values].groupby(cols, observed=observed).agg(aggfunc, **kwargs)
617+
data[cols + values]
618+
.groupby(cols, observed=observed, dropna=dropna)
619+
.agg(aggfunc, **kwargs)
614620
)
615621
row_margin = row_margin.stack()
616622

@@ -633,6 +639,7 @@ def _generate_marginal_results_without_values(
633639
kwargs,
634640
observed: bool,
635641
margins_name: Hashable = "All",
642+
dropna: bool = True,
636643
):
637644
margin_keys: list | Index
638645
if len(cols) > 0:
@@ -645,7 +652,7 @@ def _all_key():
645652
return (margins_name,) + ("",) * (len(cols) - 1)
646653

647654
if len(rows) > 0:
648-
margin = data.groupby(rows, observed=observed)[rows].apply(
655+
margin = data.groupby(rows, observed=observed, dropna=dropna)[rows].apply(
649656
aggfunc, **kwargs
650657
)
651658
all_key = _all_key()
@@ -654,7 +661,9 @@ def _all_key():
654661
margin_keys.append(all_key)
655662

656663
else:
657-
margin = data.groupby(level=0, observed=observed).apply(aggfunc, **kwargs)
664+
margin = data.groupby(level=0, observed=observed, dropna=dropna).apply(
665+
aggfunc, **kwargs
666+
)
658667
all_key = _all_key()
659668
table[all_key] = margin
660669
result = table
@@ -665,7 +674,7 @@ def _all_key():
665674
margin_keys = table.columns
666675

667676
if len(cols):
668-
row_margin = data.groupby(cols, observed=observed)[cols].apply(
677+
row_margin = data.groupby(cols, observed=observed, dropna=dropna)[cols].apply(
669678
aggfunc, **kwargs
670679
)
671680
else:

pandas/tests/reshape/test_crosstab.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def test_margin_dropna4(self):
289289
# GH: 10772: Keep np.nan in result with dropna=False
290290
df = DataFrame({"a": [1, 2, 2, 2, 2, np.nan], "b": [3, 3, 4, 4, 4, 4]})
291291
actual = crosstab(df.a, df.b, margins=True, dropna=False)
292-
expected = DataFrame([[1, 0, 1.0], [1, 3, 4.0], [0, 1, np.nan], [2, 4, 6.0]])
292+
expected = DataFrame([[1, 0, 1], [1, 3, 4], [0, 1, 1], [2, 4, 6]])
293293
expected.index = Index([1.0, 2.0, np.nan, "All"], name="a")
294294
expected.columns = Index([3, 4, "All"], name="b")
295295
tm.assert_frame_equal(actual, expected)
@@ -301,11 +301,11 @@ def test_margin_dropna5(self):
301301
)
302302
actual = crosstab(df.a, df.b, margins=True, dropna=False)
303303
expected = DataFrame(
304-
[[1, 0, 0, 1.0], [0, 1, 0, 1.0], [0, 3, 1, np.nan], [1, 4, 0, 6.0]]
304+
[[1, 0, 0, 1.0], [0, 1, 0, 1.0], [0, 3, 1, 4.0], [1, 4, 1, 6.0]]
305305
)
306306
expected.index = Index([1.0, 2.0, np.nan, "All"], name="a")
307307
expected.columns = Index([3.0, 4.0, np.nan, "All"], name="b")
308-
tm.assert_frame_equal(actual, expected)
308+
tm.assert_frame_equal(actual, expected, check_dtype=False)
309309

310310
def test_margin_dropna6(self):
311311
# GH: 10772: Keep np.nan in result with dropna=False
@@ -326,7 +326,7 @@ def test_margin_dropna6(self):
326326
names=["b", "c"],
327327
)
328328
expected = DataFrame(
329-
[[1, 0, 1, 0, 0, 0, 2], [2, 0, 1, 1, 0, 1, 5], [3, 0, 2, 1, 0, 0, 7]],
329+
[[1, 0, 1, 0, 0, 0, 2], [2, 0, 1, 1, 0, 1, 5], [3, 0, 2, 1, 0, 1, 7]],
330330
columns=m,
331331
)
332332
expected.index = Index(["bar", "foo", "All"], name="a")
@@ -349,7 +349,7 @@ def test_margin_dropna6(self):
349349
[0, 0, np.nan],
350350
[2, 0, 2.0],
351351
[1, 1, 2.0],
352-
[0, 1, np.nan],
352+
[0, 1, 1.0],
353353
[5, 2, 7.0],
354354
],
355355
index=m,

pandas/tests/reshape/test_pivot.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2585,6 +2585,36 @@ def test_pivot_table_values_as_two_params(
25852585
expected = DataFrame(data=e_data, index=e_index, columns=e_cols)
25862586
tm.assert_frame_equal(result, expected)
25872587

2588+
def test_pivot_table_margins_include_nan_groups(self):
2589+
# GH#61509
2590+
df = DataFrame(
2591+
{
2592+
"i": [1, 2, 3],
2593+
"g1": ["a", "b", "b"],
2594+
"g2": ["x", None, None],
2595+
}
2596+
)
2597+
2598+
result = df.pivot_table(
2599+
index="g1",
2600+
columns="g2",
2601+
values="i",
2602+
aggfunc="count",
2603+
dropna=False,
2604+
margins=True,
2605+
)
2606+
2607+
expected = DataFrame(
2608+
{
2609+
"x": {"a": 1.0, "b": np.nan, "All": 1.0},
2610+
np.nan: {"a": np.nan, "b": 2.0, "All": 2.0},
2611+
"All": {"a": 1.0, "b": 2.0, "All": 3.0},
2612+
}
2613+
)
2614+
expected.index.name = "g1"
2615+
expected.columns.name = "g2"
2616+
tm.assert_frame_equal(result, expected, check_dtype=False)
2617+
25882618

25892619
class TestPivot:
25902620
def test_pivot(self):

0 commit comments

Comments
 (0)