diff --git a/README.md b/README.md index 8c62eea..92654cb 100644 --- a/README.md +++ b/README.md @@ -165,6 +165,7 @@ and the CLI. Here is a list of the available environment variables: | SPARK_ON_K8S_SERVICE_ACCOUNT | The service account to use | spark | | SPARK_ON_K8S_SPARK_CONF | The spark configuration to use | {} | | SPARK_ON_K8S_CLASS_NAME | The class name to use | | +| SPARK_ON_K8S_PACKAGES | The maven packages list to add to the classpath | | | SPARK_ON_K8S_APP_ARGUMENTS | The arguments to pass to the app | [] | | SPARK_ON_K8S_APP_WAITER | The waiter to use to wait for the app to finish | no_wait | | SPARK_ON_K8S_IMAGE_PULL_POLICY | The image pull policy to use | IfNotPresent | diff --git a/spark_on_k8s/airflow/operators.py b/spark_on_k8s/airflow/operators.py index 1a1dbfe..6f3f0b6 100644 --- a/spark_on_k8s/airflow/operators.py +++ b/spark_on_k8s/airflow/operators.py @@ -56,6 +56,7 @@ class SparkOnK8SOperator(BaseOperator): the format %Y%m%d%H%M%S prefixed with a dash. spark_conf (dict[str, str], optional): Spark configuration. Defaults to None. class_name (str, optional): Spark application class name. Defaults to None. + packages (list[str], optional): List of maven coordinates of jars to include in the classpath. Defaults to None. app_arguments (list[str], optional): Spark application arguments. Defaults to None. app_waiter (Literal["no_wait", "wait", "log"], optional): Spark application waiter. Defaults to "wait". 
@@ -124,6 +125,7 @@ def __init__( app_id_suffix: str = None, spark_conf: dict[str, str] | None = None, class_name: str | None = None, + packages: list[str] | None = None, app_arguments: list[str] | None = None, app_waiter: Literal["no_wait", "wait", "log"] = "wait", image_pull_policy: Literal["Always", "Never", "IfNotPresent"] = "IfNotPresent", @@ -160,6 +162,7 @@ def __init__( self.app_id_suffix = app_id_suffix self.spark_conf = spark_conf self.class_name = class_name + self.packages = packages self.app_arguments = app_arguments self.app_waiter = app_waiter self.image_pull_policy = image_pull_policy @@ -302,6 +305,7 @@ def _submit_new_job(self, context: Context): app_name=self.app_name, spark_conf=self.spark_conf, class_name=self.class_name, + packages=self.packages, app_arguments=self.app_arguments, app_waiter="no_wait", image_pull_policy=self.image_pull_policy, diff --git a/spark_on_k8s/client.py b/spark_on_k8s/client.py index 03f2b6d..68b163b 100644 --- a/spark_on_k8s/client.py +++ b/spark_on_k8s/client.py @@ -112,6 +112,7 @@ def submit_app( app_name: str | ArgNotSet = NOTSET, spark_conf: dict[str, str] | ArgNotSet = NOTSET, class_name: str | ArgNotSet = NOTSET, + packages: list[str] | ArgNotSet = NOTSET, app_arguments: list[str] | ArgNotSet = NOTSET, app_id_suffix: Callable[[], str] | ArgNotSet = NOTSET, app_waiter: Literal["no_wait", "wait", "log"] | ArgNotSet = NOTSET, @@ -148,6 +149,7 @@ def submit_app( `spark-app{app_id_suffix()}` spark_conf: Dictionary of spark configuration to pass to the application class_name: Name of the class to execute + packages: List of maven coordinates of jars to include in the classpath app_arguments: List of arguments to pass to the application app_id_suffix: Function to generate a suffix for the application ID, defaults to `default_app_id_suffix` @@ -204,6 +206,10 @@ def submit_app( spark_conf = Configuration.SPARK_ON_K8S_SPARK_CONF if class_name is NOTSET: class_name = Configuration.SPARK_ON_K8S_CLASS_NAME + if packages is 
NOTSET: + packages = ( + Configuration.SPARK_ON_K8S_PACKAGES.split(",") if Configuration.SPARK_ON_K8S_PACKAGES else [] + ) if app_arguments is NOTSET: app_arguments = Configuration.SPARK_ON_K8S_APP_ARGUMENTS if app_id_suffix is NOTSET: @@ -330,6 +336,8 @@ def submit_app( driver_command_args = ["driver", "--master", "k8s://https://kubernetes.default.svc.cluster.local:443"] if class_name: driver_command_args.extend(["--class", class_name]) + if packages: + driver_command_args.extend(["--packages", ",".join(packages)]) driver_command_args.extend( self._spark_config_to_arguments({**basic_conf, **spark_conf}) + [app_path, *main_class_parameters] ) diff --git a/spark_on_k8s/utils/configuration.py b/spark_on_k8s/utils/configuration.py index ee42f0d..55a7d04 100644 --- a/spark_on_k8s/utils/configuration.py +++ b/spark_on_k8s/utils/configuration.py @@ -16,6 +16,7 @@ class Configuration: SPARK_ON_K8S_APP_NAME = getenv("SPARK_ON_K8S_APP_NAME") SPARK_ON_K8S_SPARK_CONF = json.loads(getenv("SPARK_ON_K8S_SPARK_CONF", "{}")) SPARK_ON_K8S_CLASS_NAME = getenv("SPARK_ON_K8S_CLASS_NAME") + SPARK_ON_K8S_PACKAGES = getenv("SPARK_ON_K8S_PACKAGES", "") SPARK_ON_K8S_APP_ARGUMENTS = json.loads(getenv("SPARK_ON_K8S_APP_ARGUMENTS", "[]")) SPARK_ON_K8S_APP_WAITER = getenv("SPARK_ON_K8S_APP_WAITER", "no_wait") SPARK_ON_K8S_IMAGE_PULL_POLICY = getenv("SPARK_ON_K8S_IMAGE_PULL_POLICY", "IfNotPresent") diff --git a/tests/airflow/test_operators.py b/tests/airflow/test_operators.py index 012f7c2..acb8f03 100644 --- a/tests/airflow/test_operators.py +++ b/tests/airflow/test_operators.py @@ -32,6 +32,7 @@ def test_execute(self, mock_submit_app): app_arguments=["100000"], app_name="pyspark-job-example", service_account="spark", + packages=["some-package"], app_waiter="no_wait", driver_resources=PodResources(cpu=1, memory=1024, memory_overhead=512), executor_resources=PodResources(cpu=1, memory=1024, memory_overhead=512), @@ -76,6 +77,7 @@ def test_execute(self, mock_submit_app): ui_reverse_proxy=True, 
spark_conf=None, class_name=None, + packages=["some-package"], secret_values=None, volumes=None, driver_volume_mounts=None, @@ -202,6 +204,7 @@ def test_rendering_templates(self, mock_submit_app): "KEY2": "value from connection", }, class_name=None, + packages=None, volumes=None, driver_volume_mounts=None, executor_volume_mounts=None,