diff --git a/src/engine_uring.cpp b/src/engine_uring.cpp index bbccba8..c243525 100644 --- a/src/engine_uring.cpp +++ b/src/engine_uring.cpp @@ -89,6 +89,13 @@ struct connection_t; struct engine_t; struct automata_t; +enum class uring_op_tag_t : std::uintptr_t { + uring_acpt_tag_k = 0, + uring_recv_tag_k, + uring_send_tag_k, + uring_stat_tag_k, // Max 2 bits +}; + enum class stage_t { waiting_to_accept_k = 0, expecting_reception_k, @@ -102,6 +109,7 @@ struct completed_event_t { connection_t* connection_ptr{}; stage_t stage{}; int result{}; + uring_op_tag_t type{}; }; class alignas(align_k) mutex_t { @@ -250,6 +258,7 @@ struct automata_t { connection_t& connection; stage_t completed_stage{}; int completed_result{}; + uring_op_tag_t type{}; void operator()() noexcept; bool is_corrupted() const noexcept { return completed_result == -EPIPE || completed_result == -EBADF; } @@ -471,6 +480,7 @@ void ucall_take_call(ucall_server_t server, uint16_t thread_idx) { *completed.connection_ptr, completed.stage, completed.result, + completed.type, }; // If everything is fine, let automata work in its normal regime. @@ -728,7 +738,7 @@ void automata_t::parse_and_raise_request() noexcept { auto parsed_request = std::get(parsed_request_or_error); scratch.is_http = request.size() != parsed_request.body.size(); - scratch.dynamic_packet = parsed_request.body; + scratch.dynamic_packet = {parsed_request.body.data(), parsed_request.json_length}; if (scratch.dynamic_packet.size() > ram_page_size_k) { sjd::parser parser; if (parser.allocate(scratch.dynamic_packet.size(), scratch.dynamic_packet.size() / 2) != sj::SUCCESS) @@ -755,7 +765,9 @@ template std::size_t engine_t::pop_completed(complete ++passed; if (!uring_cqe->user_data) continue; - events[completed].connection_ptr = (connection_t*)uring_cqe->user_data; + + events[completed].connection_ptr = (connection_t*)(uring_cqe->user_data & ~0x3); + events[completed].type = static_cast(uring_cqe->user_data & 0x3); events[completed].stage = events[completed].connection_ptr->stage; events[completed].result = uring_cqe->res; ++completed; @@ -790,7 +802,8 @@ bool engine_t::consider_accepting_new_connection() noexcept { uring_sqe = io_uring_get_sqe(&uring); io_uring_prep_accept_direct(uring_sqe, socket, &connection.client_address, &connection.client_address_len, 0, IORING_FILE_INDEX_ALLOC); - io_uring_sqe_set_data(uring_sqe, &connection); + io_uring_sqe_set_data(uring_sqe, (void*)(static_cast(uring_op_tag_t::uring_acpt_tag_k) | + (std::uintptr_t)(&connection))); // Accepting new connections can be time-less. // io_uring_sqe_set_flags(uring_sqe, IOSQE_IO_LINK); @@ -820,7 +833,8 @@ void engine_t::submit_stats_heartbeat() noexcept { uring_sqe = io_uring_get_sqe(&uring); io_uring_prep_timeout(uring_sqe, &connection.next_wakeup, 0, 0); - io_uring_sqe_set_data(uring_sqe, &connection); + io_uring_sqe_set_data(uring_sqe, (void*)(static_cast(uring_op_tag_t::uring_stat_tag_k) | + (std::uintptr_t)(&connection))); uring_result = io_uring_submit(&uring); submission_mutex.unlock(); } @@ -892,7 +906,8 @@ void automata_t::send_next() noexcept { uring_sqe->flags |= IOSQE_FIXED_FILE; uring_sqe->buf_index = engine.connections.offset_of(connection) * 2u + 1u; } - io_uring_sqe_set_data(uring_sqe, &connection); + io_uring_sqe_set_data(uring_sqe, (void*)(static_cast(uring_op_tag_t::uring_send_tag_k) | + (std::uintptr_t)(&connection))); io_uring_sqe_set_flags(uring_sqe, 0); uring_result = io_uring_submit(&engine.uring); engine.submission_mutex.unlock(); @@ -918,7 +933,8 @@ void automata_t::receive_next() noexcept { uring_sqe = io_uring_get_sqe(&engine.uring); io_uring_prep_read_fixed(uring_sqe, int(connection.descriptor), (void*)pipes.next_input_address(), pipes.next_input_length(), 0, engine.connections.offset_of(connection) * 2u); - io_uring_sqe_set_data(uring_sqe, &connection); + io_uring_sqe_set_data(uring_sqe, (void*)(static_cast(uring_op_tag_t::uring_recv_tag_k) | + (std::uintptr_t)(&connection))); io_uring_sqe_set_flags(uring_sqe, IOSQE_IO_LINK); // More than other operations this depends on the information coming from the client. @@ -937,11 +953,15 @@ void automata_t::receive_next() noexcept { void automata_t::operator()() noexcept { if (is_corrupted()) - return close_gracefully(); + if (connection.stage != stage_t::waiting_to_close_k) + return close_gracefully(); switch (connection.stage) { case stage_t::waiting_to_accept_k: + if (type != uring_op_tag_t::uring_acpt_tag_k) { + return; + } if (completed_result == -ECANCELED) { engine.release_connection(connection); @@ -959,6 +979,9 @@ void automata_t::operator()() noexcept { case stage_t::expecting_reception_k: + if (type != uring_op_tag_t::uring_recv_tag_k) { + return; + } // From documentation: // > If used, the timeout specified in the command will cancel the linked command, // > unless the linked command completes before the timeout. The timeout will complete diff --git a/src/helpers/parse.hpp b/src/helpers/parse.hpp index 3990a43..35c4dc5 100644 --- a/src/helpers/parse.hpp +++ b/src/helpers/parse.hpp @@ -78,10 +78,12 @@ inline std::variant find_callback(named_callb bool id_invalid = (id.is_double() && !id.is_int64() && !id.is_uint64()) || id.is_object() || id.is_array(); if (id_invalid) return default_error_t{-32600, "The request must have integer or string id."}; + sj::simdjson_result method = doc["method"]; bool method_invalid = !method.is_string(); if (method_invalid) return default_error_t{-32600, "The method must be a string."}; + sj::simdjson_result params = doc["params"]; bool params_present_and_invalid = !params.is_array() && !params.is_object() && params.error() == sj::SUCCESS; if (params_present_and_invalid) @@ -117,6 +119,7 @@ struct parsed_request_t { std::string_view content_type{}; std::string_view content_length{}; std::string_view body{}; + std::size_t json_length; }; /** @@ -166,8 +169,13 @@ inline std::variant split_body_headers(std::s if (pos == std::string_view::npos) return default_error_t{-32700, "Invalid JSON was received by the server."}; req.body = body.substr(pos + 4); - } else + auto res = std::from_chars(req.content_length.begin(), req.content_length.end(), req.json_length); + if (res.ec == std::errc::invalid_argument) + return default_error_t{-32700, "Invalid JSON was received by the server."}; + } else { + req.json_length = body.size(); req.body = body; + } return req; }