From d8b7b81a585c0938667082e00838053a29ecad6b Mon Sep 17 00:00:00 2001 From: code Date: Tue, 20 Aug 2024 07:28:01 +0800 Subject: [PATCH] refactor the grpc client to fix #116 (#132) --- .github/workflows/main.yml | 8 +- BUILD | 1 + cpp2sky/internal/BUILD | 6 - cpp2sky/internal/async_client.h | 141 +++++++---------- cpp2sky/internal/stream_builder.h | 56 ------- cpp2sky/tracer.h | 2 +- source/BUILD | 8 +- source/grpc_async_client_impl.cc | 244 ++++++++++++++++++------------ source/grpc_async_client_impl.h | 195 ++++++++++++++---------- source/tracer_impl.cc | 77 +++------- source/tracer_impl.h | 22 +-- source/utils/BUILD | 2 +- source/utils/buffer.h | 76 ++++++++++ source/utils/circular_buffer.h | 122 --------------- test/BUILD | 1 - test/buffer_test.cc | 149 ++++-------------- test/grpc_async_client_test.cc | 162 +++++++++++++++++--- test/mocks.h | 44 +----- test/tracer_test.cc | 15 +- test/tracing_context_test.cc | 2 +- 20 files changed, 621 insertions(+), 712 deletions(-) delete mode 100644 cpp2sky/internal/stream_builder.h create mode 100644 source/utils/buffer.h delete mode 100644 source/utils/circular_buffer.h diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 1be4465..0794767 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -27,16 +27,16 @@ jobs: - uses: actions/checkout@v3 - name: Run bazel test with GCC c++11 run: | - bazel test --cxxopt=-std=c++0x //... + bazel test --test_output=all --cxxopt=-std=c++0x //... - name: Run bazel test with GCC c++17 run: | - bazel test --cxxopt=-std=c++17 //... + bazel test --test_output=all --cxxopt=-std=c++17 //... - name: Run bazel test with CLANG c++11 run: | - bazel test --config=clang --cxxopt=-std=c++0x //... + bazel test --test_output=all -c dbg --config=clang --cxxopt=-std=c++0x //... - name: Run bazel test with CLANG c++17 run: | - bazel test --config=clang --cxxopt=-std=c++17 //... + bazel test --test_output=all -c opt --config=clang --cxxopt=-std=c++17 //... - name: Install cmake dependencies and run cmake compile run: | sudo apt update diff --git a/BUILD b/BUILD index 7f60adb..3012d2e 100644 --- a/BUILD +++ b/BUILD @@ -9,5 +9,6 @@ refresh_compile_commands( "//cpp2sky/...": "", "//source/...": "", "//test/...": "", + "//example/...": "", }, ) diff --git a/cpp2sky/internal/BUILD b/cpp2sky/internal/BUILD index 397c85c..b68a5f4 100644 --- a/cpp2sky/internal/BUILD +++ b/cpp2sky/internal/BUILD @@ -23,12 +23,6 @@ cc_library( visibility = ["//visibility:public"], ) -cc_library( - name = "stream_builder_interface", - hdrs = ["stream_builder.h"], - visibility = ["//visibility:public"], -) - cc_library( name = "matcher_interface", hdrs = ["matcher.h"], diff --git a/cpp2sky/internal/async_client.h b/cpp2sky/internal/async_client.h index d9c8b59..c71c1c2 100644 --- a/cpp2sky/internal/async_client.h +++ b/cpp2sky/internal/async_client.h @@ -14,18 +14,19 @@ #pragma once -#include -#include -#include - +#include #include -#include "source/utils/circular_buffer.h" - -using google::protobuf::Message; +#include "google/protobuf/message.h" +#include "grpcpp/generic/generic_stub.h" +#include "grpcpp/grpcpp.h" +#include "language-agent/Tracing.pb.h" namespace cpp2sky { +/** + * Template base class for gRPC async client. + */ template class AsyncClient { public: @@ -37,108 +38,78 @@ class AsyncClient { virtual void sendMessage(RequestType message) = 0; /** - * Pending message queue reference. - */ - virtual CircularBuffer& pendingMessages() = 0; - - /** - * Start stream if there is no living stream. - */ - virtual void startStream() = 0; - - /** - * Completion queue. + * Reset the client. This should be called when the client is no longer + * needed. */ - virtual grpc::CompletionQueue& completionQueue() = 0; - - /** - * gRPC Stub - */ - virtual grpc::TemplatedGenericStub& stub() = 0; + virtual void resetClient() = 0; }; template using AsyncClientPtr = std::unique_ptr>; +/** + * Template base class for gRPC async stream. The stream is used to represent + * a single gRPC stream/request. + */ template class AsyncStream { public: virtual ~AsyncStream() = default; /** - * Send message. It will move the state from Init to Write. + * Send the specified protobuf message. */ virtual void sendMessage(RequestType message) = 0; }; -enum class StreamState : uint8_t { - Initialized = 0, - Ready = 1, - Idle = 2, - WriteDone = 3, - ReadDone = 4, +template +using AsyncStreamPtr = std::unique_ptr>; + +/** + * Tag for async operation. The callback should be called when the operation is + * done. + */ +struct AsyncEventTag { + std::function callback; }; +using AsyncEventTagPtr = std::unique_ptr; + +using GrpcClientContextPtr = std::unique_ptr; +using GrpcCompletionQueue = grpc::CompletionQueue; -class AsyncStreamCallback { +/** + * Factory for creating async stream. + */ +template +class AsyncStreamFactory { public: - /** - * Callback when stream ready event occured. - */ - virtual void onReady() = 0; + virtual ~AsyncStreamFactory() = default; - /** - * Callback when idle event occured. - */ - virtual void onIdle() = 0; + using StreamPtr = AsyncStreamPtr; + using GrpcStub = grpc::TemplatedGenericStub; - /** - * Callback when write done event occured. - */ - virtual void onWriteDone() = 0; + virtual StreamPtr createStream(GrpcClientContextPtr client_ctx, + GrpcStub& stub, GrpcCompletionQueue& cq, + AsyncEventTag& basic_event_tag, + AsyncEventTag& write_event_tag) = 0; +}; - /** - * Callback when read done event occured. - */ - virtual void onReadDone() = 0; +template +using AsyncStreamFactoryPtr = + std::unique_ptr>; - /** - * Callback when stream had finished with arbitrary error. - */ - virtual void onStreamFinish() = 0; -}; +using TraceRequestType = skywalking::v3::SegmentObject; +using TraceResponseType = skywalking::v3::Commands; -struct StreamCallbackTag { - public: - void callback(bool stream_finished) { - if (stream_finished) { - callback_->onStreamFinish(); - return; - } - - switch (state_) { - case StreamState::Ready: - callback_->onReady(); - break; - case StreamState::WriteDone: - callback_->onWriteDone(); - break; - case StreamState::Idle: - callback_->onIdle(); - break; - case StreamState::ReadDone: - callback_->onReadDone(); - break; - default: - break; - } - } - - StreamState state_; - AsyncStreamCallback* callback_; -}; +using TraceAsyncStream = AsyncStream; +using TraceAsyncStreamPtr = AsyncStreamPtr; -template -using AsyncStreamSharedPtr = - std::shared_ptr>; +using TraceAsyncStreamFactory = + AsyncStreamFactory; +using TraceAsyncStreamFactoryPtr = + AsyncStreamFactoryPtr; + +using TraceAsyncClient = AsyncClient; +using TraceAsyncClientPtr = std::unique_ptr; } // namespace cpp2sky diff --git a/cpp2sky/internal/stream_builder.h b/cpp2sky/internal/stream_builder.h deleted file mode 100644 index da7e26e..0000000 --- a/cpp2sky/internal/stream_builder.h +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2021 SkyAPM - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include - -namespace cpp2sky { - -template -class ClientStreamingStreamBuilder { - public: - virtual ~ClientStreamingStreamBuilder() = default; - - /** - * Create async stream entity - */ - virtual AsyncStreamSharedPtr create( - AsyncClient& client, - std::condition_variable& cv) = 0; -}; - -template -using ClientStreamingStreamBuilderPtr = - std::unique_ptr>; - -template -class UnaryStreamBuilder { - public: - virtual ~UnaryStreamBuilder() = default; - - /** - * Create async stream entity - */ - virtual AsyncStreamSharedPtr create( - AsyncClient& client, RequestType request) = 0; -}; - -template -using UnaryStreamBuilderPtr = - std::unique_ptr>; - -} // namespace cpp2sky diff --git a/cpp2sky/tracer.h b/cpp2sky/tracer.h index 3eacd78..4cb43ff 100644 --- a/cpp2sky/tracer.h +++ b/cpp2sky/tracer.h @@ -40,6 +40,6 @@ class Tracer { using TracerPtr = std::unique_ptr; -TracerPtr createInsecureGrpcTracer(TracerConfig& cfg); +TracerPtr createInsecureGrpcTracer(const TracerConfig& cfg); } // namespace cpp2sky diff --git a/source/BUILD b/source/BUILD index c45126e..05b8ef3 100644 --- a/source/BUILD +++ b/source/BUILD @@ -4,22 +4,20 @@ cc_library( name = "cpp2sky_lib", srcs = [ "grpc_async_client_impl.cc", - "propagation_impl.cc", "tracer_impl.cc", - "tracing_context_impl.cc", ], hdrs = [ "grpc_async_client_impl.h", - "propagation_impl.h", "tracer_impl.h", - "tracing_context_impl.h", ], visibility = ["//visibility:public"], deps = [ + ":cpp2sky_data_lib", + "//cpp2sky:config_cc_proto", + "//cpp2sky:cpp2sky_data_interface", "//cpp2sky:cpp2sky_interface", "//cpp2sky/internal:async_client_interface", "//cpp2sky/internal:matcher_interface", - "//cpp2sky/internal:stream_builder_interface", "//source/matchers:suffix_matcher_lib", "//source/utils:util_lib", "@com_github_gabime_spdlog//:spdlog", diff --git a/source/grpc_async_client_impl.cc b/source/grpc_async_client_impl.cc index 42e6e0c..f4f7d6b 100644 --- a/source/grpc_async_client_impl.cc +++ b/source/grpc_async_client_impl.cc @@ -12,148 +12,198 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "grpc_async_client_impl.h" +#include "source/grpc_async_client_impl.h" + +#include #include +#include #include #include "absl/strings/string_view.h" #include "cpp2sky/exception.h" +#include "cpp2sky/internal/async_client.h" #include "spdlog/spdlog.h" namespace cpp2sky { namespace { -static constexpr absl::string_view authenticationKey = "authentication"; -} -using namespace spdlog; +static constexpr uint32_t MaxPendingMessagesSize = 1024; -GrpcAsyncSegmentReporterClient::GrpcAsyncSegmentReporterClient( - const std::string& address, grpc::CompletionQueue& cq, - ClientStreamingStreamBuilderPtr - factory, - std::shared_ptr cred) - : factory_(std::move(factory)), - cq_(cq), - stub_(grpc::CreateChannel(address, cred)) { - startStream(); -} +static std::string AuthenticationKey = "authentication"; + +static std::string TraceCollectMethod = "/TraceSegmentReportService/collect"; + +} // namespace + +using namespace spdlog; -GrpcAsyncSegmentReporterClient::~GrpcAsyncSegmentReporterClient() { - // It will wait until there is no drained messages with 5 second timeout. - if (stream_) { - std::unique_lock lck(mux_); - while (!pending_messages_.empty()) { - cv_.wait_for(lck, std::chrono::seconds(5)); - pending_messages_.clear(); +void EventLoopThread::gogo() { + while (true) { + void* got_tag{nullptr}; + bool ok{false}; + + // true if got an event from the queue or false + // if the queue is fully drained and is shutdown. + const bool status = cq_.Next(&got_tag, &ok); + if (!status) { + assert(got_tag == nullptr); + assert(!ok); + info("[Reporter] Completion queue is drained and is shutdown."); + break; } - } - resetStream(); -} + assert(got_tag != nullptr); -void GrpcAsyncSegmentReporterClient::sendMessage(TracerRequestType message) { - pending_messages_.push(message); + // The lifetime of the tag is managed by the caller. + auto* tag = static_cast(got_tag); + tag->callback(ok); + } +} - if (!stream_) { - info( - "[Reporter] No active stream, inserted message into pending message " - "queue. " - "pending message size: {}", - pending_messages_.size()); - return; +TraceAsyncStreamImpl::TraceAsyncStreamImpl(GrpcClientContextPtr client_ctx, + TraceGrpcStub& stub, + GrpcCompletionQueue& cq, + AsyncEventTag& basic_event_tag, + AsyncEventTag& write_event_tag) + : client_ctx_(std::move(client_ctx)), + basic_event_tag_(basic_event_tag), + write_event_tag_(write_event_tag) { + if (client_ctx_ == nullptr) { + client_ctx_.reset(new grpc::ClientContext()); } - stream_->sendMessage(message); + request_writer_ = + stub.PrepareCall(client_ctx_.get(), TraceCollectMethod, &cq); + request_writer_->StartCall(reinterpret_cast(&basic_event_tag_)); } -void GrpcAsyncSegmentReporterClient::startStream() { - resetStream(); +void TraceAsyncStreamImpl::sendMessage(TraceRequestType message) { + request_writer_->Write(message, reinterpret_cast(&write_event_tag_)); +} - stream_ = factory_->create(*this, cv_); - info("[Reporter] Stream {} had created.", fmt::ptr(stream_.get())); +TraceAsyncStreamPtr TraceAsyncStreamFactoryImpl::createStream( + GrpcClientContextPtr client_ctx, TraceGrpcStub& stub, + GrpcCompletionQueue& cq, AsyncEventTag& basic_event_tag, + AsyncEventTag& write_event_tag) { + return TraceAsyncStreamPtr{new TraceAsyncStreamImpl( + std::move(client_ctx), stub, cq, basic_event_tag, write_event_tag)}; } -void GrpcAsyncSegmentReporterClient::resetStream() { - if (stream_) { - info("[Reporter] Stream {} has destroyed.", fmt::ptr(stream_.get())); - stream_.reset(); - } +std::unique_ptr TraceAsyncClientImpl::createClient( + const std::string& address, const std::string& token, + TraceAsyncStreamFactoryPtr factory, CredentialsSharedPtr cred) { + return std::unique_ptr{new TraceAsyncClientImpl( + address, token, std::move(factory), std::move(cred))}; } -GrpcAsyncSegmentReporterStream::GrpcAsyncSegmentReporterStream( - AsyncClient& client, - std::condition_variable& cv, const std::string& token) - : client_(client), cv_(cv) { - if (!token.empty()) { - ctx_.AddMetadata(authenticationKey.data(), token); - } +TraceAsyncClientImpl::TraceAsyncClientImpl(const std::string& address, + const std::string& token, + TraceAsyncStreamFactoryPtr factory, + CredentialsSharedPtr cred) + : token_(token), + stream_factory_(std::move(factory)), + stub_(grpc::CreateChannel(address, cred)) { + basic_event_tag_.callback = [this](bool ok) { + if (client_reset_) { + return; + } - // Ensure pending RPC will complete if connection to the server is not - // established first because of like server is not ready. This will queue - // pending RPCs and when connection has established, Connected tag will be - // sent to CompletionQueue. - ctx_.set_wait_for_ready(true); + if (ok) { + trace("[Reporter] Stream event success.", fmt::ptr(this)); - request_writer_ = client_.stub().PrepareCall( - &ctx_, "/TraceSegmentReportService/collect", &client_.completionQueue()); - request_writer_->StartCall(reinterpret_cast(&ready_)); -} + // Mark event loop as idle because the previous Write() or + // other operations are successful. + markEventLoopIdle(); -void GrpcAsyncSegmentReporterStream::sendMessage(TracerRequestType message) { - clearPendingMessage(); -} + sendMessageOnce(); + return; + } else { + trace("[Reporter] Stream event failure.", fmt::ptr(this)); -bool GrpcAsyncSegmentReporterStream::clearPendingMessage() { - if (state_ != StreamState::Idle || client_.pendingMessages().empty()) { - return false; - } - auto message = client_.pendingMessages().front(); - if (!message.has_value()) { - return false; + // Do not mark event loop as idle because the previous Write() + // or other operations are failed. The event loop should keep + // running to process the re-creation of the stream. + assert(event_loop_idle_.load() == false); + // Reset stream and try to create a new one. + startStream(); + } + }; + + write_event_tag_.callback = [this](bool ok) { + if (ok) { + trace("[Reporter] Stream {} message sending success.", fmt::ptr(this)); + messages_sent_++; + } else { + trace("[Reporter] Stream {} message sending failure.", fmt::ptr(this)); + messages_dropped_++; + } + // Delegate the event to basic_event_tag_ to trigger the next task or + // reset the stream. + basic_event_tag_.callback(ok); + }; + + // If the factory is not provided, use the default one. + if (stream_factory_ == nullptr) { + stream_factory_.reset(new TraceAsyncStreamFactoryImpl()); } - request_writer_->Write(message.value(), - reinterpret_cast(&write_done_)); - return true; + startStream(); } -void GrpcAsyncSegmentReporterStream::onReady() { - info("[Reporter] Stream ready"); +void TraceAsyncClientImpl::sendMessageOnce() { + bool expect_idle = true; + if (event_loop_idle_.compare_exchange_strong(expect_idle, false)) { + assert(active_stream_ != nullptr); + + auto opt_message = message_buffer_.pop_front(); + if (!opt_message.has_value()) { + // No message to send, mark event loop as idle. + markEventLoopIdle(); + return; + } - state_ = StreamState::Idle; - onIdle(); + active_stream_->sendMessage(std::move(opt_message).value()); + } } -void GrpcAsyncSegmentReporterStream::onIdle() { - info("[Reporter] Stream idleing"); +void TraceAsyncClientImpl::startStream() { + if (active_stream_ != nullptr) { + resetStream(); // Reset stream before creating a new one. + } - // Release pending messages which are inserted when stream is not ready - // to write. - if (!clearPendingMessage()) { - cv_.notify_all(); + // Create the unique client context for the new stream. + // Each stream should have its own context. + auto client_ctx = GrpcClientContextPtr{new grpc::ClientContext()}; + if (!token_.empty()) { + client_ctx->AddMetadata(AuthenticationKey, token_); } -} -void GrpcAsyncSegmentReporterStream::onWriteDone() { - info("[Reporter] Write finished"); + active_stream_ = stream_factory_->createStream( + std::move(client_ctx), stub_, event_loop_.cq_, basic_event_tag_, + write_event_tag_); - // Dequeue message after sending message finished. - // With this, messages which failed to sent never lost even if connection - // was closed. because pending messages with messages which failed to send - // will drained and resend another stream. - client_.pendingMessages().pop(); - state_ = StreamState::Idle; + info("[Reporter] Stream {} has created.", fmt::ptr(active_stream_.get())); +} - onIdle(); +void TraceAsyncClientImpl::resetStream() { + info("[Reporter] Stream {} has deleted.", fmt::ptr(active_stream_.get())); + active_stream_.reset(); } -AsyncStreamSharedPtr -GrpcAsyncSegmentReporterStreamBuilder::create( - AsyncClient& client, - std::condition_variable& cv) { - return std::make_shared(client, cv, token_); +void TraceAsyncClientImpl::sendMessage(TraceRequestType message) { + messages_total_++; + + const size_t pending = message_buffer_.size(); + if (pending > MaxPendingMessagesSize) { + info("[Reporter] pending message overflow and drop message"); + messages_dropped_++; + return; + } + message_buffer_.push_back(std::move(message)); + + sendMessageOnce(); } } // namespace cpp2sky diff --git a/source/grpc_async_client_impl.h b/source/grpc_async_client_impl.h index 07a6b73..251727b 100644 --- a/source/grpc_async_client_impl.h +++ b/source/grpc_async_client_impl.h @@ -14,120 +14,153 @@ #pragma once -#include -#include - -#include +#include #include #include #include +#include +#include #include "cpp2sky/config.pb.h" #include "cpp2sky/internal/async_client.h" -#include "cpp2sky/internal/stream_builder.h" #include "language-agent/Tracing.grpc.pb.h" -#include "language-agent/Tracing.pb.h" +#include "source/utils/buffer.h" namespace cpp2sky { -namespace { -static constexpr size_t pending_message_buffer_size = 1024; -} - -using TracerRequestType = skywalking::v3::SegmentObject; -using TracerResponseType = skywalking::v3::Commands; +using CredentialsSharedPtr = std::shared_ptr; -class GrpcAsyncSegmentReporterStream; +using TraceGrpcStub = + grpc::TemplatedGenericStub; +using TraceReaderWriter = + grpc::ClientAsyncReaderWriter; +using TraceReaderWriterPtr = std::unique_ptr; -class GrpcAsyncSegmentReporterClient final - : public AsyncClient { +class EventLoopThread { public: - GrpcAsyncSegmentReporterClient( - const std::string& address, grpc::CompletionQueue& cq, - ClientStreamingStreamBuilderPtr - factory, - std::shared_ptr cred); - ~GrpcAsyncSegmentReporterClient(); + EventLoopThread() : thread_([this] { this->gogo(); }) {} + ~EventLoopThread() { exit(); } - // AsyncClient - void sendMessage(TracerRequestType message) override; - CircularBuffer& pendingMessages() override { - return pending_messages_; - } - void startStream() override; - grpc::TemplatedGenericStub& stub() - override { - return stub_; - } - grpc::CompletionQueue& completionQueue() override { return cq_; } + grpc::CompletionQueue cq_; - size_t numOfMessages() { return pending_messages_.size(); } + void exit() { + if (!exited_) { + exited_ = true; + cq_.Shutdown(); + thread_.join(); + } + } private: - void resetStream(); + bool exited_{false}; + std::thread thread_; - std::string address_; - ClientStreamingStreamBuilderPtr - factory_; - grpc::CompletionQueue& cq_; - grpc::TemplatedGenericStub stub_; - AsyncStreamSharedPtr stream_; - CircularBuffer pending_messages_{ - pending_message_buffer_size}; - - std::mutex mux_; - std::condition_variable cv_; + void gogo(); }; -class GrpcAsyncSegmentReporterStream final - : public AsyncStream, - public AsyncStreamCallback { +class TraceAsyncStreamImpl : public TraceAsyncStream { public: - GrpcAsyncSegmentReporterStream( - AsyncClient& client, - std::condition_variable& cv, const std::string& token); + TraceAsyncStreamImpl(GrpcClientContextPtr client_ctx, TraceGrpcStub& stub, + GrpcCompletionQueue& cq, AsyncEventTag& basic_event_tag, + AsyncEventTag& write_event_tag); // AsyncStream - void sendMessage(TracerRequestType message) override; - - // AsyncStreamCallback - void onReady() override; - void onIdle() override; - void onWriteDone() override; - void onReadDone() override {} - void onStreamFinish() override { client_.startStream(); } + void sendMessage(TraceRequestType message) override; private: - bool clearPendingMessage(); + GrpcClientContextPtr client_ctx_; + TraceReaderWriterPtr request_writer_; - AsyncClient& client_; - TracerResponseType commands_; - grpc::ClientContext ctx_; - std::unique_ptr< - grpc::ClientAsyncReaderWriter> - request_writer_; - StreamState state_{StreamState::Initialized}; + AsyncEventTag& basic_event_tag_; + AsyncEventTag& write_event_tag_; +}; - StreamCallbackTag ready_{StreamState::Ready, this}; - StreamCallbackTag write_done_{StreamState::WriteDone, this}; +class TraceAsyncStreamFactoryImpl : public TraceAsyncStreamFactory { + public: + TraceAsyncStreamFactoryImpl() = default; - std::condition_variable& cv_; + TraceAsyncStreamPtr createStream(GrpcClientContextPtr client_ctx, + GrpcStub& stub, GrpcCompletionQueue& cq, + AsyncEventTag& basic_event_tag, + AsyncEventTag& write_event_tag) override; }; -class GrpcAsyncSegmentReporterStreamBuilder final - : public ClientStreamingStreamBuilder { +class TraceAsyncClientImpl : public TraceAsyncClient { public: - explicit GrpcAsyncSegmentReporterStreamBuilder(const std::string& token) - : token_(token) {} + /** + * Create a new GrpcAsyncSegmentReporterClient. + * + * @param address The address of the server. + * @param token The optional token used to authenticate the client. + * If non-empty token is provided, the client will send the token + * to the server in the metadata. + * @param cred The credentials for creating the channel. + * @param factory The factory function to create the stream from the + * request writer and event tags. In most cases, the default factory + * should be used. + */ + static std::unique_ptr createClient( + const std::string& address, const std::string& token, + TraceAsyncStreamFactoryPtr factory = nullptr, + CredentialsSharedPtr cred = grpc::InsecureChannelCredentials()); + + ~TraceAsyncClientImpl() override { + if (!client_reset_) { + resetClient(); + } + } + + // AsyncClient + void sendMessage(TraceRequestType message) override; + void resetClient() override { + // After this is called, no more events will be processed. + client_reset_ = true; + message_buffer_.clear(); + event_loop_.exit(); + resetStream(); + } - // ClientStreamingStreamBuilder - AsyncStreamSharedPtr create( - AsyncClient& client, - std::condition_variable& cv) override; + protected: + TraceAsyncClientImpl( + const std::string& address, const std::string& token, + TraceAsyncStreamFactoryPtr factory = nullptr, + CredentialsSharedPtr cred = grpc::InsecureChannelCredentials()); - private: - std::string token_; + // Start or re-create the stream that used to send messages. + void startStream(); + void resetStream(); + void markEventLoopIdle() { event_loop_idle_.store(true); } + void sendMessageOnce(); + + const std::string token_; + TraceAsyncStreamFactoryPtr stream_factory_; + TraceGrpcStub stub_; + + // This may be operated by multiple threads. + std::atomic messages_total_{0}; + std::atomic messages_dropped_{0}; + std::atomic messages_sent_{0}; + + EventLoopThread event_loop_; + std::atomic client_reset_{false}; + + ValueBuffer message_buffer_; + + AsyncEventTag basic_event_tag_; + AsyncEventTag write_event_tag_; + + // The Write() of the stream could only be called once at a time + // until the previous Write() is finished (callback is called). + // Considering the complexity and the thread safety, we make sure + // that all operations on the stream are done one by one. + // This flag is used to indicate whether the event loop is idle + // before we perform the next operation on the stream. + // + // Initially the value is false because the event loop will be + // occupied by the first operation (startStream). + std::atomic event_loop_idle_{false}; + + TraceAsyncStreamPtr active_stream_; }; } // namespace cpp2sky diff --git a/source/tracer_impl.cc b/source/tracer_impl.cc index 6957f30..e7320c1 100644 --- a/source/tracer_impl.cc +++ b/source/tracer_impl.cc @@ -20,32 +20,27 @@ #include "cpp2sky/exception.h" #include "language-agent/ConfigurationDiscoveryService.pb.h" #include "matchers/suffix_matcher.h" +#include "source/grpc_async_client_impl.h" #include "spdlog/spdlog.h" namespace cpp2sky { -TracerImpl::TracerImpl(TracerConfig& config, - std::shared_ptr cred) - : config_(config), - evloop_thread_([this] { this->run(); }), - segment_factory_(config) { +using namespace spdlog; + +TracerImpl::TracerImpl(const TracerConfig& config, CredentialsSharedPtr cred) + : segment_factory_(config) { init(config, cred); } -TracerImpl::TracerImpl( - TracerConfig& config, - AsyncClientPtr reporter_client) - : config_(config), - reporter_client_(std::move(reporter_client)), - evloop_thread_([this] { this->run(); }), - segment_factory_(config) { +TracerImpl::TracerImpl(const TracerConfig& config, + TraceAsyncClientPtr async_client) + : async_client_(std::move(async_client)), segment_factory_(config) { init(config, nullptr); } TracerImpl::~TracerImpl() { - reporter_client_.reset(); - cq_.Shutdown(); - evloop_thread_.join(); + // Stop the reporter client. + async_client_->resetClient(); } TracingContextSharedPtr TracerImpl::newContext() { @@ -56,63 +51,39 @@ TracingContextSharedPtr TracerImpl::newContext(SpanContextSharedPtr span) { return segment_factory_.create(span); } -bool TracerImpl::report(TracingContextSharedPtr obj) { - if (!obj || !obj->readyToSend()) { +bool TracerImpl::report(TracingContextSharedPtr ctx) { + if (!ctx || !ctx->readyToSend()) { return false; } - for (const auto& op_name_matcher : op_name_matchers_) { - if (!obj->spans().empty() && - op_name_matcher->match(obj->spans().front()->operationName())) { + if (!ctx->spans().empty()) { + if (ignore_matcher_->match(ctx->spans().front()->operationName())) { return false; } } - reporter_client_->sendMessage(obj->createSegmentObject()); + async_client_->sendMessage(ctx->createSegmentObject()); return true; } -void TracerImpl::run() { - void* got_tag; - bool ok = false; - while (true) { - grpc::CompletionQueue::NextStatus status = cq_.AsyncNext( - &got_tag, &ok, gpr_time_from_nanos(0, GPR_CLOCK_REALTIME)); - switch (status) { - case grpc::CompletionQueue::TIMEOUT: - continue; - case grpc::CompletionQueue::SHUTDOWN: - return; - case grpc::CompletionQueue::GOT_EVENT: - break; - } - static_cast(got_tag)->callback(!ok); - } -} - -void TracerImpl::init(TracerConfig& config, - std::shared_ptr cred) { +void TracerImpl::init(const TracerConfig& config, CredentialsSharedPtr cred) { spdlog::set_level(spdlog::level::warn); - if (reporter_client_ == nullptr) { - if (config.protocol() == Protocol::GRPC) { - reporter_client_ = absl::make_unique( - config.address(), cq_, - absl::make_unique( - config.token()), - cred); - } else { - throw TracerException("REST is not supported."); + if (async_client_ == nullptr) { + if (config.protocol() != Protocol::GRPC) { + throw TracerException("Only GRPC is supported."); } + async_client_ = TraceAsyncClientImpl::createClient( + config.address(), config.token(), nullptr, std::move(cred)); } - op_name_matchers_.emplace_back(absl::make_unique( + ignore_matcher_.reset(new SuffixMatcher( std::vector(config.ignore_operation_name_suffix().begin(), config.ignore_operation_name_suffix().end()))); } -TracerPtr createInsecureGrpcTracer(TracerConfig& cfg) { - return absl::make_unique(cfg, grpc::InsecureChannelCredentials()); +TracerPtr createInsecureGrpcTracer(const TracerConfig& cfg) { + return TracerPtr{new TracerImpl(cfg, grpc::InsecureChannelCredentials())}; } } // namespace cpp2sky diff --git a/source/tracer_impl.h b/source/tracer_impl.h index 4f63f1f..febd577 100644 --- a/source/tracer_impl.h +++ b/source/tracer_impl.h @@ -34,29 +34,21 @@ using CdsResponse = skywalking::v3::Commands; class TracerImpl : public Tracer { public: - TracerImpl(TracerConfig& config, - std::shared_ptr cred); - TracerImpl( - TracerConfig& config, - AsyncClientPtr reporter_client); + TracerImpl(const TracerConfig& config, CredentialsSharedPtr credentials); + TracerImpl(const TracerConfig& config, TraceAsyncClientPtr async_client); ~TracerImpl(); TracingContextSharedPtr newContext() override; TracingContextSharedPtr newContext(SpanContextSharedPtr span) override; - bool report(TracingContextSharedPtr obj) override; + bool report(TracingContextSharedPtr ctx) override; private: - void init(TracerConfig& config, - std::shared_ptr cred); - void run(); - - TracerConfig config_; - AsyncClientPtr reporter_client_; - grpc::CompletionQueue cq_; - std::thread evloop_thread_; + void init(const TracerConfig& config, CredentialsSharedPtr cred); + + TraceAsyncClientPtr async_client_; TracingContextFactory segment_factory_; - std::list op_name_matchers_; + MatcherPtr ignore_matcher_; }; } // namespace cpp2sky diff --git a/source/utils/BUILD b/source/utils/BUILD index 6305d44..8a07d56 100644 --- a/source/utils/BUILD +++ b/source/utils/BUILD @@ -7,7 +7,7 @@ cc_library( ], hdrs = [ "base64.h", - "circular_buffer.h", + "buffer.h", "random_generator.h", "timer.h", ], diff --git a/source/utils/buffer.h b/source/utils/buffer.h new file mode 100644 index 0000000..e0f76d9 --- /dev/null +++ b/source/utils/buffer.h @@ -0,0 +1,76 @@ +// Copyright 2020 SkyAPM + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "absl/types/optional.h" + +namespace cpp2sky { + +template +class ValueBuffer { + public: + ValueBuffer() = default; + + absl::optional pop_front() { + std::unique_lock lock(mux_); + if (buf_.empty()) { + return absl::nullopt; + } + auto result = std::move(buf_.front()); + buf_.pop_front(); + return result; + } + + /** + * Insert new value. + */ + void push_back(Value value) { + std::unique_lock lock(mux_); + buf_.emplace_back(std::move(value)); + } + + /** + * Check whether buffer is empty or not. + */ + bool empty() const { + std::unique_lock lock(mux_); + return buf_.empty(); + } + + /** + * Get item count + */ + size_t size() const { + std::unique_lock lock(mux_); + return buf_.size(); + } + + /** + * Clear buffer + */ + void clear() { + std::unique_lock lock(mux_); + buf_.clear(); + } + + private: + std::deque buf_; + mutable std::mutex mux_; +}; + +} // namespace cpp2sky diff --git a/source/utils/circular_buffer.h b/source/utils/circular_buffer.h deleted file mode 100644 index 050cd13..0000000 --- a/source/utils/circular_buffer.h +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright 2020 SkyAPM - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -#include "absl/types/optional.h" - -namespace cpp2sky { - -template -class CircularBuffer { - public: - CircularBuffer(size_t max_capacity) - : back_(max_capacity - 1), max_capacity_(max_capacity) {} - - // disable copy - CircularBuffer(const CircularBuffer&) = delete; - CircularBuffer& operator=(const CircularBuffer&) = delete; - - struct Buffer { - T value; - bool is_destroyed_; - }; - - /** - * Get value which inserted older than any other values. - * It will return nullopt if buffer is empty. - */ - absl::optional front() { - if (empty()) { - return absl::nullopt; - } - return buf_[front_].value; - } - - /** - * Delete oldest value. It won't delete actual data we can treat as logical - * deletion. - */ - void pop() { - std::unique_lock lock(mux_); - popInternal(); - } - - /** - * Insert new value. If the buffer has more than max_capacity, it will delete - * the oldest value. - */ - void push(T value) { - std::unique_lock lock(mux_); - if (buf_.size() < max_capacity_) { - buf_.emplace_back(Buffer{value, false}); - back_ = (back_ + 1) % max_capacity_; - ++item_count_; - return; - } - - back_ = (back_ + 1) % max_capacity_; - if (!buf_[back_].is_destroyed_) { - popInternal(); - } - buf_[back_] = Buffer{value, false}; - ++item_count_; - } - - /** - * Check whether buffer is empty or not. - */ - bool empty() { return item_count_ == 0; } - - /** - * Get item count - */ - size_t size() const { return item_count_; } - - /** - * Clear buffer - */ - void clear() { - buf_.clear(); - item_count_ = 0; - } - - // Used for test - size_t frontIdx() { return front_; } - size_t backIdx() { return back_; } - - private: - void popInternal() { - if (empty() || buf_[front_].is_destroyed_) { - return; - } - // Not to destroy actual data. - buf_[front_].is_destroyed_ = true; - --item_count_; - front_ = (front_ + 1) % max_capacity_; - } - - size_t front_ = 0; - size_t back_ = 0; - size_t max_capacity_; - size_t item_count_ = 0; - - std::deque buf_; - std::mutex mux_; -}; - -} // namespace cpp2sky diff --git a/test/BUILD b/test/BUILD index 1a3c133..1cba365 100644 --- a/test/BUILD +++ b/test/BUILD @@ -8,7 +8,6 @@ cc_library( deps = [ "//cpp2sky/internal:async_client_interface", "//cpp2sky/internal:random_generator_interface", - "//cpp2sky/internal:stream_builder_interface", "@com_google_googletest//:gtest_main", "@skywalking_data_collect_protocol//language-agent:tracing_protocol_cc_proto", ], diff --git a/test/buffer_test.cc b/test/buffer_test.cc index dceab9c..b1770f7 100644 --- a/test/buffer_test.cc +++ b/test/buffer_test.cc @@ -12,131 +12,44 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "source/utils/buffer.h" + #include #include #include "absl/memory/memory.h" -#include "source/utils/circular_buffer.h" namespace cpp2sky { -class CircularBufferTest : public testing::Test { - protected: - void setup(size_t size) { - buf_ = absl::make_unique>(size); - } - - void evaluate(size_t expect_front, size_t expect_back, bool expect_empty) { - EXPECT_EQ(expect_front, buf_->frontIdx()); - EXPECT_EQ(expect_back, buf_->backIdx()); - EXPECT_EQ(expect_empty, buf_->empty()); - } - - void checkFront(int expect_value) { - auto a = buf_->front(); - ASSERT_TRUE(a.has_value()); - EXPECT_EQ(a.value(), expect_value); - } - - std::unique_ptr> buf_; -}; - -TEST_F(CircularBufferTest, Basic) { - setup(3); - for (auto i = 0; i < 1000; ++i) { - buf_->pop(); - } - - buf_->push(1); - buf_->push(2); - buf_->push(3); - evaluate(0, 2, false); - - buf_->push(4); - evaluate(1, 0, false); - - buf_->push(5); - buf_->push(6); - evaluate(0, 2, false); - - checkFront(4); - buf_->pop(); - evaluate(1, 2, false); - - checkFront(5); - buf_->pop(); - evaluate(2, 2, false); - - buf_->push(7); - evaluate(2, 0, false); - - checkFront(6); - buf_->pop(); - evaluate(0, 0, false); - - checkFront(7); - buf_->pop(); - // Return to Empty state - evaluate(1, 0, true); - - buf_->push(8); - evaluate(1, 1, false); - - buf_->push(9); - buf_->push(10); - buf_->push(11); - buf_->push(12); - - checkFront(10); - evaluate(0, 2, false); - - buf_->pop(); - buf_->pop(); - buf_->pop(); - - evaluate(0, 2, true); - - for (auto i = 0; i < 1000; ++i) { - buf_->pop(); - } - - evaluate(0, 2, true); -} - -TEST_F(CircularBufferTest, Basic2) { - setup(3); - - buf_->push(1); - buf_->pop(); - evaluate(1, 0, true); - - buf_->push(2); - evaluate(1, 1, false); - - buf_->push(3); - buf_->push(4); - - buf_->pop(); - - checkFront(3); - buf_->pop(); - checkFront(4); - buf_->pop(); - - buf_->pop(); -} - -TEST_F(CircularBufferTest, Basic3) { - setup(3); - - buf_->push(1); - buf_->push(2); - buf_->push(3); - buf_->pop(); - buf_->push(4); - buf_->push(5); - - evaluate(2, 1, false); +TEST(BufferTest, Basic) { + ValueBuffer buffer; + EXPECT_TRUE(buffer.empty()); + EXPECT_EQ(buffer.size(), 0); + + buffer.push_back(1); + EXPECT_FALSE(buffer.empty()); + EXPECT_EQ(buffer.size(), 1); + + buffer.push_back(2); + EXPECT_FALSE(buffer.empty()); + EXPECT_EQ(buffer.size(), 2); + + auto value = buffer.pop_front(); + EXPECT_TRUE(value.has_value()); + EXPECT_EQ(value.value(), 1); + EXPECT_FALSE(buffer.empty()); + EXPECT_EQ(buffer.size(), 1); + + value = buffer.pop_front(); + EXPECT_TRUE(value.has_value()); + EXPECT_EQ(value.value(), 2); + EXPECT_TRUE(buffer.empty()); + EXPECT_EQ(buffer.size(), 0); + + value = buffer.pop_front(); + EXPECT_FALSE(value.has_value()); + EXPECT_TRUE(buffer.empty()); + EXPECT_EQ(buffer.size(), 0); } } // namespace cpp2sky diff --git a/test/grpc_async_client_test.cc b/test/grpc_async_client_test.cc index 806f7c7..4016d6d 100644 --- a/test/grpc_async_client_test.cc +++ b/test/grpc_async_client_test.cc @@ -14,10 +14,13 @@ #include #include +#include +#include #include #include "absl/memory/memory.h" +#include "cpp2sky/internal/async_client.h" #include "language-agent/Tracing.pb.h" #include "source/grpc_async_client_impl.h" #include "test/mocks.h" @@ -26,36 +29,157 @@ namespace cpp2sky { using testing::_; -class GrpcAsyncSegmentReporterClientTest : public testing::Test { +struct TestStats { + TestStats(uint64_t total, uint64_t dropped, uint64_t sent) + : total_(total), dropped_(dropped), sent_(sent) { + pending_ = total_ - dropped_ - sent_; + } + + uint64_t total_{}; + uint64_t dropped_{}; + uint64_t sent_{}; + uint64_t pending_{}; +}; + +class TestTraceAsyncClient : public TraceAsyncClientImpl { + public: + TestTraceAsyncClient(const std::string& address, const std::string& token, + TraceAsyncStreamFactoryPtr stream_factory, + CredentialsSharedPtr credentials) + : TraceAsyncClientImpl(address, token, std::move(stream_factory), + std::move(credentials)) {} + + TestStats getTestStats() const { + TestStats stats(messages_total_.load(), messages_dropped_.load(), + messages_sent_.load()); + return stats; + } + + void notifyWriteEvent(bool success) { write_event_tag_.callback(success); } + void notifyStartEvent(bool success) { basic_event_tag_.callback(success); } + + uint64_t bufferSize() const { return message_buffer_.size(); } +}; + +class TestTraceAsyncStreamFactory : public TraceAsyncStreamFactory { + public: + TestTraceAsyncStreamFactory(std::shared_ptr mock_stream) + : mock_stream_(mock_stream) {} + + class TestTraceAsyncStream : public TraceAsyncStream { + public: + TestTraceAsyncStream(std::shared_ptr mock_stream) + : mock_stream_(mock_stream) {} + void sendMessage(TraceRequestType message) override { + mock_stream_->sendMessage(std::move(message)); + } + std::shared_ptr mock_stream_; + }; + + TraceAsyncStreamPtr createStream(GrpcClientContextPtr, GrpcStub&, + GrpcCompletionQueue&, AsyncEventTag&, + AsyncEventTag&) override { + return TraceAsyncStreamPtr{new TestTraceAsyncStream(mock_stream_)}; + } + + std::shared_ptr mock_stream_; +}; + +class TraceAsyncClientImplTest : public testing::Test { public: - GrpcAsyncSegmentReporterClientTest() { - stream_ = std::make_shared< - MockAsyncStream>(); - factory_ = absl::make_unique>(stream_); - EXPECT_CALL(*factory_, create(_, _)); - - client_ = absl::make_unique( - address_, cq_, std::move(factory_), grpc::InsecureChannelCredentials()); + TraceAsyncClientImplTest() { + client_.reset(new TestTraceAsyncClient( + address_, token_, + TraceAsyncStreamFactoryPtr{ + new TestTraceAsyncStreamFactory(mock_stream_)}, + grpc::InsecureChannelCredentials())); + } + + ~TraceAsyncClientImplTest() { + client_->resetClient(); + client_.reset(); } protected: - grpc::CompletionQueue cq_; std::string address_{"localhost:50051"}; std::string token_{"token"}; - std::shared_ptr> - stream_; - std::unique_ptr< - MockClientStreamingStreamBuilder> - factory_; - std::unique_ptr client_; + std::shared_ptr mock_stream_ = + std::make_shared(); + + std::unique_ptr client_; }; -TEST_F(GrpcAsyncSegmentReporterClientTest, SendMessageTest) { +TEST_F(TraceAsyncClientImplTest, SendMessageTest) { skywalking::v3::SegmentObject fake_message; - EXPECT_CALL(*stream_, sendMessage(_)); + EXPECT_CALL(*mock_stream_, sendMessage(_)).Times(0); client_->sendMessage(fake_message); + + auto stats = client_->getTestStats(); + EXPECT_EQ(stats.total_, 1); + EXPECT_EQ(stats.dropped_, 0); + EXPECT_EQ(stats.sent_, 0); + EXPECT_EQ(stats.pending_, 1); + EXPECT_EQ(client_->bufferSize(), 1); + + client_->notifyStartEvent(false); + + sleep(1); // wait for the event loop to process the event. + + // The stream is not ready, the message still in the buffer. + stats = client_->getTestStats(); + EXPECT_EQ(stats.total_, 1); + EXPECT_EQ(stats.dropped_, 0); + EXPECT_EQ(stats.sent_, 0); + EXPECT_EQ(stats.pending_, 1); + EXPECT_EQ(client_->bufferSize(), 1); + + EXPECT_CALL(*mock_stream_, sendMessage(_)); + client_->notifyStartEvent(true); + sleep(1); // wait for the event loop to process the event. + + // The stream is ready, the message is popped and sent. + // But before the collback is called, the stats is not updated. + + stats = client_->getTestStats(); + EXPECT_EQ(stats.total_, 1); + EXPECT_EQ(stats.dropped_, 0); + EXPECT_EQ(stats.sent_, 0); + EXPECT_EQ(stats.pending_, 1); + EXPECT_EQ(client_->bufferSize(), 0); + + client_->notifyWriteEvent(true); + sleep(1); // wait for the event loop to process the event. + + // The message is sent successfully. + stats = client_->getTestStats(); + EXPECT_EQ(stats.total_, 1); + EXPECT_EQ(stats.dropped_, 0); + EXPECT_EQ(stats.sent_, 1); + EXPECT_EQ(stats.pending_, 0); + EXPECT_EQ(client_->bufferSize(), 0); + + // Send another message. This time the stream is ready and + // previous message is sent successfully. So the new message + // should be sent immediately. + EXPECT_CALL(*mock_stream_, sendMessage(_)); + client_->sendMessage(fake_message); + sleep(1); // wait for the event loop to process the event. + + stats = client_->getTestStats(); + EXPECT_EQ(stats.total_, 2); + EXPECT_EQ(stats.dropped_, 0); + EXPECT_EQ(stats.sent_, 1); + EXPECT_EQ(stats.pending_, 1); + + client_->notifyWriteEvent(true); + sleep(1); // wait for the event loop to process the event. + + stats = client_->getTestStats(); + EXPECT_EQ(stats.total_, 2); + EXPECT_EQ(stats.dropped_, 0); + EXPECT_EQ(stats.sent_, 2); + EXPECT_EQ(stats.pending_, 0); } } // namespace cpp2sky diff --git a/test/mocks.h b/test/mocks.h index 5c48e4a..d9f81c1 100644 --- a/test/mocks.h +++ b/test/mocks.h @@ -17,11 +17,8 @@ #include #include -#include - #include "cpp2sky/internal/async_client.h" #include "cpp2sky/internal/random_generator.h" -#include "cpp2sky/internal/stream_builder.h" using testing::_; using testing::Return; @@ -34,46 +31,15 @@ class MockRandomGenerator : public RandomGenerator { MOCK_METHOD(std::string, uuid, ()); }; -template -class MockAsyncStream : public AsyncStream { - public: - MOCK_METHOD(void, sendMessage, (RequestType)); - MOCK_METHOD(void, onIdle, ()); - MOCK_METHOD(void, onWriteDone, ()); - MOCK_METHOD(void, onReady, ()); -}; - -template -class MockAsyncClient : public AsyncClient { +class MockTraceAsyncStream : public TraceAsyncStream { public: - using GenericStub = grpc::TemplatedGenericStub; - - MOCK_METHOD(void, sendMessage, (RequestType)); - MOCK_METHOD(GenericStub&, stub, ()); - MOCK_METHOD(CircularBuffer&, pendingMessages, ()); - MOCK_METHOD(void, startStream, ()); - MOCK_METHOD(grpc::CompletionQueue&, completionQueue, ()); + MOCK_METHOD(void, sendMessage, (TraceRequestType)); }; -template -class MockClientStreamingStreamBuilder final - : public ClientStreamingStreamBuilder { +class MockTraceAsyncClient : public TraceAsyncClient { public: - using AsyncClientType = AsyncClient; - using AsyncStreamSharedPtrType = - AsyncStreamSharedPtr; - - MockClientStreamingStreamBuilder( - std::shared_ptr> stream) - : stream_(stream) { - ON_CALL(*this, create(_, _)).WillByDefault(Return(stream_)); - } - - MOCK_METHOD(AsyncStreamSharedPtrType, create, - (AsyncClientType&, std::condition_variable&)); - - private: - std::shared_ptr> stream_; + MOCK_METHOD(void, sendMessage, (TraceRequestType)); + MOCK_METHOD(void, resetClient, ()); }; } // namespace cpp2sky diff --git a/test/tracer_test.cc b/test/tracer_test.cc index bab1f91..462eb07 100644 --- a/test/tracer_test.cc +++ b/test/tracer_test.cc @@ -17,6 +17,7 @@ #include #include "cpp2sky/config.pb.h" +#include "cpp2sky/internal/async_client.h" #include "mocks.h" #include "source/tracer_impl.h" @@ -26,9 +27,8 @@ TEST(TracerTest, MatchedOpShouldIgnored) { TracerConfig config; *config.add_ignore_operation_name_suffix() = "/ignored"; - TracerImpl tracer( - config, absl::make_unique< - MockAsyncClient>()); + TracerImpl tracer(config, TraceAsyncClientPtr{ + new testing::NiceMock()}); auto context = tracer.newContext(); auto span = context->createEntrySpan(); @@ -41,9 +41,8 @@ TEST(TracerTest, MatchedOpShouldIgnored) { TEST(TracerTest, NotClosedSpanExists) { TracerConfig config; - TracerImpl tracer( - config, absl::make_unique< - MockAsyncClient>()); + TracerImpl tracer(config, TraceAsyncClientPtr{ + new testing::NiceMock()}); auto context = tracer.newContext(); auto span = context->createEntrySpan(); @@ -55,8 +54,8 @@ TEST(TracerTest, NotClosedSpanExists) { TEST(TracerTest, Success) { TracerConfig config; - auto mock_reporter = absl::make_unique< - MockAsyncClient>(); + auto mock_reporter = std::unique_ptr{ + new testing::NiceMock()}; EXPECT_CALL(*mock_reporter, sendMessage(_)); TracerImpl tracer(config, std::move(mock_reporter)); diff --git a/test/tracing_context_test.cc b/test/tracing_context_test.cc index 9114670..b802b23 100644 --- a/test/tracing_context_test.cc +++ b/test/tracing_context_test.cc @@ -44,7 +44,7 @@ class TracingContextTest : public testing::Test { span_ctx_ = std::make_shared(sample_ctx); span_ext_ctx_ = std::make_shared("1"); - factory_ = absl::make_unique(config_); + factory_.reset(new TracingContextFactory(config_)); } protected: