Skip to content

Commit 1e4b8a3

Browse files
committed
add more init paths for NestedDtype
1 parent dd891ae commit 1e4b8a3

File tree

2 files changed

+56
-3
lines changed

2 files changed

+56
-3
lines changed

src/nested_pandas/series/dtype.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,36 @@ class NestedDtype(ExtensionDtype):
2828
2929
Parameters
3030
----------
31-
pyarrow_dtype : pyarrow.StructType or pd.ArrowDtype
32-
The pyarrow data type to use for the nested type. It must be a struct
33-
type where all fields are list types.
31+
pyarrow_dtype : pyarrow.StructType, pd.ArrowDtype, or Mapping[str, pa.DataType]
32+
The pyarrow data type to use for the nested type. It may be provided as
33+
a pyarrow.StructType, a pandas.ArrowDtype, or a mapping of column names to
34+
pyarrow data types (such as a dictionary); each mapped type is wrapped in a list.
35+
36+
Examples
37+
--------
38+
>>> import pyarrow as pa
39+
>>> from nested_pandas import NestedDtype
40+
41+
From pa.StructType:
42+
43+
>>> dtype = NestedDtype(pa.struct([pa.field("a", pa.list_(pa.int64())),
44+
... pa.field("b", pa.list_(pa.float64()))]))
45+
>>> dtype
46+
nested<a: [int64], b: [double]>
47+
48+
From pd.ArrowDtype:
49+
50+
>>> import pandas as pd
51+
>>> dtype = NestedDtype(pd.ArrowDtype(pa.struct([pa.field("a", pa.list_(pa.int64())),
52+
... pa.field("b", pa.list_(pa.float64()))])))
53+
>>> dtype
54+
nested<a: [int64], b: [double]>
55+
56+
From mapping of column names to pyarrow data types:
57+
58+
>>> dtype = NestedDtype({"a": pa.int64(), "b": pa.float64()})
59+
>>> dtype
60+
nested<a: [int64], b: [double]>
3461
"""
3562

3663
# ExtensionDtype overrides #
@@ -160,6 +187,15 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> ExtensionArray:
160187
pyarrow_dtype: pa.StructType
161188

162189
def __init__(self, pyarrow_dtype: pa.DataType | pd.ArrowDtype | Mapping[str, pa.DataType]) -> None:
    """Build the dtype from a pyarrow type, a pandas ArrowDtype, or a columns mapping.

    Parameters
    ----------
    pyarrow_dtype : pa.DataType, pd.ArrowDtype, or Mapping[str, pa.DataType]
        Description of the nested type. A ``pd.ArrowDtype`` is unwrapped to
        the pyarrow type it holds. A mapping of column names to pyarrow types
        is converted to a struct whose fields are lists of the given types.
        Any other value is handed to ``_validate_dtype`` unchanged.
    """
    # Allow pd.ArrowDtypes on init
    if isinstance(pyarrow_dtype, pd.ArrowDtype):
        pyarrow_dtype = pyarrow_dtype.pyarrow_dtype

    # Allow from_columns-style mapping inputs: every value type becomes the
    # element type of a list-typed struct field.
    if isinstance(pyarrow_dtype, Mapping):
        pyarrow_dtype = pa.struct({col: pa.list_(pa_type) for col, pa_type in pyarrow_dtype.items()})

    # `cast` only informs static type checkers; it performs no runtime check.
    # The actual validation is delegated to `_validate_dtype` below.
    pyarrow_dtype = cast(pa.StructType, pyarrow_dtype)

    self.pyarrow_dtype, self.list_struct_pa_dtype = self._validate_dtype(pyarrow_dtype)
164200

165201
@property

tests/nested_pandas/series/test_dtype.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ def test_from_pandas_arrow_dtype():
8383
assert dtype_from_list.pyarrow_dtype == pa.struct([pa.field("a", pa.list_(pa.int64()))])
8484

8585

86+
def test_init_from_pandas_arrow_dtype():
    """NestedDtype.__init__ accepts a pandas.ArrowDtype in either supported layout."""
    expected = pa.struct([pa.field("a", pa.list_(pa.int64()))])

    # Struct-of-lists layout wrapped in a pandas ArrowDtype.
    struct_of_lists = pd.ArrowDtype(pa.struct([pa.field("a", pa.list_(pa.int64()))]))
    assert NestedDtype(struct_of_lists).pyarrow_dtype == expected

    # List-of-structs layout wrapped in a pandas ArrowDtype.
    list_of_structs = pd.ArrowDtype(pa.list_(pa.struct([pa.field("a", pa.int64())])))
    assert NestedDtype(list_of_structs).pyarrow_dtype == expected
92+
93+
8694
def test_to_pandas_list_struct_arrow_dtype():
8795
"""Test that NestedDtype.to_pandas_arrow_dtype(list_struct=True) returns the correct pyarrow type."""
8896
dtype = NestedDtype.from_columns({"a": pa.list_(pa.int64()), "b": pa.float64()})
@@ -100,6 +108,15 @@ def test_from_columns():
100108
)
101109

102110

111+
def test_init_from_columns():
    """NestedDtype.__init__ turns a columns mapping into a struct of list-typed fields."""
    dtype = NestedDtype({"a": pa.int64(), "b": pa.float64()})

    expected_fields = [
        pa.field("a", pa.list_(pa.int64())),
        pa.field("b", pa.list_(pa.float64())),
    ]
    assert dtype.pyarrow_dtype == pa.struct(expected_fields)
118+
119+
103120
def test_na_value():
104121
"""Test that NestedDtype.na_value is a singleton instance of NAType."""
105122
dtype = NestedDtype(pa.struct([pa.field("a", pa.list_(pa.int64()))]))

0 commit comments

Comments
 (0)