diff --git a/python/docs/source/migration_guide/pyspark_upgrade.rst b/python/docs/source/migration_guide/pyspark_upgrade.rst index f91e297539649..453fd3917ceb9 100644 --- a/python/docs/source/migration_guide/pyspark_upgrade.rst +++ b/python/docs/source/migration_guide/pyspark_upgrade.rst @@ -24,8 +24,9 @@ Upgrading from PySpark 4.0 to 4.1 * In Spark 4.1, Arrow-optimized Python UDF supports UDT input / output instead of falling back to the regular UDF. To restore the legacy behavior, set ``spark.sql.execution.pythonUDF.arrow.legacy.fallbackOnUDT`` to ``true``. -* In Spark 4.1, unnecessary conversion to pandas instances is removed when ``spark.sql.execution.pythonUDTF.arrow.enabled`` is enabled. As a result, the type coercion changes when the produced output has a schema different from the specified schema. To restore the previous behavior, enable ``spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled``. +* In Spark 4.1, unnecessary conversion to pandas instances is removed when ``spark.sql.execution.pythonUDF.arrow.enabled`` is enabled. As a result, the type coercion changes when the produced output has a schema different from the specified schema. To restore the previous behavior, enable ``spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled``. +* In Spark 4.1, unnecessary conversion to pandas instances is removed when ``spark.sql.execution.pythonUDTF.arrow.enabled`` is enabled. As a result, the type coercion changes when the produced output has a schema different from the specified schema. To restore the previous behavior, enable ``spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled``. Upgrading from PySpark 3.5 to 4.0 --------------------------------- diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index b9ed246753906..d73475c792375 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -1050,6 +1050,11 @@ "Type hints for should be specified; however, got ." ] }, + "UDF_ARROW_TYPE_CONVERSION_ERROR": { + "message": [ + "Cannot convert the output value of the input '' with type '' to the specified return type of the column: ''. Please check if the data types match and try again." + ] + }, "UDF_RETURN_TYPE": { "message": [ "Return type of the user-defined function should be , but is ." 
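The two migration-guide bullets and the new `UDF_ARROW_TYPE_CONVERSION_ERROR` condition above are easiest to see with a small repro. The following sketch is not part of the patch; it assumes an active `SparkSession` bound to `spark`, and the UDF name and column alias are illustrative. It declares one return type but produces another, which is exactly the case where the 4.1 Arrow-native conversion and the legacy pandas conversion can differ (or where the new error condition is raised).

```python
from pyspark.sql.functions import udf

# Route Python UDFs through the Arrow-optimized path.
spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")

# Declared to return an int, but the function actually returns a string.
@udf(returnType="int")
def mismatched(v):
    return str(v)

# With the Arrow-native conversion, the result column is built by pyarrow;
# values it cannot convert surface UDF_ARROW_TYPE_CONVERSION_ERROR.
spark.range(3).select(mismatched("id").alias("v")).show()

# Restore the pre-4.1, pandas-based type coercion:
spark.conf.set("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled", "true")
```

Whether a given mismatch is coerced or rejected depends on the concrete type pair; the legacy flag only reproduces the 4.0 coercion behavior, it does not make mismatched results generally valid.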
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index b154318dc430c..b5a98bd75d6d2 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -131,6 +131,9 @@ def load_stream(self, stream): reader = pa.ipc.open_stream(stream) for batch in reader: yield batch + + def arrow_to_pandas(self, arrow_column, idx, struct_in_pandas="dict", ndarray_as_list=False, spark_type=None): + return arrow_column.to_pylist() def __repr__(self): return "ArrowStreamSerializer" @@ -194,6 +197,133 @@ class ArrowStreamUDTFSerializer(ArrowStreamUDFSerializer): def load_stream(self, stream): return ArrowStreamSerializer.load_stream(self, stream) +class ArrowBatchUDFSerializer(ArrowStreamSerializer): + """ + Serializer used by Python worker to evaluate Arrow UDFs + """ + + def __init__( + self, + timezone, + safecheck, + assign_cols_by_name, + arrow_cast, + struct_in_pandas="row", + ndarray_as_list=False, + ): + super(ArrowBatchUDFSerializer, self).__init__() + self._timezone = timezone + self._safecheck = safecheck + self._assign_cols_by_name = assign_cols_by_name + self._arrow_cast = arrow_cast + self._struct_in_pandas = struct_in_pandas + self._ndarray_as_list = ndarray_as_list + + def _create_array(self, arr, arrow_type, arrow_cast): + import pyarrow as pa + + assert isinstance(arr, pa.Array), arr + assert isinstance(arrow_type, pa.DataType), arrow_type + + # Cast only when the produced Arrow type differs from the expected one. + if arr.type == arrow_type: + return arr + if arrow_cast: + return arr.cast(target_type=arrow_type, safe=self._safecheck) + return arr + + def load_stream(self, stream): + import pyarrow as pa + batches = super(ArrowBatchUDFSerializer, self).load_stream(stream) + for batch in batches: + table = pa.Table.from_batches([batch]) + columns = [ + self.arrow_to_pandas(c, i) + for i, c in enumerate(table.itercolumns()) + ] + yield columns + + def dump_stream(self, iterator, stream): + """ + Override because Arrow UDFs require a START_ARROW_STREAM before the Arrow stream is sent. + This should be sent after creating the first record batch so in case of an error, it can + be sent back to the JVM before the Arrow stream starts. + """ + import pyarrow as pa + + def wrap_and_init_stream(): + should_write_start_length = True + for packed in iterator: + # Flatten tuple of lists into a single list + if isinstance(packed, tuple) and all(isinstance(x, list) for x in packed): + packed = [item for sublist in packed for item in sublist] + + if isinstance(packed, tuple) and len(packed) == 2 and isinstance(packed[1], pa.DataType): + # single array UDF in a projection + arrs = [self._create_array(packed[0], packed[1], self._arrow_cast)] + elif isinstance(packed, list): + # multiple array UDFs in a projection + arrs = [self._create_array(t[0], t[1], self._arrow_cast) for t in packed] + elif isinstance(packed, tuple) and len(packed) == 3: + # single value UDF with type information + value, arrow_type, spark_type = packed + arr = pa.array(value, type=arrow_type) + arrs = [self._create_array(arr, arrow_type, self._arrow_cast)] + else: + # Fallback: wrap a bare scalar result into a one-element int32 array. + arr = pa.array([packed], type=pa.int32()) + arrs = [self._create_array(arr, pa.int32(), self._arrow_cast)] + + batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) + + # Write the first record batch with initialization.
+ if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + yield batch + + return ArrowStreamSerializer.dump_stream(self, wrap_and_init_stream(), stream) + + def __repr__(self): + return "ArrowBatchUDFSerializer" + + def arrow_to_pandas(self, arrow_column, idx, ndarray_as_list=False, spark_type=None): + import pyarrow.types as types + from pyspark.sql import Row + + if ( + self._struct_in_pandas == "row" + and types.is_struct(arrow_column.type) + and not is_variant(arrow_column.type) + ): + series = [ + super(ArrowBatchUDFSerializer, self) + .arrow_to_pandas( + column, + i, + self._struct_in_pandas, + self._ndarray_as_list, + spark_type=None, + ) + for i, (column, field) in enumerate(zip(arrow_column.flatten(), arrow_column.type)) + ] + row_cls = Row(*[field.name for field in arrow_column.type]) + result = series[0].__class__([ + row_cls(*vals) for vals in zip(*series) + ]) + else: + result = ArrowStreamPandasSerializer.arrow_to_pandas( + self, + arrow_column, + idx, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=ndarray_as_list, + spark_type=spark_type, + ) + if spark_type is not None and hasattr(spark_type, "fromInternal"): + return result.apply(lambda v: spark_type.fromInternal(v) if v is not None else v) + return result + class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer): """ diff --git a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py index 78e72e02836f9..c83e58529c108 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py +++ b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py @@ -244,6 +244,37 @@ def test_udf_use_arrow_and_session_conf(self): ) +@unittest.skipIf( + not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message +) +class ArrowPythonUDFLegacyTestsMixin(ArrowPythonUDFTestsMixin): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.spark.conf.set("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled", "true") + + @classmethod + def tearDownClass(cls): + try: + cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled") + finally: + super().tearDownClass() + + +class ArrowPythonUDFNonLegacyTestsMixin(ArrowPythonUDFTestsMixin): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.spark.conf.set("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled", "false") + + @classmethod + def tearDownClass(cls): + try: + cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled") + finally: + super().tearDownClass() + + class ArrowPythonUDFTests(ArrowPythonUDFTestsMixin, ReusedSQLTestCase): @classmethod def setUpClass(cls): @@ -258,18 +289,32 @@ def tearDownClass(cls): super(ArrowPythonUDFTests, cls).tearDownClass() -class AsyncArrowPythonUDFTests(ArrowPythonUDFTests): +class ArrowPythonUDFLegacyTests(ArrowPythonUDFLegacyTestsMixin, ReusedSQLTestCase): @classmethod def setUpClass(cls): - super(AsyncArrowPythonUDFTests, cls).setUpClass() - cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.concurrency.level", "4") + super(ArrowPythonUDFLegacyTests, cls).setUpClass() + cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true") @classmethod def tearDownClass(cls): try: - cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.concurrency.level") + cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled") + finally: + 
super(ArrowPythonUDFLegacyTests, cls).tearDownClass() + + +class ArrowPythonUDFNonLegacyTests(ArrowPythonUDFNonLegacyTestsMixin, ReusedSQLTestCase): + @classmethod + def setUpClass(cls): + super(ArrowPythonUDFNonLegacyTests, cls).setUpClass() + cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true") + + @classmethod + def tearDownClass(cls): + try: + cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled") finally: - super(AsyncArrowPythonUDFTests, cls).tearDownClass() + super(ArrowPythonUDFNonLegacyTests, cls).tearDownClass() if __name__ == "__main__": diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 67cf25cff6e6a..04561ab0c3084 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -52,7 +52,6 @@ from pyspark.sql.functions import SkipRestOfInputTableException from pyspark.sql.pandas.serializers import ( ArrowStreamPandasUDFSerializer, - ArrowStreamArrowUDFSerializer, ArrowStreamPandasUDTFSerializer, CogroupArrowUDFSerializer, CogroupPandasUDFSerializer, @@ -63,6 +62,8 @@ TransformWithStateInPandasInitStateSerializer, TransformWithStateInPySparkRowSerializer, TransformWithStateInPySparkRowInitStateSerializer, + ArrowStreamArrowUDFSerializer, + ArrowBatchUDFSerializer, ArrowStreamUDTFSerializer, ) from pyspark.sql.pandas.types import to_arrow_type, from_arrow_schema @@ -201,7 +202,126 @@ def verify_result_length(result, length): ) -def wrap_arrow_batch_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): +def wrap_arrow_batch_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf, input_types=None): + if use_legacy_pandas_udf_conversion(runner_conf): + return wrap_arrow_batch_udf_legacy(f, args_offsets, kwargs_offsets, return_type, runner_conf) + else: + return wrap_arrow_batch_udf_arrow(f, args_offsets, kwargs_offsets, return_type, runner_conf, input_types) + + +def wrap_arrow_batch_udf_arrow(f, args_offsets, kwargs_offsets, return_type, runner_conf, input_types=None): + import pyarrow as pa + + func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) + zero_arg_exec = False + if len(args_kwargs_offsets) == 0: + args_kwargs_offsets = (0,) + zero_arg_exec = True + + arrow_return_type = to_arrow_type( + return_type, prefers_large_types=use_large_var_types(runner_conf) + ) + + result_func = lambda pdf: pdf # noqa: E731 + if type(return_type) == StringType: + result_func = lambda r: str(r) if r is not None else r # noqa: E731 + elif type(return_type) == BinaryType: + result_func = lambda r: bytes(r) if r is not None else r # noqa: E731 + elif hasattr(return_type, "fromInternal"): + result_func = lambda r: return_type.fromInternal(r) if r is not None else r # noqa: E731 + + # Get input types for UDT conversion + def convert_input_value(value, input_type): + """Convert Arrow/numpy input values to appropriate Python types for UDTs""" + if input_type is not None and hasattr(input_type, "fromInternal"): + if value is not None: + # Convert numpy array to tuple for UDT fromInternal + if hasattr(value, "tolist"): + return input_type.fromInternal(tuple(value.tolist())) + else: + return input_type.fromInternal(value) + return value + + if zero_arg_exec: + def get_args(*args: pa.RecordBatch): + if len(args) > 0: + if hasattr(args[0], "num_rows"): + batch_size = args[0].num_rows + else: + batch_size = len(args[0]) + else: + batch_size = 1 + return [() for _ in range(batch_size)] + else: + def get_args(*args: pa.RecordBatch): + arrays = [ + arg.combine_chunks() if isinstance(arg, pa.ChunkedArray) 
else arg + for arg in args + ] + + def get_python_value(array, index): + """Get Python value from array at index, handling both PyArrow and numpy arrays""" + value = array[index] + if hasattr(value, 'as_py'): + # PyArrow scalar + return value.as_py() + elif hasattr(value, 'item') and value.size == 1: + # Single-element numpy array + return value.item() + elif hasattr(value, 'tolist'): + # Multi-element numpy array (for UDT) + return tuple(value.tolist()) + else: + # Other Python object + return value + + if input_types is not None: + # branch for UDF conversion + converted_args = [] + for i in range(len(arrays[0]) if arrays else 0): + row_values = [] + for j, array in enumerate(arrays): + py_value = get_python_value(array, i) + converted_value = convert_input_value(py_value, input_types[j] if j < len(input_types) else None) + row_values.append(converted_value) + converted_args.append(tuple(row_values)) + return converted_args + else: + raw_args = [] + for i in range(len(arrays[0]) if arrays else 0): + row_values = tuple(get_python_value(array, i) for array in arrays) + raw_args.append(row_values) + return raw_args + + @fail_on_stopiteration + def evaluate(*args: pa.RecordBatch): + results = [result_func(func(*row)) for row in get_args(*args)] + try: + arr = pa.array(results, type=arrow_return_type) + except (pa.lib.ArrowInvalid, pa.lib.ArrowTypeError) as e: + raise PySparkRuntimeError( + errorClass="UDF_ARROW_TYPE_CONVERSION_ERROR", + messageParameters={ + "value": str(results), + "arrow_type": str(arrow_return_type), + "error": str(e), + }, + ) from e + return (arr, arrow_return_type) + + def make_output(*a): + out = evaluate(*a) + if isinstance(out, list): + return out + else: + return [out] + + return ( + args_kwargs_offsets, + make_output, + ) + +def wrap_arrow_batch_udf_legacy(f, args_offsets, kwargs_offsets, return_type, runner_conf): import pandas as pd func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) @@ -961,7 +1081,7 @@ def profiling_func(*args, **kwargs): return profiling_func -def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profiler): +def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profiler, input_types=None): num_arg = read_int(infile) if eval_type in ( @@ -973,7 +1093,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, # The below doesn't support named argument, but shares the same protocol. 
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, - PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF, + PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF ): args_offsets = [] kwargs_offsets = {} @@ -1029,7 +1149,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil elif eval_type == PythonEvalType.SQL_SCALAR_ARROW_UDF: return wrap_scalar_arrow_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf) elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF: - return wrap_arrow_batch_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf) + return wrap_arrow_batch_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf, input_types) elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF: return args_offsets, wrap_pandas_batch_iter_udf(func, return_type, runner_conf) elif eval_type == PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF: @@ -1097,6 +1217,9 @@ def assign_cols_by_name(runner_conf): def use_large_var_types(runner_conf): return runner_conf.get("spark.sql.execution.arrow.useLargeVarTypes", "false").lower() == "true" +def use_legacy_pandas_udf_conversion(runner_conf): + return runner_conf.get("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled", "false").lower() == "true" + # Read and process a serialized user-defined table function (UDTF) from a socket. # It expects the UDTF to be in a specific format and performs various checks to @@ -1804,6 +1927,8 @@ def read_udfs(pickleSer, infile, eval_type): state_server_port = None key_schema = None + input_types = None + if eval_type in ( PythonEvalType.SQL_ARROW_BATCHED_UDF, PythonEvalType.SQL_SCALAR_PANDAS_UDF, @@ -1914,6 +2039,13 @@ def read_udfs(pickleSer, infile, eval_type): ): # Arrow cast for type coercion is disabled by default ser = ArrowStreamArrowUDFSerializer(timezone, safecheck, _assign_cols_by_name, False) + elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF and not use_legacy_pandas_udf_conversion(runner_conf): + input_types = ( + [f.dataType for f in _parse_datatype_json_string(utf8_deserializer.loads(infile))] + if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF + else None + ) + ser = ArrowBatchUDFSerializer(timezone, safecheck, _assign_cols_by_name, False, "row") else: # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of # pandas Series. See SPARK-27240. @@ -1975,7 +2107,7 @@ def read_udfs(pickleSer, infile, eval_type): assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here." 
arg_offsets, udf = read_single_udf( - pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler + pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler, input_types=input_types ) def func(_, iterator): @@ -2068,7 +2200,7 @@ def extract_key_value_indexes(grouped_arg_offsets): # See FlatMapGroupsInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = read_single_udf( - pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler + pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler, input_types=None ) parsed_offsets = extract_key_value_indexes(arg_offsets) @@ -2087,7 +2219,7 @@ def mapper(a): # See TransformWithStateInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = read_single_udf( - pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler + pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler, input_types=None ) parsed_offsets = extract_key_value_indexes(arg_offsets) ser.key_offsets = parsed_offsets[0][0] @@ -2118,7 +2250,7 @@ def values_gen(): # See TransformWithStateInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = read_single_udf( - pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler + pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler, input_types=None ) # parsed offsets: # [ @@ -2156,7 +2288,7 @@ def values_gen(): # See TransformWithStateInPySparkExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = read_single_udf( - pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler + pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler, input_types=None ) parsed_offsets = extract_key_value_indexes(arg_offsets) ser.key_offsets = parsed_offsets[0][0] @@ -2183,7 +2315,7 @@ def mapper(a): # See TransformWithStateInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = read_single_udf( - pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler + pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler, input_types=None ) # parsed offsets: # [ @@ -2218,7 +2350,7 @@ def mapper(a): # See FlatMapGroupsInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = read_single_udf( - pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler + pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler, input_types=None ) parsed_offsets = extract_key_value_indexes(arg_offsets) @@ -2244,7 +2376,7 @@ def mapper(a): # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = read_single_udf( - pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler + pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler, input_types=None ) parsed_offsets = extract_key_value_indexes(arg_offsets) @@ -2280,7 +2412,7 @@ def mapper(a): # support combining multiple UDFs. 
assert num_udfs == 1 arg_offsets, f = read_single_udf( - pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler + pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler, input_types=None ) parsed_offsets = extract_key_value_indexes(arg_offsets) @@ -2299,7 +2431,7 @@ def mapper(a): # support combining multiple UDFs. assert num_udfs == 1 arg_offsets, f = read_single_udf( - pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler + pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler, input_types=None ) parsed_offsets = extract_key_value_indexes(arg_offsets) @@ -2325,11 +2457,13 @@ def mapper(a): for i in range(num_udfs): udfs.append( read_single_udf( - pickleSer, infile, eval_type, runner_conf, udf_index=i, profiler=profiler + pickleSer, infile, eval_type, runner_conf, udf_index=i, profiler=profiler, input_types=input_types ) ) def mapper(a): + if hasattr(a, 'num_columns') and a.num_columns == 0: + return None result = tuple(f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs) # In the special case of a single UDF this will return a single result rather # than a tuple of results; this is the format that the JVM side expects. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a322ec8a7e215..f519016e74d88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3792,6 +3792,17 @@ object SQLConf { .booleanConf .createWithDefault(false) + val PYTHON_UDF_LEGACY_PANDAS_CONVERSION_ENABLED = + buildConf("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled") + .internal() + .doc(s"When true and ${PYTHON_UDF_ARROW_ENABLED.key} is enabled, extra pandas " + + "conversion happens during (de)serialization between JVM and Python workers. " + + "This matters especially when the produced output has a schema different from " + + "specified schema, resulting in a different type coercion.") + .version("4.1.0") + .booleanConf + .createWithDefault(false) + val PYTHON_PLANNER_EXEC_MEMORY = buildConf("spark.sql.planner.pythonExecution.memory") .doc("Specifies the memory allocation for executing Python code in Spark driver, in MiB. 
" + @@ -6752,6 +6763,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def legacyPandasConversion: Boolean = getConf(PYTHON_TABLE_UDF_LEGACY_PANDAS_CONVERSION_ENABLED) + def legacyPandasConversionUDF: Boolean = getConf(PYTHON_UDF_LEGACY_PANDAS_CONVERSION_ENABLED) + def pythonPlannerExecMemory: Option[Long] = getConf(PYTHON_PLANNER_EXEC_MEMORY) def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 9a9fb574b87fb..920cfe76d8c34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -131,7 +131,11 @@ object ArrowPythonRunner { val legacyPandasConversion = Seq( SQLConf.PYTHON_TABLE_UDF_LEGACY_PANDAS_CONVERSION_ENABLED.key -> conf.legacyPandasConversion.toString) + val legacyPandasConversionUDF = Seq( + SQLConf.PYTHON_UDF_LEGACY_PANDAS_CONVERSION_ENABLED.key -> + conf.legacyPandasConversionUDF.toString) Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck ++ - arrowAyncParallelism ++ useLargeVarTypes ++ legacyPandasConversion: _*) + arrowAyncParallelism ++ useLargeVarTypes ++ + legacyPandasConversion ++ legacyPandasConversionUDF: _*) } }