Skip to content

Commit

Permalink
coll/accelerator: add reduce_scatter
Browse files Browse the repository at this point in the history
add support for MPI_Reduce_scatter

Signed-off-by: Edgar Gabriel <[email protected]>
  • Loading branch information
edgargabriel committed Dec 30, 2024
1 parent 88cd4a5 commit 70a4ea3
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 2 deletions.
3 changes: 2 additions & 1 deletion ompi/mca/coll/accelerator/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
#

sources = coll_accelerator_module.c coll_accelerator_reduce.c coll_accelerator_allreduce.c \
coll_accelerator_reduce_scatter_block.c coll_accelerator_component.c \
coll_accelerator_reduce_scatter_block.c coll_accelerator_reduce_scatter.c \
coll_accelerator_component.c \
coll_accelerator_scan.c coll_accelerator_exscan.c coll_accelerator.h

# Make the output library in this directory, and name it either
Expand Down
7 changes: 7 additions & 0 deletions ompi/mca/coll/accelerator/coll_accelerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ mca_coll_accelerator_reduce_scatter_block(const void *sbuf, void *rbuf, size_t r
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module);

int
mca_coll_accelerator_reduce_scatter(const void *sbuf, void *rbuf, ompi_count_array_t rcounts,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module);


/* Checks the type of pointer
*
Expand Down
5 changes: 4 additions & 1 deletion ompi/mca/coll/accelerator/coll_accelerator_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2019 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2024 Triad National Security, LLC. All rights reserved.
* $COPYRIGHT$
*
Expand Down Expand Up @@ -96,6 +96,7 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,
accelerator_module->super.coll_allreduce = mca_coll_accelerator_allreduce;
accelerator_module->super.coll_reduce = mca_coll_accelerator_reduce;
accelerator_module->super.coll_reduce_local = mca_coll_accelerator_reduce_local;
accelerator_module->super.coll_reduce_scatter = mca_coll_accelerator_reduce_scatter;
accelerator_module->super.coll_reduce_scatter_block = mca_coll_accelerator_reduce_scatter_block;
if (!OMPI_COMM_IS_INTER(comm)) {
accelerator_module->super.coll_scan = mca_coll_accelerator_scan;
Expand Down Expand Up @@ -144,6 +145,7 @@ mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_local);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block);
if (!OMPI_COMM_IS_INTER(comm)) {
/* MPI does not define scan/exscan on intercommunicators */
Expand All @@ -163,6 +165,7 @@ mca_coll_accelerator_module_disable(mca_coll_base_module_t *module,
ACCELERATOR_UNINSTALL_COLL_API(comm, s, allreduce);
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce);
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_local);
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_scatter);
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_scatter_block);
if (!OMPI_COMM_IS_INTER(comm))
{
Expand Down
108 changes: 108 additions & 0 deletions ompi/mca/coll/accelerator/coll_accelerator_reduce_scatter.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright (c) 2014-2017 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
* reserved.
* Copyright (c) 2014-2015 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2022 Amazon.com, Inc. or its affiliates. All Rights reserved.
* Copyright (c) 2024 Triad National Security, LLC. All rights reserved.
* Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
*
* $HEADER$
*/

#include "ompi_config.h"
#include "coll_accelerator.h"

#include <stdio.h>

#include "ompi/op/op.h"
#include "opal/datatype/opal_convertor.h"

/*
* reduce_scatter_block
*
* Function: - reduce then scatter
* Accepts: - same as MPI_Reduce_scatter()
* Returns: - MPI_SUCCESS or error code
*
* Algorithm:
* reduce and scatter (needs to be cleaned
* up at some point)
*/
int
mca_coll_accelerator_reduce_scatter(const void *sbuf, void *rbuf, ompi_count_array_t rcounts,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
{
mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;
ptrdiff_t gap;
char *rbuf1 = NULL, *sbuf1 = NULL, *rbuf2 = NULL;
int sbuf_dev, rbuf_dev;
size_t sbufsize, rbufsize, elemsize;
int rc, i;
int comm_size = ompi_comm_size(comm);
int total_count = 0;

elemsize = opal_datatype_span(&dtype->super, 1, &gap);
for (i = 0; i < comm_size; i++) {
total_count += ompi_count_array_get(rcounts, i);
}
sbufsize = elemsize * total_count;

rc = mca_coll_accelerator_check_buf((void *)sbuf, &sbuf_dev);
if (0 > rc) {
return rc;
}
if ((MPI_IN_PLACE != sbuf) && (0 < rc)) {
sbuf1 = (char*)malloc(sbufsize);
if (NULL == sbuf1) {
return OMPI_ERR_OUT_OF_RESOURCE;
}
mca_coll_accelerator_memcpy(sbuf1, MCA_ACCELERATOR_NO_DEVICE_ID, sbuf, sbuf_dev, sbufsize,
MCA_ACCELERATOR_TRANSFER_DTOH);
sbuf = sbuf1 - gap;
}

rc = mca_coll_accelerator_check_buf(rbuf, &rbuf_dev);
if (0 > rc) {
goto exit;
}
rbufsize = elemsize * ompi_count_array_get(rcounts, ompi_comm_rank(comm));
if (0 < rc) {
rbuf1 = (char*)malloc(rbufsize);
if (NULL == rbuf1) {
rc = OMPI_ERR_OUT_OF_RESOURCE;
goto exit;
}
mca_coll_accelerator_memcpy(rbuf1, MCA_ACCELERATOR_NO_DEVICE_ID, rbuf, rbuf_dev, rbufsize,
MCA_ACCELERATOR_TRANSFER_DTOH);
rbuf2 = rbuf; /* save away original buffer */
rbuf = rbuf1 - gap;
}
rc = s->c_coll.coll_reduce_scatter(sbuf, rbuf, rcounts, dtype, op, comm,
s->c_coll.coll_reduce_scatter_block_module);
if (0 > rc) {
goto exit;
}

if (NULL != rbuf1) {
mca_coll_accelerator_memcpy(rbuf2, rbuf_dev, rbuf1, MCA_ACCELERATOR_NO_DEVICE_ID, rbufsize,
MCA_ACCELERATOR_TRANSFER_HTOD);
}

exit:
if (NULL != sbuf1) {
free(sbuf1);
}
if (NULL != rbuf1) {
free(rbuf1);
}

return rc;
}

0 comments on commit 70a4ea3

Please sign in to comment.