67 changes: 48 additions & 19 deletions python/cudf_polars/cudf_polars/experimental/join.py
Member Author: All changes to this file directly revert #20724.

@@ -12,7 +12,7 @@
 from cudf_polars.experimental.base import PartitionInfo, get_key_name
 from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node
 from cudf_polars.experimental.repartition import Repartition
-from cudf_polars.experimental.shuffle import Shuffle
+from cudf_polars.experimental.shuffle import Shuffle, _hash_partition_dataframe
 from cudf_polars.experimental.utils import _concat, _fallback_inform, _lower_ir_fallback

 if TYPE_CHECKING:
@@ -344,36 +344,65 @@ def _(
         small_name = get_key_name(right)
         small_size = partition_info[right].count
         large_name = get_key_name(left)
+        large_on = ir.left_on
     else:
         small_side = "Left"
         small_name = get_key_name(left)
         small_size = partition_info[left].count
         large_name = get_key_name(right)
+        large_on = ir.right_on

     graph: MutableMapping[Any, Any] = {}

     out_name = get_key_name(ir)
     out_size = partition_info[ir].count
-    concat_name = f"concat-{out_name}"
+    split_name = f"split-{out_name}"
+    getit_name = f"getit-{out_name}"
+    inter_name = f"inter-{out_name}"

-    # Concatenate the small partitions
-    if small_size > 1:
-        graph[(concat_name, 0)] = (
-            partial(_concat, context=context),
-            *((small_name, j) for j in range(small_size)),
-        )
-        small_name = concat_name
+    # Split each large partition if we have
+    # multiple small partitions (unless this
+    # is an inner join)
+    split_large = ir.options[0] != "Inner" and small_size > 1

     for part_out in range(out_size):
-        join_children = [(large_name, part_out), (small_name, 0)]
-        if small_side == "Left":
-            join_children.reverse()
-        graph[(out_name, part_out)] = (
-            partial(ir.do_evaluate, context=context),
-            ir.left_on,
-            ir.right_on,
-            ir.options,
-            *join_children,
-        )
+        if split_large:
+            graph[(split_name, part_out)] = (
+                _hash_partition_dataframe,
+                (large_name, part_out),
+                part_out,
+                small_size,
+                None,
+                large_on,
+            )
+
+        _concat_list = []
+        for j in range(small_size):
+            left_key: tuple[str, int] | tuple[str, int, int]
+            if split_large:
+                left_key = (getit_name, part_out, j)
+                graph[left_key] = (operator.getitem, (split_name, part_out), j)
+            else:
+                left_key = (large_name, part_out)
+            join_children = [left_key, (small_name, j)]
+            if small_side == "Left":
+                join_children.reverse()
+
+            inter_key = (inter_name, part_out, j)
+            graph[(inter_name, part_out, j)] = (
+                partial(ir.do_evaluate, context=context),
+                ir.left_on,
+                ir.right_on,
+                ir.options,
+                *join_children,
+            )
+            _concat_list.append(inter_key)
+        if len(_concat_list) == 1:
+            graph[(out_name, part_out)] = graph.pop(_concat_list[0])
+        else:
+            graph[(out_name, part_out)] = (
+                partial(_concat, context=context),
+                *_concat_list,
+            )

     return graph
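To make the restored graph shape concrete, here is a toy, self-contained sketch (all helpers, key names, and data are hypothetical stand-ins, not cudf-polars APIs) of the keys built above for a non-inner join with one output partition and two small partitions: the large partition is hash-split, piece j is joined against small partition j, and the pieces are concatenated into the output partition.

```python
import operator

def _hash_split(rows, count):
    # Stand-in for _hash_partition_dataframe: bucket rows by hash value.
    buckets = [[] for _ in range(count)]
    for row in rows:
        buckets[hash(row) % count].append(row)
    return buckets

def _join(left, right):
    # Stand-in for ir.do_evaluate: a trivial equality join.
    return [row for row in left if row in right]

def _concat(*parts):
    # Stand-in for the real _concat helper.
    return [row for part in parts for row in part]

small_size = 2
graph = {
    ("large", 0): [1, 2, 3, 4],
    ("small", 0): [2, 4],  # rows with hash % 2 == 0
    ("small", 1): [1, 3],  # rows with hash % 2 == 1
    ("split", 0): (_hash_split, ("large", 0), small_size),
    ("getit", 0, 0): (operator.getitem, ("split", 0), 0),
    ("getit", 0, 1): (operator.getitem, ("split", 0), 1),
    ("inter", 0, 0): (_join, ("getit", 0, 0), ("small", 0)),
    ("inter", 0, 1): (_join, ("getit", 0, 1), ("small", 1)),
    ("out", 0): (_concat, ("inter", 0, 0), ("inter", 0, 1)),
}

def get(key):
    # Minimal dask-style executor: a task is a tuple of (callable, *args).
    val = graph[key]
    if isinstance(val, tuple) and callable(val[0]):
        fn, *args = val
        return fn(*(get(a) if isinstance(a, tuple) and a in graph else a for a in args))
    return val

print(get(("out", 0)))  # -> [2, 4, 1, 3]
```

The split is what keeps non-inner joins correct: each large-side row should meet only the small partition holding its hash bucket, otherwise unmatched rows would be emitted once per small partition. For inner joins the code above skips the split and joins the whole large partition against each (disjoint) small partition, which leaves the result unchanged.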
90 changes: 59 additions & 31 deletions python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py
@@ -7,6 +7,7 @@
 import asyncio
 from typing import TYPE_CHECKING, Any, Literal

+from rapidsmpf.memory.buffer import MemoryType
 from rapidsmpf.streaming.core.message import Message
 from rapidsmpf.streaming.cudf.table_chunk import TableChunk

@@ -40,7 +41,7 @@ async def get_small_table(
     context: Context,
     small_child: IR,
     ch_small: ChannelPair,
-) -> list[DataFrame]:
+) -> tuple[list[DataFrame], int]:
     """
     Get the small-table DataFrame partitions from the small-table ChannelPair.

@@ -55,16 +56,17 @@
     Returns
     -------
-    list[DataFrame]
-        The small-table DataFrame partitions.
+        The small-table DataFrame partitions and the size of the small-side in bytes.
     """
     small_chunks = []
+    small_size = 0
     while (msg := await ch_small.data.recv(context)) is not None:
         small_chunks.append(
             TableChunk.from_message(msg).make_available_and_spill(
                 context.br(), allow_overbooking=True
             )
         )
+        small_size += small_chunks[-1].data_alloc_size(MemoryType.DEVICE)
     assert small_chunks, "Empty small side"

     return [
@@ -75,7 +77,7 @@ async def get_small_table(
             small_chunk.stream,
         )
         for small_chunk in small_chunks
-    ]
+    ], small_size


 @define_py_node()
@@ -88,6 +90,7 @@ async def broadcast_join_node(
     ch_right: ChannelPair,
     broadcast_side: Literal["left", "right"],
     collective_id: int,
+    target_partition_size: int,
 ) -> None:
     """
     Join node for rapidsmpf.
@@ -108,8 +111,10 @@
         The right input ChannelPair.
     broadcast_side
         The side to broadcast.
-    collective_id: int
+    collective_id
         Pre-allocated collective ID for this operation.
+    target_partition_size
+        The target partition size in bytes.
     """
     async with shutdown_on_error(
         context,
@@ -156,26 +161,34 @@
         await ch_out.send_metadata(context, output_metadata)

         # Collect small-side
-        small_df = _concat(
-            *await get_small_table(context, small_child, small_ch),
-            context=ir_context,
-        )
-        if context.comm().nranks > 1 and not small_duplicated:
+        small_dfs, small_size = await get_small_table(context, small_child, small_ch)
+        need_allgather = context.comm().nranks > 1 and not small_duplicated
+        if (
+            ir.options[0] != "Inner" or small_size < target_partition_size
+        ) and not need_allgather:
+            # Pre-concat for non-inner joins, otherwise
+            # we need a local shuffle, and face additional
+            # memory pressure anyway.
+            small_dfs = [_concat(*small_dfs, context=ir_context)]
Member Author, commenting on lines -159 to +172:
Before #20724, we were only pre-concatenating for non-inner joins (to avoid the extra local shuffle).

After #20724, we were always pre-concatenating.

Now, we pre-concatenate when it makes sense:

  • To avoid the local shuffle (e.g. non-inner joins)
  • We have small data
  • We aren't already concatenating via an allgather
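
Restated as a predicate (a hypothetical helper mirroring the condition in the diff above, not code from this PR):

```python
def should_pre_concat(
    join_kind: str,
    small_size: int,
    target_partition_size: int,
    need_allgather: bool,
) -> bool:
    # Non-inner joins always pre-concatenate (it avoids the extra local
    # shuffle); inner joins pre-concatenate only while the small side stays
    # below the target partition size. An allgather already produces a single
    # concatenated table, so pre-concatenating there would be redundant.
    return (
        join_kind != "Inner" or small_size < target_partition_size
    ) and not need_allgather
```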

+        if need_allgather:
             allgather = AllGatherManager(context, collective_id)
-            allgather.insert(
-                0,
-                TableChunk.from_pylibcudf_table(
-                    small_df.table, small_df.stream, exclusive_view=True
-                ),
-            )
+            for s_id in range(len(small_dfs)):
+                small_df = small_dfs.pop()
+                allgather.insert(
+                    s_id,
+                    TableChunk.from_pylibcudf_table(
+                        small_df.table, small_df.stream, exclusive_view=True
+                    ),
+                )
             allgather.insert_finished()
-            small_table = await allgather.extract_concatenated(small_df.stream)
-            small_df = DataFrame.from_table(
-                small_table,
-                list(small_child.schema.keys()),
-                list(small_child.schema.values()),
-                small_df.stream,
-            )
+            small_dfs = [
+                DataFrame.from_table(
+                    await allgather.extract_concatenated(small_df.stream),
+                    list(small_child.schema.keys()),
+                    list(small_child.schema.values()),
+                    small_df.stream,
+                )
+            ]

         # Stream through large side, joining with the small-side
         while (msg := await large_ch.data.recv(context)) is not None:
@@ -191,14 +204,22 @@
             )

             # Perform the join
-            df = await asyncio.to_thread(
-                ir.do_evaluate,
-                *ir._non_child_args,
-                *(
-                    [large_df, small_df]
-                    if broadcast_side == "right"
-                    else [small_df, large_df]
-                ),
+            df = _concat(
+                *[
+                    (
+                        await asyncio.to_thread(
+                            ir.do_evaluate,
+                            *ir._non_child_args,
+                            *(
+                                [large_df, small_df]
+                                if broadcast_side == "right"
+                                else [small_df, large_df]
+                            ),
+                            context=ir_context,
+                        )
+                    )
+                    for small_df in small_dfs
+                ],
                 context=ir_context,
             )

@@ -270,6 +291,12 @@
     else:
         broadcast_side = "left"

+    # Get target partition size
+    config_options = rec.state["config_options"]
+    executor = config_options.executor
+    assert executor.name == "streaming", "Join node requires streaming executor"
+    target_partition_size = executor.target_partition_size
+
     nodes[ir] = [
         broadcast_join_node(
             rec.state["context"],
@@ -280,6 +307,7 @@
             channels[right].reserve_output_slot(),
             broadcast_side=broadcast_side,
             collective_id=rec.state["collective_id_map"][ir],
+            target_partition_size=target_partition_size,
         )
     ]
     return nodes, channels
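
For orientation, the streaming loop above amounts to this pattern (a simplified synchronous sketch with hypothetical stand-in callables; the real node receives TableChunk messages and runs each join via asyncio.to_thread):

```python
def stream_join(large_chunks, small_dfs, join, concat, broadcast_side="right"):
    # Join each large-side chunk against every retained small-side partition,
    # then concatenate the partial results into a single output chunk.
    for large_df in large_chunks:
        partials = [
            join(large_df, small_df)
            if broadcast_side == "right"
            else join(small_df, large_df)
            for small_df in small_dfs
        ]
        yield concat(*partials)
```

When the small side was pre-concatenated (or allgathered), `small_dfs` holds a single DataFrame, so each output chunk is the result of one join.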