Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/xorq/backends/let/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ class DataFusionCompiler(SQLGlotCompiler):
ops.CountDistinctStar,
ops.DateDelta,
ops.RowID,
ops.Strftime,
ops.TimeDelta,
ops.TimestampDelta,
)
Expand Down Expand Up @@ -664,5 +663,8 @@ def visit_ArrayAny(self, op, *, arg):
),
)

def visit_Strftime(self, op, *, arg, format_str):
    """Compile a strftime operation into a ``temporal_strftime`` call.

    The already-compiled ``arg`` and ``format_str`` are forwarded unchanged,
    in that order, to ``self.f.temporal_strftime``; ``op`` itself is not read.
    """
    strftime_fn = self.f.temporal_strftime
    return strftime_fn(arg, format_str)


compiler = DataFusionCompiler()
12 changes: 12 additions & 0 deletions python/xorq/backends/let/datafusion/udfs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import itertools
import typing
from urllib.parse import parse_qs, urlsplit

import pyarrow as pa
Expand Down Expand Up @@ -129,3 +130,14 @@ def regex_split(s: str, pattern: str) -> list[str]:
)
pattern = patterns[0].as_py()
return pc.split_pattern_regex(s, pattern)


def temporal_strftime(array: dt.Timestamp(scale=9), pattern: dt.string) -> dt.string:
    """Format every timestamp in ``array`` as a string with one strftime pattern.

    ``pattern`` is received as an array; its distinct values are collected and
    the single remaining value is used as the format for all rows.

    Raises
    ------
    com.XorqError
        If ``pattern`` contains more than one distinct value.
    """
    pattern, *others = typing.cast(pa.StringArray, pattern).unique().to_pylist()

    # Review feedback applied: rely on truth-value testing (a non-empty list is
    # truthy) rather than comparing ``len()`` against zero.
    # https://docs.python.org/3/library/stdtypes.html#truth-value-testing
    if others:
        raise com.XorqError(
            "Only a single scalar pattern is supported for DataFusion strftime"
        )

    return pc.strftime(array, pattern)
18 changes: 18 additions & 0 deletions python/xorq/backends/let/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,24 @@ def test_timestamp_extract_week_of_year(alltypes, alltypes_df):
assert_series_equal(result, expected)


@pytest.mark.parametrize(
    ("expr_fn", "pandas_pattern"),
    [
        param(
            lambda t: t.timestamp_col.strftime("%Y%m%d").name("formatted"),
            "%Y%m%d",
            id="literal_format_str",
        ),
    ],
)
def test_strftime(alltypes, alltypes_df, expr_fn, pandas_pattern):
    """Check that backend strftime output matches pandas ``dt.strftime``."""
    # Reference values come from pandas, renamed to match the expression's name.
    formatted = alltypes_df.timestamp_col.dt.strftime(pandas_pattern)
    expected = formatted.rename("formatted")

    result = expr_fn(alltypes).execute()
    assert_series_equal(result, expected)


PANDAS_UNITS = {
"m": "Min",
"ms": "L",
Expand Down