From 70a4ea377889598d62e17711b0abc1ccc904b125 Mon Sep 17 00:00:00 2001
From: Edgar Gabriel <edgar.gabriel@amd.com>
Date: Sun, 29 Dec 2024 17:55:56 -0600
Subject: [PATCH] coll/accelerator: add reduce_scatter

add support for MPI_Reduce_scatter

Signed-off-by: Edgar Gabriel <edgar.gabriel@amd.com>
---
 ompi/mca/coll/accelerator/Makefile.am         |   3 +-
 ompi/mca/coll/accelerator/coll_accelerator.h  |   7 ++
 .../accelerator/coll_accelerator_module.c     |   5 +-
 .../coll_accelerator_reduce_scatter.c         | 108 ++++++++++++++++++
 4 files changed, 121 insertions(+), 2 deletions(-)
 create mode 100644 ompi/mca/coll/accelerator/coll_accelerator_reduce_scatter.c

diff --git a/ompi/mca/coll/accelerator/Makefile.am b/ompi/mca/coll/accelerator/Makefile.am
index e3621c1d05a..d9b41006530 100644
--- a/ompi/mca/coll/accelerator/Makefile.am
+++ b/ompi/mca/coll/accelerator/Makefile.am
@@ -12,7 +12,8 @@
 #
 
 sources = coll_accelerator_module.c coll_accelerator_reduce.c coll_accelerator_allreduce.c \
-          coll_accelerator_reduce_scatter_block.c coll_accelerator_component.c \
+          coll_accelerator_reduce_scatter_block.c coll_accelerator_reduce_scatter.c       \
+          coll_accelerator_component.c \
           coll_accelerator_scan.c coll_accelerator_exscan.c coll_accelerator.h
 
 # Make the output library in this directory, and name it either
diff --git a/ompi/mca/coll/accelerator/coll_accelerator.h b/ompi/mca/coll/accelerator/coll_accelerator.h
index 70d971cc9a8..a719746d8b6 100644
--- a/ompi/mca/coll/accelerator/coll_accelerator.h
+++ b/ompi/mca/coll/accelerator/coll_accelerator.h
@@ -78,6 +78,13 @@ mca_coll_accelerator_reduce_scatter_block(const void *sbuf, void *rbuf, size_t r
                                    struct ompi_communicator_t *comm,
                                    mca_coll_base_module_t *module);
 
+int
+mca_coll_accelerator_reduce_scatter(const void *sbuf, void *rbuf, ompi_count_array_t rcounts,
+                                   struct ompi_datatype_t *dtype,
+                                   struct ompi_op_t *op,
+                                   struct ompi_communicator_t *comm,
+                                   mca_coll_base_module_t *module);
+
 
 /* Checks the type of pointer
  *
diff --git a/ompi/mca/coll/accelerator/coll_accelerator_module.c b/ompi/mca/coll/accelerator/coll_accelerator_module.c
index 4005f6cdec9..862eaed8ad7 100644
--- a/ompi/mca/coll/accelerator/coll_accelerator_module.c
+++ b/ompi/mca/coll/accelerator/coll_accelerator_module.c
@@ -6,7 +6,7 @@
  * Copyright (c) 2014-2024 NVIDIA Corporation.  All rights reserved.
  * Copyright (c) 2019      Research Organization for Information Science
  *                         and Technology (RIST).  All rights reserved.
- * Copyright (c) 2023      Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2024      Triad National Security, LLC. All rights reserved.
  * $COPYRIGHT$
  *
@@ -96,6 +96,7 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,
     accelerator_module->super.coll_allreduce  = mca_coll_accelerator_allreduce;
     accelerator_module->super.coll_reduce     = mca_coll_accelerator_reduce;
     accelerator_module->super.coll_reduce_local         = mca_coll_accelerator_reduce_local;
+    accelerator_module->super.coll_reduce_scatter       = mca_coll_accelerator_reduce_scatter;
     accelerator_module->super.coll_reduce_scatter_block = mca_coll_accelerator_reduce_scatter_block;
     if (!OMPI_COMM_IS_INTER(comm)) {
         accelerator_module->super.coll_scan       = mca_coll_accelerator_scan;
@@ -144,6 +145,7 @@ mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
     ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce);
     ACCELERATOR_INSTALL_COLL_API(comm, s, reduce);
     ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_local);
+    ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter);
     ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block);
     if (!OMPI_COMM_IS_INTER(comm)) {
         /* MPI does not define scan/exscan on intercommunicators */
@@ -163,6 +165,7 @@ mca_coll_accelerator_module_disable(mca_coll_base_module_t *module,
     ACCELERATOR_UNINSTALL_COLL_API(comm, s, allreduce);
     ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce);
     ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_local);
+    ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_scatter);
     ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_scatter_block);
     if (!OMPI_COMM_IS_INTER(comm))
     {
diff --git a/ompi/mca/coll/accelerator/coll_accelerator_reduce_scatter.c b/ompi/mca/coll/accelerator/coll_accelerator_reduce_scatter.c
new file mode 100644
index 00000000000..70200c567a8
--- /dev/null
+++ b/ompi/mca/coll/accelerator/coll_accelerator_reduce_scatter.c
@@ -0,0 +1,108 @@
+/*
+ * Copyright (c) 2014-2017 The University of Tennessee and The University
+ *                         of Tennessee Research Foundation.  All rights
+ *                         reserved.
+ * Copyright (c) 2014-2015 NVIDIA Corporation.  All rights reserved.
+ * Copyright (c) 2022      Amazon.com, Inc. or its affiliates.  All Rights reserved.
+ * Copyright (c) 2024      Triad National Security, LLC. All rights reserved.
+ * Copyright (c) 2024      Advanced Micro Devices, Inc. All Rights reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+#include "ompi_config.h"
+#include "coll_accelerator.h"
+
+#include <stdio.h>
+
+#include "ompi/op/op.h"
+#include "opal/datatype/opal_convertor.h"
+
+/*
+ *	reduce_scatter_block
+ *
+ *	Function:	- reduce then scatter
+ *	Accepts:	- same as MPI_Reduce_scatter()
+ *	Returns:	- MPI_SUCCESS or error code
+ *
+ * Algorithm:
+ *     reduce and scatter (needs to be cleaned
+ *     up at some point)
+ */
+int
+mca_coll_accelerator_reduce_scatter(const void *sbuf, void *rbuf, ompi_count_array_t rcounts,
+                                   struct ompi_datatype_t *dtype,
+                                   struct ompi_op_t *op,
+                                   struct ompi_communicator_t *comm,
+                                   mca_coll_base_module_t *module)
+{
+    mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;
+    ptrdiff_t gap;
+    char *rbuf1 = NULL, *sbuf1 = NULL, *rbuf2 = NULL;
+    int sbuf_dev, rbuf_dev;
+    size_t sbufsize, rbufsize, elemsize;
+    int rc, i;
+    int comm_size = ompi_comm_size(comm);
+    int total_count = 0;
+    
+    elemsize = opal_datatype_span(&dtype->super, 1, &gap);
+    for (i = 0; i < comm_size; i++) {
+	total_count += ompi_count_array_get(rcounts, i);
+    }
+    sbufsize = elemsize * total_count;
+
+    rc = mca_coll_accelerator_check_buf((void *)sbuf, &sbuf_dev);
+    if (0 > rc) {
+        return rc;
+    }
+    if ((MPI_IN_PLACE != sbuf) && (0 < rc)) {
+        sbuf1 = (char*)malloc(sbufsize);
+        if (NULL == sbuf1) {
+            return OMPI_ERR_OUT_OF_RESOURCE;
+        }
+        mca_coll_accelerator_memcpy(sbuf1, MCA_ACCELERATOR_NO_DEVICE_ID, sbuf, sbuf_dev, sbufsize,
+                                    MCA_ACCELERATOR_TRANSFER_DTOH);
+        sbuf = sbuf1 - gap;
+    }
+
+    rc = mca_coll_accelerator_check_buf(rbuf, &rbuf_dev);
+    if (0 > rc) {
+        goto exit;
+    }
+    rbufsize = elemsize * ompi_count_array_get(rcounts, ompi_comm_rank(comm));
+    if (0 < rc) {
+        rbuf1 = (char*)malloc(rbufsize);
+        if (NULL == rbuf1) {
+            rc = OMPI_ERR_OUT_OF_RESOURCE;
+	    goto exit;
+        }
+        mca_coll_accelerator_memcpy(rbuf1, MCA_ACCELERATOR_NO_DEVICE_ID, rbuf, rbuf_dev, rbufsize,
+                                    MCA_ACCELERATOR_TRANSFER_DTOH);
+        rbuf2 = rbuf; /* save away original buffer */
+        rbuf = rbuf1 - gap;
+    }
+    rc = s->c_coll.coll_reduce_scatter(sbuf, rbuf, rcounts, dtype, op, comm,
+                                       s->c_coll.coll_reduce_scatter_block_module);
+    if (0 > rc) {
+        goto exit;
+    }
+
+    if (NULL != rbuf1) {
+        mca_coll_accelerator_memcpy(rbuf2, rbuf_dev, rbuf1, MCA_ACCELERATOR_NO_DEVICE_ID, rbufsize,
+                                    MCA_ACCELERATOR_TRANSFER_HTOD);
+    }
+
+ exit:
+    if (NULL != sbuf1) {
+        free(sbuf1);
+    }
+    if (NULL != rbuf1) {
+        free(rbuf1);
+    }
+
+    return rc;
+}
+