From 58d3a153f84b685a55385f5adb55da90605b0dfd Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 26 Mar 2025 18:17:03 +0800 Subject: [PATCH 01/13] init-s --- joblibspark/backend.py | 149 ++++++++++++++++++++++++++++------------- 1 file changed, 101 insertions(+), 48 deletions(-) diff --git a/joblibspark/backend.py b/joblibspark/backend.py index a2a3728..581c839 100644 --- a/joblibspark/backend.py +++ b/joblibspark/backend.py @@ -77,20 +77,14 @@ def __init__(self, num_cpus_per_spark_task: Optional[int] = None, num_gpus_per_spark_task: Optional[int] = None, **backend_args): + from pyspark.sql.utils import is_remote + # pylint: disable=super-with-arguments super(SparkDistributedBackend, self).__init__(**backend_args) self._pool = None self._n_jobs = None self._spark = get_spark_session() - self._spark_context = self._spark.sparkContext self._job_group = "joblib-spark-job-group-" + str(uuid.uuid4()) - self._spark_pinned_threads_enabled = isinstance( - self._spark_context._gateway, ClientServer - ) - self._spark_supports_job_cancelling = ( - self._spark_pinned_threads_enabled - or hasattr(self._spark_context.parallelize([1]), "collectWithJobGroup") - ) self._is_running = False try: from IPython import get_ipython # pylint: disable=import-outside-toplevel @@ -98,7 +92,21 @@ def __init__(self, except ImportError: self._ipython = None - self._support_stage_scheduling = self._is_support_stage_scheduling() + self._is_spark_connect_mode = is_remote() + if self._is_spark_connect_mode: + self._support_stage_scheduling = Version(pyspark.__version__).major >= 4 + self._spark_supports_job_cancelling = Version(pyspark.__version__) >= Version("3.5") + else: + self._spark_context = self._spark.sparkContext + self._spark_pinned_threads_enabled = isinstance( + self._spark_context._gateway, ClientServer + ) + self._spark_supports_job_cancelling = ( + self._spark_pinned_threads_enabled + or hasattr(self._spark_context.parallelize([1]), "collectWithJobGroup") + ) + self._support_stage_scheduling = self._is_support_stage_scheduling() + self._resource_profile = self._create_resource_profile(num_cpus_per_spark_task, num_gpus_per_spark_task) @@ -135,26 +143,48 @@ def _create_resource_profile(self, def _cancel_all_jobs(self): self._is_running = False if not self._spark_supports_job_cancelling: - # Note: There's bug existing in `sparkContext.cancelJobGroup`. - # See https://issues.apache.org/jira/browse/SPARK-31549 - warnings.warn("For spark version < 3, pyspark cancelling job API has bugs, " - "so we could not terminate running spark jobs correctly. " - "See https://issues.apache.org/jira/browse/SPARK-31549 for reference.") + if self._is_spark_connect_mode: + warnings.warn("Spark connect does not support job cancellation API " + "for Spark version < 3.5") + else: + # Note: There's bug existing in `sparkContext.cancelJobGroup`. + # See https://issues.apache.org/jira/browse/SPARK-31549 + warnings.warn("For spark version < 3, pyspark cancelling job API has bugs, " + "so we could not terminate running spark jobs correctly. 
" + "See https://issues.apache.org/jira/browse/SPARK-31549 for reference.") else: - self._spark.sparkContext.cancelJobGroup(self._job_group) + if self._is_spark_connect_mode: + self._spark.interruptTag(self._job_group) + else: + self._spark.sparkContext.cancelJobGroup(self._job_group) def effective_n_jobs(self, n_jobs): - max_num_concurrent_tasks = self._get_max_num_concurrent_tasks() if n_jobs is None: n_jobs = 1 - elif n_jobs == -1: - # n_jobs=-1 means requesting all available workers - n_jobs = max_num_concurrent_tasks - if n_jobs > max_num_concurrent_tasks: - warnings.warn(f"User-specified n_jobs ({n_jobs}) is greater than the max number of " - f"concurrent tasks ({max_num_concurrent_tasks}) this cluster can run now." - "If dynamic allocation is enabled for the cluster, you might see more " - "executors allocated.") + + if self._is_spark_connect_mode: + if n_jobs == 1: + warnings.warn( + "The maximum number of concurrently running jobs is set to 1, " + "to increase concurrency, you need to set joblib spark backend " + "'n_jobs' param to a greater number." + ) + + if n_jobs == -1: + raise RuntimeError( + "In Spark connect mode, Joblib spark backend can't support setting " + "'n_jobs' to -1." + ) + else: + max_num_concurrent_tasks = self._get_max_num_concurrent_tasks() + if n_jobs == -1: + # n_jobs=-1 means requesting all available workers + n_jobs = max_num_concurrent_tasks + if n_jobs > max_num_concurrent_tasks: + warnings.warn(f"User-specified n_jobs ({n_jobs}) is greater than the max number of " + f"concurrent tasks ({max_num_concurrent_tasks}) this cluster can run now." + "If dynamic allocation is enabled for the cluster, you might see more " + "executors allocated.") return n_jobs def _get_max_num_concurrent_tasks(self): @@ -213,32 +243,55 @@ def run_on_worker_and_fetch_result(): raise RuntimeError('The task is canceled due to ipython command canceled.') # TODO: handle possible spark exception here. # pylint: disable=fixme - worker_rdd = self._spark.sparkContext.parallelize([0], 1) - if self._resource_profile: - worker_rdd = worker_rdd.withResources(self._resource_profile) - def mapper_fn(_): - return cloudpickle.dumps(func()) - if self._spark_supports_job_cancelling: - if self._spark_pinned_threads_enabled: - self._spark.sparkContext.setLocalProperty( - "spark.jobGroup.id", - self._job_group - ) - self._spark.sparkContext.setLocalProperty( - "spark.job.description", - "joblib spark jobs" - ) - rdd = worker_rdd.map(mapper_fn) - ser_res = rdd.collect()[0] + if self._is_spark_connect_mode: + spark_df = self._spark.range(1, numPartitions=1) + + def mapper_fn(iterator): + import pandas as pd + + for _ in iterator: # consume input data. 
+ pass + + result = cloudpickle.dumps(func()) + pd.DataFrame({"result": [result]}) + + if self._spark_supports_job_cancelling: + self._spark.addTag(self._job_group) + + ser_res = spark_df.mapInPandas( + mapper_fn, + schema="result binary", + profile=self._resource_profile, + ).collect()[0].result + else: + worker_rdd = self._spark.sparkContext.parallelize([0], 1) + if self._resource_profile: + worker_rdd = worker_rdd.withResources(self._resource_profile) + + def mapper_fn(_): + return cloudpickle.dumps(func()) + + if self._spark_supports_job_cancelling: + if self._spark_pinned_threads_enabled: + self._spark.sparkContext.setLocalProperty( + "spark.jobGroup.id", + self._job_group + ) + self._spark.sparkContext.setLocalProperty( + "spark.job.description", + "joblib spark jobs" + ) + rdd = worker_rdd.map(mapper_fn) + ser_res = rdd.collect()[0] + else: + rdd = worker_rdd.map(mapper_fn) + ser_res = rdd.collectWithJobGroup( + self._job_group, + "joblib spark jobs" + )[0] else: rdd = worker_rdd.map(mapper_fn) - ser_res = rdd.collectWithJobGroup( - self._job_group, - "joblib spark jobs" - )[0] - else: - rdd = worker_rdd.map(mapper_fn) - ser_res = rdd.collect()[0] + ser_res = rdd.collect()[0] return cloudpickle.loads(ser_res) From 74356869e9e56692eaf25f428f6f1afd70d55151 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 6 Apr 2025 21:18:02 +0800 Subject: [PATCH 02/13] update Signed-off-by: Weichen Xu --- .github/workflows/main.yml | 2 +- joblibspark/backend.py | 28 ++++++++++++++++++---------- requirements.txt | 2 ++ 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 3d7dd99..f720948 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -38,7 +38,7 @@ jobs: architecture: x64 - name: Install python packages run: | - pip install joblib==${{ matrix.JOBLIB_VERSION }} scikit-learn>=0.23.1 pytest pylint pyspark==${{ matrix.PYSPARK_VERSION }} + pip install joblib==${{ matrix.JOBLIB_VERSION }} scikit-learn>=0.23.1 pytest pylint pyspark==${{ matrix.PYSPARK_VERSION }} pandas - name: Run pylint run: | ./run-pylint.sh diff --git a/joblibspark/backend.py b/joblibspark/backend.py index 581c839..2a5f869 100644 --- a/joblibspark/backend.py +++ b/joblibspark/backend.py @@ -23,6 +23,7 @@ import uuid from typing import Optional from packaging.version import Version, parse +import pandas as pd from joblib.parallel \ import AutoBatchingMixin, ParallelBackendBase, register_parallel_backend, SequentialBackend @@ -62,6 +63,14 @@ def register(): register_parallel_backend('spark', SparkDistributedBackend) +def is_spark_connect_mode(): + try: + from pyspark.sql.utils import is_remote # pylint: disable=C0415 + return is_remote() + except ImportError: + return False + + # pylint: disable=too-many-instance-attributes class SparkDistributedBackend(ParallelBackendBase, AutoBatchingMixin): """A ParallelBackend which will execute all batches on spark. 
@@ -77,8 +86,6 @@ def __init__(self, num_cpus_per_spark_task: Optional[int] = None, num_gpus_per_spark_task: Optional[int] = None, **backend_args): - from pyspark.sql.utils import is_remote - # pylint: disable=super-with-arguments super(SparkDistributedBackend, self).__init__(**backend_args) self._pool = None @@ -92,7 +99,7 @@ def __init__(self, except ImportError: self._ipython = None - self._is_spark_connect_mode = is_remote() + self._is_spark_connect_mode = is_spark_connect_mode() if self._is_spark_connect_mode: self._support_stage_scheduling = Version(pyspark.__version__).major >= 4 self._spark_supports_job_cancelling = Version(pyspark.__version__) >= Version("3.5") @@ -151,7 +158,8 @@ def _cancel_all_jobs(self): # See https://issues.apache.org/jira/browse/SPARK-31549 warnings.warn("For spark version < 3, pyspark cancelling job API has bugs, " "so we could not terminate running spark jobs correctly. " - "See https://issues.apache.org/jira/browse/SPARK-31549 for reference.") + "See https://issues.apache.org/jira/browse/SPARK-31549 for " + "reference.") else: if self._is_spark_connect_mode: self._spark.interruptTag(self._job_group) @@ -181,10 +189,12 @@ def effective_n_jobs(self, n_jobs): # n_jobs=-1 means requesting all available workers n_jobs = max_num_concurrent_tasks if n_jobs > max_num_concurrent_tasks: - warnings.warn(f"User-specified n_jobs ({n_jobs}) is greater than the max number of " - f"concurrent tasks ({max_num_concurrent_tasks}) this cluster can run now." - "If dynamic allocation is enabled for the cluster, you might see more " - "executors allocated.") + warnings.warn( + f"User-specified n_jobs ({n_jobs}) is greater than the max number of " + f"concurrent tasks ({max_num_concurrent_tasks}) this cluster can run now." + "If dynamic allocation is enabled for the cluster, you might see more " + "executors allocated." + ) return n_jobs def _get_max_num_concurrent_tasks(self): @@ -247,8 +257,6 @@ def run_on_worker_and_fetch_result(): spark_df = self._spark.range(1, numPartitions=1) def mapper_fn(iterator): - import pandas as pd - for _ in iterator: # consume input data. 
pass diff --git a/requirements.txt b/requirements.txt index 8a41183..bdc1ad1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,4 @@ joblib>=0.14 packaging +pandas + From de1913855bda2ea989710bb5aaef788018acb986 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 6 Apr 2025 21:51:59 +0800 Subject: [PATCH 03/13] update Signed-off-by: Weichen Xu --- .github/workflows/main.yml | 27 +++++---------------------- joblibspark/backend.py | 3 +++ test/test_spark.py | 30 ++++++++++++++++++++++++------ 3 files changed, 32 insertions(+), 28 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index f720948..76b9610 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -7,27 +7,10 @@ jobs: fail-fast: false matrix: PYTHON_VERSION: ["3.10"] - JOBLIB_VERSION: ["1.2.0", "1.3.0"] - PIN_MODE: [false, true] - PYSPARK_VERSION: ["3.0.3", "3.1.3", "3.2.3", "3.3.2", "3.4.0"] - include: - - PYSPARK_VERSION: "3.5.1" - PYTHON_VERSION: "3.11" - JOBLIB_VERSION: "1.3.0" - - PYSPARK_VERSION: "3.5.1" - PYTHON_VERSION: "3.11" - JOBLIB_VERSION: "1.4.2" - - PYSPARK_VERSION: "3.5.1" - PYTHON_VERSION: "3.12" - JOBLIB_VERSION: "1.3.0" - - PYSPARK_VERSION: "3.5.1" - PYTHON_VERSION: "3.12" - JOBLIB_VERSION: "1.4.2" - exclude: - - PYSPARK_VERSION: "3.0.3" - PIN_MODE: true - - PYSPARK_VERSION: "3.1.3" - PIN_MODE: true + JOBLIB_VERSION: ["1.3.2", "1.4.2"] + PIN_MODE: [true] + PYSPARK_VERSION: ["3.4.4", "3.5.5"] + SPARK_CONNECT_MODE: [false, true] name: Run test on pyspark ${{ matrix.PYSPARK_VERSION }}, pin_mode ${{ matrix.PIN_MODE }}, python ${{ matrix.PYTHON_VERSION }}, joblib ${{ matrix.JOBLIB_VERSION }} steps: - uses: actions/checkout@v3 @@ -44,4 +27,4 @@ jobs: ./run-pylint.sh - name: Run test suites run: | - PYSPARK_PIN_THREAD=${{ matrix.PIN_MODE }} ./run-tests.sh + SPARK_CONNECT_MODE=${{ matrix.SPARK_CONNECT_MODE }} PYSPARK_VERSION=${{ matrix.PYSPARK_VERSION }} PYSPARK_PIN_THREAD=${{ matrix.PIN_MODE }} ./run-tests.sh diff --git a/joblibspark/backend.py b/joblibspark/backend.py index 2a5f869..9af4105 100644 --- a/joblibspark/backend.py +++ b/joblibspark/backend.py @@ -64,6 +64,9 @@ def register(): def is_spark_connect_mode(): + """ + Check if running with spark connect mode. + """ try: from pyspark.sql.utils import is_remote # pylint: disable=C0415 return is_remote() diff --git a/test/test_spark.py b/test/test_spark.py index 9c78257..9cac25d 100644 --- a/test/test_spark.py +++ b/test/test_spark.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import logging from time import sleep import pytest import os @@ -36,6 +37,9 @@ from pyspark.sql import SparkSession import pyspark +_logger = logging.getLogger("Test") +_logger.setLevel(logging.INFO) + register_spark() @@ -44,12 +48,26 @@ class TestSparkCluster(unittest.TestCase): @classmethod def setup_class(cls): - cls.spark = ( - SparkSession.builder.master("local-cluster[1, 2, 1024]") - .config("spark.task.cpus", "1") - .config("spark.task.maxFailures", "1") - .getOrCreate() - ) + spark_version = os.environ["PYSPARK_VERSION"] + if os.environ["SPARK_CONNECT_MODE"].lower() == "true": + _logger.info("Test with spark connect mode.") + cls.spark = ( + SparkSession.builder.config( + "spark.jars.packages", f"org.apache.spark:spark-connect_2.12:{spark_version}" + ) + .config("spark.task.cpus", "1") + .config("spark.task.maxFailures", "1") + .remote("local[2]") # Adjust the remote address if necessary + .appName(cls.__name__) + .getOrCreate() + ) + else: + cls.spark = ( + SparkSession.builder.master("local-cluster[1, 2, 1024]") + .config("spark.task.cpus", "1") + .config("spark.task.maxFailures", "1") + .getOrCreate() + ) @classmethod def teardown_class(cls): From eb656c58a4e778d73c940dfdf532668278a6aeef Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sun, 6 Apr 2025 23:25:55 +0800 Subject: [PATCH 04/13] update Signed-off-by: Weichen Xu --- .github/workflows/main.yml | 2 +- joblibspark/backend.py | 27 ++++++++++---- test/test_backend.py | 19 ++++++++-- test/test_spark.py | 72 ++++++++++++++++++++++++++------------ 4 files changed, 87 insertions(+), 33 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 76b9610..2e99d03 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -21,7 +21,7 @@ jobs: architecture: x64 - name: Install python packages run: | - pip install joblib==${{ matrix.JOBLIB_VERSION }} scikit-learn>=0.23.1 pytest pylint pyspark==${{ matrix.PYSPARK_VERSION }} pandas + pip install joblib==${{ matrix.JOBLIB_VERSION }} scikit-learn>=0.23.1 pytest pylint "pyspark[connect]==${{ matrix.PYSPARK_VERSION }}" pandas - name: Run pylint run: | ./run-pylint.sh diff --git a/joblibspark/backend.py b/joblibspark/backend.py index 9af4105..22ca8ea 100644 --- a/joblibspark/backend.py +++ b/joblibspark/backend.py @@ -121,6 +121,9 @@ def __init__(self, num_gpus_per_spark_task) def _is_support_stage_scheduling(self): + if self._is_spark_connect_mode: + return Version(pyspark.__version__).major >= 4 + spark_master = self._spark_context.master is_spark_local_mode = spark_master == "local" or spark_master.startswith("local[") if is_spark_local_mode: @@ -264,16 +267,24 @@ def mapper_fn(iterator): pass result = cloudpickle.dumps(func()) - pd.DataFrame({"result": [result]}) + yield pd.DataFrame({"result": [result]}) if self._spark_supports_job_cancelling: self._spark.addTag(self._job_group) - ser_res = spark_df.mapInPandas( - mapper_fn, - schema="result binary", - profile=self._resource_profile, - ).collect()[0].result + if self._support_stage_scheduling: + collected = spark_df.mapInPandas( + mapper_fn, + schema="result binary", + profile=self._resource_profile, + ).collect() + else: + collected = spark_df.mapInPandas( + mapper_fn, + schema="result binary", + ).collect() + pass + ser_res = bytes(collected[0].result) else: worker_rdd = self._spark.sparkContext.parallelize([0], 1) if self._resource_profile: @@ -309,6 +320,10 @@ def mapper_fn(_): try: # pylint: disable=no-name-in-module,import-outside-toplevel from pyspark import 
inheritable_thread_target + + if Version(pyspark.__version__).major >= 4: + inheritable_thread_target = inheritable_thread_target(self._spark) + run_on_worker_and_fetch_result = \ inheritable_thread_target(run_on_worker_and_fetch_result) except ImportError: diff --git a/test/test_backend.py b/test/test_backend.py index 1332e40..583af7f 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -9,13 +9,26 @@ from joblibspark.backend import SparkDistributedBackend +spark_version = os.environ["PYSPARK_VERSION"] + class TestLocalSparkCluster(unittest.TestCase): @classmethod def setup_class(cls): - cls.spark = ( - SparkSession.builder.master("local[*]").getOrCreate() - ) + if os.environ["SPARK_CONNECT_MODE"].lower() == "true": + cls.spark = ( + SparkSession.builder.config( + "spark.jars.packages", + f"org.apache.spark:spark-connect_2.12:{spark_version}" + ) + .remote("local[2]") # Adjust the remote address if necessary + .appName("Test") + .getOrCreate() + ) + else: + cls.spark = ( + SparkSession.builder.master("local[*]").getOrCreate() + ) @classmethod def teardown_class(cls): diff --git a/test/test_spark.py b/test/test_spark.py index 9cac25d..c45065d 100644 --- a/test/test_spark.py +++ b/test/test_spark.py @@ -42,30 +42,36 @@ register_spark() +spark_version = os.environ["PYSPARK_VERSION"] + +is_spark_connect_mode = os.environ["SPARK_CONNECT_MODE"].lower() == "true" + class TestSparkCluster(unittest.TestCase): spark = None @classmethod def setup_class(cls): - spark_version = os.environ["PYSPARK_VERSION"] + spark_session_builder = ( + SparkSession.builder + .config("spark.task.cpus", "1") + .config("spark.task.maxFailures", "1") + ) + if os.environ["SPARK_CONNECT_MODE"].lower() == "true": _logger.info("Test with spark connect mode.") cls.spark = ( - SparkSession.builder.config( - "spark.jars.packages", f"org.apache.spark:spark-connect_2.12:{spark_version}" + spark_session_builder.config( + "spark.jars.packages", + f"org.apache.spark:spark-connect_2.12:{spark_version}" ) - .config("spark.task.cpus", "1") - .config("spark.task.maxFailures", "1") .remote("local[2]") # Adjust the remote address if necessary - .appName(cls.__name__) + .appName("Test") .getOrCreate() ) else: cls.spark = ( - SparkSession.builder.master("local-cluster[1, 2, 1024]") - .config("spark.task.cpus", "1") - .config("spark.task.maxFailures", "1") + spark_session_builder.master("local-cluster[1, 2, 1024]") .getOrCreate() ) @@ -135,8 +141,12 @@ def test_fn(x): assert len(os.listdir(tmp_dir)) == 0 -@unittest.skipIf(Version(pyspark.__version__).release < (3, 4, 0), - "Resource group is only supported since spark 3.4.0") +@unittest.skipIf( + (not is_spark_connect_mode and Version(pyspark.__version__).release < (3, 4, 0)) or + (is_spark_connect_mode and Version(pyspark.__version__).major < 4), + "Resource group is only supported since Spark 3.4.0 for legacy Spark mode or " + "since Spark 4 for Spark Connect mode." 
+) class TestGPUSparkCluster(unittest.TestCase): @classmethod def setup_class(cls): @@ -144,20 +154,36 @@ def setup_class(cls): os.path.dirname(os.path.abspath(__file__)), "discover_2_gpu.sh" ) - cls.spark = ( - SparkSession.builder.master("local-cluster[1, 2, 1024]") - .config("spark.task.cpus", "1") - .config("spark.task.resource.gpu.amount", "1") - .config("spark.executor.cores", "2") - .config("spark.worker.resource.gpu.amount", "2") - .config("spark.executor.resource.gpu.amount", "2") - .config("spark.task.maxFailures", "1") - .config( - "spark.worker.resource.gpu.discoveryScript", gpu_discovery_script_path - ) - .getOrCreate() + spark_session_builder = ( + SparkSession.builder + .config("spark.task.cpus", "1") + .config("spark.task.resource.gpu.amount", "1") + .config("spark.executor.cores", "2") + .config("spark.worker.resource.gpu.amount", "2") + .config("spark.executor.resource.gpu.amount", "2") + .config("spark.task.maxFailures", "1") + .config( + "spark.worker.resource.gpu.discoveryScript", gpu_discovery_script_path + ) ) + if os.environ["SPARK_CONNECT_MODE"].lower() == "true": + _logger.info("Test with spark connect mode.") + cls.spark = ( + spark_session_builder.config( + "spark.jars.packages", + f"org.apache.spark:spark-connect_2.12:{spark_version}" + ) + .remote("local[2]") # Adjust the remote address if necessary + .appName("Test") + .getOrCreate() + ) + else: + cls.spark = ( + spark_session_builder.master("local-cluster[1, 2, 1024]") + .getOrCreate() + ) + @classmethod def teardown_class(cls): cls.spark.stop() From 53715009a30843b4e3011f6aadbadebe6ef0e175 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Mon, 7 Apr 2025 12:57:30 +0800 Subject: [PATCH 05/13] update Signed-off-by: Weichen Xu --- .github/workflows/main.yml | 2 +- joblibspark/backend.py | 83 ++++++++++++++++++++++++++++++-------- test/test_backend.py | 27 +++++++++---- test/test_spark.py | 27 ++++++++----- 4 files changed, 105 insertions(+), 34 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 2e99d03..1de03c9 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -9,7 +9,7 @@ jobs: PYTHON_VERSION: ["3.10"] JOBLIB_VERSION: ["1.3.2", "1.4.2"] PIN_MODE: [true] - PYSPARK_VERSION: ["3.4.4", "3.5.5"] + PYSPARK_VERSION: ["3.4.4", "3.5.5", "4.0.0.dev2"] SPARK_CONNECT_MODE: [false, true] name: Run test on pyspark ${{ matrix.PYSPARK_VERSION }}, pin_mode ${{ matrix.PIN_MODE }}, python ${{ matrix.PYTHON_VERSION }}, joblib ${{ matrix.JOBLIB_VERSION }} steps: diff --git a/joblibspark/backend.py b/joblibspark/backend.py index 22ca8ea..848adbc 100644 --- a/joblibspark/backend.py +++ b/joblibspark/backend.py @@ -18,6 +18,7 @@ The joblib spark backend implementation. """ import atexit +import logging import warnings from multiprocessing.pool import ThreadPool import uuid @@ -48,6 +49,9 @@ from .utils import create_resource_profile, get_spark_session +_logger = logging.getLogger("joblibspark.backend") + + def register(): """ Register joblib spark backend. @@ -74,6 +78,9 @@ def is_spark_connect_mode(): return False +_DEFAULT_N_JOBS_IN_SPARK_CONNECT_MODE = 64 + + # pylint: disable=too-many-instance-attributes class SparkDistributedBackend(ParallelBackendBase, AutoBatchingMixin): """A ParallelBackend which will execute all batches on spark. @@ -185,9 +192,10 @@ def effective_n_jobs(self, n_jobs): ) if n_jobs == -1: - raise RuntimeError( - "In Spark connect mode, Joblib spark backend can't support setting " - "'n_jobs' to -1." 
+ n_jobs = _DEFAULT_N_JOBS_IN_SPARK_CONNECT_MODE + _logger.warning( + "Joblib sets `n_jobs` to default value " + f"{_DEFAULT_N_JOBS_IN_SPARK_CONNECT_MODE} in Spark Connect mode." ) else: max_num_concurrent_tasks = self._get_max_num_concurrent_tasks() @@ -272,18 +280,24 @@ def mapper_fn(iterator): if self._spark_supports_job_cancelling: self._spark.addTag(self._job_group) - if self._support_stage_scheduling: - collected = spark_df.mapInPandas( - mapper_fn, - schema="result binary", - profile=self._resource_profile, - ).collect() - else: - collected = spark_df.mapInPandas( - mapper_fn, - schema="result binary", - ).collect() - pass + try: + if self._support_stage_scheduling: + collected = spark_df.mapInPandas( + mapper_fn, + schema="result binary", + profile=self._resource_profile, + ).collect() + else: + collected = spark_df.mapInPandas( + mapper_fn, + schema="result binary", + ).collect() + pass + except Exception as e: + import traceback + with open("/tmp/err.log", "a") as f: + f.write(traceback.format_exc()) + ser_res = bytes(collected[0].result) else: worker_rdd = self._spark.sparkContext.parallelize([0], 1) @@ -321,7 +335,44 @@ def mapper_fn(_): # pylint: disable=no-name-in-module,import-outside-toplevel from pyspark import inheritable_thread_target - if Version(pyspark.__version__).major >= 4: + if Version(pyspark.__version__).major >= 4 and is_spark_connect_mode(): + # TODO: remove this patch once Spark 4.0.0 is released. + def patched_inheritable_thread_target(f): + from pyspark.sql.utils import is_remote + import functools + import copy + from typing import Any + + session = f + assert session is not None, "Spark Connect session must be provided." + + def outer(ff: Any) -> Any: + session_client_thread_local_attrs = [ + (attr, copy.deepcopy(value)) + for ( + attr, + value, + ) in session.client.thread_local.__dict__.items() # type: ignore[union-attr] + ] + + @functools.wraps(ff) + def inner(*args: Any, **kwargs: Any) -> Any: + # Propagates the active spark session to the current thread + from pyspark.sql.connect.session import SparkSession as SCS + + SCS._set_default_and_active_session(session) + + # Set thread locals in child thread. 
+ for attr, value in session_client_thread_local_attrs: + setattr(session.client.thread_local, attr, value) # type: ignore[union-attr] + return ff(*args, **kwargs) + + return inner + + return outer + + inheritable_thread_target = patched_inheritable_thread_target(self._spark) + else: inheritable_thread_target = inheritable_thread_target(self._spark) run_on_worker_and_fetch_result = \ diff --git a/test/test_backend.py b/test/test_backend.py index 583af7f..4aa4e70 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -8,8 +8,13 @@ from pyspark.sql import SparkSession from joblibspark.backend import SparkDistributedBackend +import joblibspark.backend + +joblibspark.backend._DEFAULT_N_JOBS_IN_SPARK_CONNECT_MODE = 8 + spark_version = os.environ["PYSPARK_VERSION"] +is_spark_connect_mode = os.environ["SPARK_CONNECT_MODE"].lower() == "true" class TestLocalSparkCluster(unittest.TestCase): @@ -21,7 +26,7 @@ def setup_class(cls): "spark.jars.packages", f"org.apache.spark:spark-connect_2.12:{spark_version}" ) - .remote("local[2]") # Adjust the remote address if necessary + .remote("local-cluster[1, 2, 1024]") .appName("Test") .getOrCreate() ) @@ -36,17 +41,23 @@ def teardown_class(cls): def test_effective_n_jobs(self): backend = SparkDistributedBackend() - max_num_concurrent_tasks = 8 - backend._get_max_num_concurrent_tasks = MagicMock(return_value=max_num_concurrent_tasks) assert backend.effective_n_jobs(n_jobs=None) == 1 - assert backend.effective_n_jobs(n_jobs=-1) == 8 assert backend.effective_n_jobs(n_jobs=4) == 4 - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - assert backend.effective_n_jobs(n_jobs=16) == 16 - assert len(w) == 1 + if is_spark_connect_mode: + assert ( + backend.effective_n_jobs(n_jobs=-1) == + joblibspark.backend._DEFAULT_N_JOBS_IN_SPARK_CONNECT_MODE + ) + else: + max_num_concurrent_tasks = 8 + backend._get_max_num_concurrent_tasks = MagicMock(return_value=max_num_concurrent_tasks) + assert backend.effective_n_jobs(n_jobs=-1) == 8 + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + assert backend.effective_n_jobs(n_jobs=16) == 16 + assert len(w) == 1 def test_resource_profile_supported(self): backend = SparkDistributedBackend() diff --git a/test/test_spark.py b/test/test_spark.py index c45065d..7e0ef6d 100644 --- a/test/test_spark.py +++ b/test/test_spark.py @@ -28,6 +28,7 @@ from joblib.parallel import Parallel, delayed, parallel_backend from joblibspark import register_spark +import joblibspark.backend from sklearn.utils import parallel_backend from sklearn.model_selection import cross_val_score @@ -40,12 +41,22 @@ _logger = logging.getLogger("Test") _logger.setLevel(logging.INFO) -register_spark() +joblibspark.backend._DEFAULT_N_JOBS_IN_SPARK_CONNECT_MODE = 2 + spark_version = os.environ["PYSPARK_VERSION"] is_spark_connect_mode = os.environ["SPARK_CONNECT_MODE"].lower() == "true" +if spark_version == "4.0.0.dev2": + spark_connect_jar = "org.apache.spark:spark-connect_2.13:4.0.0-preview2" +elif Version(spark_version).major < 4: + spark_connect_jar = f"org.apache.spark:spark-connect_2.12:{spark_version}" +else: + raise RuntimeError("Unsupported Spark version.") + +register_spark() + class TestSparkCluster(unittest.TestCase): spark = None @@ -62,10 +73,9 @@ def setup_class(cls): _logger.info("Test with spark connect mode.") cls.spark = ( spark_session_builder.config( - "spark.jars.packages", - f"org.apache.spark:spark-connect_2.12:{spark_version}" + "spark.jars.packages", spark_connect_jar ) - 
.remote("local[2]") # Adjust the remote address if necessary + .remote("local-cluster[1, 2, 1024]") # Adjust the remote address if necessary .appName("Test") .getOrCreate() ) @@ -89,8 +99,8 @@ def slow_raise_value_error(condition, duration=0.05): raise ValueError("condition evaluated to True") with parallel_backend('spark') as (ba, _): - seq = Parallel(n_jobs=5)(delayed(inc)(i) for i in range(10)) - assert seq == [inc(i) for i in range(10)] + seq = Parallel(n_jobs=2)(delayed(inc)(i) for i in range(2)) + assert seq == [inc(i) for i in range(2)] with pytest.raises(BaseException): Parallel(n_jobs=5)(delayed(slow_raise_value_error)(i == 3) @@ -171,10 +181,9 @@ def setup_class(cls): _logger.info("Test with spark connect mode.") cls.spark = ( spark_session_builder.config( - "spark.jars.packages", - f"org.apache.spark:spark-connect_2.12:{spark_version}" + "spark.jars.packages", spark_connect_jar ) - .remote("local[2]") # Adjust the remote address if necessary + .remote("local-cluster[1, 2, 1024]") # Adjust the remote address if necessary .appName("Test") .getOrCreate() ) From 146e2e9b83980c6b373f38f36e0a12b4674d0f52 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Mon, 7 Apr 2025 13:10:05 +0800 Subject: [PATCH 06/13] update Signed-off-by: Weichen Xu --- joblibspark/backend.py | 31 ++++++++++++------------------- test/test_backend.py | 18 ++++++++++-------- 2 files changed, 22 insertions(+), 27 deletions(-) diff --git a/joblibspark/backend.py b/joblibspark/backend.py index 848adbc..3e69d02 100644 --- a/joblibspark/backend.py +++ b/joblibspark/backend.py @@ -280,23 +280,18 @@ def mapper_fn(iterator): if self._spark_supports_job_cancelling: self._spark.addTag(self._job_group) - try: - if self._support_stage_scheduling: - collected = spark_df.mapInPandas( - mapper_fn, - schema="result binary", - profile=self._resource_profile, - ).collect() - else: - collected = spark_df.mapInPandas( - mapper_fn, - schema="result binary", - ).collect() - pass - except Exception as e: - import traceback - with open("/tmp/err.log", "a") as f: - f.write(traceback.format_exc()) + if self._support_stage_scheduling: + collected = spark_df.mapInPandas( + mapper_fn, + schema="result binary", + profile=self._resource_profile, + ).collect() + else: + collected = spark_df.mapInPandas( + mapper_fn, + schema="result binary", + ).collect() + pass ser_res = bytes(collected[0].result) else: @@ -372,8 +367,6 @@ def inner(*args: Any, **kwargs: Any) -> Any: return outer inheritable_thread_target = patched_inheritable_thread_target(self._spark) - else: - inheritable_thread_target = inheritable_thread_target(self._spark) run_on_worker_and_fetch_result = \ inheritable_thread_target(run_on_worker_and_fetch_result) diff --git a/test/test_backend.py b/test/test_backend.py index 4aa4e70..3d223ce 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -17,14 +17,21 @@ is_spark_connect_mode = os.environ["SPARK_CONNECT_MODE"].lower() == "true" +if spark_version == "4.0.0.dev2": + spark_connect_jar = "org.apache.spark:spark-connect_2.13:4.0.0-preview2" +elif Version(spark_version).major < 4: + spark_connect_jar = f"org.apache.spark:spark-connect_2.12:{spark_version}" +else: + raise RuntimeError("Unsupported Spark version.") + + class TestLocalSparkCluster(unittest.TestCase): @classmethod def setup_class(cls): if os.environ["SPARK_CONNECT_MODE"].lower() == "true": cls.spark = ( SparkSession.builder.config( - "spark.jars.packages", - f"org.apache.spark:spark-connect_2.12:{spark_version}" + "spark.jars.packages", spark_connect_jar ) 
.remote("local-cluster[1, 2, 1024]") .appName("Test") @@ -32,7 +39,7 @@ def setup_class(cls): ) else: cls.spark = ( - SparkSession.builder.master("local[*]").getOrCreate() + SparkSession.builder.master("local-cluster[1, 2, 1024]").getOrCreate() ) @classmethod @@ -59,11 +66,6 @@ def test_effective_n_jobs(self): assert backend.effective_n_jobs(n_jobs=16) == 16 assert len(w) == 1 - def test_resource_profile_supported(self): - backend = SparkDistributedBackend() - # The test fixture uses a local (standalone) Spark instance, which doesn't support stage-level scheduling. - assert not backend._support_stage_scheduling - class TestBasicSparkCluster(unittest.TestCase): @classmethod From 79050baa281e5596efa01a5cf8ab3fda93f7dbe1 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Mon, 7 Apr 2025 13:34:32 +0800 Subject: [PATCH 07/13] update Signed-off-by: Weichen Xu --- joblibspark/backend.py | 30 ++++++++++++++++++------------ run-pylint.sh | 2 +- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/joblibspark/backend.py b/joblibspark/backend.py index 3e69d02..cc4ce31 100644 --- a/joblibspark/backend.py +++ b/joblibspark/backend.py @@ -280,18 +280,23 @@ def mapper_fn(iterator): if self._spark_supports_job_cancelling: self._spark.addTag(self._job_group) - if self._support_stage_scheduling: - collected = spark_df.mapInPandas( - mapper_fn, - schema="result binary", - profile=self._resource_profile, - ).collect() - else: - collected = spark_df.mapInPandas( - mapper_fn, - schema="result binary", - ).collect() - pass + try: + if self._support_stage_scheduling: + collected = spark_df.mapInPandas( + mapper_fn, + schema="result binary", + profile=self._resource_profile, + ).collect() + else: + collected = spark_df.mapInPandas( + mapper_fn, + schema="result binary", + ).collect() + + except Exception as e: + with open("/tmp/err", "a") as f: + import traceback + f.write(traceback.format_exc()) ser_res = bytes(collected[0].result) else: @@ -332,6 +337,7 @@ def mapper_fn(_): if Version(pyspark.__version__).major >= 4 and is_spark_connect_mode(): # TODO: remove this patch once Spark 4.0.0 is released. + # the patch is for propagating the Spark session to current thread. 
def patched_inheritable_thread_target(f): from pyspark.sql.utils import is_remote import functools diff --git a/run-pylint.sh b/run-pylint.sh index 91276d3..960c9a8 100755 --- a/run-pylint.sh +++ b/run-pylint.sh @@ -21,5 +21,5 @@ set -e # Run pylint -python -m pylint joblibspark +# python -m pylint joblibspark From 7a4a3b540e0c2fe2d5353d6706037bcd70e9ed2e Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Mon, 7 Apr 2025 13:48:07 +0800 Subject: [PATCH 08/13] update Signed-off-by: Weichen Xu --- .github/workflows/main.yml | 11 ++++++++--- joblibspark/backend.py | 36 +++++++++++++++++------------------- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 1de03c9..f4c0f50 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -8,10 +8,15 @@ jobs: matrix: PYTHON_VERSION: ["3.10"] JOBLIB_VERSION: ["1.3.2", "1.4.2"] - PIN_MODE: [true] PYSPARK_VERSION: ["3.4.4", "3.5.5", "4.0.0.dev2"] SPARK_CONNECT_MODE: [false, true] - name: Run test on pyspark ${{ matrix.PYSPARK_VERSION }}, pin_mode ${{ matrix.PIN_MODE }}, python ${{ matrix.PYTHON_VERSION }}, joblib ${{ matrix.JOBLIB_VERSION }} + exclude: + - PYSPARK_VERSION: "3.4.4" + SPARK_CONNECT_MODE: true + - PYSPARK_VERSION: "3.5.5" + SPARK_CONNECT_MODE: true + + name: Run test on pyspark ${{ matrix.PYSPARK_VERSION }}, python ${{ matrix.PYTHON_VERSION }}, joblib ${{ matrix.JOBLIB_VERSION }} steps: - uses: actions/checkout@v3 - name: Setup python ${{ matrix.PYTHON_VERSION }} @@ -27,4 +32,4 @@ jobs: ./run-pylint.sh - name: Run test suites run: | - SPARK_CONNECT_MODE=${{ matrix.SPARK_CONNECT_MODE }} PYSPARK_VERSION=${{ matrix.PYSPARK_VERSION }} PYSPARK_PIN_THREAD=${{ matrix.PIN_MODE }} ./run-tests.sh + SPARK_CONNECT_MODE=${{ matrix.SPARK_CONNECT_MODE }} PYSPARK_VERSION=${{ matrix.PYSPARK_VERSION }} ./run-tests.sh diff --git a/joblibspark/backend.py b/joblibspark/backend.py index cc4ce31..3d689cc 100644 --- a/joblibspark/backend.py +++ b/joblibspark/backend.py @@ -111,8 +111,12 @@ def __init__(self, self._is_spark_connect_mode = is_spark_connect_mode() if self._is_spark_connect_mode: - self._support_stage_scheduling = Version(pyspark.__version__).major >= 4 - self._spark_supports_job_cancelling = Version(pyspark.__version__) >= Version("3.5") + if Version(pyspark.__version__).major < 4: + raise RuntimeError( + "Joblib spark does not support Spark Connect with PySpark version < 4." 
+ ) + self._support_stage_scheduling = True + self._spark_supports_job_cancelling = True else: self._spark_context = self._spark.sparkContext self._spark_pinned_threads_enabled = isinstance( @@ -280,23 +284,17 @@ def mapper_fn(iterator): if self._spark_supports_job_cancelling: self._spark.addTag(self._job_group) - try: - if self._support_stage_scheduling: - collected = spark_df.mapInPandas( - mapper_fn, - schema="result binary", - profile=self._resource_profile, - ).collect() - else: - collected = spark_df.mapInPandas( - mapper_fn, - schema="result binary", - ).collect() - - except Exception as e: - with open("/tmp/err", "a") as f: - import traceback - f.write(traceback.format_exc()) + if self._support_stage_scheduling: + collected = spark_df.mapInPandas( + mapper_fn, + schema="result binary", + profile=self._resource_profile, + ).collect() + else: + collected = spark_df.mapInPandas( + mapper_fn, + schema="result binary", + ).collect() ser_res = bytes(collected[0].result) else: From b8c6d7876f0724a0e88131bd3f676fac90a6d729 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Mon, 7 Apr 2025 13:52:00 +0800 Subject: [PATCH 09/13] update Signed-off-by: Weichen Xu --- .github/workflows/main.yml | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index f4c0f50..9314a51 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -10,12 +10,11 @@ jobs: JOBLIB_VERSION: ["1.3.2", "1.4.2"] PYSPARK_VERSION: ["3.4.4", "3.5.5", "4.0.0.dev2"] SPARK_CONNECT_MODE: [false, true] - exclude: - - PYSPARK_VERSION: "3.4.4" - SPARK_CONNECT_MODE: true - - PYSPARK_VERSION: "3.5.5" - SPARK_CONNECT_MODE: true - + exclude: + - PYSPARK_VERSION: "3.4.4" + SPARK_CONNECT_MODE: true + - PYSPARK_VERSION: "3.5.5" + SPARK_CONNECT_MODE: true name: Run test on pyspark ${{ matrix.PYSPARK_VERSION }}, python ${{ matrix.PYTHON_VERSION }}, joblib ${{ matrix.JOBLIB_VERSION }} steps: - uses: actions/checkout@v3 From 64cb0a643a795ec89ef95e1fe7c8943c749065a8 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Mon, 7 Apr 2025 13:56:15 +0800 Subject: [PATCH 10/13] update Signed-off-by: Weichen Xu --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9314a51..85f275b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -15,7 +15,7 @@ jobs: SPARK_CONNECT_MODE: true - PYSPARK_VERSION: "3.5.5" SPARK_CONNECT_MODE: true - name: Run test on pyspark ${{ matrix.PYSPARK_VERSION }}, python ${{ matrix.PYTHON_VERSION }}, joblib ${{ matrix.JOBLIB_VERSION }} + name: Run test on pyspark ${{ matrix.PYSPARK_VERSION }}, Use Spark Connect ${{ matrix.SPARK_CONNECT_MODE }}, joblib ${{ matrix.JOBLIB_VERSION }} steps: - uses: actions/checkout@v3 - name: Setup python ${{ matrix.PYTHON_VERSION }} From 189d4d299b9abfb9584884eb7682dc23f60e6714 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Mon, 7 Apr 2025 17:18:49 +0800 Subject: [PATCH 11/13] update Signed-off-by: Weichen Xu --- .github/workflows/main.yml | 2 +- joblibspark/backend.py | 2 +- test/test_backend.py | 12 +++++------- test/test_spark.py | 14 ++++++-------- 4 files changed, 13 insertions(+), 17 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 85f275b..49da6ed 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -31,4 +31,4 @@ jobs: ./run-pylint.sh - name: Run test suites run: | - SPARK_CONNECT_MODE=${{ matrix.SPARK_CONNECT_MODE 
}} PYSPARK_VERSION=${{ matrix.PYSPARK_VERSION }} ./run-tests.sh + TEST_SPARK_CONNECT=${{ matrix.SPARK_CONNECT_MODE }} PYSPARK_VERSION=${{ matrix.PYSPARK_VERSION }} ./run-tests.sh diff --git a/joblibspark/backend.py b/joblibspark/backend.py index 3d689cc..80a9ebe 100644 --- a/joblibspark/backend.py +++ b/joblibspark/backend.py @@ -24,7 +24,6 @@ import uuid from typing import Optional from packaging.version import Version, parse -import pandas as pd from joblib.parallel \ import AutoBatchingMixin, ParallelBackendBase, register_parallel_backend, SequentialBackend @@ -275,6 +274,7 @@ def run_on_worker_and_fetch_result(): spark_df = self._spark.range(1, numPartitions=1) def mapper_fn(iterator): + import pandas as pd for _ in iterator: # consume input data. pass diff --git a/test/test_backend.py b/test/test_backend.py index 3d223ce..56fe575 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -14,21 +14,19 @@ spark_version = os.environ["PYSPARK_VERSION"] -is_spark_connect_mode = os.environ["SPARK_CONNECT_MODE"].lower() == "true" +is_spark_connect_mode = os.environ["TEST_SPARK_CONNECT"].lower() == "true" -if spark_version == "4.0.0.dev2": - spark_connect_jar = "org.apache.spark:spark-connect_2.13:4.0.0-preview2" -elif Version(spark_version).major < 4: - spark_connect_jar = f"org.apache.spark:spark-connect_2.12:{spark_version}" +if Version(spark_version).major >= 4: + spark_connect_jar = "" else: - raise RuntimeError("Unsupported Spark version.") + spark_connect_jar = f"org.apache.spark:spark-connect_2.12:{spark_version}" class TestLocalSparkCluster(unittest.TestCase): @classmethod def setup_class(cls): - if os.environ["SPARK_CONNECT_MODE"].lower() == "true": + if is_spark_connect_mode: cls.spark = ( SparkSession.builder.config( "spark.jars.packages", spark_connect_jar diff --git a/test/test_spark.py b/test/test_spark.py index 7e0ef6d..e871b45 100644 --- a/test/test_spark.py +++ b/test/test_spark.py @@ -46,14 +46,12 @@ spark_version = os.environ["PYSPARK_VERSION"] -is_spark_connect_mode = os.environ["SPARK_CONNECT_MODE"].lower() == "true" +is_spark_connect_mode = os.environ["TEST_SPARK_CONNECT"].lower() == "true" -if spark_version == "4.0.0.dev2": - spark_connect_jar = "org.apache.spark:spark-connect_2.13:4.0.0-preview2" -elif Version(spark_version).major < 4: - spark_connect_jar = f"org.apache.spark:spark-connect_2.12:{spark_version}" +if Version(spark_version).major >= 4: + spark_connect_jar = "" else: - raise RuntimeError("Unsupported Spark version.") + spark_connect_jar = f"org.apache.spark:spark-connect_2.12:{spark_version}" register_spark() @@ -69,7 +67,7 @@ def setup_class(cls): .config("spark.task.maxFailures", "1") ) - if os.environ["SPARK_CONNECT_MODE"].lower() == "true": + if is_spark_connect_mode: _logger.info("Test with spark connect mode.") cls.spark = ( spark_session_builder.config( @@ -177,7 +175,7 @@ def setup_class(cls): ) ) - if os.environ["SPARK_CONNECT_MODE"].lower() == "true": + if is_spark_connect_mode: _logger.info("Test with spark connect mode.") cls.spark = ( spark_session_builder.config( From fcccd180c749049ce298e9f64372a2e31726b364 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Mon, 7 Apr 2025 17:28:08 +0800 Subject: [PATCH 12/13] format Signed-off-by: Weichen Xu --- joblibspark/backend.py | 16 ++++++++++------ run-pylint.sh | 3 +-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/joblibspark/backend.py b/joblibspark/backend.py index 80a9ebe..0a48947 100644 --- a/joblibspark/backend.py +++ b/joblibspark/backend.py @@ -196,6 +196,7 @@ def 
effective_n_jobs(self, n_jobs): if n_jobs == -1: n_jobs = _DEFAULT_N_JOBS_IN_SPARK_CONNECT_MODE + # pylint: disable = logging-fstring-interpolation _logger.warning( "Joblib sets `n_jobs` to default value " f"{_DEFAULT_N_JOBS_IN_SPARK_CONNECT_MODE} in Spark Connect mode." @@ -274,7 +275,7 @@ def run_on_worker_and_fetch_result(): spark_df = self._spark.range(1, numPartitions=1) def mapper_fn(iterator): - import pandas as pd + import pandas as pd # pylint: disable=import-outside-toplevel for _ in iterator: # consume input data. pass @@ -334,10 +335,10 @@ def mapper_fn(_): from pyspark import inheritable_thread_target if Version(pyspark.__version__).major >= 4 and is_spark_connect_mode(): + # pylint: disable=fixme # TODO: remove this patch once Spark 4.0.0 is released. # the patch is for propagating the Spark session to current thread. - def patched_inheritable_thread_target(f): - from pyspark.sql.utils import is_remote + def patched_inheritable_thread_target(f): # pylint: disable=invalid-name import functools import copy from typing import Any @@ -345,13 +346,14 @@ def patched_inheritable_thread_target(f): session = f assert session is not None, "Spark Connect session must be provided." - def outer(ff: Any) -> Any: + def outer(ff: Any) -> Any: # pylint: disable=invalid-name session_client_thread_local_attrs = [ + # type: ignore[union-attr] (attr, copy.deepcopy(value)) for ( attr, value, - ) in session.client.thread_local.__dict__.items() # type: ignore[union-attr] + ) in session.client.thread_local.__dict__.items() ] @functools.wraps(ff) @@ -359,11 +361,13 @@ def inner(*args: Any, **kwargs: Any) -> Any: # Propagates the active spark session to the current thread from pyspark.sql.connect.session import SparkSession as SCS + # pylint: disable=protected-access SCS._set_default_and_active_session(session) # Set thread locals in child thread. for attr, value in session_client_thread_local_attrs: - setattr(session.client.thread_local, attr, value) # type: ignore[union-attr] + # type: ignore[union-attr] + setattr(session.client.thread_local, attr, value) return ff(*args, **kwargs) return inner diff --git a/run-pylint.sh b/run-pylint.sh index 960c9a8..60c6ce5 100755 --- a/run-pylint.sh +++ b/run-pylint.sh @@ -21,5 +21,4 @@ set -e # Run pylint -# python -m pylint joblibspark - +python -m pylint joblibspark From 934f1551dd4b206806aebe446c4d7254cff47d09 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Mon, 7 Apr 2025 17:34:03 +0800 Subject: [PATCH 13/13] format Signed-off-by: Weichen Xu --- joblibspark/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/joblibspark/backend.py b/joblibspark/backend.py index 0a48947..c3ea2ec 100644 --- a/joblibspark/backend.py +++ b/joblibspark/backend.py @@ -361,7 +361,7 @@ def inner(*args: Any, **kwargs: Any) -> Any: # Propagates the active spark session to the current thread from pyspark.sql.connect.session import SparkSession as SCS - # pylint: disable=protected-access + # pylint: disable=protected-access,no-member SCS._set_default_and_active_session(session) # Set thread locals in child thread.
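
Taken together, the series switches the backend from the RDD-based execution path to a mapInPandas call whenever the session is a Spark Connect session, with cancellation handled through job tags instead of job groups. The standalone sketch below illustrates that pattern outside the joblib backend plumbing; the remote address, tag name, and my_task function are illustrative placeholders and not values taken from the patches.

# Sketch of the Spark Connect execution path used by this series: run a Python
# callable on one executor via mapInPandas and ship the cloudpickled result
# back as a single binary column. Assumes a pyspark[connect] install (>= 3.5
# for the tag APIs); the remote URL and task body are made up for the example.
import cloudpickle
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("local[2]").getOrCreate()

def my_task():
    return sum(range(10))

def mapper_fn(iterator):
    import pandas as pd
    for _ in iterator:  # consume the single dummy input row
        pass
    yield pd.DataFrame({"result": [cloudpickle.dumps(my_task())]})

job_group = "joblib-spark-job-group-example"
spark.addTag(job_group)  # cancellable later via spark.interruptTag(job_group)

row = (
    spark.range(1, numPartitions=1)  # one row, one partition -> one Spark task
    .mapInPandas(mapper_fn, schema="result binary")
    .collect()[0]
)
print(cloudpickle.loads(bytes(row.result)))  # prints 45

Returning the result as one cloudpickled binary column lets the backend reuse the same collect-and-unpickle step for both the RDD and the Spark Connect paths, and passing profile=... to mapInPandas (gated to Spark 4+ in the patches via _support_stage_scheduling) is how stage-level scheduling carries over to Connect mode.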