Skip to content

Commit a32df9c

Browse files
authored
Fix bug for NestedFrame.drop to support 'columns=' (#407)
* update drop support via 'columns=' and add tests * delete comment * remove redundant logic * fix docstring link
1 parent ebe6315 commit a32df9c

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

src/nested_pandas/nestedframe/core.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,8 @@ def drop(
727727
Remove rows or columns by specifying label names and corresponding
728728
axis, or by directly specifying index or column names. When using a
729729
multi-index, labels on different levels can be removed by
730-
specifying the level. See the user guide for more information about
730+
specifying the level. See the `user guide <https://pandas.pydata.org/docs/user_guide
731+
/advanced.html#advanced-shown-levels>`_ for more information about
731732
the now unused levels.
732733
733734
Parameters
@@ -777,10 +778,14 @@ def drop(
777778
"""
778779

779780
# axis 1 requires special handling for nested columns
780-
if axis == 1:
781+
if axis == 1 or columns is not None:
781782
# label convergence
782783
if isinstance(labels, str):
783784
labels = [labels]
785+
elif columns is not None:
786+
labels = [columns] if isinstance(columns, str) else columns
787+
columns = None
788+
axis = 1
784789
nested_labels = [label for label in labels if self._is_known_hierarchical_column(label)]
785790
base_labels = [label for label in labels if not self._is_known_hierarchical_column(label)]
786791

tests/nested_pandas/nestedframe/test_nestedframe.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,6 +1434,15 @@ def test_drop():
14341434
with pytest.raises(KeyError):
14351435
base.drop(["a", "nested.not_a_field"], axis=1)
14361436

1437+
# Test dropping nested column using 'columns='
1438+
dropped_cols = base.drop(columns="nested.c")
1439+
assert "c" not in dropped_cols.nested.nest.columns
1440+
1441+
# Test dropping multiple nested columns using 'columns='
1442+
dropped_multcols = base.drop(columns=["nested.c", "nested2.f"])
1443+
assert "c" not in dropped_multcols.nested.nest.columns
1444+
assert "f" not in dropped_multcols.nested2.nest.columns
1445+
14371446

14381447
def test_min():
14391448
"""Test min function return correct result with and without the nested columns"""

0 commit comments

Comments
 (0)