
Commit 637dba9

feat: Add CometBroadcastExchangeExec to support broadcasting the result of Comet native operator (apache#80)
1 parent 4bbc307 commit 637dba9

File tree

7 files changed: +331 -22 lines changed


common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 9 additions & 0 deletions
@@ -139,6 +139,15 @@ object CometConf {
       .booleanConf
       .createWithDefault(false)
 
+  val COMET_EXEC_BROADCAST_ENABLED: ConfigEntry[Boolean] =
+    conf(s"$COMET_EXEC_CONFIG_PREFIX.broadcast.enabled")
+      .doc(
+        "Whether to enable broadcasting for Comet native operators. By default, " +
+          "this config is false. Note that this feature is not fully supported yet " +
+          "and only enabled for test purpose.")
+      .booleanConf
+      .createWithDefault(false)
+
   val COMET_EXEC_SHUFFLE_CODEC: ConfigEntry[String] = conf(
     s"$COMET_EXEC_CONFIG_PREFIX.shuffle.codec")
     .doc(
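
For context, here is a minimal, hypothetical sketch of how the new flag might be enabled in a test session. It is not part of this commit; the key names assume the usual "spark.comet.exec" value of COMET_EXEC_CONFIG_PREFIX, so check CometConf for the authoritative names.

// Hypothetical test setup, not part of this commit. The per-operator key below is assumed from
// isCometOperatorEnabled(conf, "broadcastExchangeExec") in the extension rule.
import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .master("local[*]")
  .appName("comet-broadcast-smoke-test")
  .config("spark.comet.enabled", "true")
  .config("spark.comet.exec.enabled", "true")
  .config("spark.comet.exec.broadcastExchangeExec.enabled", "true") // assumed operator flag
  .config("spark.comet.exec.broadcast.enabled", "true") // COMET_EXEC_BROADCAST_ENABLED, added above
  .getOrCreate()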

spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

Lines changed: 16 additions & 2 deletions
@@ -37,12 +37,12 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 import org.apache.comet.CometConf._
-import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported}
+import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometBroadCastEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported}
 import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde

@@ -331,6 +331,16 @@ class CometSparkSessionExtensions
           u
         }
 
+      case b: BroadcastExchangeExec
+          if isCometNative(b.child) && isCometOperatorEnabled(conf, "broadcastExchangeExec") &&
+            isCometBroadCastEnabled(conf) =>
+        QueryPlanSerde.operator2Proto(b) match {
+          case Some(nativeOp) =>
+            val cometOp = CometBroadcastExchangeExec(b, b.child)
+            CometSinkPlaceHolder(nativeOp, b, cometOp)
+          case None => b
+        }
+
       // Native shuffle for Comet operators
       case s: ShuffleExchangeExec
           if isCometShuffleEnabled(conf) &&

@@ -482,6 +492,10 @@ object CometSparkSessionExtensions extends Logging {
     conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf)
   }
 
+  private[comet] def isCometBroadCastEnabled(conf: SQLConf): Boolean = {
+    COMET_EXEC_BROADCAST_ENABLED.get(conf)
+  }
+
   private[comet] def isCometShuffleEnabled(conf: SQLConf): Boolean =
     COMET_EXEC_SHUFFLE_ENABLED.get(conf) &&
       (conf.contains("spark.shuffle.manager") && conf.getConfString("spark.shuffle.manager") ==
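
For orientation, a rough sketch (not part of the commit) of the rewrite performed by the new BroadcastExchangeExec case above. It only fires when the child is already Comet-native and both the broadcastExchangeExec operator flag and the new broadcast config are enabled:

  BroadcastExchangeExec                    CometSinkPlaceHolder(nativeOp)
    +- <Comet-native child>       =>         +- CometBroadcastExchangeExec(b, b.child)
                                                  +- <Comet-native child>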

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 2 additions & 1 deletion
@@ -32,7 +32,7 @@ import org.apache.spark.sql.comet.{CometHashAggregateExec, CometPlan, CometSinkP
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String

@@ -1802,6 +1802,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
       case _: UnionExec => true
       case _: ShuffleExchangeExec => true
       case _: TakeOrderedAndProjectExec => true
+      case _: BroadcastExchangeExec => true
       case _ => false
     }
   }

spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala (new file)

Lines changed: 256 additions & 0 deletions
@@ -0,0 +1,256 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql.comet

import java.util.UUID
import java.util.concurrent.{Future, TimeoutException, TimeUnit}

import scala.concurrent.{ExecutionContext, Promise}
import scala.concurrent.duration.NANOSECONDS
import scala.util.control.NonFatal

import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{ColumnarToRowExec, FileSourceScanExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.{SparkFatalException, ThreadUtils}
import org.apache.spark.util.io.ChunkedByteBuffer

import com.google.common.base.Objects

/**
 * A [[CometBroadcastExchangeExec]] collects, transforms and finally broadcasts the result of a
 * transformed SparkPlan. This is a copy of the [[BroadcastExchangeExec]] class with the necessary
 * changes to support the Comet operator.
 *
 * [[CometBroadcastExchangeExec]] will be used in broadcast join operator.
 *
 * Note that this class cannot extend `CometExec` as usual similar to other Comet operators. As
 * the trait `BroadcastExchangeLike` in Spark extends abstract class `Exchange`, it limits the
 * flexibility to extend `CometExec` and `Exchange` at the same time.
 */
case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
    extends BroadcastExchangeLike {
  import CometBroadcastExchangeExec._

  override val runId: UUID = UUID.randomUUID

  override lazy val metrics: Map[String, SQLMetric] = Map(
    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
    "collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"),
    "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build"),
    "broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast"))

  override def doCanonicalize(): SparkPlan = {
    CometBroadcastExchangeExec(originalPlan.canonicalized, child.canonicalized)
  }

  override def runtimeStatistics: Statistics = {
    val dataSize = metrics("dataSize").value
    val rowCount = metrics("numOutputRows").value
    Statistics(dataSize, Some(rowCount))
  }

  @transient
  private lazy val promise = Promise[broadcast.Broadcast[Any]]()

  @transient
  override lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] =
    promise.future

  @transient
  private val timeout: Long = conf.broadcastTimeout

  @transient
  private lazy val maxBroadcastRows = 512000000

  @transient
  override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
    SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
      session,
      CometBroadcastExchangeExec.executionContext) {
      try {
        // Setup a job group here so later it may get cancelled by groupId if necessary.
        sparkContext.setJobGroup(
          runId.toString,
          s"broadcast exchange (runId $runId)",
          interruptOnCancel = true)
        val beforeCollect = System.nanoTime()

        val countsAndBytes = child.asInstanceOf[CometExec].getByteArrayRdd().collect()
        val numRows = countsAndBytes.map(_._1).sum
        val input = countsAndBytes.iterator.map(countAndBytes => countAndBytes._2)

        longMetric("numOutputRows") += numRows
        if (numRows >= maxBroadcastRows) {
          throw QueryExecutionErrors.cannotBroadcastTableOverMaxTableRowsError(
            maxBroadcastRows,
            numRows)
        }

        val beforeBuild = System.nanoTime()
        longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)

        val batches = input.toArray

        val dataSize = batches.map(_.size).sum

        longMetric("dataSize") += dataSize
        if (dataSize >= MAX_BROADCAST_TABLE_BYTES) {
          throw QueryExecutionErrors.cannotBroadcastTableOverMaxTableBytesError(
            MAX_BROADCAST_TABLE_BYTES,
            dataSize)
        }

        val beforeBroadcast = System.nanoTime()
        longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild)

        // SPARK-39983 - Broadcast the relation without caching the unserialized object.
        val broadcasted = sparkContext
          .broadcastInternal(batches, serializedOnly = true)
          .asInstanceOf[broadcast.Broadcast[Any]]
        longMetric("broadcastTime") += NANOSECONDS.toMillis(System.nanoTime() - beforeBroadcast)
        val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
        SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
        promise.trySuccess(broadcasted)
        broadcasted
      } catch {
        // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
        // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
        // will catch this exception and re-throw the wrapped fatal throwable.
        case oe: OutOfMemoryError =>
          val tables = child.collect { case f: FileSourceScanExec => f.tableIdentifier }.flatten
          val ex = new SparkFatalException(
            QueryExecutionErrors.notEnoughMemoryToBuildAndBroadcastTableError(oe, tables))
          promise.tryFailure(ex)
          throw ex
        case e if !NonFatal(e) =>
          val ex = new SparkFatalException(e)
          promise.tryFailure(ex)
          throw ex
        case e: Throwable =>
          promise.tryFailure(e)
          throw e
      }
    }
  }

  override protected def doPrepare(): Unit = {
    // Materialize the future.
    relationFuture
  }

  override protected def doExecute(): RDD[InternalRow] = {
    throw QueryExecutionErrors.executeCodePathUnsupportedError("CometBroadcastExchangeExec")
  }

  override def supportsColumnar: Boolean = true

  // This is basically for unit test only.
  override def executeCollect(): Array[InternalRow] =
    ColumnarToRowExec(this).executeCollect()

  // This is basically for unit test only, called by `executeCollect` indirectly.
  override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
    val broadcasted = executeBroadcast[Array[ChunkedByteBuffer]]()

    new CometBatchRDD(sparkContext, broadcasted.value.length, broadcasted)
  }

  override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
    try {
      relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
    } catch {
      case ex: TimeoutException =>
        logError(s"Could not execute broadcast in $timeout secs.", ex)
        if (!relationFuture.isDone) {
          sparkContext.cancelJobGroup(runId.toString)
          relationFuture.cancel(true)
        }
        throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex))
    }
  }

  override def equals(obj: Any): Boolean = {
    obj match {
      case other: CometBroadcastExchangeExec =>
        this.originalPlan == other.originalPlan &&
        this.output == other.output && this.child == other.child
      case _ =>
        false
    }
  }

  override def hashCode(): Int = Objects.hashCode(output, child)

  override def stringArgs: Iterator[Any] = Iterator(output, child)

  override protected def withNewChildInternal(newChild: SparkPlan): CometBroadcastExchangeExec =
    copy(child = newChild)
}

object CometBroadcastExchangeExec {
  val MAX_BROADCAST_TABLE_BYTES: Long = 8L << 30

  private[comet] val executionContext = ExecutionContext.fromExecutorService(
    ThreadUtils.newDaemonCachedThreadPool(
      "comet-broadcast-exchange",
      SQLConf.get.getConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD)))
}

/**
 * [[CometBatchRDD]] is a [[RDD]] of [[ColumnarBatch]]s that are broadcasted to the executors. It
 * is only used by [[CometBroadcastExchangeExec]] to broadcast the result of a Comet operator.
 *
 * @param sc
 *   SparkContext
 * @param numPartitions
 *   number of partitions
 * @param value
 *   the broadcasted batches which are serialized into an array of [[ChunkedByteBuffer]]s
 */
class CometBatchRDD(
    sc: SparkContext,
    numPartitions: Int,
    value: broadcast.Broadcast[Array[ChunkedByteBuffer]])
    extends RDD[ColumnarBatch](sc, Nil) {

  override def getPartitions: Array[Partition] = (0 until numPartitions).toArray.map { i =>
    new CometBatchPartition(i, value)
  }

  override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
    val partition = split.asInstanceOf[CometBatchPartition]

    partition.value.value.flatMap(CometExec.decodeBatches(_)).toIterator
  }
}

class CometBatchPartition(
    override val index: Int,
    val value: broadcast.Broadcast[Array[ChunkedByteBuffer]])
    extends Partition {}
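
As a rough, hypothetical illustration (not part of the commit), here is a query shape that should exercise this operator once the configs above are enabled; paths and column names are placeholders.

// Hypothetical example. With Comet Parquet scans and the broadcast flags enabled, the planner
// rule shown earlier may replace the BroadcastExchangeExec over a Comet-native child with a
// CometBroadcastExchangeExec wrapped in a CometSinkPlaceHolder.
import org.apache.spark.sql.functions.broadcast

val dim = spark.read.parquet("/tmp/dim")     // placeholder input paths
val fact = spark.read.parquet("/tmp/fact")

val joined = fact.join(broadcast(dim), "id") // broadcast hint on the small side
joined.explain()                             // inspect the physical plan for Comet operators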

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 18 additions & 17 deletions
@@ -71,7 +71,7 @@ abstract class CometExec extends CometPlan {
   /**
    * Executes this Comet operator and serialized output ColumnarBatch into bytes.
    */
-  private def getByteArrayRdd(): RDD[(Long, ChunkedByteBuffer)] = {
+  def getByteArrayRdd(): RDD[(Long, ChunkedByteBuffer)] = {
     executeColumnar().mapPartitionsInternal { iter =>
       val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
       val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)

@@ -85,28 +85,14 @@ abstract class CometExec extends CometPlan {
     }
   }
 
-  /**
-   * Decodes the byte arrays back to ColumnarBatches and put them into buffer.
-   */
-  private def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = {
-    if (bytes.size == 0) {
-      return Iterator.empty
-    }
-
-    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
-    val cbbis = bytes.toInputStream()
-    val ins = new DataInputStream(codec.compressedInputStream(cbbis))
-
-    new ArrowReaderIterator(Channels.newChannel(ins))
-  }
-
   /**
    * Executes the Comet operator and returns the result as an iterator of ColumnarBatch.
    */
   def executeColumnarCollectIterator(): (Long, Iterator[ColumnarBatch]) = {
     val countsAndBytes = getByteArrayRdd().collect()
     val total = countsAndBytes.map(_._1).sum
-    val rows = countsAndBytes.iterator.flatMap(countAndBytes => decodeBatches(countAndBytes._2))
+    val rows = countsAndBytes.iterator
+      .flatMap(countAndBytes => CometExec.decodeBatches(countAndBytes._2))
     (total, rows)
   }
 }

@@ -133,6 +119,21 @@ object CometExec {
     val bytes = outputStream.toByteArray
     new CometExecIterator(newIterId, inputs, bytes, nativeMetrics)
   }
+
+  /**
+   * Decodes the byte arrays back to ColumnarBatches and put them into buffer.
+   */
+  def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = {
+    if (bytes.size == 0) {
+      return Iterator.empty
+    }
+
+    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+    val cbbis = bytes.toInputStream()
+    val ins = new DataInputStream(codec.compressedInputStream(cbbis))
+
+    new ArrowReaderIterator(Channels.newChannel(ins))
+  }
 }
 
 /**
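
A minimal sketch (not from the commit) of how the per-partition (rowCount, bytes) pairs produced by getByteArrayRdd() are decoded now that decodeBatches lives on the CometExec companion object; cometPlan is a hypothetical CometExec node.

// Hypothetical usage, mirroring executeColumnarCollectIterator() above.
val countsAndBytes = cometPlan.getByteArrayRdd().collect() // (rowCount, serialized batches) per partition
val totalRows = countsAndBytes.map(_._1).sum
val batches = countsAndBytes.iterator.flatMap { case (_, bytes) =>
  CometExec.decodeBatches(bytes) // decompress and read back into ColumnarBatches
}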
