Skip to content

Commit 1fde41d

Browse files
committed
more tests
1 parent cff8d0b commit 1fde41d

File tree

2 files changed

+94
-2
lines changed

2 files changed

+94
-2
lines changed

src/nested_pandas/series/accessor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from numpy.typing import ArrayLike
1111
from pandas.api.extensions import register_series_accessor
1212

13-
from nested_pandas.nestedframe.core import NestedFrame
1413
from nested_pandas.series.dtype import NestedDtype
1514
from nested_pandas.series.nestedseries import NestedSeries
1615
from nested_pandas.series.packer import pack_flat, pack_sorted_df_into_struct
@@ -498,7 +497,12 @@ def __getitem__(self, key: str | list[str]) -> NestedSeries:
498497
if not key.index.equals(flat_df.index):
499498
raise ValueError("Boolean mask must have the same index as the flattened nested dataframe.")
500499
# Apply the mask to the series, return a new NestedFrame
501-
return NestedFrame(index=self._series.index).add_nested(flat_df[key], name=self._series.name)
500+
# return NestedFrame(index=self._series.index).add_nested(flat_df[key], name=self._series.name)
501+
return NestedSeries(
502+
pack_flat(flat_df[key]),
503+
index=self._series.index,
504+
name=self._series.name,
505+
)
502506

503507
# A list of fields may return a pd.Series or a NestedSeries depending
504508
# on the number of fields requested and their dtypes

tests/nested_pandas/series/test_nestedseries.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,94 @@ def test_nestedseries_list_lengths():
8484
assert list(series.list_lengths) == [2, 2]
8585

8686

87+
def test_nestedseries_getitem_single_field():
88+
"""Test getitem for a single field in NestedSeries."""
89+
series = NestedSeries(
90+
data=[
91+
(np.array([1, 2]), np.array([0, 1])),
92+
(np.array([3, 4]), np.array([0, 1])),
93+
],
94+
index=[0, 1],
95+
dtype=NestedDtype(pa.struct([("a", pa.list_(pa.int64())), ("b", pa.list_(pa.int64()))])),
96+
)
97+
98+
result = series["a"]
99+
expected = pd.Series([1, 2, 3, 4], index=[0, 0, 1, 1], dtype=pd.ArrowDtype(pa.int64()), name="a")
100+
pd.testing.assert_series_equal(result, expected)
101+
102+
103+
def test_nestedseries_getitem_multiple_fields():
104+
"""Test getitem for multiple fields in NestedSeries."""
105+
series = NestedSeries(
106+
data=[
107+
(np.array([1, 2]), np.array([0, 1])),
108+
(np.array([3, 4]), np.array([0, 1])),
109+
],
110+
index=[0, 1],
111+
dtype=NestedDtype(pa.struct([("a", pa.list_(pa.int64())), ("b", pa.list_(pa.int64()))])),
112+
)
113+
114+
result = series[["a", "b"]]
115+
expected = series # Full selection returns the original structure
116+
pd.testing.assert_series_equal(result, expected)
117+
118+
119+
def test_nestedseries_getitem_masking():
120+
"""Test getitem with boolean masking in NestedSeries."""
121+
series = NestedSeries(
122+
data=[
123+
(np.array([1, 2]), np.array([0, 1])),
124+
(np.array([3, 4]), np.array([0, 1])),
125+
],
126+
index=[0, 1],
127+
dtype=NestedDtype(pa.struct([("a", pa.list_(pa.int64())), ("b", pa.list_(pa.int64()))])),
128+
name="nested",
129+
)
130+
131+
mask = pd.Series([True, False, False, True], index=[0, 0, 1, 1], dtype=bool, name="mask")
132+
result = series[mask]
133+
assert result.flat_length == 2
134+
135+
136+
def test_nestedseries_getitem_index():
137+
"""Test getitem with ordinary index selection in NestedSeries."""
138+
series = NestedSeries(
139+
data=[
140+
(np.array([1, 2]), np.array([0, 1])),
141+
(np.array([3, 4]), np.array([0, 1])),
142+
],
143+
index=[0, 1],
144+
dtype=NestedDtype(pa.struct([("a", pa.list_(pa.int64())), ("b", pa.list_(pa.int64()))])),
145+
)
146+
147+
result = series[0]
148+
expected = pd.DataFrame({"a": [1, 2], "b": [0, 1]}, index=[0, 1])
149+
pd.testing.assert_frame_equal(result, expected)
150+
151+
152+
def test_nestedseries_getitem_non_nested_dtype():
153+
"""Test setitem with a non-nested dtype."""
154+
series = NestedSeries(
155+
data=[1, 2, 3],
156+
index=[0, 1, 2],
157+
dtype=pd.ArrowDtype(pa.int64()),
158+
)
159+
160+
assert series[0] == 1
161+
162+
163+
def test_nestedseries_setitem_non_nested_dtype():
164+
"""Test setitem with a non-nested dtype."""
165+
series = NestedSeries(
166+
data=[1, 2, 3],
167+
index=[0, 1, 2],
168+
dtype=pd.ArrowDtype(pa.int64()),
169+
)
170+
171+
series[0] = 10
172+
assert series[0] == 10
173+
174+
87175
def test_nestedseries_to_flat():
88176
"""Test to_flat method of NestedSeries."""
89177
series = NestedSeries(

0 commit comments

Comments
 (0)