diff --git a/python/cudf_polars/cudf_polars/experimental/join.py b/python/cudf_polars/cudf_polars/experimental/join.py
index 499786ffcaa..9cd06330be1 100644
--- a/python/cudf_polars/cudf_polars/experimental/join.py
+++ b/python/cudf_polars/cudf_polars/experimental/join.py
@@ -12,7 +12,7 @@
 from cudf_polars.experimental.base import PartitionInfo, get_key_name
 from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node
 from cudf_polars.experimental.repartition import Repartition
-from cudf_polars.experimental.shuffle import Shuffle
+from cudf_polars.experimental.shuffle import Shuffle, _hash_partition_dataframe
 from cudf_polars.experimental.utils import _concat, _fallback_inform, _lower_ir_fallback

 if TYPE_CHECKING:
@@ -344,36 +344,65 @@ def _(
         small_name = get_key_name(right)
         small_size = partition_info[right].count
         large_name = get_key_name(left)
+        large_on = ir.left_on
     else:
         small_side = "Left"
         small_name = get_key_name(left)
         small_size = partition_info[left].count
         large_name = get_key_name(right)
+        large_on = ir.right_on

     graph: MutableMapping[Any, Any] = {}
     out_name = get_key_name(ir)
     out_size = partition_info[ir].count
-    concat_name = f"concat-{out_name}"
+    split_name = f"split-{out_name}"
+    getit_name = f"getit-{out_name}"
+    inter_name = f"inter-{out_name}"

-    # Concatenate the small partitions
-    if small_size > 1:
-        graph[(concat_name, 0)] = (
-            partial(_concat, context=context),
-            *((small_name, j) for j in range(small_size)),
-        )
-        small_name = concat_name
+    # Split each large partition when there are
+    # multiple small partitions, unless this is
+    # an inner join (which can be joined piecewise)
+    split_large = ir.options[0] != "Inner" and small_size > 1

     for part_out in range(out_size):
-        join_children = [(large_name, part_out), (small_name, 0)]
-        if small_side == "Left":
-            join_children.reverse()
-        graph[(out_name, part_out)] = (
-            partial(ir.do_evaluate, context=context),
-            ir.left_on,
-            ir.right_on,
-            ir.options,
-            *join_children,
-        )
+        if split_large:
+            graph[(split_name, part_out)] = (
+                _hash_partition_dataframe,
+                (large_name, part_out),
+                part_out,
+                small_size,
+                None,
+                large_on,
+            )
+
+        _concat_list = []
+        for j in range(small_size):
+            left_key: tuple[str, int] | tuple[str, int, int]
+            if split_large:
+                left_key = (getit_name, part_out, j)
+                graph[left_key] = (operator.getitem, (split_name, part_out), j)
+            else:
+                left_key = (large_name, part_out)
+            join_children = [left_key, (small_name, j)]
+            if small_side == "Left":
+                join_children.reverse()
+
+            inter_key = (inter_name, part_out, j)
+            graph[inter_key] = (
+                partial(ir.do_evaluate, context=context),
+                ir.left_on,
+                ir.right_on,
+                ir.options,
+                *join_children,
+            )
+            _concat_list.append(inter_key)
+        if len(_concat_list) == 1:
+            graph[(out_name, part_out)] = graph.pop(_concat_list[0])
+        else:
+            graph[(out_name, part_out)] = (
+                partial(_concat, context=context),
+                *_concat_list,
+            )

     return graph

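The task-graph strategy above hinges on two join identities. Below is a minimal pandas sketch of both; pandas, `canon`, and `mod_partition` are illustrative stand-ins only (the actual implementation splits with `_hash_partition_dataframe`, and the lowering is expected to supply small-side partitions that are already hash-aligned on the join keys):

```python
import pandas as pd

large = pd.DataFrame({"k": [0, 1, 2, 2, 5], "x": list("abcde")})
small = pd.DataFrame({"k": [0, 2, 4, 5], "y": [10, 20, 30, 40]})

def canon(df):
    # Canonical row order so unordered join results can be compared.
    return df.sort_values(["k", "x"]).reset_index(drop=True)

# Inner joins distribute over a union of the small side, so each small
# partition can be joined independently and the results concatenated:
s0, s1 = small.iloc[:2], small.iloc[2:]
piecewise = pd.concat([large.merge(s0, on="k"), large.merge(s1, on="k")])
assert canon(piecewise).equals(canon(large.merge(small, on="k")))

def mod_partition(df, on, count):
    # Toy stand-in for hash partitioning: route each row by key so that
    # equal keys always land in the same bucket.
    buckets = df[on] % count
    return [df[buckets == j] for j in range(count)]

# Non-inner joins do NOT distribute that way (a left join against each
# piece would emit unmatched large rows once per small partition), but
# joining key-aligned buckets pairwise restores the identity:
aligned = pd.concat(
    lg.merge(sm, on="k", how="left")
    for lg, sm in zip(mod_partition(large, "k", 2), mod_partition(small, "k", 2))
)
assert canon(aligned).equals(canon(large.merge(small, on="k", how="left")))
```

This is why `split_large` is gated on `ir.options[0] != "Inner"`: inner joins take the cheap piecewise path, while other join types pay for one hash split of each large partition instead.
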
diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py
index c03366a09ad..c971e5d003f 100644
--- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py
+++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py
@@ -7,6 +7,7 @@
 import asyncio
 from typing import TYPE_CHECKING, Any, Literal

+from rapidsmpf.memory.buffer import MemoryType
 from rapidsmpf.streaming.core.message import Message
 from rapidsmpf.streaming.cudf.table_chunk import TableChunk

@@ -40,7 +41,7 @@ async def get_small_table(
     context: Context,
     small_child: IR,
     ch_small: ChannelPair,
-) -> list[DataFrame]:
+) -> tuple[list[DataFrame], int]:
     """
     Get the small-table DataFrame partitions from the small-table ChannelPair.

@@ -55,16 +56,17 @@

     Returns
     -------
-    list[DataFrame]
-        The small-table DataFrame partitions.
+        The small-table DataFrame partitions and the size of the small side in bytes.
     """
     small_chunks = []
+    small_size = 0
     while (msg := await ch_small.data.recv(context)) is not None:
         small_chunks.append(
             TableChunk.from_message(msg).make_available_and_spill(
                 context.br(), allow_overbooking=True
             )
         )
+        small_size += small_chunks[-1].data_alloc_size(MemoryType.DEVICE)

     assert small_chunks, "Empty small side"
     return [
@@ -75,7 +77,7 @@
             small_chunk.stream,
         )
         for small_chunk in small_chunks
-    ]
+    ], small_size


 @define_py_node()
@@ -88,6 +90,7 @@ async def broadcast_join_node(
     ch_right: ChannelPair,
     broadcast_side: Literal["left", "right"],
     collective_id: int,
+    target_partition_size: int,
 ) -> None:
     """
     Join node for rapidsmpf.
@@ -108,8 +111,10 @@
         The right input ChannelPair.
     broadcast_side
         The side to broadcast.
-    collective_id: int
+    collective_id
         Pre-allocated collective ID for this operation.
+    target_partition_size
+        The target partition size in bytes.
     """
     async with shutdown_on_error(
         context,
@@ -156,26 +161,33 @@
         await ch_out.send_metadata(context, output_metadata)

         # Collect small-side
-        small_df = _concat(
-            *await get_small_table(context, small_child, small_ch),
-            context=ir_context,
-        )
-        if context.comm().nranks > 1 and not small_duplicated:
+        small_dfs, small_size = await get_small_table(context, small_child, small_ch)
+        need_allgather = context.comm().nranks > 1 and not small_duplicated
+        if (
+            ir.options[0] != "Inner" or small_size < target_partition_size
+        ) and not need_allgather:
+            # Pre-concatenate for non-inner joins (splitting them would
+            # require a local shuffle and extra memory pressure anyway)
+            # and for small sides already below the target size.
+            small_dfs = [_concat(*small_dfs, context=ir_context)]
+        if need_allgather:
             allgather = AllGatherManager(context, collective_id)
-            allgather.insert(
-                0,
-                TableChunk.from_pylibcudf_table(
-                    small_df.table, small_df.stream, exclusive_view=True
-                ),
-            )
+            for s_id, small_df in enumerate(small_dfs):
+                allgather.insert(
+                    s_id,
+                    TableChunk.from_pylibcudf_table(
+                        small_df.table, small_df.stream, exclusive_view=True
+                    ),
+                )
             allgather.insert_finished()
-            small_table = await allgather.extract_concatenated(small_df.stream)
-            small_df = DataFrame.from_table(
-                small_table,
-                list(small_child.schema.keys()),
-                list(small_child.schema.values()),
-                small_df.stream,
-            )
+            small_dfs = [
+                DataFrame.from_table(
+                    await allgather.extract_concatenated(small_df.stream),
+                    list(small_child.schema.keys()),
+                    list(small_child.schema.values()),
+                    small_df.stream,
+                )
+            ]

         # Stream through large side, joining with the small-side
         while (msg := await large_ch.data.recv(context)) is not None:
@@ -191,14 +203,22 @@
             )

             # Perform the join
-            df = await asyncio.to_thread(
-                ir.do_evaluate,
-                *ir._non_child_args,
-                *(
-                    [large_df, small_df]
-                    if broadcast_side == "right"
-                    else [small_df, large_df]
-                ),
+            df = _concat(
+                *[
+                    (
+                        await asyncio.to_thread(
+                            ir.do_evaluate,
+                            *ir._non_child_args,
+                            *(
+                                [large_df, small_df]
+                                if broadcast_side == "right"
+                                else [small_df, large_df]
+                            ),
+                            context=ir_context,
+                        )
+                    )
+                    for small_df in small_dfs
+                ],
                 context=ir_context,
             )

@@ -270,6 +290,12 @@ def _(
     else:
         broadcast_side = "left"

+    # Get target partition size
+    config_options = rec.state["config_options"]
+    executor = config_options.executor
+    assert executor.name == "streaming", "Join node requires streaming executor"
+    target_partition_size = executor.target_partition_size
+
     nodes[ir] = [
         broadcast_join_node(
             rec.state["context"],
@@ -280,6 +306,7 @@
             channels[right].reserve_output_slot(),
             broadcast_side=broadcast_side,
             collective_id=rec.state["collective_id_map"][ir],
+            target_partition_size=target_partition_size,
         )
     ]
     return nodes, channels
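Putting the pieces together, the small-side handling in `broadcast_join_node` reduces to a three-way decision. The function below restates it for illustration only; its name and string results are hypothetical, not cudf-polars API:

```python
def plan_small_side(
    join_kind: str,              # ir.options[0], e.g. "Inner" or "Left"
    small_size: int,             # device bytes summed in get_small_table
    target_partition_size: int,  # executor.target_partition_size
    need_allgather: bool,        # comm().nranks > 1 and not small_duplicated
) -> str:
    """Sketch of the small-side strategy used above."""
    if need_allgather:
        # Chunks are inserted into the allgather individually and
        # extracted concatenated, so the join sees one small table.
        return "allgather, then a single join per large chunk"
    if join_kind != "Inner" or small_size < target_partition_size:
        # Non-inner joins need the whole small side at once for
        # correctness; a small-enough side is cheap to concat anyway.
        return "pre-concat, then a single join per large chunk"
    # Only this branch leaves multiple entries in small_dfs.
    return "piecewise joins per large chunk, concatenated afterwards"
```

Only the last branch keeps `small_dfs` as multiple chunks, which is why the join loop wraps `ir.do_evaluate` in a `_concat` over `small_dfs`, and why that path is reserved for inner joins (per the distribution identity sketched between the two file diffs above).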