From a1c1ccafa16cfdc155519fa38f9a5b782a1a5571 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 10 May 2023 12:29:34 -0400 Subject: [PATCH] [SUPPORT] Fix RingBuffer ReadWithCallback (#14743) This PR bugfixes ring buffer ReadWithCallback when the callback send function read part of the data. Also backported one robustness fix from unity. Testcases are added --- src/support/ring_buffer.h | 15 +++++- tests/cpp/support/ring_buffer_test.cc | 68 +++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 2 deletions(-) create mode 100644 tests/cpp/support/ring_buffer_test.cc diff --git a/src/support/ring_buffer.h b/src/support/ring_buffer.h index 1c6a6f8b4350..866c9c4424e0 100644 --- a/src/support/ring_buffer.h +++ b/src/support/ring_buffer.h @@ -24,6 +24,8 @@ #ifndef TVM_SUPPORT_RING_BUFFER_H_ #define TVM_SUPPORT_RING_BUFFER_H_ +#include + #include #include #include @@ -61,6 +63,9 @@ class RingBuffer { if (head_ptr_ + bytes_available_ > old_size) { // copy the ring overflow part into the tail. size_t ncopy = head_ptr_ + bytes_available_ - old_size; + if (old_size + ncopy > ring_.size()) { + ring_.resize(old_size + ncopy); + } memcpy(&ring_[0] + old_size, &ring_[0], ncopy); } } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity) { @@ -101,6 +106,9 @@ class RingBuffer { } head_ptr_ = (head_ptr_ + size) % ring_.size(); bytes_available_ -= size; + if (bytes_available_ == 0) { + head_ptr_ = 0; + } } /*! * \brief Read data from buffer with and put them to non-blocking send function. @@ -115,12 +123,15 @@ class RingBuffer { ICHECK_NE(size, 0U); size_t ncopy = std::min(size, ring_.size() - head_ptr_); size_t nsend = fsend(&ring_[0] + head_ptr_, ncopy); - bytes_available_ -= nsend; if (ncopy == nsend && ncopy < size) { size_t nsend2 = fsend(&ring_[0], size - ncopy); - bytes_available_ -= nsend2; nsend += nsend2; } + head_ptr_ = (head_ptr_ + nsend) % ring_.size(); + bytes_available_ -= nsend; + if (bytes_available_ == 0) { + head_ptr_ = 0; + } return nsend; } /*! diff --git a/tests/cpp/support/ring_buffer_test.cc b/tests/cpp/support/ring_buffer_test.cc new file mode 100644 index 000000000000..9b78b2767731 --- /dev/null +++ b/tests/cpp/support/ring_buffer_test.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../../../src/support/ring_buffer.h" + +#include + +namespace tvm { +namespace support { +namespace { + +TEST(RingBuffer, ReadWrite) { + RingBuffer buffer; + std::vector data = {1, 2, 3, 4}; + std::vector output; + + buffer.Write(data.data(), data.size() * 4); + ASSERT_EQ(buffer.bytes_available(), data.size() * 4); + + output.resize(4); + buffer.Read(output.data(), data.size() * 4); + + for (size_t i = 0; i < output.size(); ++i) { + ASSERT_EQ(output[i], data[i]); + } +} + +TEST(RingBuffer, ReadWithCallback) { + RingBuffer buffer; + std::vector data = {1, 2, 3, 4}; + std::vector output; + + buffer.Write(data.data(), data.size() * 4); + + auto callback0 = [](const char* data, size_t size) -> size_t { + const int* iptr = reinterpret_cast(data); + ICHECK_EQ(iptr[0], 1); + ICHECK_EQ(iptr[1], 2); + return size; + }; + buffer.ReadWithCallback(callback0, 2 * sizeof(int)); + auto callback1 = [](const char* data, size_t size) -> size_t { + const int* iptr = reinterpret_cast(data); + ICHECK_EQ(iptr[0], 3); + ICHECK_EQ(iptr[1], 4); + return size; + }; + buffer.ReadWithCallback(callback1, 2 * sizeof(int)); +} +} // namespace +} // namespace support +} // namespace tvm