77import asyncio
88from typing import TYPE_CHECKING , Any , Literal
99
10+ from rapidsmpf .memory .buffer import MemoryType
1011from rapidsmpf .streaming .core .message import Message
1112from rapidsmpf .streaming .cudf .table_chunk import TableChunk
1213
@@ -40,7 +41,7 @@ async def get_small_table(
4041 context : Context ,
4142 small_child : IR ,
4243 ch_small : ChannelPair ,
43- ) -> list [DataFrame ]:
44+ ) -> tuple [ list [DataFrame ], int ]:
4445 """
4546 Get the small-table DataFrame partitions from the small-table ChannelPair.
4647
@@ -55,16 +56,17 @@ async def get_small_table(
5556
5657 Returns
5758 -------
58- list[DataFrame]
59- The small-table DataFrame partitions.
59+ The small-table DataFrame partitions and the size of the small-side in bytes.
6060 """
6161 small_chunks = []
62+ small_size = 0
6263 while (msg := await ch_small .data .recv (context )) is not None :
6364 small_chunks .append (
6465 TableChunk .from_message (msg ).make_available_and_spill (
6566 context .br (), allow_overbooking = True
6667 )
6768 )
69+ small_size += small_chunks [- 1 ].data_alloc_size (MemoryType .DEVICE )
6870 assert small_chunks , "Empty small side"
6971
7072 return [
@@ -75,7 +77,7 @@ async def get_small_table(
7577 small_chunk .stream ,
7678 )
7779 for small_chunk in small_chunks
78- ]
80+ ], small_size
7981
8082
8183@define_py_node ()
@@ -88,6 +90,7 @@ async def broadcast_join_node(
8890 ch_right : ChannelPair ,
8991 broadcast_side : Literal ["left" , "right" ],
9092 collective_id : int ,
93+ target_partition_size : int ,
9194) -> None :
9295 """
9396 Join node for rapidsmpf.
@@ -108,8 +111,10 @@ async def broadcast_join_node(
108111 The right input ChannelPair.
109112 broadcast_side
110113 The side to broadcast.
111- collective_id: int
114+ collective_id
112115 Pre-allocated collective ID for this operation.
116+ target_partition_size
117+ The target partition size in bytes.
113118 """
114119 async with shutdown_on_error (
115120 context ,
@@ -156,26 +161,33 @@ async def broadcast_join_node(
156161 await ch_out .send_metadata (context , output_metadata )
157162
158163 # Collect small-side
159- small_df = _concat (
160- * await get_small_table (context , small_child , small_ch ),
161- context = ir_context ,
162- )
163- if context .comm ().nranks > 1 and not small_duplicated :
164+ small_dfs , small_size = await get_small_table (context , small_child , small_ch )
165+ need_allgather = context .comm ().nranks > 1 and not small_duplicated
166+ if (
167+ ir .options [0 ] != "Inner" or small_size < target_partition_size
168+ ) and not need_allgather :
169+ # Pre-concat for non-inner joins, otherwise
170+ # we need a local shuffle, and face additional
171+ # memory pressure anyway.
172+ small_dfs = [_concat (* small_dfs , context = ir_context )]
173+ if need_allgather :
164174 allgather = AllGatherManager (context , collective_id )
165- allgather .insert (
166- 0 ,
167- TableChunk .from_pylibcudf_table (
168- small_df .table , small_df .stream , exclusive_view = True
169- ),
170- )
175+ for s_id , small_df in enumerate (small_dfs ):
176+ allgather .insert (
177+ s_id ,
178+ TableChunk .from_pylibcudf_table (
179+ small_df .table , small_df .stream , exclusive_view = True
180+ ),
181+ )
171182 allgather .insert_finished ()
172- small_table = await allgather .extract_concatenated (small_df .stream )
173- small_df = DataFrame .from_table (
174- small_table ,
175- list (small_child .schema .keys ()),
176- list (small_child .schema .values ()),
177- small_df .stream ,
178- )
183+ small_dfs = [
184+ DataFrame .from_table (
185+ await allgather .extract_concatenated (small_df .stream ),
186+ list (small_child .schema .keys ()),
187+ list (small_child .schema .values ()),
188+ small_df .stream ,
189+ )
190+ ]
179191
180192 # Stream through large side, joining with the small-side
181193 while (msg := await large_ch .data .recv (context )) is not None :
@@ -191,14 +203,22 @@ async def broadcast_join_node(
191203 )
192204
193205 # Perform the join
194- df = await asyncio .to_thread (
195- ir .do_evaluate ,
196- * ir ._non_child_args ,
197- * (
198- [large_df , small_df ]
199- if broadcast_side == "right"
200- else [small_df , large_df ]
201- ),
206+ df = _concat (
207+ * [
208+ (
209+ await asyncio .to_thread (
210+ ir .do_evaluate ,
211+ * ir ._non_child_args ,
212+ * (
213+ [large_df , small_df ]
214+ if broadcast_side == "right"
215+ else [small_df , large_df ]
216+ ),
217+ context = ir_context ,
218+ )
219+ )
220+ for small_df in small_dfs
221+ ],
202222 context = ir_context ,
203223 )
204224
@@ -270,6 +290,12 @@ def _(
270290 else :
271291 broadcast_side = "left"
272292
293+ # Get target partition size
294+ config_options = rec .state ["config_options" ]
295+ executor = config_options .executor
296+ assert executor .name == "streaming" , "Join node requires streaming executor"
297+ target_partition_size = executor .target_partition_size
298+
273299 nodes [ir ] = [
274300 broadcast_join_node (
275301 rec .state ["context" ],
@@ -280,6 +306,7 @@ def _(
280306 channels [right ].reserve_output_slot (),
281307 broadcast_side = broadcast_side ,
282308 collective_id = rec .state ["collective_id_map" ][ir ],
309+ target_partition_size = target_partition_size ,
283310 )
284311 ]
285312 return nodes , channels
0 commit comments