[DRAFT][PYTHON] Improve Python UDF Arrow Serializer Performance #51225

Draft
wants to merge 40 commits into base: master
3 changes: 2 additions & 1 deletion python/docs/source/migration_guide/pyspark_upgrade.rst
@@ -24,8 +24,9 @@ Upgrading from PySpark 4.0 to 4.1

* In Spark 4.1, Arrow-optimized Python UDF supports UDT input / output instead of falling back to the regular UDF. To restore the legacy behavior, set ``spark.sql.execution.pythonUDF.arrow.legacy.fallbackOnUDT`` to ``true``.

* In Spark 4.1, unnecessary conversion to pandas instances is removed when ``spark.sql.execution.pythonUDTF.arrow.enabled`` is enabled. As a result, the type coercion changes when the produced output has a schema different from the specified schema. To restore the previous behavior, enable ``spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled``.
* In Spark 4.1, unnecessary conversion to pandas instances is removed when ``spark.sql.execution.pythonUDF.arrow.enabled`` is enabled. As a result, the type coercion changes when the produced output has a schema different from the specified schema. To restore the previous behavior, enable ``spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled``.

* In Spark 4.1, unnecessary conversion to pandas instances is removed when ``spark.sql.execution.pythonUDTF.arrow.enabled`` is enabled. As a result, the type coercion changes when the produced output has a schema different from the specified schema. To restore the previous behavior, enable ``spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled``.

Upgrading from PySpark 3.5 to 4.0
---------------------------------
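As context for the migration notes above, here is a minimal sketch (not part of the diff) of toggling these configurations from a PySpark session. The session variable `spark` and the `plus_one` UDF are illustrative assumptions, not taken from the PR.

```python
from pyspark.sql.functions import udf

# Arrow-optimized Python UDF execution is controlled by this conf.
spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")

# Restore the pre-4.1 pandas-based type coercion described above (legacy behavior).
spark.conf.set("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled", "true")

@udf(returnType="long")
def plus_one(v):  # illustrative UDF, not from the PR
    return v + 1

spark.range(3).select(plus_one("id")).show()
```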
5 changes: 5 additions & 0 deletions python/pyspark/errors/error-conditions.json
@@ -1050,6 +1050,11 @@
"Type hints for <target> should be specified; however, got <sig>."
]
},
"UDF_ARROW_TYPE_CONVERSION_ERROR": {
"message": [
"Cannot convert the output value of the input '<data>' with type '<schema>' to the specified return type of the column: '<arrow_schema>'. Please check if the data types match and try again."
]
},
"UDF_RETURN_TYPE": {
"message": [
"Return type of the user-defined function should be <expected>, but is <actual>."
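A hedged sketch of how this new error condition might be raised and rendered from Python. The `errorClass` / `messageParameters` keyword names follow the PySpark 4.x error framework, and the parameter values and call site are assumptions, not taken from the PR.

```python
from pyspark.errors import PySparkTypeError

# Placeholders in the message template (<data>, <schema>, <arrow_schema>)
# are filled from messageParameters.
err = PySparkTypeError(
    errorClass="UDF_ARROW_TYPE_CONVERSION_ERROR",
    messageParameters={
        "data": "[1, 2, 3]",
        "schema": "int64",
        "arrow_schema": "string",
    },
)
print(str(err))
```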
130 changes: 130 additions & 0 deletions python/pyspark/sql/pandas/serializers.py
@@ -131,6 +131,9 @@ def load_stream(self, stream):
reader = pa.ipc.open_stream(stream)
for batch in reader:
yield batch

def arrow_to_pandas(self, arrow_column, idx, struct_in_pandas="dict", ndarray_as_list=False, spark_type=None):
return arrow_column.to_pylist()

def __repr__(self):
return "ArrowStreamSerializer"
@@ -194,6 +197,133 @@ class ArrowStreamUDTFSerializer(ArrowStreamUDFSerializer):
def load_stream(self, stream):
return ArrowStreamSerializer.load_stream(self, stream)

class ArrowBatchUDFSerializer(ArrowStreamSerializer):
"""
Serializer used by the Python worker to evaluate Arrow-optimized Python UDFs.

Loads Arrow record batches as per-column Python values and writes UDF results
back as an Arrow stream.
"""

def __init__(
self,
timezone,
safecheck,
assign_cols_by_name,
arrow_cast,
struct_in_pandas="row",
ndarray_as_list=False,
):
super(ArrowBatchUDFSerializer, self).__init__()
self._timezone = timezone
self._safecheck = safecheck
self._assign_cols_by_name = assign_cols_by_name
self._arrow_cast = arrow_cast
self._struct_in_pandas = struct_in_pandas
self._ndarray_as_list = ndarray_as_list

def _create_array(self, arr, arrow_type, arrow_cast):
import pyarrow as pa

from pyspark.errors import PySparkTypeError

assert isinstance(arr, pa.Array), arr
assert isinstance(arrow_type, pa.DataType), arrow_type

if arr.type == arrow_type:
return arr
elif arrow_cast:
# Cast to the declared Arrow type; `safe` follows the session's safe-check setting.
return arr.cast(target_type=arrow_type, safe=self._safecheck)
else:
# Assumption: report the mismatch via the UDF_ARROW_TYPE_CONVERSION_ERROR
# condition added to error-conditions.json in this PR.
raise PySparkTypeError(
errorClass="UDF_ARROW_TYPE_CONVERSION_ERROR",
messageParameters={
"data": str(arr),
"schema": str(arr.type),
"arrow_schema": str(arrow_type),
},
)
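To illustrate the cast behavior `_create_array` relies on, here is a standalone pyarrow sketch (outside Spark; the values are assumptions): a safe cast rejects lossy conversions, while an unsafe cast does not.

```python
import pyarrow as pa

arr = pa.array([1, 2, 3_000_000_000], type=pa.int64())

try:
    # Safe cast: raises because 3_000_000_000 does not fit in int32.
    arr.cast(target_type=pa.int32(), safe=True)
except pa.lib.ArrowInvalid as e:
    print("safe cast rejected:", e)

# Unsafe cast: silently wraps the overflowing value.
print(arr.cast(target_type=pa.int32(), safe=False))
```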

def load_stream(self, stream):
import pyarrow as pa
batches = super(ArrowBatchUDFSerializer, self).load_stream(stream)
for batch in batches:
table = pa.Table.from_batches([batch])
columns = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(table.itercolumns())
]
yield columns
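A Spark-free sketch of the pyarrow mechanics `load_stream` uses: each record batch is wrapped in a Table, its columns are iterated, and, in the simplest path (the base `arrow_to_pandas` added near the top of this file), each column is materialized with `to_pylist()`. Column names and values here are assumptions.

```python
import pyarrow as pa

batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])],
    names=["id", "value"],
)
table = pa.Table.from_batches([batch])
columns = [col.to_pylist() for col in table.itercolumns()]
print(columns)  # [[1, 2, 3], ['a', 'b', 'c']]
```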

def dump_stream(self, iterator, stream):
"""
Override because Arrow UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
This should be sent after creating the first record batch so in case of an error, it can
be sent back to the JVM before the Arrow stream starts.
"""
import pyarrow as pa

def wrap_and_init_stream():
should_write_start_length = True
for packed in iterator:
# Flatten tuple of lists into a single list
if isinstance(packed, tuple) and all(isinstance(x, list) for x in packed):
packed = [item for sublist in packed for item in sublist]

if isinstance(packed, tuple) and len(packed) == 2 and isinstance(packed[1], pa.DataType):
# single array UDF in a projection
arrs = [self._create_array(packed[0], packed[1], self._arrow_cast)]
elif isinstance(packed, list):
# multiple array UDFs in a projection
arrs = [self._create_array(t[0], t[1], self._arrow_cast) for t in packed]
elif isinstance(packed, tuple) and len(packed) == 3:
Contributor (inline review comment): It seems the conditions in the arrow-optimized UDF are more complicated. In what case will this branch be chosen?
# single value UDF with type information
value, arrow_type, spark_type = packed
arr = pa.array(value, type=arrow_type)
arrs = [self._create_array(arr, arrow_type, self._arrow_cast)]
else:
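# Fallback: wrap a bare scalar result as a single-element int32 array.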
arr = pa.array([packed], type=pa.int32())
arrs = [self._create_array(arr, pa.int32(), self._arrow_cast)]

batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])

# Write the first record batch with initialization.
if should_write_start_length:
write_int(SpecialLengths.START_ARROW_STREAM, stream)
should_write_start_length = False
yield batch

return ArrowStreamSerializer.dump_stream(self, wrap_and_init_stream(), stream)
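For illustration, a standalone sketch of the batch construction done in `wrap_and_init_stream` for the "multiple array UDFs in a projection" case; the packed values and types are assumptions.

```python
import pyarrow as pa

# Each UDF result arrives as (result array, expected Arrow type).
packed = [
    (pa.array([1, 2, 3]), pa.int64()),
    (pa.array([0.5, 1.5, 2.5]), pa.float64()),
]

# Cast only when the produced type differs from the expected one.
arrs = [arr if arr.type == typ else arr.cast(typ) for arr, typ in packed]

# Column names follow the "_<index>" convention used by the serializer.
batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
print(batch.schema)
```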

def __repr__(self):
return "ArrowBatchUDFSerializer"

def arrow_to_pandas(self, arrow_column, idx, ndarray_as_list=False, spark_type=None):
import pyarrow.types as types
from pyspark.sql import Row

if (
self._struct_in_pandas == "row"
and types.is_struct(arrow_column.type)
and not is_variant(arrow_column.type)
):
series = [
super(ArrowBatchUDFSerializer, self)
.arrow_to_pandas(
column,
i,
self._struct_in_pandas,
self._ndarray_as_list,
spark_type=None,
)
for i, (column, field) in enumerate(zip(arrow_column.flatten(), arrow_column.type))
]
row_cls = Row(*[field.name for field in arrow_column.type])
result = series[0].__class__([
row_cls(*vals) for vals in zip(*series)
])
else:
result = ArrowStreamPandasSerializer.arrow_to_pandas(
self,
arrow_column,
idx,
struct_in_pandas=self._struct_in_pandas,
ndarray_as_list=ndarray_as_list,
spark_type=spark_type,
)
if spark_type is not None and hasattr(spark_type, "fromInternal"):
return result.apply(lambda v: spark_type.fromInternal(v) if v is not None else v)
return result
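A minimal sketch of the struct-to-Row path above, outside the worker: the struct column is flattened into its child arrays, and the per-field values are zipped back into `Row` objects. Field names and values are assumptions.

```python
import pyarrow as pa
from pyspark.sql import Row

struct_arr = pa.array(
    [{"a": 1, "b": "x"}, {"a": 2, "b": "y"}],
    type=pa.struct([("a", pa.int64()), ("b", pa.string())]),
)

fields = [field.name for field in struct_arr.type]
children = [child.to_pylist() for child in struct_arr.flatten()]

row_cls = Row(*fields)
rows = [row_cls(*vals) for vals in zip(*children)]
print(rows)  # [Row(a=1, b='x'), Row(a=2, b='y')]
```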


class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer):
"""
55 changes: 50 additions & 5 deletions python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -244,6 +244,37 @@ def test_udf_use_arrow_and_session_conf(self):
)


@unittest.skipIf(
not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message
)
class ArrowPythonUDFLegacyTestsMixin(ArrowPythonUDFTestsMixin):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.spark.conf.set("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled", "true")

@classmethod
def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled")
finally:
super().tearDownClass()


class ArrowPythonUDFNonLegacyTestsMixin(ArrowPythonUDFTestsMixin):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.spark.conf.set("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled", "false")

@classmethod
def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled")
finally:
super().tearDownClass()


class ArrowPythonUDFTests(ArrowPythonUDFTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
@@ -258,18 +289,32 @@ def tearDownClass(cls):
super(ArrowPythonUDFTests, cls).tearDownClass()


class AsyncArrowPythonUDFTests(ArrowPythonUDFTests):
class ArrowPythonUDFLegacyTests(ArrowPythonUDFLegacyTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
super(AsyncArrowPythonUDFTests, cls).setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.concurrency.level", "4")
super(ArrowPythonUDFLegacyTests, cls).setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")

@classmethod
def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.concurrency.level")
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
finally:
super(ArrowPythonUDFLegacyTests, cls).tearDownClass()


class ArrowPythonUDFNonLegacyTests(ArrowPythonUDFNonLegacyTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
super(ArrowPythonUDFNonLegacyTests, cls).setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")

@classmethod
def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
finally:
super(AsyncArrowPythonUDFTests, cls).tearDownClass()
super(ArrowPythonUDFNonLegacyTests, cls).tearDownClass()


if __name__ == "__main__":