diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index f6d50cc936..eb2a7019a6 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -946,15 +946,18 @@ Response Serialization::deserializeResponse(std::istream& is) { auto requestId = su::deserialize(is); auto errOrResult = su::deserialize>(is); + auto clientId = su::deserialize>(is); - return std::holds_alternative(errOrResult) ? Response{requestId, std::get(errOrResult)} - : Response{requestId, std::get(errOrResult)}; + return std::holds_alternative(errOrResult) + ? Response{requestId, std::get(errOrResult), clientId} + : Response{requestId, std::get(errOrResult), clientId}; } void Serialization::serialize(Response const& response, std::ostream& os) { su::serialize(response.mImpl->mRequestId, os); su::serialize(response.mImpl->mErrOrResult, os); + su::serialize(response.mImpl->mClientId, os); } size_t Serialization::serializedSize(Response const& response) @@ -962,6 +965,7 @@ size_t Serialization::serializedSize(Response const& response) size_t totalSize = 0; totalSize += su::serializedSize(response.mImpl->mRequestId); totalSize += su::serializedSize(response.mImpl->mErrOrResult); + totalSize += su::serializedSize(response.mImpl->mClientId); return totalSize; } diff --git a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp index 1d618c86da..bd7aacfcb6 100644 --- a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp +++ b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp @@ -160,6 +160,7 @@ void compareResponse(texec::Response res, texec::Response res2) { compareResult(res.getResult(), res2.getResult()); } + EXPECT_EQ(res.getClientId(), res2.getClientId()); } template @@ -428,11 +429,15 @@ TEST(SerializeUtilsTest, ResultResponse) auto val = texec::Response(1, "my error msg"); testSerializeDeserialize(val); } + { + auto val = texec::Response(1, "my error msg", 2); + testSerializeDeserialize(val); + } } TEST(SerializeUtilsTest, VectorResponses) { - int numResponses = 10; + int numResponses = 15; std::vector responsesIn; for (int i = 0; i < numResponses; ++i) { @@ -443,11 +448,16 @@ TEST(SerializeUtilsTest, VectorResponses) std::nullopt, std::vector{texec::FinishReason::kEND_ID}}; responsesIn.emplace_back(i, res); } - else + else if (i < 10) { std::string errMsg = "my_err_msg" + std::to_string(i); responsesIn.emplace_back(i, errMsg); } + else + { + std::string errMsg = "my_err_msg" + std::to_string(i); + responsesIn.emplace_back(i, errMsg, i + 1); + } } auto buffer = texec::Serialization::serialize(responsesIn);