Skip to content

Commit 859acb4

Browse files
authored
feat: make cast accept built-in Python types (#858)
1 parent fe0738a commit 859acb4

File tree

2 files changed

+47
-11
lines changed

2 files changed

+47
-11
lines changed

python/datafusion/expr.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
functions as functions_internal,
2929
)
3030
from datafusion.common import NullTreatment, RexType, DataTypeMap
31-
from typing import Any, Optional
31+
from typing import Any, Optional, Type
3232
import pyarrow as pa
3333

3434
# The following are imported from the internal representation. We may choose to
@@ -372,8 +372,25 @@ def is_not_null(self) -> Expr:
372372
"""Returns ``True`` if this expression is not null."""
373373
return Expr(self.expr.is_not_null())
374374

375-
def cast(self, to: pa.DataType[Any]) -> Expr:
375+
_to_pyarrow_types = {
376+
float: pa.float64(),
377+
int: pa.int64(),
378+
str: pa.string(),
379+
bool: pa.bool_(),
380+
}
381+
382+
def cast(
383+
self, to: pa.DataType[Any] | Type[float] | Type[int] | Type[str] | Type[bool]
384+
) -> Expr:
376385
"""Cast to a new data type."""
386+
if not isinstance(to, pa.DataType):
387+
try:
388+
to = self._to_pyarrow_types[to]
389+
except KeyError:
390+
raise TypeError(
391+
"Expected instance of pyarrow.DataType or builtins.type"
392+
)
393+
377394
return Expr(self.expr.cast(to))
378395

379396
def rex_type(self) -> RexType:

python/datafusion/tests/test_functions.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@ def df():
4444
datetime(2020, 7, 2),
4545
]
4646
),
47+
pa.array([False, True, True]),
4748
],
48-
names=["a", "b", "c", "d"],
49+
names=["a", "b", "c", "d", "e"],
4950
)
5051
return ctx.create_dataframe([[batch]])
5152

@@ -63,15 +64,14 @@ def test_named_struct(df):
6364
)
6465

6566
expected = """DataFrame()
66-
+-------+---+---------+------------------------------+
67-
| a | b | c | d |
68-
+-------+---+---------+------------------------------+
69-
| Hello | 4 | hello | {a: Hello, b: 4, c: hello } |
70-
| World | 5 | world | {a: World, b: 5, c: world } |
71-
| ! | 6 | ! | {a: !, b: 6, c: !} |
72-
+-------+---+---------+------------------------------+
67+
+-------+---+---------+------------------------------+-------+
68+
| a | b | c | d | e |
69+
+-------+---+---------+------------------------------+-------+
70+
| Hello | 4 | hello | {a: Hello, b: 4, c: hello } | false |
71+
| World | 5 | world | {a: World, b: 5, c: world } | true |
72+
| ! | 6 | ! | {a: !, b: 6, c: !} | true |
73+
+-------+---+---------+------------------------------+-------+
7374
""".strip()
74-
7575
assert str(df) == expected
7676

7777

@@ -978,3 +978,22 @@ def test_binary_string_functions(df):
978978
assert pa.array(result.column(1)).cast(pa.string()) == pa.array(
979979
["Hello", "World", "!"]
980980
)
981+
982+
983+
@pytest.mark.parametrize(
984+
"python_datatype, name, expected",
985+
[
986+
pytest.param(bool, "e", pa.bool_(), id="bool"),
987+
pytest.param(int, "b", pa.int64(), id="int"),
988+
pytest.param(float, "b", pa.float64(), id="float"),
989+
pytest.param(str, "b", pa.string(), id="str"),
990+
],
991+
)
992+
def test_cast(df, python_datatype, name: str, expected):
993+
df = df.select(
994+
column(name).cast(python_datatype).alias("actual"),
995+
column(name).cast(expected).alias("expected"),
996+
)
997+
result = df.collect()
998+
result = result[0]
999+
assert result.column(0) == result.column(1)

0 commit comments

Comments
 (0)