From 98330898d2575deb6bda36e071edaa383577b49c Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Mon, 7 Apr 2025 17:49:39 +0800
Subject: [PATCH 1/5] init

Signed-off-by: Weichen Xu
---
 joblibspark/backend.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/joblibspark/backend.py b/joblibspark/backend.py
index c3ea2ec..d8e987f 100644
--- a/joblibspark/backend.py
+++ b/joblibspark/backend.py
@@ -153,8 +153,15 @@ def _create_resource_profile(self,
         if self._support_stage_scheduling:
             self.using_stage_scheduling = True

-            default_cpus_per_task = int(self._spark.conf.get("spark.task.cpus", "1"))
-            default_gpus_per_task = int(self._spark.conf.get("spark.task.resource.gpu.amount", "0"))
+            if is_spark_connect_mode():
+                # In Spark Connect mode, we can't read the Spark cluster configuration.
+                default_cpus_per_task = 1
+                default_gpus_per_task = 0
+            else:
+                default_cpus_per_task = int(self._spark.conf.get("spark.task.cpus", "1"))
+                default_gpus_per_task = int(
+                    self._spark.conf.get("spark.task.resource.gpu.amount", "0")
+                )

             num_cpus_per_spark_task = num_cpus_per_spark_task or default_cpus_per_task
             num_gpus_per_spark_task = num_gpus_per_spark_task or default_gpus_per_task

From df93ee066de88e236d8bfcfb589f35b20b6078a9 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Mon, 7 Apr 2025 18:00:00 +0800
Subject: [PATCH 2/5] nit

Signed-off-by: Weichen Xu
---
 joblibspark/backend.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/joblibspark/backend.py b/joblibspark/backend.py
index d8e987f..eeb04d3 100644
--- a/joblibspark/backend.py
+++ b/joblibspark/backend.py
@@ -110,10 +110,6 @@ def __init__(self,
         self._is_spark_connect_mode = is_spark_connect_mode()

         if self._is_spark_connect_mode:
-            if Version(pyspark.__version__).major < 4:
-                raise RuntimeError(
-                    "Joblib spark does not support Spark Connect with PySpark version < 4."
-                )
             self._support_stage_scheduling = True
             self._spark_supports_job_cancelling = True
         else:
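Patch 1 falls back to fixed defaults because a Spark Connect client cannot
query server-side settings such as spark.task.cpus through spark.conf. A
minimal standalone sketch of the same defaulting logic (the helper name is
hypothetical, not part of the diff; `spark` is assumed to be a classic,
non-Connect SparkSession):

    def resolve_task_resources(spark, connect_mode, num_cpus=None, num_gpus=None):
        # Hypothetical helper mirroring patch 1: under Spark Connect, assume
        # 1 CPU / 0 GPUs per task unless the caller overrides them explicitly.
        if connect_mode:
            default_cpus, default_gpus = 1, 0
        else:
            # Classic mode can read the cluster's per-task resource config.
            default_cpus = int(spark.conf.get("spark.task.cpus", "1"))
            default_gpus = int(spark.conf.get("spark.task.resource.gpu.amount", "0"))
        return (num_cpus or default_cpus, num_gpus or default_gpus)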
From a30089fdc6781c8c9dfb739fafaeb80b8832fc61 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Mon, 7 Apr 2025 18:24:33 +0800
Subject: [PATCH 3/5] fix test

Signed-off-by: Weichen Xu
---
 joblibspark/backend.py | 10 +++++-----
 test/test_spark.py     |  8 ++++++--
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/joblibspark/backend.py b/joblibspark/backend.py
index eeb04d3..b757240 100644
--- a/joblibspark/backend.py
+++ b/joblibspark/backend.py
@@ -334,14 +334,11 @@ def mapper_fn(_):
             return cloudpickle.loads(ser_res)

         try:
-            # pylint: disable=no-name-in-module,import-outside-toplevel
-            from pyspark import inheritable_thread_target
-
             if Version(pyspark.__version__).major >= 4 and is_spark_connect_mode():
                 # pylint: disable=fixme
                 # TODO: remove this patch once Spark 4.0.0 is released.
                 # The patch is for propagating the Spark session to the current thread.
-                def patched_inheritable_thread_target(f):  # pylint: disable=invalid-name
+                def inheritable_thread_target(f):  # pylint: disable=invalid-name
                     import functools
                     import copy
                     from typing import Any

                     session = f
                     assert session is not None, "Spark Connect session must be provided."
@@ -377,7 +374,10 @@ def inner(*args: Any, **kwargs: Any) -> Any:

                     return outer

-                inheritable_thread_target = patched_inheritable_thread_target(self._spark)
+                inheritable_thread_target = inheritable_thread_target(self._spark)
+            else:
+                # pylint: disable=no-name-in-module,import-outside-toplevel
+                from pyspark import inheritable_thread_target

             run_on_worker_and_fetch_result = \
                 inheritable_thread_target(run_on_worker_and_fetch_result)

diff --git a/test/test_spark.py b/test/test_spark.py
index e871b45..6a776c9 100644
--- a/test/test_spark.py
+++ b/test/test_spark.py
@@ -203,8 +203,12 @@ def get_spark_context(x):
             assert len(taskcontext.resources().get("gpu").addresses) == 1
             return TaskContext.get()

-        with parallel_backend('spark') as (ba, _):
-            Parallel(n_jobs=5)(delayed(get_spark_context)(i) for i in range(10))
+        if is_spark_connect_mode():
+            with parallel_backend('spark', num_gpus_per_spark_task=1) as (ba, _):
+                Parallel(n_jobs=5)(delayed(get_spark_context)(i) for i in range(10))
+        else:
+            with parallel_backend('spark') as (ba, _):
+                Parallel(n_jobs=5)(delayed(get_spark_context)(i) for i in range(10))

     def test_customized_resource_group(self):
         def get_spark_context(x):
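The wrapper that patch 3 renames exists because joblib may invoke the job
function from a fresh thread in which no Spark Connect session is active; the
elided body re-registers the captured session before calling through. A
simplified standalone analogue of that capture-and-rebind pattern (the
thread-local slot below is a hypothetical stand-in, not the real pyspark
mechanism):

    import functools
    import threading
    from typing import Any, Callable

    # Hypothetical stand-in for the session registry the real patch updates
    # through pyspark internals; it only illustrates the pattern.
    _active_session = threading.local()

    def session_propagating_target(session: Any) -> Callable:
        # Capture `session` now; re-bind it inside whatever thread runs `ff`.
        def outer(ff: Callable) -> Callable:
            @functools.wraps(ff)
            def inner(*args: Any, **kwargs: Any) -> Any:
                _active_session.value = session  # make the session visible here
                return ff(*args, **kwargs)
            return inner
        return outer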
From 80981ea36b65e10d033cb5ccbfa465f449a07d80 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Mon, 7 Apr 2025 18:30:25 +0800
Subject: [PATCH 4/5] format

Signed-off-by: Weichen Xu
---
 .github/workflows/main.yml |  6 +++---
 joblibspark/backend.py     | 11 ++++++-----
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 49da6ed..09fa802 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -26,9 +26,9 @@ jobs:
       - name: Install python packages
         run: |
           pip install joblib==${{ matrix.JOBLIB_VERSION }} scikit-learn>=0.23.1 pytest pylint "pyspark[connect]==${{ matrix.PYSPARK_VERSION }}" pandas
-      - name: Run pylint
-        run: |
-          ./run-pylint.sh
       - name: Run test suites
         run: |
           TEST_SPARK_CONNECT=${{ matrix.SPARK_CONNECT_MODE }} PYSPARK_VERSION=${{ matrix.PYSPARK_VERSION }} ./run-tests.sh
+      - name: Run pylint
+        run: |
+          ./run-pylint.sh

diff --git a/joblibspark/backend.py b/joblibspark/backend.py
index b757240..0d56032 100644
--- a/joblibspark/backend.py
+++ b/joblibspark/backend.py
@@ -339,9 +339,9 @@ def mapper_fn(_):
                 # TODO: remove this patch once Spark 4.0.0 is released.
                 # The patch is for propagating the Spark session to the current thread.
                 def inheritable_thread_target(f):  # pylint: disable=invalid-name
-                    import functools
-                    import copy
-                    from typing import Any
+                    import functools  # pylint: disable=C0415
+                    import copy  # pylint: disable=C0415
+                    from typing import Any  # pylint: disable=C0415

                     session = f
                     assert session is not None, "Spark Connect session must be provided."
@@ -359,6 +359,7 @@ def outer(ff: Any) -> Any:  # pylint: disable=invalid-name
                         @functools.wraps(ff)
                         def inner(*args: Any, **kwargs: Any) -> Any:
                             # Propagates the active Spark session to the current thread
+                            # pylint: disable=C0415
                             from pyspark.sql.connect.session import SparkSession as SCS
                             # pylint: disable=protected-access,no-member
@@ -376,8 +377,8 @@ def inner(*args: Any, **kwargs: Any) -> Any:

                 inheritable_thread_target = inheritable_thread_target(self._spark)
             else:
-                # pylint: disable=no-name-in-module,import-outside-toplevel
-                from pyspark import inheritable_thread_target
+                # pylint: disable=no-name-in-module
+                from pyspark import inheritable_thread_target  # pylint: disable=C0415

             run_on_worker_and_fetch_result = \
                 inheritable_thread_target(run_on_worker_and_fetch_result)

From 0734bd3d129a44b10096891db48df76f13575422 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Mon, 7 Apr 2025 18:36:18 +0800
Subject: [PATCH 5/5] clean

Signed-off-by: Weichen Xu
---
 joblibspark/backend.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/joblibspark/backend.py b/joblibspark/backend.py
index 0d56032..61583f5 100644
--- a/joblibspark/backend.py
+++ b/joblibspark/backend.py
@@ -110,6 +110,10 @@ def __init__(self,
         self._is_spark_connect_mode = is_spark_connect_mode()

         if self._is_spark_connect_mode:
+            if Version(pyspark.__version__).major < 4:
+                raise RuntimeError(
+                    "Joblib spark does not support Spark Connect with PySpark version < 4."
+                )
             self._support_stage_scheduling = True
             self._spark_supports_job_cancelling = True
         else:
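With the series applied, the backend supports Spark Connect on PySpark >= 4,
and callers should pass per-task resources explicitly because the Connect
client cannot read spark.task.* defaults. A usage sketch, assuming a reachable
Spark Connect endpoint (the URL below is a placeholder):

    from pyspark.sql import SparkSession
    from joblib import Parallel, delayed, parallel_backend
    from joblibspark import register_spark

    def square(x):
        return x * x

    # Placeholder endpoint; point this at a real Spark Connect server.
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    register_spark()

    # State the per-task GPU requirement explicitly instead of relying on
    # cluster defaults, which the Connect client cannot see.
    with parallel_backend("spark", num_gpus_per_spark_task=1):
        results = Parallel(n_jobs=4)(delayed(square)(i) for i in range(8))
    print(results)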