Commit 052e909

Partially revert broadcast-join change (#20779)

- Supersedes #20749
- Partially reverts #20724
- Fully reverts the change for the "tasks" runtime
- Partially reverts the change for "rapidsmpf" (we can use the actual data size to decide whether to concatenate during runtime)

Authors:
- Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
- Tom Augspurger (https://github.com/TomAugspurger)

URL: #20779

1 parent 868ad85 commit 052e909
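
The rapidsmpf half of the change keeps one idea at runtime: measure the actual device size of the small (broadcast) side and use it to decide whether to pre-concatenate before joining. Below is a minimal sketch of that predicate, mirroring the condition added in the rapidsmpf diff further down; the standalone function and its argument names are illustrative only, not part of the codebase:

def should_preconcat(
    join_kind: str,  # ir.options[0] in the diff, e.g. "Inner" or "Left"
    small_size: int,  # measured device size of the small side, in bytes
    target_partition_size: int,  # executor's target partition size, in bytes
    need_allgather: bool,  # True if the small side must still be all-gathered
) -> bool:
    # Mirrors the diff: pre-concatenate for non-inner joins, or when the
    # small side is already below the target partition size, but never
    # before an allgather, whose extraction step concatenates the gathered
    # chunks itself.
    return (
        join_kind != "Inner" or small_size < target_partition_size
    ) and not need_allgather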

2 files changed: +106 −50 lines

python/cudf_polars/cudf_polars/experimental/join.py (48 additions, 19 deletions)
@@ -12,7 +12,7 @@
 from cudf_polars.experimental.base import PartitionInfo, get_key_name
 from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node
 from cudf_polars.experimental.repartition import Repartition
-from cudf_polars.experimental.shuffle import Shuffle
+from cudf_polars.experimental.shuffle import Shuffle, _hash_partition_dataframe
 from cudf_polars.experimental.utils import _concat, _fallback_inform, _lower_ir_fallback

 if TYPE_CHECKING:
@@ -344,36 +344,65 @@ def _(
         small_name = get_key_name(right)
         small_size = partition_info[right].count
         large_name = get_key_name(left)
+        large_on = ir.left_on
     else:
         small_side = "Left"
         small_name = get_key_name(left)
         small_size = partition_info[left].count
         large_name = get_key_name(right)
+        large_on = ir.right_on

     graph: MutableMapping[Any, Any] = {}

     out_name = get_key_name(ir)
     out_size = partition_info[ir].count
-    concat_name = f"concat-{out_name}"
+    split_name = f"split-{out_name}"
+    getit_name = f"getit-{out_name}"
+    inter_name = f"inter-{out_name}"

-    # Concatenate the small partitions
-    if small_size > 1:
-        graph[(concat_name, 0)] = (
-            partial(_concat, context=context),
-            *((small_name, j) for j in range(small_size)),
-        )
-        small_name = concat_name
+    # Split each large partition if we have
+    # multiple small partitions (unless this
+    # is an inner join)
+    split_large = ir.options[0] != "Inner" and small_size > 1

     for part_out in range(out_size):
-        join_children = [(large_name, part_out), (small_name, 0)]
-        if small_side == "Left":
-            join_children.reverse()
-        graph[(out_name, part_out)] = (
-            partial(ir.do_evaluate, context=context),
-            ir.left_on,
-            ir.right_on,
-            ir.options,
-            *join_children,
-        )
+        if split_large:
+            graph[(split_name, part_out)] = (
+                _hash_partition_dataframe,
+                (large_name, part_out),
+                part_out,
+                small_size,
+                None,
+                large_on,
+            )
+
+        _concat_list = []
+        for j in range(small_size):
+            left_key: tuple[str, int] | tuple[str, int, int]
+            if split_large:
+                left_key = (getit_name, part_out, j)
+                graph[left_key] = (operator.getitem, (split_name, part_out), j)
+            else:
+                left_key = (large_name, part_out)
+            join_children = [left_key, (small_name, j)]
+            if small_side == "Left":
+                join_children.reverse()
+
+            inter_key = (inter_name, part_out, j)
+            graph[(inter_name, part_out, j)] = (
+                partial(ir.do_evaluate, context=context),
+                ir.left_on,
+                ir.right_on,
+                ir.options,
+                *join_children,
+            )
+            _concat_list.append(inter_key)
+        if len(_concat_list) == 1:
+            graph[(out_name, part_out)] = graph.pop(_concat_list[0])
+        else:
+            graph[(out_name, part_out)] = (
+                partial(_concat, context=context),
+                *_concat_list,
+            )

     return graph
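
To see what this new wiring computes, the sketch below builds the same dict-shaped task graph with toy list-based stand-ins for `_hash_partition_dataframe`, `ir.do_evaluate`, and `_concat`, plus a tiny recursive executor. All the helpers here (`hash_split`, `inner_join`, `concat`, `get`) are hypothetical scaffolding for illustration; in the real code the split path is only taken for non-inner joins with more than one small partition, but the graph shape is identical:

import operator

def hash_split(rows, npieces):
    # Stand-in for _hash_partition_dataframe: bucket rows by key hash.
    out = [[] for _ in range(npieces)]
    for row in rows:
        out[hash(row[0]) % npieces].append(row)
    return out

def inner_join(left, right):
    # Stand-in for ir.do_evaluate: a toy hash join on the first column.
    lookup = {}
    for key, *vals in right:
        lookup.setdefault(key, []).append(vals)
    return [(k, *lv, *rv) for k, *lv in left for rv in lookup.get(k, [])]

def concat(*parts):
    # Stand-in for _concat.
    return [row for part in parts for row in part]

def get(graph, key):
    # Tiny recursive executor: a task is (fn, *args); args may be graph keys.
    task = graph[key]
    if isinstance(task, tuple) and callable(task[0]):
        fn, *args = task
        return fn(*(get(graph, a) if a in graph else a for a in args))
    return task

# One output partition; small side pre-shuffled into 3 hash-aligned pieces.
n = 3
graph = {("large", 0): [("a", 1), ("b", 2), ("c", 3), ("d", 4)]}
for j, piece in enumerate(hash_split([("a", "x"), ("b", "y"), ("d", "z")], n)):
    graph[("small", j)] = piece

# The wiring added by the diff: split the large partition, join piece j
# with small partition j, then concatenate the intermediate results.
graph[("split", 0)] = (hash_split, ("large", 0), n)
for j in range(n):
    graph[("getit", 0, j)] = (operator.getitem, ("split", 0), j)
    graph[("inter", 0, j)] = (inner_join, ("getit", 0, j), ("small", j))
graph[("out", 0)] = (concat, *(("inter", 0, j) for j in range(n)))

print(get(graph, ("out", 0)))  # matched rows: ("a", 1, "x"), ("b", 2, "y"), ("d", 4, "z")

The `split_large` guard excludes inner joins because an inner join distributes over a union of the small side: joining the whole large partition against each small piece and concatenating already gives the right answer, so no hash split of the large side is needed. For left or full joins that shortcut would duplicate unmatched rows, hence the split.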

python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py (58 additions, 31 deletions)
@@ -7,6 +7,7 @@
 import asyncio
 from typing import TYPE_CHECKING, Any, Literal

+from rapidsmpf.memory.buffer import MemoryType
 from rapidsmpf.streaming.core.message import Message
 from rapidsmpf.streaming.cudf.table_chunk import TableChunk

@@ -40,7 +41,7 @@ async def get_small_table(
     context: Context,
     small_child: IR,
     ch_small: ChannelPair,
-) -> list[DataFrame]:
+) -> tuple[list[DataFrame], int]:
     """
     Get the small-table DataFrame partitions from the small-table ChannelPair.

@@ -55,16 +56,17 @@

     Returns
     -------
-    list[DataFrame]
-        The small-table DataFrame partitions.
+        The small-table DataFrame partitions and the size of the small-side in bytes.
     """
     small_chunks = []
+    small_size = 0
     while (msg := await ch_small.data.recv(context)) is not None:
         small_chunks.append(
             TableChunk.from_message(msg).make_available_and_spill(
                 context.br(), allow_overbooking=True
             )
         )
+        small_size += small_chunks[-1].data_alloc_size(MemoryType.DEVICE)
     assert small_chunks, "Empty small side"

     return [
@@ -75,7 +77,7 @@
             small_chunk.stream,
         )
         for small_chunk in small_chunks
-    ]
+    ], small_size


 @define_py_node()
@@ -88,6 +90,7 @@ async def broadcast_join_node(
     ch_right: ChannelPair,
     broadcast_side: Literal["left", "right"],
     collective_id: int,
+    target_partition_size: int,
 ) -> None:
     """
     Join node for rapidsmpf.
@@ -108,8 +111,10 @@
         The right input ChannelPair.
     broadcast_side
         The side to broadcast.
-    collective_id: int
+    collective_id
         Pre-allocated collective ID for this operation.
+    target_partition_size
+        The target partition size in bytes.
     """
     async with shutdown_on_error(
         context,
@@ -156,26 +161,33 @@
         await ch_out.send_metadata(context, output_metadata)

         # Collect small-side
-        small_df = _concat(
-            *await get_small_table(context, small_child, small_ch),
-            context=ir_context,
-        )
-        if context.comm().nranks > 1 and not small_duplicated:
+        small_dfs, small_size = await get_small_table(context, small_child, small_ch)
+        need_allgather = context.comm().nranks > 1 and not small_duplicated
+        if (
+            ir.options[0] != "Inner" or small_size < target_partition_size
+        ) and not need_allgather:
+            # Pre-concat for non-inner joins, otherwise
+            # we need a local shuffle, and face additional
+            # memory pressure anyway.
+            small_dfs = [_concat(*small_dfs, context=ir_context)]
+        if need_allgather:
             allgather = AllGatherManager(context, collective_id)
-            allgather.insert(
-                0,
-                TableChunk.from_pylibcudf_table(
-                    small_df.table, small_df.stream, exclusive_view=True
-                ),
-            )
+            for s_id, small_df in enumerate(small_dfs):
+                allgather.insert(
+                    s_id,
+                    TableChunk.from_pylibcudf_table(
+                        small_df.table, small_df.stream, exclusive_view=True
+                    ),
+                )
             allgather.insert_finished()
-            small_table = await allgather.extract_concatenated(small_df.stream)
-            small_df = DataFrame.from_table(
-                small_table,
-                list(small_child.schema.keys()),
-                list(small_child.schema.values()),
-                small_df.stream,
-            )
+            small_dfs = [
+                DataFrame.from_table(
+                    await allgather.extract_concatenated(small_df.stream),
+                    list(small_child.schema.keys()),
+                    list(small_child.schema.values()),
+                    small_df.stream,
+                )
+            ]

         # Stream through large side, joining with the small-side
         while (msg := await large_ch.data.recv(context)) is not None:
@@ -191,14 +203,22 @@
             )

             # Perform the join
-            df = await asyncio.to_thread(
-                ir.do_evaluate,
-                *ir._non_child_args,
-                *(
-                    [large_df, small_df]
-                    if broadcast_side == "right"
-                    else [small_df, large_df]
-                ),
+            df = _concat(
+                *[
+                    (
+                        await asyncio.to_thread(
+                            ir.do_evaluate,
+                            *ir._non_child_args,
+                            *(
+                                [large_df, small_df]
+                                if broadcast_side == "right"
+                                else [small_df, large_df]
+                            ),
+                            context=ir_context,
+                        )
+                    )
+                    for small_df in small_dfs
+                ],
                 context=ir_context,
             )

@@ -270,6 +290,12 @@ def _(
     else:
         broadcast_side = "left"

+    # Get target partition size
+    config_options = rec.state["config_options"]
+    executor = config_options.executor
+    assert executor.name == "streaming", "Join node requires streaming executor"
+    target_partition_size = executor.target_partition_size
+
     nodes[ir] = [
         broadcast_join_node(
             rec.state["context"],
280306
channels[right].reserve_output_slot(),
281307
broadcast_side=broadcast_side,
282308
collective_id=rec.state["collective_id_map"][ir],
309+
target_partition_size=target_partition_size,
283310
)
284311
]
285312
return nodes, channels
