diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala index c2994c97a411..bb29adde9776 100644 --- a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala @@ -35,6 +35,7 @@ import java.util.{Locale, ServiceLoader} import scala.collection.concurrent.TrieMap import scala.collection.immutable.{HashSet, List, Map} import scala.collection.mutable +import scala.collection.JavaConverters._ // scalastyle:off multiple.string.literals // scalastyle:off file.size.limit @@ -48,6 +49,7 @@ private[spark] object CosmosConfigNames { val TenantId = "spark.cosmos.account.tenantId" val ResourceGroupName = "spark.cosmos.account.resourceGroupName" val AzureEnvironment = "spark.cosmos.account.azureEnvironment" + val AzureEnvironmentEndpointOverride = "spark.cosmos.account.azureEnvironmentEndpointOverride" val AuthType = "spark.cosmos.auth.type" val ClientId = "spark.cosmos.auth.aad.clientId" val ResourceId = "spark.cosmos.auth.aad.resourceId" @@ -628,6 +630,21 @@ private object CosmosAccountConfig extends BasicLoggingTrait { }, helpMessage = "The azure environment of the CosmosDB account: `Azure`, `AzureChina`, `AzureUsGovernment`, `AzureGermany`.") + private val AzureEnvironmentOverrideEndpoints = CosmosConfigEntry[java.util.Map[String, String]](key = CosmosConfigNames.AzureEnvironmentEndpointOverride, + mandatory = false, + parseFromStringFunction = azureEnvironmentOverrideEndpointsAsCommaSeparatedString => { + val endpoints: java.util.Map[String, String] = + azureEnvironmentOverrideEndpointsAsCommaSeparatedString + .split(",") + .flatMap(_.split("=", 2) match { + case Array(k, v) => Some(k -> v) + case _ => throw new IllegalArgumentException(s"Azure environment override endpoint string $azureEnvironmentOverrideEndpointsAsCommaSeparatedString is not valid") + }).toMap + + endpoints + }, + helpMessage = "The azure environment endpoints to override. e.g. `portalUrl=https://example.com,managementEndpointUrl=https://other.com`") + private val ClientBuilderInterceptors = CosmosConfigEntry[String](key = CosmosConfigNames.ClientBuilderInterceptors, mandatory = false, parseFromStringFunction = clientBuilderInterceptorFQDN => clientBuilderInterceptorFQDN, @@ -670,6 +687,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait { val resourceGroupNameOpt = CosmosConfigEntry.parse(cfg, ResourceGroupName) val tenantIdOpt = CosmosConfigEntry.parse(cfg, TenantId) val azureEnvironmentOpt = CosmosConfigEntry.parse(cfg, AzureEnvironmentTypeEnum) + val azureEnvironmentOverrideEndpointsOpt = CosmosConfigEntry.parse(cfg, AzureEnvironmentOverrideEndpoints) + val clientBuilderInterceptors = CosmosConfigEntry.parse(cfg, ClientBuilderInterceptors) val clientInterceptors = CosmosConfigEntry.parse(cfg, ClientInterceptors) @@ -747,6 +766,19 @@ private object CosmosAccountConfig extends BasicLoggingTrait { } } + val azureEnvironmentEndpoints = azureEnvironmentOverrideEndpointsOpt match { + case Some(overrideMap) => + // Only override the keys that exist in the azureEnvironment endpoints map + azureEnvironmentOpt.get.asScala.map { + case (k, v) => + overrideMap.asScala.get(k) match { + case Some(overrideVal) if overrideVal.nonEmpty => k -> overrideVal + case _ => k -> v + } + }.asJava + case None => azureEnvironmentOpt.get + } + CosmosAccountConfig( endpointOpt.get, authConfig, @@ -762,7 +794,7 @@ private object CosmosAccountConfig extends BasicLoggingTrait { subscriptionIdOpt, tenantIdOpt, resourceGroupNameOpt, - azureEnvironmentOpt.get, + azureEnvironmentEndpoints, if (clientBuilderInterceptorsList.nonEmpty) { Some(clientBuilderInterceptorsList.toList) } else { None }, if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None }) }