diff --git a/prov/efa/src/efa.h b/prov/efa/src/efa.h index 4d8e982355c..e1cf716cb05 100644 --- a/prov/efa/src/efa.h +++ b/prov/efa/src/efa.h @@ -107,6 +107,41 @@ struct efa_fabric { #endif }; +struct efa_context { + uint64_t completion_flags; + fi_addr_t addr; +}; + +#if defined(static_assert) +static_assert(sizeof(struct efa_context) <= sizeof(struct fi_context2), + "efa_context must not be larger than fi_context2"); +#endif + +/** + * Prepare and return a pointer to an EFA context structure. + * + * @param context Pointer to the msg context. + * @param addr Peer address associated with the operation. + * @param flags Operation flags (e.g., FI_COMPLETION). + * @param completion_flags Completion flags reported in the cq entry. + * @return A pointer to an initialized EFA context structure, + * or NULL if context is invalid or FI_COMPLETION is not set. + */ +static inline struct efa_context *efa_fill_context(const void *context, + fi_addr_t addr, + uint64_t flags, + uint64_t completion_flags) +{ + if (!context || !(flags & FI_COMPLETION)) + return NULL; + + struct efa_context *efa_context = (struct efa_context *) context; + efa_context->completion_flags = completion_flags; + efa_context->addr = addr; + + return efa_context; +} + static inline int efa_str_to_ep_addr(const char *node, const char *service, struct efa_ep_addr *addr) { diff --git a/prov/efa/src/efa_cq.c b/prov/efa/src/efa_cq.c index a5b737d89ac..eeffe60cf3f 100644 --- a/prov/efa/src/efa_cq.c +++ b/prov/efa/src/efa_cq.c @@ -36,7 +36,10 @@ static void efa_cq_construct_cq_entry(struct ibv_cq_ex *ibv_cqx, struct fi_cq_tagged_entry *entry) { entry->op_context = (void *)ibv_cqx->wr_id; - entry->flags = efa_cq_opcode_to_fi_flags(ibv_wc_read_opcode(ibv_cqx)); + if (ibv_cqx->wr_id) + entry->flags = ((struct efa_context *) ibv_cqx->wr_id)->completion_flags; + else + entry->flags = efa_cq_opcode_to_fi_flags(ibv_wc_read_opcode(ibv_cqx)); entry->len = ibv_wc_read_byte_len(ibv_cqx); entry->buf = NULL; entry->data = 0; @@ -81,8 +84,7 @@ static void efa_cq_handle_error(struct efa_base_ep *base_ep, err_entry.prov_errno = prov_errno; if (is_tx) - // TODO: get correct peer addr for TX operation - addr = FI_ADDR_NOTAVAIL; + addr = ibv_cq_ex->wr_id ? ((struct efa_context *)ibv_cq_ex->wr_id)->addr : FI_ADDR_NOTAVAIL; else addr = efa_av_reverse_lookup(base_ep->av, ibv_wc_read_slid(ibv_cq_ex), diff --git a/prov/efa/src/efa_msg.c b/prov/efa/src/efa_msg.c index c2af757e112..5d5768c8ff1 100644 --- a/prov/efa/src/efa_msg.c +++ b/prov/efa/src/efa_msg.c @@ -101,7 +101,8 @@ static inline ssize_t efa_post_recv(struct efa_base_ep *base_ep, const struct fi wr = &base_ep->efa_recv_wr_vec[wr_index].wr; wr->num_sge = msg->iov_count; wr->sg_list = base_ep->efa_recv_wr_vec[wr_index].sge; - wr->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL); + wr->wr_id = (uintptr_t) efa_fill_context(msg->context, msg->addr, flags, + FI_RECV | FI_MSG); for (i = 0; i < msg->iov_count; i++) { addr = (uintptr_t)msg->msg_iov[i].iov_base; @@ -224,7 +225,8 @@ static inline ssize_t efa_post_send(struct efa_base_ep *base_ep, const struct fi base_ep->is_wr_started = true; } - qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL); + qp->ibv_qp_ex->wr_id = (uintptr_t) efa_fill_context( + msg->context, msg->addr, flags, FI_SEND | FI_MSG); if (flags & FI_REMOTE_CQ_DATA) { ibv_wr_send_imm(qp->ibv_qp_ex, msg->data); diff --git a/prov/efa/src/efa_rma.c b/prov/efa/src/efa_rma.c index 8fee3a2021b..da33b44350f 100644 --- a/prov/efa/src/efa_rma.c +++ b/prov/efa/src/efa_rma.c @@ -90,7 +90,9 @@ static inline ssize_t efa_rma_post_read(struct efa_base_ep *base_ep, ibv_wr_start(qp->ibv_qp_ex); base_ep->is_wr_started = true; } - qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL); + + qp->ibv_qp_ex->wr_id = (uintptr_t) efa_fill_context( + msg->context, msg->addr, flags, FI_RMA | FI_READ); /* ep->domain->info->tx_attr->rma_iov_limit is set to 1 */ ibv_wr_rdma_read(qp->ibv_qp_ex, msg->rma_iov[0].key, msg->rma_iov[0].addr); @@ -225,7 +227,9 @@ static inline ssize_t efa_rma_post_write(struct efa_base_ep *base_ep, ibv_wr_start(qp->ibv_qp_ex); base_ep->is_wr_started = true; } - qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL); + + qp->ibv_qp_ex->wr_id = (uintptr_t) efa_fill_context( + msg->context, msg->addr, flags, FI_RMA | FI_WRITE); if (flags & FI_REMOTE_CQ_DATA) { ibv_wr_rdma_write_imm(qp->ibv_qp_ex, msg->rma_iov[0].key, diff --git a/prov/efa/src/rdm/efa_rdm_pke.h b/prov/efa/src/rdm/efa_rdm_pke.h index 223822ce595..3bd0e51390d 100644 --- a/prov/efa/src/rdm/efa_rdm_pke.h +++ b/prov/efa/src/rdm/efa_rdm_pke.h @@ -195,7 +195,7 @@ struct efa_rdm_pke { _Alignas(EFA_RDM_PKE_ALIGNMENT) char wiredata[0]; }; -#if defined(static_assert) && defined(__x86_64__) +#if defined(static_assert) static_assert(sizeof (struct efa_rdm_pke) % EFA_RDM_PKE_ALIGNMENT == 0, "efa_rdm_pke alignment check"); #endif diff --git a/prov/efa/src/rdm/efa_rdm_protocol.h b/prov/efa/src/rdm/efa_rdm_protocol.h index 8840ce5f401..a4e608a5180 100644 --- a/prov/efa/src/rdm/efa_rdm_protocol.h +++ b/prov/efa/src/rdm/efa_rdm_protocol.h @@ -115,7 +115,7 @@ struct efa_ep_addr { #define EFA_RDM_RUNT_PKT_END 148 #define EFA_RDM_EXTRA_REQ_PKT_END 148 -#if defined(static_assert) && defined(__x86_64__) +#if defined(static_assert) #define EFA_RDM_ENSURE_HEADER_SIZE(hdr, size) \ static_assert(sizeof (struct hdr) == (size), #hdr " size check") #else diff --git a/prov/efa/src/rdm/efa_rdm_util.h b/prov/efa/src/rdm/efa_rdm_util.h index 7c3daa3432f..123fda9c59f 100644 --- a/prov/efa/src/rdm/efa_rdm_util.h +++ b/prov/efa/src/rdm/efa_rdm_util.h @@ -10,7 +10,7 @@ #define EFA_RDM_MSG_PREFIX_SIZE (sizeof(struct efa_rdm_pke) + sizeof(struct efa_rdm_eager_msgrtm_hdr) + EFA_RDM_REQ_OPT_RAW_ADDR_HDR_SIZE) -#if defined(static_assert) && defined(__x86_64__) +#if defined(static_assert) static_assert(EFA_RDM_MSG_PREFIX_SIZE % 8 == 0, "message prefix size alignment check"); #endif diff --git a/prov/efa/test/efa_unit_test_cq.c b/prov/efa/test/efa_unit_test_cq.c index e69fb8b432e..0a96c64c67e 100644 --- a/prov/efa/test/efa_unit_test_cq.c +++ b/prov/efa/test/efa_unit_test_cq.c @@ -813,7 +813,8 @@ void test_ibv_cq_ex_read_ignore_removed_peer() #endif static void test_efa_cq_read(struct efa_resource *resource, fi_addr_t *addr, - int ibv_wc_opcode, int status, int vendor_error) + int ibv_wc_opcode, int status, int vendor_error, + struct efa_context *ctx) { int ret; size_t raw_addr_len = sizeof(struct efa_ep_addr); @@ -847,7 +848,7 @@ static void test_efa_cq_read(struct efa_resource *resource, fi_addr_t *addr, if (ibv_wc_opcode == IBV_WC_RECV) { ibv_cqx = container_of(base_ep->util_ep.rx_cq, struct efa_cq, util_cq)->ibv_cq.ibv_cq_ex; ibv_cqx->start_poll = &efa_mock_ibv_start_poll_return_mock; - ibv_cqx->wr_id = (uintptr_t)12345; + ctx->completion_flags = FI_RECV | FI_MSG; will_return(efa_mock_ibv_start_poll_return_mock, 0); ibv_cqx->status = status; } else { @@ -855,8 +856,11 @@ static void test_efa_cq_read(struct efa_resource *resource, fi_addr_t *addr, /* this mock will set ibv_cq_ex->wr_id to the wr_id of the head of global send_wr, * and set ibv_cq_ex->status to mock value */ ibv_cqx->start_poll = &efa_mock_ibv_start_poll_use_saved_send_wr_with_mock_status; + ctx->completion_flags = FI_SEND | FI_MSG; will_return(efa_mock_ibv_start_poll_use_saved_send_wr_with_mock_status, status); } + ctx->addr = *addr; + ibv_cqx->wr_id = (uintptr_t) ctx; ibv_cqx->next_poll = &efa_mock_ibv_next_poll_return_mock; ibv_cqx->end_poll = &efa_mock_ibv_end_poll_check_mock; @@ -894,19 +898,29 @@ void test_efa_cq_read_send_success(struct efa_resource **state) { struct efa_resource *resource = *state; struct efa_unit_test_buff send_buff; + struct efa_base_ep *base_ep; + struct efa_context *efa_context; + struct fi_context2 ctx; struct fi_cq_data_entry cq_entry; fi_addr_t addr; int ret; - test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_SUCCESS, 0); + test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_SUCCESS, 0, + (struct efa_context *) &ctx); efa_unit_test_buff_construct(&send_buff, resource, 4096 /* buff_size */); assert_int_equal(g_ibv_submitted_wr_id_cnt, 0); ret = fi_send(resource->ep, send_buff.buff, send_buff.size, - fi_mr_desc(send_buff.mr), addr, (void *) 12345); + fi_mr_desc(send_buff.mr), addr, &ctx); assert_int_equal(ret, 0); assert_int_equal(g_ibv_submitted_wr_id_cnt, 1); + base_ep = container_of(resource->ep, struct efa_base_ep, util_ep.ep_fid); + efa_context = (struct efa_context *) base_ep->qp->ibv_qp_ex->wr_id; + assert_true(efa_context->completion_flags & FI_SEND); + assert_true(efa_context->completion_flags & FI_MSG); + assert_true(efa_context->addr == addr); + ret = fi_cq_read(resource->cq, &cq_entry, 1); /* fi_cq_read() called efa_mock_ibv_start_poll_use_saved_send_wr(), which pulled one send_wr from g_ibv_submitted_wr_idv=_vec */ assert_int_equal(g_ibv_submitted_wr_id_cnt, 0); @@ -923,17 +937,27 @@ void test_efa_cq_read_recv_success(struct efa_resource **state) { struct efa_resource *resource = *state; struct efa_unit_test_buff recv_buff; + struct efa_base_ep *base_ep; + struct efa_context *efa_context; struct fi_cq_data_entry cq_entry; + struct fi_context2 ctx; fi_addr_t addr; int ret; - test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_SUCCESS, 0); + test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_SUCCESS, 0, + (struct efa_context *) &ctx); efa_unit_test_buff_construct(&recv_buff, resource, 4096 /* buff_size */); ret = fi_recv(resource->ep, recv_buff.buff, recv_buff.size, - fi_mr_desc(recv_buff.mr), addr, NULL); + fi_mr_desc(recv_buff.mr), addr, &ctx); assert_int_equal(ret, 0); + base_ep = container_of(resource->ep, struct efa_base_ep, util_ep.ep_fid); + efa_context = (struct efa_context *) base_ep->efa_recv_wr_vec[base_ep->recv_wr_index].wr.wr_id; + assert_true(efa_context->completion_flags & FI_RECV); + assert_true(efa_context->completion_flags & FI_MSG); + assert_true(efa_context->addr == addr); + ret = fi_cq_read(resource->cq, &cq_entry, 1); assert_int_equal(ret, 1); @@ -973,20 +997,29 @@ void test_efa_cq_read_send_failure(struct efa_resource **state) { struct efa_resource *resource = *state; struct efa_unit_test_buff send_buff; + struct efa_base_ep *base_ep; + struct efa_context *efa_context; struct fi_cq_data_entry cq_entry; + struct fi_context2 ctx; fi_addr_t addr; int ret; test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_GENERAL_ERR, - EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE); + EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE, (struct efa_context *) &ctx); efa_unit_test_buff_construct(&send_buff, resource, 4096 /* buff_size */); assert_int_equal(g_ibv_submitted_wr_id_cnt, 0); ret = fi_send(resource->ep, send_buff.buff, send_buff.size, - fi_mr_desc(send_buff.mr), addr, (void *) 12345); + fi_mr_desc(send_buff.mr), addr, &ctx); assert_int_equal(ret, 0); assert_int_equal(g_ibv_submitted_wr_id_cnt, 1); + base_ep = container_of(resource->ep, struct efa_base_ep, util_ep.ep_fid); + efa_context = (struct efa_context *) base_ep->qp->ibv_qp_ex->wr_id; + assert_true(efa_context->completion_flags & FI_SEND); + assert_true(efa_context->completion_flags & FI_MSG); + assert_true(efa_context->addr == addr); + ret = fi_cq_read(resource->cq, &cq_entry, 1); /* fi_cq_read() called efa_mock_ibv_start_poll_use_saved_send_wr(), which pulled one send_wr from g_ibv_submitted_wr_idv=_vec */ assert_int_equal(g_ibv_submitted_wr_id_cnt, 0); @@ -1010,18 +1043,27 @@ void test_efa_cq_read_recv_failure(struct efa_resource **state) { struct efa_resource *resource = *state; struct efa_unit_test_buff recv_buff; + struct efa_base_ep *base_ep; + struct efa_context *efa_context; struct fi_cq_data_entry cq_entry; + struct fi_context2 ctx; fi_addr_t addr; int ret; test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_GENERAL_ERR, - EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE); + EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE, (struct efa_context *) &ctx); efa_unit_test_buff_construct(&recv_buff, resource, 4096 /* buff_size */); ret = fi_recv(resource->ep, recv_buff.buff, recv_buff.size, - fi_mr_desc(recv_buff.mr), addr, NULL); + fi_mr_desc(recv_buff.mr), addr, &ctx); assert_int_equal(ret, 0); + base_ep = container_of(resource->ep, struct efa_base_ep, util_ep.ep_fid); + efa_context = (struct efa_context *) base_ep->efa_recv_wr_vec[base_ep->recv_wr_index].wr.wr_id; + assert_true(efa_context->completion_flags & FI_RECV); + assert_true(efa_context->completion_flags & FI_MSG); + assert_true(efa_context->addr == addr); + ret = fi_cq_read(resource->cq, &cq_entry, 1); assert_int_equal(ret, -FI_EAVAIL);