From 12595344f98bd061c9187435260ce0d6dfa44728 Mon Sep 17 00:00:00 2001
From: Partho Sarthi
Date: Thu, 2 Jan 2025 14:36:53 -0800
Subject: [PATCH] Minor refactoring

Signed-off-by: Partho Sarthi
---
 .../spark/rapids/tool/tuning/AutoTuner.scala  | 69 +++++++++----------
 .../tool/tuning/ProfilingAutoTunerSuite.scala | 25 +++----
 2 files changed, 44 insertions(+), 50 deletions(-)

diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/tuning/AutoTuner.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/tuning/AutoTuner.scala
index f05ab77a6..c10f0f099 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/tuning/AutoTuner.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/tuning/AutoTuner.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2024, NVIDIA CORPORATION.
+ * Copyright (c) 2024-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -275,15 +275,15 @@ class RecommendationEntry(val name: String,
  */
 // scalastyle:on line.size.limit
 object ShuffleManagerResolver {
-  // Databricks version to RapidsShuffleManager version mapping.
-  private val DatabricksVersionMap = Map(
+  // Supported Databricks version to RapidsShuffleManager version mapping.
+  private val supportedDatabricksVersionMap = Array(
     "11.3" -> "330db",
     "12.3" -> "332db",
     "13.3" -> "341db"
   )
 
-  // Spark version to RapidsShuffleManager version mapping.
-  private val SparkVersionMap = Map(
+  // Supported Spark version to RapidsShuffleManager version mapping.
+  private val supportedSparkVersionMap = Array(
     "3.2.0" -> "320",
     "3.2.1" -> "321",
     "3.2.2" -> "322",
@@ -299,8 +299,7 @@ object ShuffleManagerResolver {
     "3.4.2" -> "342",
     "3.4.3" -> "343",
     "3.5.0" -> "350",
-    "3.5.1" -> "351",
-    "4.0.0" -> "400"
+    "3.5.1" -> "351"
   )
 
   def buildShuffleManagerClassName(smVersion: String): String = {
@@ -322,36 +321,35 @@
    *
    * Example:
    * sparkVersion: "3.2.0-amzn-1"
-   * versionMap: {"3.2.0" -> "320", "3.2.1" -> "321"}
-   * Then, smVersion: "320"
+   * supportedVersionsMap: ["3.2.0" -> "320", "3.2.1" -> "321"]
+   * return: Right("com.nvidia.spark.rapids.spark320.RapidsShuffleManager")
    *
-   * sparkVersion: "13.3-ml-1"
-   * versionMap: {"11.3" -> "330db", "12.3" -> "332db", "13.3" -> "341db"}
-   * Then, smVersion: "341db"
+   * sparkVersion: "13.3.x-gpu-ml-scala2.12"
+   * supportedVersionsMap: ["11.3" -> "330db", "12.3" -> "332db", "13.3" -> "341db"]
+   * return: Right("com.nvidia.spark.rapids.spark341db.RapidsShuffleManager")
    *
    * sparkVersion: "3.1.2"
-   * versionMap: {"3.2.0" -> "320", "3.2.1" -> "321"}
-   * Then, smVersion: None
+   * supportedVersionsMap: ["3.2.0" -> "320", "3.2.1" -> "321"]
+   * return: Left("Could not recommend RapidsShuffleManager as the provided version
+   *         3.1.2 is not supported.")
    *
    * @return Either an error message (Left) or the RapidsShuffleManager class name (Right)
    */
   private def getClassNameInternal(
-      versionMap: Map[String, String], sparkVersion: String): Either[String, String] = {
-    val smVersionOpt = versionMap.collectFirst {
-      case (key, value) if sparkVersion.contains(key) => value
-    }
-    smVersionOpt match {
-      case Some(smVersion) =>
-        Right(buildShuffleManagerClassName(smVersion))
-      case None =>
-        Left(commentForUnsupportedVersion(sparkVersion))
+      supportedVersionsMap: Array[(String, String)],
+      sparkVersion: String): Either[String, String] = {
+    supportedVersionsMap.collectFirst {
+      case (supportedVersion, smVersion) if sparkVersion.contains(supportedVersion) => smVersion
+    } match {
+      case Some(smVersion) => Right(buildShuffleManagerClassName(smVersion))
+      case None => Left(commentForUnsupportedVersion(sparkVersion))
     }
   }
 
   /**
-   * Determines the appropriate RapidsShuffleManager class name based on the provided versions.
-   * Databricks version takes precedence over Spark version. If a valid class name is not found,
-   * an error message is returned.
+   * Determines the appropriate RapidsShuffleManager class name based on the provided Databricks or
+   * Spark version. Databricks version takes precedence over Spark version. If a valid class name
+   * is not found, an error message is returned.
    *
    * @param dbVersion Databricks version.
    * @param sparkVersion Spark version.
@@ -360,12 +358,9 @@ object ShuffleManagerResolver {
   def getClassName(
       dbVersion: Option[String], sparkVersion: Option[String]): Either[String, String] = {
     (dbVersion, sparkVersion) match {
-      case (Some(dbVer), _) =>
-        getClassNameInternal(DatabricksVersionMap, dbVer)
-      case (None, Some(sparkVer)) =>
-        getClassNameInternal(SparkVersionMap, sparkVer)
-      case _ =>
-        Left(commentForMissingVersion)
+      case (Some(dbVer), _) => getClassNameInternal(supportedDatabricksVersionMap, dbVer)
+      case (None, Some(sparkVer)) => getClassNameInternal(supportedSparkVersionMap, sparkVer)
+      case _ => Left(commentForMissingVersion)
     }
   }
 }
@@ -824,10 +819,8 @@ class AutoTuner(
     // TODO - do we do anything with 200 shuffle partitions or maybe if its close
     // set the Spark config spark.shuffle.sort.bypassMergeThreshold
     getShuffleManagerClassName match {
-      case Right(smClassName) =>
-        appendRecommendation("spark.shuffle.manager", smClassName)
-      case Left(errMessage) =>
-        appendComment(errMessage)
+      case Right(smClassName) => appendRecommendation("spark.shuffle.manager", smClassName)
+      case Left(comment) => appendComment(comment)
     }
     appendComment(autoTunerConfigsProvider.classPathComments("rapids.shuffle.jars"))
     recommendFileCache()
@@ -861,8 +854,8 @@ class AutoTuner(
   }
 
   /**
-   * Resolves the RapidsShuffleManager class name based on the Spark or Databricks version.
-   * If a valid class name is not found an error message is appended as a comment.
+   * Resolves the RapidsShuffleManager class name based on the Databricks or Spark version.
+   * If a valid class name is not found, an error message is returned.
    */
   def getShuffleManagerClassName: Either[String, String] = {
     val dbVersion = getPropertyValue(DatabricksParseHelper.PROP_TAG_CLUSTER_SPARK_VERSION_KEY)
diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/tuning/ProfilingAutoTunerSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/tuning/ProfilingAutoTunerSuite.scala
index af9a28a26..f6acc325e 100644
--- a/core/src/test/scala/com/nvidia/spark/rapids/tool/tuning/ProfilingAutoTunerSuite.scala
+++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/tuning/ProfilingAutoTunerSuite.scala
@@ -2191,14 +2191,15 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
       autoTuner: AutoTuner, expectedSmVersion: String): Unit = {
     autoTuner.getShuffleManagerClassName match {
-      case Right(smVersion) =>
-        assert(smVersion == ShuffleManagerResolver.buildShuffleManagerClassName(expectedSmVersion))
+      case Right(smClassName) =>
+        assert(smClassName ==
+          ShuffleManagerResolver.buildShuffleManagerClassName(expectedSmVersion))
       case Left(comment) =>
         fail(s"Expected valid RapidsShuffleManager but got comment: $comment")
     }
   }
 
-  test("test shuffle manager version for supported databricks") {
+  test("test shuffle manager version for supported databricks version") {
     val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
     val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
       mutable.Map("spark.rapids.sql.enabled" -> "true",
@@ -2213,7 +2214,7 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
     verifyRecommendedShuffleManagerVersion(autoTuner, expectedSmVersion="330db")
   }
 
-  test("test shuffle manager version for supported non-databricks") {
+  test("test shuffle manager version for supported spark version") {
     val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
     val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
       mutable.Map("spark.rapids.sql.enabled" -> "true",
@@ -2226,7 +2227,7 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
     verifyRecommendedShuffleManagerVersion(autoTuner, expectedSmVersion="330")
   }
 
-  test("test shuffle manager version for supported custom version") {
+  test("test shuffle manager version for supported custom spark version") {
     val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
     val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
       mutable.Map("spark.rapids.sql.enabled" -> "true",
@@ -2247,14 +2248,14 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
       autoTuner: AutoTuner, sparkVersion: String): Unit = {
     autoTuner.getShuffleManagerClassName match {
-      case Right(smVersion) =>
-        fail(s"Expected error comment but got valid RapidsShuffleManager with version $smVersion")
+      case Right(smClassName) =>
+        fail(s"Expected error comment but got valid RapidsShuffleManager: $smClassName")
       case Left(comment) =>
         assert(comment == ShuffleManagerResolver.commentForUnsupportedVersion(sparkVersion))
     }
   }
 
-  test("test shuffle manager version for unsupported databricks") {
+  test("test shuffle manager version for unsupported databricks version") {
     val databricksVersion = "9.1.x-gpu-ml-scala2.12"
     val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
     val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
@@ -2269,7 +2270,7 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
     verifyUnsupportedSparkVersionForShuffleManager(autoTuner, databricksVersion)
   }
 
-  test("test shuffle manager version for unsupported non-databricks") {
+  test("test shuffle manager version for unsupported spark version") {
     val sparkVersion = "3.1.2"
     val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
     val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
@@ -2282,7 +2283,7 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
     verifyUnsupportedSparkVersionForShuffleManager(autoTuner, sparkVersion)
   }
 
-  test("test shuffle manager version for unsupported custom version") {
+  test("test shuffle manager version for unsupported custom spark version") {
     val customSparkVersion = "3.1.2-custom"
     val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
     val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
@@ -2306,8 +2307,8 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
       infoProvider, PlatformFactory.createInstance())
     // Verify that the shuffle manager is not recommended for missing Spark version
     autoTuner.getShuffleManagerClassName match {
-      case Right(smVersion) =>
-        fail(s"Expected error comment but got valid RapidsShuffleManager with version $smVersion")
+      case Right(smClassName) =>
+        fail(s"Expected error comment but got valid RapidsShuffleManager: $smClassName")
       case Left(comment) =>
         assert(comment == ShuffleManagerResolver.commentForMissingVersion)
     }
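
Note (outside the patch itself): switching the lookup tables from Map to Array means collectFirst
scans the supported versions in their declared order when substring-matching the reported
Spark or Databricks version, which keeps the recommendation deterministic and is likely part of
the motivation for this refactor. The minimal, self-contained sketch below approximates that
lookup for readers skimming the diff; the object name, the trimmed-down version table, and the
error-message wording are illustrative assumptions, while the class-name format and the
Either-based result mirror the patched scaladoc.

// Standalone sketch (assumed names; not the repository code).
object ShuffleManagerLookupSketch {
  // Illustrative subset of the supported Spark version -> RapidsShuffleManager version table.
  private val supportedSparkVersionMap = Array(
    "3.2.0" -> "320",
    "3.3.0" -> "330",
    "3.5.1" -> "351"
  )

  // Class-name format taken from the examples in the patched scaladoc.
  private def buildClassName(smVersion: String): String =
    s"com.nvidia.spark.rapids.spark$smVersion.RapidsShuffleManager"

  // The first supported version contained in the reported version string wins;
  // otherwise return a Left with an explanatory message (wording simplified here).
  def resolve(sparkVersion: String): Either[String, String] =
    supportedSparkVersionMap.collectFirst {
      case (supported, smVersion) if sparkVersion.contains(supported) => smVersion
    } match {
      case Some(smVersion) => Right(buildClassName(smVersion))
      case None => Left(s"Version $sparkVersion is not supported.")
    }

  def main(args: Array[String]): Unit = {
    println(resolve("3.3.0-amzn-1")) // Right(com.nvidia.spark.rapids.spark330.RapidsShuffleManager)
    println(resolve("3.1.2"))        // Left(Version 3.1.2 is not supported.)
  }
}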