
Literal arguments to UDFs cause incorrect behavior #3603

Open
kevinzwang opened this issue Dec 18, 2024 · 0 comments
Labels: bug (Something isn't working), p2 (Nice to have features)


Describe the bug

When you pass a literal column to a UDF, the UDF receives it as a Series containing a single value rather than one matching the length of the other inputs. That behavior may be acceptable on its own, but once batch_size is set on the UDF, execution fails altogether.

To Reproduce

>>> import daft
>>> @daft.udf(return_dtype=daft.DataType.int64(), batch_size=1)
... def my_sum(a, b):
...     return a.sum(b)
...
>>> df = daft.from_pydict({"a": [1, 2, 3]})
>>> df = df.with_column("b", daft.lit(0))
>>> df.select(my_sum(df["a"], df["b"])).show()
Error when running pipeline node ProjectOperator
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[8], line 1
----> 1 df.select(my_sum(df["a"], df["b"])).show()

File ~/Desktop/Daft/daft/api_annotations.py:26, in DataframePublicAPI.<locals>._wrap(*args, **kwargs)
     24 type_check_function(func, *args, **kwargs)
     25 timed_method = time_df_method(func)
---> 26 return timed_method(*args, **kwargs)

File ~/Desktop/Daft/daft/analytics.py:199, in time_df_method.<locals>.tracked_method(*args, **kwargs)
    197 start = time.time()
    198 try:
--> 199     result = method(*args, **kwargs)
    200 except Exception as e:
    201     _ANALYTICS_CLIENT.track_df_method_call(
    202         method_name=method.__name__, duration_seconds=time.time() - start, error=str(type(e).__name__)
    203     )

File ~/Desktop/Daft/daft/dataframe/dataframe.py:2641, in DataFrame.show(self, n)
   2628 @DataframePublicAPI
   2629 def show(self, n: int = 8) -> None:
   2630     """Executes enough of the DataFrame in order to display the first ``n`` rows.
   2631 
   2632     If IPython is installed, this will use IPython's `display` utility to pretty-print in a
   (...)
   2639         n: number of rows to show. Defaults to 8.
   2640     """
-> 2641     dataframe_display = self._construct_show_display(n)
   2642     try:
   2643         from IPython.display import display

File ~/Desktop/Daft/daft/dataframe/dataframe.py:2598, in DataFrame._construct_show_display(self, n)
   2596 tables = []
   2597 seen = 0
-> 2598 for table in get_context().get_or_create_runner().run_iter_tables(builder, results_buffer_size=1):
   2599     tables.append(table)
   2600     seen += len(table)

File ~/Desktop/Daft/daft/runners/native_runner.py:89, in NativeRunner.run_iter_tables(self, builder, results_buffer_size)
     86 def run_iter_tables(
     87     self, builder: LogicalPlanBuilder, results_buffer_size: int | None = None
     88 ) -> Iterator[MicroPartition]:
---> 89     for result in self.run_iter(builder, results_buffer_size=results_buffer_size):
     90         yield result.partition()

File ~/Desktop/Daft/daft/runners/native_runner.py:84, in NativeRunner.run_iter(self, builder, results_buffer_size)
     78 executor = NativeExecutor.from_logical_plan_builder(builder)
     79 results_gen = executor.run(
     80     {k: v.values() for k, v in self._part_set_cache.get_all_partition_sets().items()},
     81     daft_execution_config,
     82     results_buffer_size,
     83 )
---> 84 yield from results_gen

File ~/Desktop/Daft/daft/execution/native_executor.py:40, in <genexpr>(.0)
     35 from daft.runners.partitioning import LocalMaterializedResult
     37 psets_mp = {
     38     part_id: [part.micropartition()._micropartition for part in parts] for part_id, parts in psets.items()
     39 }
---> 40 return (
     41     LocalMaterializedResult(MicroPartition._from_pymicropartition(part))
     42     for part in self._executor.run(psets_mp, daft_execution_config, results_buffer_size)
     43 )

File ~/Desktop/Daft/daft/udf.py:140, in run_udf(func, bound_args, evaluated_expressions, py_return_dtype, batch_size)
    136 else:
    137     # all inputs must have the same lengths for batching
    138     # not sure this error can possibly be triggered but it's here
    139     if len(set(len(s) for s in evaluated_expressions)) != 1:
--> 140         raise RuntimeError(
    141             f"User-defined function `{func}` failed: cannot run in batches when inputs are different lengths: {tuple(len(series) for series in evaluated_expressions)}"
    142         )
    144     results = []
    145     for i in range(0, len(evaluated_expressions[0]), batch_size):

RuntimeError: User-defined function `<function my_sum at 0x14adc5440>` failed: cannot run in batches when inputs are different lengths: (3, 1)

Expected behavior

The UDF should work with a literal column even when batch_size is set, and we might also want to consider broadcasting the literal value so that the UDF always receives equal-length Series arguments.
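To illustrate the broadcasting semantics proposed above, here is a minimal sketch in plain Python. This is not Daft's actual Series API or internal batching code; the helper names and behavior are assumptions about what the fix could look like:

```python
def broadcast_to_length(values, n):
    """Broadcast a length-1 list to length n, mirroring the proposed
    handling of literal arguments before batching a UDF."""
    if len(values) == n:
        return values
    if len(values) == 1:
        # A literal evaluates to a single value; repeat it so every
        # batch slice sees equal-length inputs.
        return values * n
    raise ValueError(f"cannot broadcast length {len(values)} to {n}")


def run_batched(func, columns, batch_size):
    """Run func over equal-length slices of size batch_size,
    broadcasting any length-1 (literal) columns first."""
    n = max(len(c) for c in columns)
    columns = [broadcast_to_length(c, n) for c in columns]
    out = []
    for i in range(0, n, batch_size):
        out.extend(func(*(c[i : i + batch_size] for c in columns)))
    return out


# With broadcasting, the literal column no longer trips the
# "different lengths" check seen in the traceback above.
result = run_batched(
    lambda a, b: [x + y for x, y in zip(a, b)], [[1, 2, 3], [0]], batch_size=1
)
# result == [1, 2, 3]
```

The key design question is the same one raised above: whether a literal should reach the UDF as a length-1 Series (documented and handled by the batching length check) or be eagerly broadcast so UDF authors never see a length mismatch.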

Component(s)

Expressions, Other

Additional context

No response

kevinzwang added the bug (Something isn't working) and p2 (Nice to have features) labels on Dec 18, 2024