/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql.comet

import java.util.UUID
import java.util.concurrent.{Future, TimeoutException, TimeUnit}

import scala.concurrent.{ExecutionContext, Promise}
import scala.concurrent.duration.NANOSECONDS
import scala.util.control.NonFatal

import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{ColumnarToRowExec, FileSourceScanExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.{SparkFatalException, ThreadUtils}
import org.apache.spark.util.io.ChunkedByteBuffer

import com.google.common.base.Objects

/**
 * A [[CometBroadcastExchangeExec]] collects, transforms and finally broadcasts the result of a
 * transformed SparkPlan. It is a copy of Spark's [[BroadcastExchangeExec]] with the changes
 * necessary to support Comet operators.
 *
 * [[CometBroadcastExchangeExec]] is used by broadcast join operators.
 *
 * Note that, unlike other Comet operators, this class cannot extend `CometExec`. Because the
 * Spark trait `BroadcastExchangeLike` extends the abstract class `Exchange`, it is not possible
 * to extend both `CometExec` and `Exchange` at the same time.
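 *
 * A minimal usage sketch (the plan names below are illustrative assumptions, not definitions
 * from this file): a broadcast join planning rule is expected to wrap the Comet native child
 * that produces the broadcast side, keeping the original Spark plan around for canonicalization
 * and equality checks.
 * {{{
 *   // `sparkBroadcastPlan` stands for the original Spark plan being replaced, and `cometChild`
 *   // for the Comet native operator whose serialized batches get broadcast (both hypothetical).
 *   val exchange = CometBroadcastExchangeExec(sparkBroadcastPlan, cometChild)
 *   // Consumers read the broadcast value as serialized buffers, as `doExecuteColumnar` does:
 *   val relation = exchange.executeBroadcast[Array[ChunkedByteBuffer]]()
 * }}}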
 */
case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
    extends BroadcastExchangeLike {
  import CometBroadcastExchangeExec._

  override val runId: UUID = UUID.randomUUID

  override lazy val metrics: Map[String, SQLMetric] = Map(
    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
    "collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"),
    "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build"),
    "broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast"))

  override def doCanonicalize(): SparkPlan = {
    CometBroadcastExchangeExec(originalPlan.canonicalized, child.canonicalized)
  }

  override def runtimeStatistics: Statistics = {
    val dataSize = metrics("dataSize").value
    val rowCount = metrics("numOutputRows").value
    Statistics(dataSize, Some(rowCount))
  }

  @transient
  private lazy val promise = Promise[broadcast.Broadcast[Any]]()

  @transient
  override lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] =
    promise.future

  @transient
  private val timeout: Long = conf.broadcastTimeout

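  // Maximum number of rows that can be broadcast; this mirrors the 512 million row cap enforced
  // by Spark's BroadcastExchangeExec.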
  @transient
  private lazy val maxBroadcastRows = 512000000

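  // Asynchronously collects the child's serialized batches on the driver and broadcasts them on
  // the shared `executionContext` below, completing `promise` either way so that callers can
  // block on `relationFuture` or compose on `completionFuture`.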
  @transient
  override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
    SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
      session,
      CometBroadcastExchangeExec.executionContext) {
      try {
        // Set up a job group here so it can later be cancelled by groupId if necessary.
        sparkContext.setJobGroup(
          runId.toString,
          s"broadcast exchange (runId $runId)",
          interruptOnCancel = true)
        val beforeCollect = System.nanoTime()

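        // `getByteArrayRdd()` is assumed to return one (rowCount, serialized batches) pair per
        // partition: the counts feed the `numOutputRows` metric below and the byte buffers
        // become the broadcast payload.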
        val countsAndBytes = child.asInstanceOf[CometExec].getByteArrayRdd().collect()
        val numRows = countsAndBytes.map(_._1).sum
        val input = countsAndBytes.iterator.map(countAndBytes => countAndBytes._2)

        longMetric("numOutputRows") += numRows
        if (numRows >= maxBroadcastRows) {
          throw QueryExecutionErrors.cannotBroadcastTableOverMaxTableRowsError(
            maxBroadcastRows,
            numRows)
        }

        val beforeBuild = System.nanoTime()
        longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)

        val batches = input.toArray

        val dataSize = batches.map(_.size).sum

        longMetric("dataSize") += dataSize
        if (dataSize >= MAX_BROADCAST_TABLE_BYTES) {
          throw QueryExecutionErrors.cannotBroadcastTableOverMaxTableBytesError(
            MAX_BROADCAST_TABLE_BYTES,
            dataSize)
        }

        val beforeBroadcast = System.nanoTime()
        longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild)

        // SPARK-39983 - Broadcast the relation without caching the unserialized object.
        val broadcasted = sparkContext
          .broadcastInternal(batches, serializedOnly = true)
          .asInstanceOf[broadcast.Broadcast[Any]]
        longMetric("broadcastTime") += NANOSECONDS.toMillis(System.nanoTime() - beforeBroadcast)
        val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
        SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
        promise.trySuccess(broadcasted)
        broadcasted
      } catch {
        // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
        // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
        // will catch this exception and re-throw the wrapped fatal throwable.
        case oe: OutOfMemoryError =>
          val tables = child.collect { case f: FileSourceScanExec => f.tableIdentifier }.flatten
          val ex = new SparkFatalException(
            QueryExecutionErrors.notEnoughMemoryToBuildAndBroadcastTableError(oe, tables))
          promise.tryFailure(ex)
          throw ex
        case e if !NonFatal(e) =>
          val ex = new SparkFatalException(e)
          promise.tryFailure(ex)
          throw ex
        case e: Throwable =>
          promise.tryFailure(e)
          throw e
      }
    }
  }

  override protected def doPrepare(): Unit = {
    // Materialize the future.
    relationFuture
  }

  override protected def doExecute(): RDD[InternalRow] = {
    throw QueryExecutionErrors.executeCodePathUnsupportedError("CometBroadcastExchangeExec")
  }

  override def supportsColumnar: Boolean = true

  // This is used for unit tests only.
  override def executeCollect(): Array[InternalRow] =
    ColumnarToRowExec(this).executeCollect()

  // This is also used for unit tests only; it is called indirectly by `executeCollect`.
  override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
    val broadcasted = executeBroadcast[Array[ChunkedByteBuffer]]()

    new CometBatchRDD(sparkContext, broadcasted.value.length, broadcasted)
  }

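  // Blocks until the broadcast is ready, waiting at most `timeout` (spark.sql.broadcastTimeout)
  // seconds; on timeout the job group started in `relationFuture` is cancelled.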
  override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
    try {
      relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
    } catch {
      case ex: TimeoutException =>
        logError(s"Could not execute broadcast in $timeout secs.", ex)
        if (!relationFuture.isDone) {
          sparkContext.cancelJobGroup(runId.toString)
          relationFuture.cancel(true)
        }
        throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex))
    }
  }

  override def equals(obj: Any): Boolean = {
    obj match {
      case other: CometBroadcastExchangeExec =>
        this.originalPlan == other.originalPlan &&
        this.output == other.output && this.child == other.child
      case _ =>
        false
    }
  }

  override def hashCode(): Int = Objects.hashCode(output, child)

  override def stringArgs: Iterator[Any] = Iterator(output, child)

  override protected def withNewChildInternal(newChild: SparkPlan): CometBroadcastExchangeExec =
    copy(child = newChild)
}

object CometBroadcastExchangeExec {
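  // 8 GiB: the same upper bound on the serialized size of a broadcast table that Spark's
  // BroadcastExchangeExec enforces.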
  val MAX_BROADCAST_TABLE_BYTES: Long = 8L << 30

  private[comet] val executionContext = ExecutionContext.fromExecutorService(
    ThreadUtils.newDaemonCachedThreadPool(
      "comet-broadcast-exchange",
      SQLConf.get.getConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD)))
}

/**
 * [[CometBatchRDD]] is an [[RDD]] of [[ColumnarBatch]]es that are broadcast to the executors. It
 * is only used by [[CometBroadcastExchangeExec]] to broadcast the result of a Comet operator.
 *
 * @param sc
 *   SparkContext
 * @param numPartitions
 *   number of partitions
 * @param value
 *   the broadcast batches, serialized into an array of [[ChunkedByteBuffer]]s
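 *
 * A rough sketch of how it is constructed (this mirrors `doExecuteColumnar` above):
 * {{{
 *   val broadcasted = executeBroadcast[Array[ChunkedByteBuffer]]()
 *   val rdd = new CometBatchRDD(sparkContext, broadcasted.value.length, broadcasted)
 * }}}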
 */
class CometBatchRDD(
    sc: SparkContext,
    numPartitions: Int,
    value: broadcast.Broadcast[Array[ChunkedByteBuffer]])
    extends RDD[ColumnarBatch](sc, Nil) {

  override def getPartitions: Array[Partition] = (0 until numPartitions).toArray.map { i =>
    new CometBatchPartition(i, value)
  }

  override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
    val partition = split.asInstanceOf[CometBatchPartition]

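    // Each partition holds the same broadcast value and decodes the full set of serialized
    // buffers into ColumnarBatches.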
    partition.value.value.flatMap(CometExec.decodeBatches(_)).toIterator
  }
}

class CometBatchPartition(
    override val index: Int,
    val value: broadcast.Broadcast[Array[ChunkedByteBuffer]])
    extends Partition {}