Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prov/efa: Implement FI_CONTEXT2 in EFA Direct #10707

Merged
merged 2 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions prov/efa/src/efa.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,41 @@ struct efa_fabric {
#endif
};

struct efa_context {
uint64_t completion_flags;
fi_addr_t addr;
};

#if defined(static_assert)
static_assert(sizeof(struct efa_context) <= sizeof(struct fi_context2),
"efa_context must not be larger than fi_context2");
#endif

/**
* Prepare and return a pointer to an EFA context structure.
*
* @param context Pointer to the msg context.
* @param addr Peer address associated with the operation.
* @param flags Operation flags (e.g., FI_COMPLETION).
* @param completion_flags Completion flags reported in the cq entry.
* @return A pointer to an initialized EFA context structure,
* or NULL if context is invalid or FI_COMPLETION is not set.
*/
static inline struct efa_context *efa_fill_context(const void *context,
fi_addr_t addr,
uint64_t flags,
uint64_t completion_flags)
{
if (!context || !(flags & FI_COMPLETION))
return NULL;

struct efa_context *efa_context = (struct efa_context *) context;
efa_context->completion_flags = completion_flags;
efa_context->addr = addr;

return efa_context;
}

static inline
int efa_str_to_ep_addr(const char *node, const char *service, struct efa_ep_addr *addr)
{
Expand Down
8 changes: 5 additions & 3 deletions prov/efa/src/efa_cq.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ static void efa_cq_construct_cq_entry(struct ibv_cq_ex *ibv_cqx,
struct fi_cq_tagged_entry *entry)
{
entry->op_context = (void *)ibv_cqx->wr_id;
entry->flags = efa_cq_opcode_to_fi_flags(ibv_wc_read_opcode(ibv_cqx));
if (ibv_cqx->wr_id)
entry->flags = ((struct efa_context *) ibv_cqx->wr_id)->completion_flags;
else
entry->flags = efa_cq_opcode_to_fi_flags(ibv_wc_read_opcode(ibv_cqx));
entry->len = ibv_wc_read_byte_len(ibv_cqx);
entry->buf = NULL;
entry->data = 0;
Expand Down Expand Up @@ -81,8 +84,7 @@ static void efa_cq_handle_error(struct efa_base_ep *base_ep,
err_entry.prov_errno = prov_errno;

if (is_tx)
// TODO: get correct peer addr for TX operation
addr = FI_ADDR_NOTAVAIL;
addr = ibv_cq_ex->wr_id ? ((struct efa_context *)ibv_cq_ex->wr_id)->addr : FI_ADDR_NOTAVAIL;
else
addr = efa_av_reverse_lookup(base_ep->av,
ibv_wc_read_slid(ibv_cq_ex),
Expand Down
6 changes: 4 additions & 2 deletions prov/efa/src/efa_msg.c
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ static inline ssize_t efa_post_recv(struct efa_base_ep *base_ep, const struct fi
wr = &base_ep->efa_recv_wr_vec[wr_index].wr;
wr->num_sge = msg->iov_count;
wr->sg_list = base_ep->efa_recv_wr_vec[wr_index].sge;
wr->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL);
wr->wr_id = (uintptr_t) efa_fill_context(msg->context, msg->addr, flags,
FI_RECV | FI_MSG);

for (i = 0; i < msg->iov_count; i++) {
addr = (uintptr_t)msg->msg_iov[i].iov_base;
Expand Down Expand Up @@ -224,7 +225,8 @@ static inline ssize_t efa_post_send(struct efa_base_ep *base_ep, const struct fi
base_ep->is_wr_started = true;
}

qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL);
qp->ibv_qp_ex->wr_id = (uintptr_t) efa_fill_context(
msg->context, msg->addr, flags, FI_SEND | FI_MSG);

if (flags & FI_REMOTE_CQ_DATA) {
ibv_wr_send_imm(qp->ibv_qp_ex, msg->data);
Expand Down
8 changes: 6 additions & 2 deletions prov/efa/src/efa_rma.c
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ static inline ssize_t efa_rma_post_read(struct efa_base_ep *base_ep,
ibv_wr_start(qp->ibv_qp_ex);
base_ep->is_wr_started = true;
}
qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL);

qp->ibv_qp_ex->wr_id = (uintptr_t) efa_fill_context(
msg->context, msg->addr, flags, FI_RMA | FI_READ);

/* ep->domain->info->tx_attr->rma_iov_limit is set to 1 */
ibv_wr_rdma_read(qp->ibv_qp_ex, msg->rma_iov[0].key, msg->rma_iov[0].addr);
Expand Down Expand Up @@ -225,7 +227,9 @@ static inline ssize_t efa_rma_post_write(struct efa_base_ep *base_ep,
ibv_wr_start(qp->ibv_qp_ex);
base_ep->is_wr_started = true;
}
qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL);

qp->ibv_qp_ex->wr_id = (uintptr_t) efa_fill_context(
msg->context, msg->addr, flags, FI_RMA | FI_WRITE);

if (flags & FI_REMOTE_CQ_DATA) {
ibv_wr_rdma_write_imm(qp->ibv_qp_ex, msg->rma_iov[0].key,
Expand Down
2 changes: 1 addition & 1 deletion prov/efa/src/rdm/efa_rdm_pke.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ struct efa_rdm_pke {
_Alignas(EFA_RDM_PKE_ALIGNMENT) char wiredata[0];
};

#if defined(static_assert) && defined(__x86_64__)
#if defined(static_assert)
static_assert(sizeof (struct efa_rdm_pke) % EFA_RDM_PKE_ALIGNMENT == 0, "efa_rdm_pke alignment check");
#endif

Expand Down
2 changes: 1 addition & 1 deletion prov/efa/src/rdm/efa_rdm_protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ struct efa_ep_addr {
#define EFA_RDM_RUNT_PKT_END 148
#define EFA_RDM_EXTRA_REQ_PKT_END 148

#if defined(static_assert) && defined(__x86_64__)
#if defined(static_assert)
#define EFA_RDM_ENSURE_HEADER_SIZE(hdr, size) \
static_assert(sizeof (struct hdr) == (size), #hdr " size check")
#else
Expand Down
2 changes: 1 addition & 1 deletion prov/efa/src/rdm/efa_rdm_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

#define EFA_RDM_MSG_PREFIX_SIZE (sizeof(struct efa_rdm_pke) + sizeof(struct efa_rdm_eager_msgrtm_hdr) + EFA_RDM_REQ_OPT_RAW_ADDR_HDR_SIZE)

#if defined(static_assert) && defined(__x86_64__)
#if defined(static_assert)
static_assert(EFA_RDM_MSG_PREFIX_SIZE % 8 == 0, "message prefix size alignment check");
#endif

Expand Down
62 changes: 52 additions & 10 deletions prov/efa/test/efa_unit_test_cq.c
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,8 @@ void test_ibv_cq_ex_read_ignore_removed_peer()
#endif

static void test_efa_cq_read(struct efa_resource *resource, fi_addr_t *addr,
int ibv_wc_opcode, int status, int vendor_error)
int ibv_wc_opcode, int status, int vendor_error,
struct efa_context *ctx)
{
int ret;
size_t raw_addr_len = sizeof(struct efa_ep_addr);
Expand Down Expand Up @@ -847,16 +848,19 @@ static void test_efa_cq_read(struct efa_resource *resource, fi_addr_t *addr,
if (ibv_wc_opcode == IBV_WC_RECV) {
ibv_cqx = container_of(base_ep->util_ep.rx_cq, struct efa_cq, util_cq)->ibv_cq.ibv_cq_ex;
ibv_cqx->start_poll = &efa_mock_ibv_start_poll_return_mock;
ibv_cqx->wr_id = (uintptr_t)12345;
ctx->completion_flags = FI_RECV | FI_MSG;
will_return(efa_mock_ibv_start_poll_return_mock, 0);
ibv_cqx->status = status;
} else {
ibv_cqx = container_of(base_ep->util_ep.tx_cq, struct efa_cq, util_cq)->ibv_cq.ibv_cq_ex;
/* this mock will set ibv_cq_ex->wr_id to the wr_id of the head of global send_wr,
* and set ibv_cq_ex->status to mock value */
ibv_cqx->start_poll = &efa_mock_ibv_start_poll_use_saved_send_wr_with_mock_status;
ctx->completion_flags = FI_SEND | FI_MSG;
will_return(efa_mock_ibv_start_poll_use_saved_send_wr_with_mock_status, status);
}
ctx->addr = *addr;
ibv_cqx->wr_id = (uintptr_t) ctx;

ibv_cqx->next_poll = &efa_mock_ibv_next_poll_return_mock;
ibv_cqx->end_poll = &efa_mock_ibv_end_poll_check_mock;
Expand Down Expand Up @@ -894,19 +898,29 @@ void test_efa_cq_read_send_success(struct efa_resource **state)
{
struct efa_resource *resource = *state;
struct efa_unit_test_buff send_buff;
struct efa_base_ep *base_ep;
struct efa_context *efa_context;
struct fi_context2 ctx;
struct fi_cq_data_entry cq_entry;
fi_addr_t addr;
int ret;

test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_SUCCESS, 0);
test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_SUCCESS, 0,
(struct efa_context *) &ctx);
efa_unit_test_buff_construct(&send_buff, resource, 4096 /* buff_size */);

assert_int_equal(g_ibv_submitted_wr_id_cnt, 0);
ret = fi_send(resource->ep, send_buff.buff, send_buff.size,
fi_mr_desc(send_buff.mr), addr, (void *) 12345);
fi_mr_desc(send_buff.mr), addr, &ctx);
assert_int_equal(ret, 0);
assert_int_equal(g_ibv_submitted_wr_id_cnt, 1);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to this PR, but generally it worth checking the completion flags in the cq entry to be correct.

base_ep = container_of(resource->ep, struct efa_base_ep, util_ep.ep_fid);
efa_context = (struct efa_context *) base_ep->qp->ibv_qp_ex->wr_id;
assert_true(efa_context->completion_flags & FI_SEND);
assert_true(efa_context->completion_flags & FI_MSG);
assert_true(efa_context->addr == addr);

ret = fi_cq_read(resource->cq, &cq_entry, 1);
/* fi_cq_read() called efa_mock_ibv_start_poll_use_saved_send_wr(), which pulled one send_wr from g_ibv_submitted_wr_idv=_vec */
assert_int_equal(g_ibv_submitted_wr_id_cnt, 0);
Expand All @@ -923,17 +937,27 @@ void test_efa_cq_read_recv_success(struct efa_resource **state)
{
struct efa_resource *resource = *state;
struct efa_unit_test_buff recv_buff;
struct efa_base_ep *base_ep;
struct efa_context *efa_context;
struct fi_cq_data_entry cq_entry;
struct fi_context2 ctx;
fi_addr_t addr;
int ret;

test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_SUCCESS, 0);
test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_SUCCESS, 0,
(struct efa_context *) &ctx);
efa_unit_test_buff_construct(&recv_buff, resource, 4096 /* buff_size */);

ret = fi_recv(resource->ep, recv_buff.buff, recv_buff.size,
fi_mr_desc(recv_buff.mr), addr, NULL);
fi_mr_desc(recv_buff.mr), addr, &ctx);
assert_int_equal(ret, 0);

base_ep = container_of(resource->ep, struct efa_base_ep, util_ep.ep_fid);
efa_context = (struct efa_context *) base_ep->efa_recv_wr_vec[base_ep->recv_wr_index].wr.wr_id;
assert_true(efa_context->completion_flags & FI_RECV);
assert_true(efa_context->completion_flags & FI_MSG);
assert_true(efa_context->addr == addr);

ret = fi_cq_read(resource->cq, &cq_entry, 1);
assert_int_equal(ret, 1);

Expand Down Expand Up @@ -973,20 +997,29 @@ void test_efa_cq_read_send_failure(struct efa_resource **state)
{
struct efa_resource *resource = *state;
struct efa_unit_test_buff send_buff;
struct efa_base_ep *base_ep;
struct efa_context *efa_context;
struct fi_cq_data_entry cq_entry;
struct fi_context2 ctx;
fi_addr_t addr;
int ret;

test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_GENERAL_ERR,
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE);
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE, (struct efa_context *) &ctx);
efa_unit_test_buff_construct(&send_buff, resource, 4096 /* buff_size */);

assert_int_equal(g_ibv_submitted_wr_id_cnt, 0);
ret = fi_send(resource->ep, send_buff.buff, send_buff.size,
fi_mr_desc(send_buff.mr), addr, (void *) 12345);
fi_mr_desc(send_buff.mr), addr, &ctx);
assert_int_equal(ret, 0);
jiaxiyan marked this conversation as resolved.
Show resolved Hide resolved
assert_int_equal(g_ibv_submitted_wr_id_cnt, 1);

base_ep = container_of(resource->ep, struct efa_base_ep, util_ep.ep_fid);
efa_context = (struct efa_context *) base_ep->qp->ibv_qp_ex->wr_id;
assert_true(efa_context->completion_flags & FI_SEND);
assert_true(efa_context->completion_flags & FI_MSG);
assert_true(efa_context->addr == addr);

ret = fi_cq_read(resource->cq, &cq_entry, 1);
/* fi_cq_read() called efa_mock_ibv_start_poll_use_saved_send_wr(), which pulled one send_wr from g_ibv_submitted_wr_idv=_vec */
assert_int_equal(g_ibv_submitted_wr_id_cnt, 0);
Expand All @@ -1010,18 +1043,27 @@ void test_efa_cq_read_recv_failure(struct efa_resource **state)
{
struct efa_resource *resource = *state;
struct efa_unit_test_buff recv_buff;
struct efa_base_ep *base_ep;
struct efa_context *efa_context;
struct fi_cq_data_entry cq_entry;
struct fi_context2 ctx;
fi_addr_t addr;
int ret;

test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_GENERAL_ERR,
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE);
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE, (struct efa_context *) &ctx);
efa_unit_test_buff_construct(&recv_buff, resource, 4096 /* buff_size */);

ret = fi_recv(resource->ep, recv_buff.buff, recv_buff.size,
fi_mr_desc(recv_buff.mr), addr, NULL);
fi_mr_desc(recv_buff.mr), addr, &ctx);
assert_int_equal(ret, 0);

base_ep = container_of(resource->ep, struct efa_base_ep, util_ep.ep_fid);
efa_context = (struct efa_context *) base_ep->efa_recv_wr_vec[base_ep->recv_wr_index].wr.wr_id;
assert_true(efa_context->completion_flags & FI_RECV);
assert_true(efa_context->completion_flags & FI_MSG);
assert_true(efa_context->addr == addr);

ret = fi_cq_read(resource->cq, &cq_entry, 1);
assert_int_equal(ret, -FI_EAVAIL);

Expand Down
Loading