diff --git a/src/coll_patterns/recursive_knomial.h b/src/coll_patterns/recursive_knomial.h index 1888169f8b..5935fbddb6 100644 --- a/src/coll_patterns/recursive_knomial.h +++ b/src/coll_patterns/recursive_knomial.h @@ -23,6 +23,8 @@ enum { KN_PATTERN_ALLGATHER, KN_PATTERN_ALLGATHERV, KN_PATTERN_ALLGATHERX, + KN_PATTERN_GATHER, + KN_PATTERN_GATHERX, }; typedef struct ucc_knomial_pattern { @@ -83,7 +85,7 @@ static inline ucc_rank_t ucc_kn_pattern_radix_pow_init(ucc_knomial_pattern_t *p, static inline void ucc_knomial_pattern_init_impl(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix, ucc_knomial_pattern_t *p, - int backward) + int backward, int has_extra) { ucc_rank_t fs = radix; ucc_rank_t n_full_subtrees; @@ -100,7 +102,7 @@ ucc_knomial_pattern_init_impl(ucc_rank_t size, ucc_rank_t rank, p->backward = backward; p->iteration = 0; n_full_subtrees = ucc_kn_pattern_n_full(p); - p->n_extra = size - n_full_subtrees * p->full_pow_size; + p->n_extra = has_extra ? size - n_full_subtrees * p->full_pow_size : 0; p->n_iters = (p->n_extra && n_full_subtrees == 1) ? p->pow_radix_sup - 1 : p->pow_radix_sup; p->radix_pow = ucc_kn_pattern_radix_pow_init(p, backward); @@ -115,14 +117,22 @@ ucc_knomial_pattern_init_backward(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix, ucc_knomial_pattern_t *p) { - ucc_knomial_pattern_init_impl(size, rank, radix, p, 1); + ucc_knomial_pattern_init_impl(size, rank, radix, p, 1, 1); } static inline void ucc_knomial_pattern_init(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix, ucc_knomial_pattern_t *p) { - ucc_knomial_pattern_init_impl(size, rank, radix, p, 0); + ucc_knomial_pattern_init_impl(size, rank, radix, p, 0, 1); +} + +static inline void +ucc_knomial_pattern_init_no_extra(ucc_rank_t size, ucc_rank_t rank, + ucc_kn_radix_t radix, + ucc_knomial_pattern_t *p) +{ + ucc_knomial_pattern_init_impl(size, rank, radix, p, 0, 0); } static inline ucc_rank_t @@ -186,6 +196,23 @@ ucc_knomial_pattern_get_loop_peer(ucc_knomial_pattern_t *p, ucc_rank_t rank, ucc_knomial_pattern_loop_rank_inv(p, peer); } +static inline ucc_rank_t +ucc_knomial_pattern_get_base_rank(ucc_knomial_pattern_t *p, ucc_rank_t rank) +{ + ucc_rank_t step_size = p->radix_pow * p->radix; + ucc_rank_t lrank; + ucc_kn_radix_t s; + + lrank = ucc_knomial_pattern_loop_rank(p, rank); + s = ucc_div_round_up(step_size - (lrank % step_size), p->radix_pow); + + if (s == p->radix) { + return rank; + } else { + return ucc_knomial_pattern_get_loop_peer(p, rank, s); + } +} + static inline void ucc_knomial_pattern_next_iteration(ucc_knomial_pattern_t *p) { @@ -224,11 +251,13 @@ static inline ucc_rank_t ucc_knomial_calc_recv_dist(ucc_rank_t team_size, ucc_rank_t rank, ucc_rank_t radix, ucc_rank_t root) { + ucc_rank_t root_base = 0; + ucc_rank_t dist = 1; + if (rank == root) { return 0; } - ucc_rank_t root_base = 0 ; - ucc_rank_t dist = 1; + while (dist <= team_size) { if (rank < root_base + radix * dist) { break; diff --git a/src/coll_patterns/sra_knomial.h b/src/coll_patterns/sra_knomial.h index 11b99dcf53..de4e45a8f8 100644 --- a/src/coll_patterns/sra_knomial.h +++ b/src/coll_patterns/sra_knomial.h @@ -184,6 +184,82 @@ ucc_knx_block(ucc_rank_t rank, ucc_rank_t size, ucc_kn_radix_t radix, *b_offset = offset; } +static inline void +ucc_kn_g_pattern_init(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix, + size_t count, ucc_knomial_pattern_t *p) +{ + ucc_knomial_pattern_init_no_extra(size, rank, radix, p); + p->type = KN_PATTERN_GATHER; + p->count = count; + p->block_size = p->radix_pow * radix; + 
p->block_offset = ucc_knomial_pattern_loop_rank(p, rank) / p->block_size * + p->block_size; +} + +static inline void +ucc_kn_gx_pattern_init(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix, + size_t count, ucc_knomial_pattern_t *p) +{ + ucc_knomial_pattern_init_backward(size, rank, radix, p); + p->type = KN_PATTERN_GATHERX; + p->count = count; + if (p->node_type != KN_NODE_EXTRA) { + p->block_size = ucc_kn_compute_step_radix(p); + ucc_knx_block(rank, size, radix, count, p->n_iters - 1, + &p->block_size_counts, &p->block_offset); + + } + +} + +static inline void +ucc_kn_g_pattern_peer_seg(ucc_rank_t peer, ucc_knomial_pattern_t *p, + size_t *seg_count, ptrdiff_t *seg_offset) +{ + ucc_rank_t step_radix, seg_index; + + *seg_count = 0; + *seg_offset = 0; + switch (p->type) { + case KN_PATTERN_GATHER: + *seg_count = ucc_min(p->radix_pow, p->size - peer) * (p->count / p->size); + *seg_offset = peer * (p->count / p->size); + return; + case KN_PATTERN_GATHERX: + step_radix = ucc_kn_compute_step_radix(p); + seg_index = ucc_kn_compute_seg_index(peer, p->radix_pow, p); + *seg_offset = ucc_buffer_block_offset(p->block_size_counts, step_radix, + seg_index) + p->block_offset; + *seg_count = ucc_buffer_block_count(p->block_size_counts, step_radix, + seg_index); + return; + default: + ucc_assert(0); + } +} + +static inline void ucc_kn_g_pattern_next_iter(ucc_knomial_pattern_t *p) +{ + ucc_rank_t rank; + if (p->type == KN_PATTERN_GATHERX) { + ucc_knomial_pattern_next_iteration_backward(p); + + if (!ucc_knomial_pattern_loop_done(p)) { + ucc_knx_block(p->rank, p->size, p->radix, p->count, + p->n_iters - 1 - p->iteration, + &p->block_size_counts, &p->block_offset); + } + } else { + rank = ucc_knomial_pattern_loop_rank(p, p->rank); + ucc_knomial_pattern_next_iteration(p); + + if (!ucc_knomial_pattern_loop_done(p)) { + p->block_size *= ucc_kn_compute_step_radix(p); + p->block_offset = rank / p->block_size * p->block_size; + } + } +} + static inline void ucc_kn_ag_pattern_init(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix, size_t count, ucc_knomial_pattern_t *p) diff --git a/src/components/tl/ucp/gather/gather.c b/src/components/tl/ucp/gather/gather.c index d4748a8025..32b24f85f4 100644 --- a/src/components/tl/ucp/gather/gather.c +++ b/src/components/tl/ucp/gather/gather.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ @@ -17,62 +17,13 @@ ucc_base_coll_alg_info_t [UCC_TL_UCP_GATHER_ALG_LAST] = { .id = 0, .name = NULL, .desc = NULL}}; -static inline uint32_t calc_buffer_size(ucc_rank_t rank, uint32_t radix, ucc_rank_t team_size) -{ - uint32_t radix_valuation; - - if (rank == 0) { - return team_size; - } - radix_valuation = calc_valuation(rank, radix); - return (uint32_t)ucc_min(pow(radix, radix_valuation), team_size - rank); -} - ucc_status_t ucc_tl_ucp_gather_init(ucc_tl_ucp_task_t *task) { - ucc_coll_args_t * args = &TASK_ARGS(task); - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - ucc_rank_t myrank = UCC_TL_TEAM_RANK(team); - ucc_rank_t team_size = UCC_TL_TEAM_SIZE(team); - ucc_rank_t root = args->root; - ucc_rank_t vrank = (myrank - root + team_size) % team_size; - ucc_status_t status = UCC_OK; - ucc_memory_type_t mtype; - ucc_datatype_t dt; - size_t count, data_size; - uint32_t buffer_size; - int isleaf; - - if (root == myrank) { - count = args->dst.info.count; - dt = args->dst.info.datatype; - mtype = args->dst.info.mem_type; - } else { - count = args->src.info.count; - dt = args->src.info.datatype; - mtype = args->src.info.mem_type; - } - data_size = count * ucc_dt_size(dt); - task->super.post = ucc_tl_ucp_gather_knomial_start; - task->super.progress = ucc_tl_ucp_gather_knomial_progress; - task->super.finalize = ucc_tl_ucp_gather_knomial_finalize; - task->gather_kn.radix = - ucc_min(UCC_TL_UCP_TEAM_LIB(team)->cfg.gather_kn_radix, team_size); - CALC_KN_TREE_DIST(team_size, task->gather_kn.radix, - task->gather_kn.max_dist); - isleaf = (vrank % task->gather_kn.radix != 0 || vrank == team_size - 1); - task->gather_kn.scratch_mc_header = NULL; + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_rank_t size = UCC_TL_TEAM_SIZE(team); + ucc_kn_radix_t radix; - if (vrank == 0) { - task->gather_kn.scratch = args->dst.info.buffer; - } else if (isleaf) { - task->gather_kn.scratch = args->src.info.buffer; - } else { - buffer_size = calc_buffer_size(vrank, task->gather_kn.radix, team_size); - status = ucc_mc_alloc(&task->gather_kn.scratch_mc_header, - buffer_size * data_size, mtype); - task->gather_kn.scratch = task->gather_kn.scratch_mc_header->addr; - } + radix = ucc_min(UCC_TL_UCP_TEAM_LIB(team)->cfg.gather_kn_radix, size); - return status; + return ucc_tl_ucp_gather_knomial_init_common(task, radix); } diff --git a/src/components/tl/ucp/gather/gather.h b/src/components/tl/ucp/gather/gather.h index 26a3df4138..78ce0e7420 100644 --- a/src/components/tl/ucp/gather/gather.h +++ b/src/components/tl/ucp/gather/gather.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -45,4 +45,12 @@ void ucc_tl_ucp_gather_knomial_progress(ucc_coll_task_t *task); ucc_status_t ucc_tl_ucp_gather_knomial_finalize(ucc_coll_task_t *task); +ucc_status_t ucc_tl_ucp_gather_knomial_init_common(ucc_tl_ucp_task_t *task, + ucc_kn_radix_t radix); + +/* Internal interface with custom radix */ +ucc_status_t ucc_tl_ucp_gather_knomial_init_r(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task_h, + ucc_kn_radix_t radix); #endif diff --git a/src/components/tl/ucp/gather/gather_knomial.c b/src/components/tl/ucp/gather/gather_knomial.c index ab6ff8933e..7052540306 100644 --- a/src/components/tl/ucp/gather/gather_knomial.c +++ b/src/components/tl/ucp/gather/gather_knomial.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -9,135 +9,197 @@ #include "core/ucc_progress_queue.h" #include "tl_ucp_sendrecv.h" #include "utils/ucc_math.h" +#include "coll_patterns/sra_knomial.h" #define SAVE_STATE(_phase) \ do { \ task->gather_kn.phase = _phase; \ } while (0) +static inline uint32_t calc_buffer_size(ucc_rank_t vrank, uint32_t radix, + ucc_rank_t tsize) +{ + uint32_t radix_valuation; + + if (vrank == 0) { + return tsize; + } + + radix_valuation = calc_valuation(vrank, radix); + return (uint32_t)ucc_min(pow(radix, radix_valuation), tsize - vrank); +} + +/* gather knomial is used as regular gather collective and as part of reduce SRG */ void ucc_tl_ucp_gather_knomial_progress(ucc_coll_task_t *coll_task) { - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); - ucc_coll_args_t * args = &TASK_ARGS(task); - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - ucc_rank_t team_size = UCC_TL_TEAM_SIZE(team); - ucc_rank_t rank = UCC_TL_TEAM_RANK(team); - ucc_rank_t size = UCC_TL_TEAM_SIZE(team); - ucc_rank_t root = (ucc_rank_t)args->root; - uint32_t radix = task->gather_kn.radix; - ucc_rank_t vrank = (rank - root + size) % size; - ucc_memory_type_t mtype = args->src.info.mem_type; - ucc_status_t status = UCC_OK; - size_t data_size = - args->src.info.count * ucc_dt_size(args->src.info.datatype); - size_t msg_size, msg_count; - void * scratch_offset; - ucc_rank_t vpeer, peer, vroot_at_level, root_at_level, pos; - uint32_t i; + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_coll_args_t *args = &TASK_ARGS(task); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team); + ucc_rank_t rank = UCC_TL_TEAM_RANK(team); + ucc_rank_t root = (ucc_rank_t)args->root; + uint32_t radix = task->gather_kn.radix; + ucc_rank_t vrank = VRANK(rank, root, tsize); + ucc_memory_type_t mtype = args->src.info.mem_type; + ucc_status_t status = UCC_OK; + ucc_knomial_pattern_t *p = &task->gather_kn.p; + size_t dt_size = ucc_dt_size(args->src.info.datatype); + size_t data_size = args->src.info.count * dt_size; + ucc_coll_type_t ct = args->coll_type; + size_t msg_size, peer_seg_count; + void *scratch_offset; + ucc_rank_t vpeer, peer, vroot_at_level, root_at_level, num_blocks; + ucc_kn_radix_t loop_step; + ptrdiff_t peer_seg_offset; -UCC_GATHER_KN_PHASE_PROGRESS: - if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { + if (task->gather_kn.p.node_type == KN_NODE_EXTRA) { + ucc_assert(ct == UCC_COLL_TYPE_REDUCE); + task->super.status = UCC_OK; return; } UCC_GATHER_KN_GOTO_PHASE(task->gather_kn.phase); - UCC_GATHER_KN_PHASE_INIT: - while (task->gather_kn.dist <= task->gather_kn.max_dist) { + while (!ucc_knomial_pattern_loop_done(p)) { + if (task->tagged.send_posted > 0) { + goto UCC_GATHER_KN_PHASE_PROGRESS; + } + scratch_offset = task->gather_kn.scratch; - if (vrank % task->gather_kn.dist == 0) { - pos = (vrank / task->gather_kn.dist) % radix; - if (pos == 0) { - for (i = 1; i < radix; i++) { - vpeer = vrank + i * task->gather_kn.dist; - msg_count = ucc_min(task->gather_kn.dist, team_size - vpeer); - if (vpeer >= size) { - break; - } else if (vrank != 0) { - msg_size = data_size * msg_count; - scratch_offset = PTR_OFFSET( - scratch_offset, data_size * task->gather_kn.dist); - peer = (vpeer + root) % size; - UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(scratch_offset, - msg_size, mtype, peer, - team, task), - task, out); - } else { //The root is a particular case 
because it must aggregate the data sorted by ranks - peer = (vpeer + root) % size; + vroot_at_level = ucc_knomial_pattern_get_base_rank(p, vrank); + if (vroot_at_level == vrank) { + for (loop_step = 1; loop_step < radix; loop_step++) { + vpeer = ucc_knomial_pattern_get_loop_peer(p, vrank, loop_step); + if (vpeer == UCC_KN_PEER_NULL) { + continue; + } + ucc_kn_g_pattern_peer_seg(vpeer, p, &peer_seg_count, + &peer_seg_offset); + peer = INV_VRANK(vpeer, root, tsize); + if (vrank != 0) { + msg_size = peer_seg_count * dt_size; + if (args->coll_type != UCC_COLL_TYPE_GATHER) { scratch_offset = PTR_OFFSET(task->gather_kn.scratch, - data_size * peer); - // check if received data correspond to contiguous ranks - if (msg_count <= team_size - peer) { - msg_size = data_size * msg_count; - UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(scratch_offset, - msg_size, mtype, - peer, team, task), - task, out); - } else { // in this case, data must be split in two at the destination buffer - msg_size = data_size * (team_size - peer); - UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(scratch_offset, - msg_size, mtype, - peer, team, task), - task, out); - - msg_size = - data_size * (msg_count - (team_size - peer)); - UCPCHECK_GOTO(ucc_tl_ucp_recv_nb( - task->gather_kn.scratch, msg_size, - mtype, peer, team, task), - task, out); + peer_seg_offset * dt_size); + } else { + scratch_offset = PTR_OFFSET(scratch_offset, + data_size * + task->gather_kn.dist); + } + UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(scratch_offset, + msg_size, mtype, peer, + team, task), + task, out); + } else { + /* + the root is a particular case because it must aggregate + the data sorted by ranks + */ + scratch_offset = PTR_OFFSET(task->gather_kn.scratch, + data_size * peer); + num_blocks = ucc_min(task->gather_kn.dist, tsize - vpeer); + /* check if received data correspond to contiguous ranks */ + if ((ct == UCC_COLL_TYPE_REDUCE) || + (num_blocks <= tsize - peer)) { + msg_size = peer_seg_count * dt_size; + if (args->coll_type != UCC_COLL_TYPE_GATHER) { + scratch_offset = PTR_OFFSET(task->gather_kn.scratch, + peer_seg_offset * dt_size); } + UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(scratch_offset, + msg_size, mtype, + peer, team, task), + task, out); + } else { + /* + in this case, data must be split in two + at the destination buffer + */ + msg_size = data_size * (tsize - peer); + UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(scratch_offset, + msg_size, mtype, + peer, team, task), + task, out); + msg_size = data_size * (num_blocks - (tsize - peer)); + UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(task->gather_kn.scratch, + msg_size, mtype, + peer, team, task), + task, out); } } + } - if (task->gather_kn.dist == 1) { //check if first passage - msg_size = data_size; - if (rank != root) { - status = ucc_mc_memcpy(task->gather_kn.scratch, - args->src.info.buffer, msg_size, - args->src.info.mem_type, mtype); - } else if (!UCC_IS_INPLACE(*args)) { - status = ucc_mc_memcpy( - PTR_OFFSET(task->gather_kn.scratch, data_size * rank), - args->src.info.buffer, msg_size, - args->src.info.mem_type, mtype); - } + if ((ct != UCC_COLL_TYPE_REDUCE) && + ucc_knomial_pattern_loop_first_iteration(p)) { + msg_size = data_size; + if (rank != root) { + status = ucc_mc_memcpy(task->gather_kn.scratch, + args->src.info.buffer, msg_size, + args->src.info.mem_type, mtype); + } else if (!UCC_IS_INPLACE(*args)) { + status = ucc_mc_memcpy( + PTR_OFFSET(task->gather_kn.scratch, data_size * rank), + args->src.info.buffer, msg_size, + args->src.info.mem_type, mtype); + } - if (ucc_unlikely(UCC_OK != status)) { - task->super.status = status; - 
return; - } + if (ucc_unlikely(UCC_OK != status)) { + task->super.status = status; + return; } } else { - vroot_at_level = vrank - pos * task->gather_kn.dist; - root_at_level = (vroot_at_level + root) % size; - msg_count = ucc_min(task->gather_kn.dist, - team_size - vrank); - msg_size = data_size * msg_count; - if (root_at_level != root || msg_count <= team_size - rank) { - UCPCHECK_GOTO(ucc_tl_ucp_send_nb(task->gather_kn.scratch, - msg_size, mtype, - root_at_level, team, task), - task, out); + if (rank == root && ucc_knomial_pattern_loop_first_iteration(p) && !UCC_IS_INPLACE(*args)) { + ucc_kn_g_pattern_peer_seg(vrank, p, &peer_seg_count, + &peer_seg_offset); + status = ucc_mc_memcpy( + PTR_OFFSET(task->gather_kn.scratch, peer_seg_offset * dt_size), + PTR_OFFSET(args->src.info.buffer, peer_seg_offset * dt_size), peer_seg_count * dt_size, + args->src.info.mem_type, mtype); + } + } + } else { + root_at_level = INV_VRANK(vroot_at_level, root, tsize); + num_blocks = ucc_min(task->gather_kn.dist, tsize - vrank); + if ((ct == UCC_COLL_TYPE_REDUCE) || + (root_at_level != root) || + (num_blocks <= tsize - rank)) { + ucc_kn_g_pattern_peer_seg(vrank, p, &peer_seg_count, + &peer_seg_offset); + msg_size = peer_seg_count * dt_size; + if (args->coll_type == UCC_COLL_TYPE_GATHER) { + scratch_offset = task->gather_kn.scratch; } else { - msg_size = data_size * (team_size - rank); - UCPCHECK_GOTO(ucc_tl_ucp_send_nb(task->gather_kn.scratch, - msg_size, mtype, - root_at_level, team, task), - task, out); - msg_size = data_size * (msg_count - (team_size - rank)); - UCPCHECK_GOTO( - ucc_tl_ucp_send_nb( - PTR_OFFSET(task->gather_kn.scratch, - data_size * (team_size - rank)), - msg_size, mtype, root_at_level, team, task), - task, out); + scratch_offset = PTR_OFFSET(task->gather_kn.scratch, + peer_seg_offset * dt_size); } + UCPCHECK_GOTO(ucc_tl_ucp_send_nb(scratch_offset, + msg_size, mtype, + root_at_level, team, task), + task, out); + } else { + // need to split in this case due to root and tree topology + msg_size = data_size * (tsize - rank); + UCPCHECK_GOTO(ucc_tl_ucp_send_nb(task->gather_kn.scratch, + msg_size, mtype, + root_at_level, team, task), + task, out); + msg_size = data_size * (num_blocks - (tsize - rank)); + UCPCHECK_GOTO( + ucc_tl_ucp_send_nb(PTR_OFFSET(task->gather_kn.scratch, + data_size * (tsize - rank)), + msg_size, mtype, root_at_level, team, + task), + task, out); } } + +UCC_GATHER_KN_PHASE_PROGRESS: + if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { + SAVE_STATE(UCC_GATHER_KN_PHASE_PROGRESS); + return; + } task->gather_kn.dist *= radix; - SAVE_STATE(UCC_GATHER_KN_PHASE_INIT); - goto UCC_GATHER_KN_PHASE_PROGRESS; + ucc_kn_g_pattern_next_iter(p); } ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task)); @@ -149,18 +211,31 @@ void ucc_tl_ucp_gather_knomial_progress(ucc_coll_task_t *coll_task) ucc_status_t ucc_tl_ucp_gather_knomial_start(ucc_coll_task_t *coll_task) { - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); - ucc_coll_args_t * args = &TASK_ARGS(task); - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - ucc_rank_t root = (ucc_rank_t)args->root; - ucc_rank_t rank = UCC_TL_TEAM_RANK(team); - ucc_rank_t size = UCC_TL_TEAM_SIZE(team); + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_coll_args_t *args = &TASK_ARGS(task); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_rank_t root = (ucc_rank_t)args->root; + ucc_rank_t trank = UCC_TL_TEAM_RANK(team); + ucc_rank_t size = UCC_TL_TEAM_SIZE(team); - if (root == rank && UCC_IS_INPLACE(*args)) { + if (root 
== trank && UCC_IS_INPLACE(*args)) { args->src.info = args->dst.info; args->src.info.count = args->dst.info.count / size; } + if (args->coll_type == UCC_COLL_TYPE_GATHER) { + ucc_kn_g_pattern_init(size, VRANK(trank, root, size), + task->gather_kn.radix, args->src.info.count * size, + &task->gather_kn.p); + } else { + /* reduce srg */ + ucc_assert(args->coll_type == UCC_COLL_TYPE_REDUCE); + task->gather_kn.scratch = args->dst.info.buffer; + ucc_kn_gx_pattern_init(size, VRANK(trank, root, size), + task->gather_kn.radix, args->dst.info.count, + &task->gather_kn.p); + } + UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_gather_kn_start", 0); ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); @@ -179,3 +254,105 @@ ucc_status_t ucc_tl_ucp_gather_knomial_finalize(ucc_coll_task_t *coll_task) } return ucc_tl_ucp_coll_finalize(coll_task); } + +ucc_status_t ucc_tl_ucp_gather_knomial_init_common(ucc_tl_ucp_task_t *task, + ucc_kn_radix_t radix) +{ + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_rank_t trank = UCC_TL_TEAM_RANK(team); + ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team); + ucc_coll_args_t *args = &TASK_ARGS(task); + ucc_rank_t root = args->root; + ucc_rank_t vrank = VRANK(trank, root, tsize); + ucc_status_t status = UCC_OK; + ucc_memory_type_t mtype; + ucc_datatype_t dt; + size_t count, data_size; + uint32_t buffer_size; + int is_leaf; + + if (UCC_IS_ROOT(*args, trank)) { + count = args->dst.info.count; + dt = args->dst.info.datatype; + mtype = args->dst.info.mem_type; + } else { + count = args->src.info.count; + dt = args->src.info.datatype; + mtype = args->src.info.mem_type; + } + data_size = count * ucc_dt_size(dt); + task->super.post = ucc_tl_ucp_gather_knomial_start; + task->super.progress = ucc_tl_ucp_gather_knomial_progress; + task->super.finalize = ucc_tl_ucp_gather_knomial_finalize; + task->gather_kn.radix = radix; + CALC_KN_TREE_DIST(tsize, task->gather_kn.radix, + task->gather_kn.max_dist); + task->gather_kn.scratch_mc_header = NULL; + + if (args->coll_type == UCC_COLL_TYPE_REDUCE) { + task->gather_kn.scratch = args->dst.info.buffer; + } else { + ucc_assert(args->coll_type == UCC_COLL_TYPE_GATHER); + is_leaf = ((vrank % radix != 0) || (vrank == tsize - 1)); + if (vrank == 0) { + task->gather_kn.scratch = args->dst.info.buffer; + } else if (is_leaf) { + task->gather_kn.scratch = args->src.info.buffer; + } else { + buffer_size = calc_buffer_size(vrank, task->gather_kn.radix, tsize); + status = ucc_mc_alloc(&task->gather_kn.scratch_mc_header, + buffer_size * data_size, mtype); + task->gather_kn.scratch = task->gather_kn.scratch_mc_header->addr; + } + } + + return status; +} + +ucc_status_t ucc_tl_ucp_gather_knomial_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task_h) +{ + ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); + ucc_rank_t tsize = UCC_TL_TEAM_SIZE(tl_team); + ucc_tl_ucp_task_t *task; + ucc_status_t status; + ucc_kn_radix_t radix; + + task = ucc_tl_ucp_init_task(coll_args, team); + if (ucc_unlikely(!task)) { + return UCC_ERR_NO_MEMORY; + } + + radix = ucc_min(UCC_TL_UCP_TEAM_LIB(tl_team)->cfg.gather_kn_radix, tsize); + + status = ucc_tl_ucp_gather_knomial_init_common(task, radix); + if (ucc_unlikely(status != UCC_OK)) { + ucc_tl_ucp_put_task(task); + return status; + } + *task_h = &task->super; + return UCC_OK; +} + +ucc_status_t ucc_tl_ucp_gather_knomial_init_r(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task_h, + ucc_kn_radix_t radix) +{ + ucc_tl_ucp_task_t *task; + ucc_status_t status; + 
+ task = ucc_tl_ucp_init_task(coll_args, team); + if (ucc_unlikely(!task)) { + return UCC_ERR_NO_MEMORY; + } + + status = ucc_tl_ucp_gather_knomial_init_common(task, radix); + if (ucc_unlikely(status != UCC_OK)) { + ucc_tl_ucp_put_task(task); + return status; + } + *task_h = &task->super; + return UCC_OK; +} diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h index 9668e46183..2769244d39 100644 --- a/src/components/tl/ucp/tl_ucp_coll.h +++ b/src/components/tl/ucp/tl_ucp_coll.h @@ -237,10 +237,11 @@ typedef struct ucc_tl_ucp_task { ucc_ee_executor_t *executor; } reduce_dbt; struct { + int phase; + ucc_knomial_pattern_t p; ucc_rank_t dist; ucc_rank_t max_dist; uint32_t radix; - int phase; void * scratch; ucc_mc_buffer_header_t *scratch_mc_header; } gather_kn; @@ -320,17 +321,17 @@ static inline void ucc_tl_ucp_put_task(ucc_tl_ucp_task_t *task) } static inline ucc_status_t -ucc_tl_ucp_get_schedule(ucc_tl_ucp_team_t *team, - ucc_base_coll_args_t *args, +ucc_tl_ucp_get_schedule(ucc_tl_ucp_team_t *team, + ucc_base_coll_args_t *args, ucc_tl_ucp_schedule_t **schedule) { - ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team); + ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team); *schedule = ucc_mpool_get(&ctx->req_mp); - if (ucc_unlikely(!(*schedule))) { return UCC_ERR_NO_MEMORY; } + UCC_TL_UCP_PROFILE_REQUEST_NEW(schedule, "tl_ucp_sched", 0); return ucc_schedule_init(&((*schedule)->super.super), args, &team->super.super); @@ -342,7 +343,6 @@ static inline void ucc_tl_ucp_put_schedule(ucc_schedule_t *schedule) ucc_mpool_put(schedule); } - ucc_status_t ucc_tl_ucp_coll_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t * team, ucc_coll_task_t ** task_h);
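
A note on the segment arithmetic: for KN_PATTERN_GATHER, ucc_kn_g_pattern_peer_seg() gives every rank count/size elements and lets a subtree root forward min(radix_pow, size - peer) consecutive rank blocks. The following standalone sketch (plain C, no UCC headers; the team size and count values are made up for illustration) mirrors that branch:

    /*
     * Standalone illustration of the KN_PATTERN_GATHER branch of
     * ucc_kn_g_pattern_peer_seg(): seg_count/seg_offset for the peers a
     * subtree root receives from at an iteration with subtree size
     * radix_pow. Plain C, no UCC types.
     */
    #include <stdio.h>

    int main(void)
    {
        size_t size = 8, count = 16; /* team size, total count (assumed) */
        size_t radix_pow = 2;        /* subtree size at this iteration   */
        size_t peer;

        for (peer = 0; peer < size; peer += radix_pow) {
            size_t blocks     = radix_pow < size - peer ? radix_pow
                                                        : size - peer;
            size_t seg_count  = blocks * (count / size);
            size_t seg_offset = peer * (count / size);

            printf("peer %zu: seg_count %zu seg_offset %zu\n",
                   peer, seg_count, seg_offset);
        }
        return 0;
    }

With size = 8, count = 16 and radix_pow = 2 this prints 4-element segments at offsets 0, 4, 8 and 12, i.e. the contiguous rank-ordered layout the root aggregates into.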
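
The scratch sizing in ucc_tl_ucp_gather_knomial_init_common() follows calc_buffer_size(): a non-leaf at virtual rank v keeps radix^valuation(v) rank blocks, clamped to tsize - v. Below is a self-contained sketch under the assumption that calc_valuation(v, r) returns the largest k with r^k dividing v; a result of one block marks a leaf, which reuses the source buffer instead of allocating:

    /*
     * Sketch of the non-leaf scratch sizing, assuming calc_valuation(v, r)
     * returns the largest k such that r^k divides v.
     */
    #include <stdio.h>

    static unsigned buffer_blocks(unsigned vrank, unsigned radix,
                                  unsigned tsize)
    {
        unsigned v = vrank, blocks = 1;

        if (vrank == 0) {
            return tsize; /* the (virtual) root holds the full result */
        }
        while (v % radix == 0) { /* blocks = radix^valuation(vrank) */
            blocks *= radix;
            v      /= radix;
        }
        /* clamp to the ranks that actually sit below vrank */
        return blocks < tsize - vrank ? blocks : tsize - vrank;
    }

    int main(void)
    {
        unsigned tsize = 10, radix = 2, v;

        for (v = 0; v < tsize; v++) {
            printf("vrank %u: %u block(s)\n",
                   v, buffer_blocks(v, radix, tsize));
        }
        return 0;
    }

For tsize = 10, radix = 2 this yields, e.g., 4 blocks at vrank 4 and 2 blocks at vrank 8 (clamped by tsize - vrank), matching the ucc_min() in the patch.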
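
Finally, the new internal entry point ucc_tl_ucp_gather_knomial_init_r() lets a composite algorithm choose the radix itself instead of reading cfg.gather_kn_radix. A minimal sketch of such a caller — the helper name, include path and call site are hypothetical, not part of this patch:

    /* Hypothetical caller (not part of this patch): builds a gather stage
     * for a larger schedule, e.g. reduce SRG, with a caller-chosen radix. */
    #include "components/tl/ucp/gather/gather.h" /* illustrative path */

    static ucc_status_t
    build_gather_stage(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
                       ucc_kn_radix_t radix, ucc_coll_task_t **stage)
    {
        return ucc_tl_ucp_gather_knomial_init_r(coll_args, team, stage,
                                                radix);
    }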