diff --git a/.github/actions/regression-tests/action.yml b/.github/actions/regression-tests/action.yml index bafa6a1b7999..f45db013468f 100644 --- a/.github/actions/regression-tests/action.yml +++ b/.github/actions/regression-tests/action.yml @@ -26,7 +26,7 @@ runs: # timeout-minutes: 20 steps: - name: Run PyTests - id: first + id: main shell: bash run: | ls -l ${GITHUB_WORKSPACE}/ @@ -37,7 +37,7 @@ runs: export DRAGONFLY_PATH="${GITHUB_WORKSPACE}/${{inputs.build-folder-name}}/${{inputs.dfly-executable}}" export UBSAN_OPTIONS=print_stacktrace=1:halt_on_error=1 # to crash on errors - timeout 20m pytest -m "${{inputs.filter}}" --durations=10 --color=yes --json-report --json-report-file=report.json dragonfly --ignore=dragonfly/replication_test.py --log-cli-level=INFO || code=$? + timeout 40m pytest -m "${{inputs.filter}}" --durations=10 --color=yes --json-report --json-report-file=report.json dragonfly --log-cli-level=INFO || code=$? # timeout returns 124 if we exceeded the timeout duration if [[ $code -eq 124 ]]; then @@ -50,32 +50,6 @@ runs: exit 1 fi - - name: Run PyTests replication test - id: second - if: ${{ inputs.run-only-on-ubuntu-latest == 'true' || (inputs.run-only-on-ubuntu-latest == 'false' && matrix.runner == 'ubuntu-latest') }} - shell: bash - run: | - echo "Running PyTests replication test" - cd ${GITHUB_WORKSPACE}/tests - # used by PyTests - export DRAGONFLY_PATH="${GITHUB_WORKSPACE}/${{inputs.build-folder-name}}/${{inputs.dfly-executable}}" - - - timeout 20m pytest -m "${{inputs.filter}}" --durations=10 --color=yes --json-report \ - --json-report-file=rep1_report.json dragonfly/replication_test.py --log-cli-level=INFO \ - --df alsologtostderr $1 $2 || code=$? 
- - # timeout returns 124 if we exceeded the timeout duration - if [[ $code -eq 124 ]]; then - echo "TIMEDOUT=1">> "$GITHUB_OUTPUT" - exit 1 - fi - - # when a test fails in pytest it returns 1 but there are other return codes as well so we just check if the code is non zero - if [[ $code -ne 0 ]]; then - exit 1 - fi - - name: Print last log on timeout if: failure() shell: bash @@ -106,14 +80,7 @@ runs: } cd ${GITHUB_WORKSPACE}/tests failed_tests="" - # The order in of if is important, and expected to be the oposite order of the pytest run. - # As github runner will not run the next step if the pytest failed, we start from the last - # report file and if exist we get the failed test from the pytest run, if there are any. - if [ -f rep2_report.json ]; then - failed_tests=$(get_failed_tests rep2_report.json) - elif [ -f rep1_report.json ]; then - failed_tests=$(get_failed_tests rep1_report.json) - elif [ -f report.json ]; then + if [ -f report.json ]; then failed_tests=$(get_failed_tests report.json) fi diff --git a/.github/workflows/regression-tests.yml b/.github/workflows/regression-tests.yml index 8c5dc0cde5a7..bdfd74bb6fcf 100644 --- a/.github/workflows/regression-tests.yml +++ b/.github/workflows/regression-tests.yml @@ -41,7 +41,6 @@ jobs: with: dfly-executable: dragonfly gspace-secret: ${{ secrets.GSPACES_BOT_DF_BUILD }} - run-only-on-ubuntu-latest: false build-folder-name: build # This expression serves as a ternary operator, i.e. if the condition holds it returns # 'not NON_EXISTING_MARK' otherwise not opt_only. 
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c75f24d0ca1a..03fbad63ff98 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -131,7 +131,6 @@ jobs: with: dfly-executable: dragonfly-x86_64 gspace-secret: ${{ secrets.GSPACES_BOT_DF_BUILD }} - run-only-on-ubuntu-latest: true build-folder-name: ${{ env.RELEASE_DIR }} - name: Save artifacts run: | diff --git a/src/core/bloom.cc b/src/core/bloom.cc index ca55fa19f071..179cd19b5aaa 100644 --- a/src/core/bloom.cc +++ b/src/core/bloom.cc @@ -8,6 +8,7 @@ #include #include +#include #include #include "base/logging.h" diff --git a/src/core/search/ast_expr.cc b/src/core/search/ast_expr.cc index b65a34d0bf43..767e8797877f 100644 --- a/src/core/search/ast_expr.cc +++ b/src/core/search/ast_expr.cc @@ -58,12 +58,13 @@ AstTagsNode::AstTagsNode(AstExpr&& l, std::string tag) { } AstKnnNode::AstKnnNode(uint32_t limit, std::string_view field, OwnedFtVector vec, - std::string_view score_alias) + std::string_view score_alias, std::optional ef_runtime) : filter{nullptr}, limit{limit}, field{field.substr(1)}, vec{std::move(vec)}, - score_alias{score_alias} { + score_alias{score_alias}, + ef_runtime{ef_runtime} { } AstKnnNode::AstKnnNode(AstNode&& filter, AstKnnNode&& self) { @@ -72,3 +73,9 @@ AstKnnNode::AstKnnNode(AstNode&& filter, AstKnnNode&& self) { } } // namespace dfly::search + +namespace std { +ostream& operator<<(ostream& os, optional o) { + return os; +} +} // namespace std diff --git a/src/core/search/ast_expr.h b/src/core/search/ast_expr.h index 2d6e8bdf40e7..f06b212849ed 100644 --- a/src/core/search/ast_expr.h +++ b/src/core/search/ast_expr.h @@ -74,7 +74,8 @@ struct AstTagsNode { struct AstKnnNode { AstKnnNode() = default; AstKnnNode(uint32_t limit, std::string_view field, OwnedFtVector vec, - std::string_view score_alias); + std::string_view score_alias, std::optional ef_runtime); + AstKnnNode(AstNode&& sub, AstKnnNode&& self); friend std::ostream& 
operator<<(std::ostream& stream, const AstKnnNode& matrix) { @@ -86,6 +87,7 @@ struct AstKnnNode { std::string field; OwnedFtVector vec; std::string score_alias; + std::optional ef_runtime; }; struct AstSortNode { @@ -108,6 +110,11 @@ struct AstNode : public NodeVariants { const NodeVariants& Variant() const& { return *this; } + + // Aggregations: KNN, SORTBY. They reorder result sets and optionally reduce them. + bool IsAggregation() const { + return std::holds_alternative(Variant()); + } }; using AstExpr = AstNode; @@ -115,4 +122,6 @@ using AstExpr = AstNode; } // namespace search } // namespace dfly -namespace std {} // namespace std +namespace std { +ostream& operator<<(ostream& os, optional o); +} diff --git a/src/core/search/indices.cc b/src/core/search/indices.cc index 9573f2c037cc..747860da7022 100644 --- a/src/core/search/indices.cc +++ b/src/core/search/indices.cc @@ -195,6 +195,9 @@ const float* FlatVectorIndex::Get(DocId doc) const { } struct HnswlibAdapter { + // Default setting of hnswlib/hnswalg + constexpr static size_t kDefaultEfRuntime = 10; + HnswlibAdapter(const SchemaField::VectorParams& params) : space_{MakeSpace(params.dim, params.sim)}, world_{GetSpacePtr(), params.capacity, @@ -214,11 +217,13 @@ struct HnswlibAdapter { world_.markDelete(id); } - vector> Knn(float* target, size_t k) { + vector> Knn(float* target, size_t k, std::optional ef) { + world_.setEf(ef.value_or(kDefaultEfRuntime)); return QueueToVec(world_.searchKnn(target, k)); } - vector> Knn(float* target, size_t k, const vector& allowed) { + vector> Knn(float* target, size_t k, std::optional ef, + const vector& allowed) { struct BinsearchFilter : hnswlib::BaseFilterFunctor { virtual bool operator()(hnswlib::labeltype id) { return binary_search(allowed->begin(), allowed->end(), id); @@ -229,6 +234,7 @@ struct HnswlibAdapter { const vector* allowed; }; + world_.setEf(ef.value_or(kDefaultEfRuntime)); BinsearchFilter filter{&allowed}; return QueueToVec(world_.searchKnn(target, k, 
&filter)); } @@ -276,12 +282,14 @@ void HnswVectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) { adapter_->Add(ptr.get(), id); } -std::vector> HnswVectorIndex::Knn(float* target, size_t k) const { - return adapter_->Knn(target, k); +std::vector> HnswVectorIndex::Knn(float* target, size_t k, + std::optional ef) const { + return adapter_->Knn(target, k, ef); } std::vector> HnswVectorIndex::Knn(float* target, size_t k, + std::optional ef, const std::vector& allowed) const { - return adapter_->Knn(target, k, allowed); + return adapter_->Knn(target, k, ef, allowed); } void HnswVectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) { diff --git a/src/core/search/indices.h b/src/core/search/indices.h index 4d20333460ba..c94944a2143c 100644 --- a/src/core/search/indices.h +++ b/src/core/search/indices.h @@ -136,8 +136,8 @@ struct HnswVectorIndex : public BaseVectorIndex { void Add(DocId id, DocumentAccessor* doc, std::string_view field) override; void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override; - std::vector> Knn(float* target, size_t k) const; - std::vector> Knn(float* target, size_t k, + std::vector> Knn(float* target, size_t k, std::optional ef) const; + std::vector> Knn(float* target, size_t k, std::optional ef, const std::vector& allowed) const; private: diff --git a/src/core/search/lexer.lex b/src/core/search/lexer.lex index 76e002c006d8..5dee656a03ec 100644 --- a/src/core/search/lexer.lex +++ b/src/core/search/lexer.lex @@ -64,6 +64,7 @@ term_char [_]|\w "|" return Parser::make_OR_OP (loc()); "KNN" return Parser::make_KNN (loc()); "AS" return Parser::make_AS (loc()); +"EF_RUNTIME" return Parser::make_EF_RUNTIME (loc()); [0-9]{1,9} return make_UINT32(matched_view(), loc()); [+-]?(([0-9]*[.])?[0-9]+|inf) return make_DOUBLE(matched_view(), loc()); diff --git a/src/core/search/parser.y b/src/core/search/parser.y index 906312ab8d98..7a3cce413f26 100644 --- a/src/core/search/parser.y +++ 
b/src/core/search/parser.y @@ -34,6 +34,7 @@ #define yylex driver->scanner()->Lex using namespace std; + } %parse-param { QueryDriver *driver } @@ -46,18 +47,19 @@ using namespace std; %define api.token.prefix {TOK_} %token - LPAREN "(" - RPAREN ")" - STAR "*" - ARROW "=>" - COLON ":" - LBRACKET "[" - RBRACKET "]" - LCURLBR "{" - RCURLBR "}" - OR_OP "|" - KNN "KNN" - AS "AS" + LPAREN "(" + RPAREN ")" + STAR "*" + ARROW "=>" + COLON ":" + LBRACKET "[" + RBRACKET "]" + LCURLBR "{" + RCURLBR "}" + OR_OP "|" + KNN "KNN" + AS "AS" + EF_RUNTIME "EF_RUNTIME" ; %token AND_OP @@ -81,6 +83,7 @@ using namespace std; %nterm knn_query %nterm opt_knn_alias +%nterm > opt_ef_runtime %printer { yyo << $$; } <*>; @@ -93,13 +96,17 @@ final_query: { driver->Set(AstKnnNode(std::move($1), std::move($3))); } knn_query: - LBRACKET KNN UINT32 FIELD TERM opt_knn_alias RBRACKET - { $$ = AstKnnNode($3, $4, BytesToFtVector($5), $6); } + LBRACKET KNN UINT32 FIELD TERM opt_knn_alias opt_ef_runtime RBRACKET + { $$ = AstKnnNode($3, $4, BytesToFtVector($5), $6, $7); } opt_knn_alias: AS TERM { $$ = std::move($2); } | { $$ = std::string{}; } +opt_ef_runtime: + /* empty */ { $$ = std::nullopt; } + | EF_RUNTIME UINT32 { $$ = $2; } + filter: search_expr { $$ = std::move($1); } | STAR { $$ = AstStarNode(); } @@ -174,5 +181,5 @@ tag_list: void dfly::search::Parser::error(const location_type& l, const string& m) { - cerr << l << ": " << m << '\n'; + driver->Error(l, m); } diff --git a/src/core/search/query_driver.cc b/src/core/search/query_driver.cc index 03904e2bd5ad..2422be10e3b1 100644 --- a/src/core/search/query_driver.cc +++ b/src/core/search/query_driver.cc @@ -18,6 +18,10 @@ void QueryDriver::ResetScanner() { scanner_->SetParams(params_); } +void QueryDriver::Error(const Parser::location_type& loc, std::string_view msg) { + LOG(ERROR) << "Parse error " << loc << ": " << msg; +} + } // namespace search } // namespace dfly diff --git a/src/core/search/query_driver.h b/src/core/search/query_driver.h 
index bef9ded96d7c..ca5401868a42 100644 --- a/src/core/search/query_driver.h +++ b/src/core/search/query_driver.h @@ -52,6 +52,8 @@ class QueryDriver { return scanner_.get(); } + void Error(const Parser::location_type& loc, std::string_view msg); + public: Parser::location_type location; diff --git a/src/core/search/search.cc b/src/core/search/search.cc index ca01c3916f7e..3e192a6224f8 100644 --- a/src/core/search/search.cc +++ b/src/core/search/search.cc @@ -353,9 +353,10 @@ struct BasicSearch { void SearchKnnHnsw(HnswVectorIndex* vec_index, const AstKnnNode& knn, IndexResult&& sub_results) { if (indices_->GetAllDocs().size() == sub_results.Size()) - knn_distances_ = vec_index->Knn(knn.vec.first.get(), knn.limit); + knn_distances_ = vec_index->Knn(knn.vec.first.get(), knn.limit, knn.ef_runtime); else - knn_distances_ = vec_index->Knn(knn.vec.first.get(), knn.limit, sub_results.Take()); + knn_distances_ = + vec_index->Knn(knn.vec.first.get(), knn.limit, knn.ef_runtime, sub_results.Take()); } // [KNN limit @field vec]: Compute distance from `vec` to all vectors keep closest `limit` @@ -420,6 +421,7 @@ struct BasicSearch { profile_builder_ ? 
make_optional(profile_builder_->Take()) : nullopt; size_t total = result.Size(); + return SearchResult{total, max(total, preagg_total_), result.Take(limit_), @@ -428,6 +430,7 @@ struct BasicSearch { std::move(error_)}; } + private: const FieldIndices* indices_; size_t limit_; @@ -622,4 +625,8 @@ void SearchAlgorithm::EnableProfiling() { profiling_enabled_ = true; } +bool SearchAlgorithm::IsProfilingEnabled() const { + return profiling_enabled_; +} + } // namespace dfly::search diff --git a/src/core/search/search.h b/src/core/search/search.h index 5e8b14c949a5..7387444af0e4 100644 --- a/src/core/search/search.h +++ b/src/core/search/search.h @@ -144,6 +144,7 @@ class SearchAlgorithm { std::optional HasAggregation() const; void EnableProfiling(); + bool IsProfilingEnabled() const; private: bool profiling_enabled_ = false; diff --git a/src/core/search/search_parser_test.cc b/src/core/search/search_parser_test.cc index 5ef9944babed..032aecc8553f 100644 --- a/src/core/search/search_parser_test.cc +++ b/src/core/search/search_parser_test.cc @@ -149,4 +149,24 @@ TEST_F(SearchParserTest, KNN) { NEXT_TOK(TOK_LBRACKET); } +TEST_F(SearchParserTest, KNNfull) { + SetInput("*=>[KNN 1 @vector field_vec AS vec_sort EF_RUNTIME 15]"); + NEXT_TOK(TOK_STAR); + NEXT_TOK(TOK_ARROW); + NEXT_TOK(TOK_LBRACKET); + + NEXT_TOK(TOK_KNN); + NEXT_EQ(TOK_UINT32, uint32_t, 1); + NEXT_TOK(TOK_FIELD); + NEXT_TOK(TOK_TERM); + + NEXT_TOK(TOK_AS); + NEXT_EQ(TOK_TERM, string, "vec_sort"); + + NEXT_TOK(TOK_EF_RUNTIME); + NEXT_EQ(TOK_UINT32, uint32_t, 15); + + NEXT_TOK(TOK_RBRACKET); +} + } // namespace dfly::search diff --git a/src/core/sorted_map.cc b/src/core/sorted_map.cc index d0dc060bf783..70294feca228 100644 --- a/src/core/sorted_map.cc +++ b/src/core/sorted_map.cc @@ -198,7 +198,16 @@ unsigned char* ZzlInsert(unsigned char* zl, sds ele, double score) { return zzlInsertAt(zl, NULL, ele, score); } -int SortedMap::DfImpl::ScoreSdsPolicy::KeyCompareTo::operator()(ScoreSds a, ScoreSds b) const { 
+SortedMap::SortedMap(PMR_NS::memory_resource* mr) + : score_map(new ScoreMap(mr)), score_tree(new ScoreTree(mr)) { +} + +SortedMap::~SortedMap() { + delete score_tree; + delete score_map; +} + +int SortedMap::ScoreSdsPolicy::KeyCompareTo::operator()(ScoreSds a, ScoreSds b) const { sds sdsa = (sds)(uint64_t(a) & kSdsMask); sds sdsb = (sds)(uint64_t(b) & kSdsMask); @@ -224,7 +233,7 @@ int SortedMap::DfImpl::ScoreSdsPolicy::KeyCompareTo::operator()(ScoreSds a, Scor return sdscmp(sdsa, sdsb); } -int SortedMap::DfImpl::Add(double score, sds ele, int in_flags, int* out_flags, double* newscore) { +int SortedMap::Add(double score, sds ele, int in_flags, int* out_flags, double* newscore) { // does not take ownership over ele. DCHECK(!isnan(score)); @@ -272,7 +281,7 @@ int SortedMap::DfImpl::Add(double score, sds ele, int in_flags, int* out_flags, return 1; } -optional SortedMap::DfImpl::GetScore(sds ele) const { +optional SortedMap::GetScore(sds ele) const { ScoreSds obj = score_map->FindObj(ele); if (obj != nullptr) { return GetObjScore(obj); @@ -281,19 +290,8 @@ optional SortedMap::DfImpl::GetScore(sds ele) const { return std::nullopt; } -void SortedMap::DfImpl::Init(PMR_NS::memory_resource* mr) { - score_map = new ScoreMap(mr); - score_tree = new ScoreTree(mr); -} - -void SortedMap::DfImpl::Free() { - DVLOG(1) << "Freeing SortedMap"; - delete score_tree; - delete score_map; -} - // Takes ownership over ele. 
-bool SortedMap::DfImpl::Insert(double score, sds ele) { +bool SortedMap::Insert(double score, sds ele) { DVLOG(1) << "Inserting " << ele << " with score " << score; auto [newk, added] = score_map->AddOrUpdate(string_view{ele, sdslen(ele)}, score); @@ -306,7 +304,7 @@ bool SortedMap::DfImpl::Insert(double score, sds ele) { return true; } -optional SortedMap::DfImpl::GetRank(sds ele, bool reverse) const { +optional SortedMap::GetRank(sds ele, bool reverse) const { ScoreSds obj = score_map->FindObj(ele); if (obj == nullptr) return std::nullopt; @@ -316,8 +314,8 @@ optional SortedMap::DfImpl::GetRank(sds ele, bool reverse) const { return reverse ? score_map->UpperBoundSize() - *rank - 1 : *rank; } -SortedMap::ScoredArray SortedMap::DfImpl::GetRange(const zrangespec& range, unsigned offset, - unsigned limit, bool reverse) const { +SortedMap::ScoredArray SortedMap::GetRange(const zrangespec& range, unsigned offset, unsigned limit, + bool reverse) const { ScoredArray arr; if (score_tree->Size() <= offset || limit == 0) return arr; @@ -387,8 +385,8 @@ SortedMap::ScoredArray SortedMap::DfImpl::GetRange(const zrangespec& range, unsi return arr; } -SortedMap::ScoredArray SortedMap::DfImpl::GetLexRange(const zlexrangespec& range, unsigned offset, - unsigned limit, bool reverse) const { +SortedMap::ScoredArray SortedMap::GetLexRange(const zlexrangespec& range, unsigned offset, + unsigned limit, bool reverse) const { if (score_tree->Size() <= offset || limit == 0) return {}; @@ -459,7 +457,7 @@ SortedMap::ScoredArray SortedMap::DfImpl::GetLexRange(const zlexrangespec& range return arr; } -uint8_t* SortedMap::DfImpl::ToListPack() const { +uint8_t* SortedMap::ToListPack() const { uint8_t* lp = lpNew(0); score_tree->Iterate(0, UINT32_MAX, [&](ScoreSds ele) { @@ -470,7 +468,7 @@ uint8_t* SortedMap::DfImpl::ToListPack() const { return lp; } -bool SortedMap::DfImpl::Delete(sds ele) { +bool SortedMap::Delete(sds ele) { ScoreSds obj = score_map->FindObj(ele); if (obj == nullptr) return 
false; @@ -480,17 +478,17 @@ bool SortedMap::DfImpl::Delete(sds ele) { return true; } -size_t SortedMap::DfImpl::MallocSize() const { +size_t SortedMap::MallocSize() const { // TODO: add malloc used to BPTree. return score_map->SetMallocUsed() + score_map->ObjMallocUsed() + score_tree->NodeCount() * 256; } -bool SortedMap::DfImpl::Reserve(size_t sz) { +bool SortedMap::Reserve(size_t sz) { score_map->Reserve(sz); return true; } -size_t SortedMap::DfImpl::DeleteRangeByRank(unsigned start, unsigned end) { +size_t SortedMap::DeleteRangeByRank(unsigned start, unsigned end) { DCHECK_LE(start, end); DCHECK_LT(end, score_tree->Size()); @@ -510,7 +508,7 @@ size_t SortedMap::DfImpl::DeleteRangeByRank(unsigned start, unsigned end) { return end - start + 1; } -size_t SortedMap::DfImpl::DeleteRangeByScore(const zrangespec& range) { +size_t SortedMap::DeleteRangeByScore(const zrangespec& range) { char buf[16] = {0}; size_t deleted = 0; @@ -539,7 +537,7 @@ size_t SortedMap::DfImpl::DeleteRangeByScore(const zrangespec& range) { return deleted; } -size_t SortedMap::DfImpl::DeleteRangeByLex(const zlexrangespec& range) { +size_t SortedMap::DeleteRangeByLex(const zlexrangespec& range) { if (score_tree->Size() == 0) return 0; @@ -574,7 +572,7 @@ size_t SortedMap::DfImpl::DeleteRangeByLex(const zlexrangespec& range) { return deleted; } -SortedMap::ScoredArray SortedMap::DfImpl::PopTopScores(unsigned count, bool reverse) { +SortedMap::ScoredArray SortedMap::PopTopScores(unsigned count, bool reverse) { DCHECK_EQ(score_map->UpperBoundSize(), score_tree->Size()); size_t sz = score_map->UpperBoundSize(); @@ -608,7 +606,7 @@ SortedMap::ScoredArray SortedMap::DfImpl::PopTopScores(unsigned count, bool reve return res; } -size_t SortedMap::DfImpl::Count(const zrangespec& range) const { +size_t SortedMap::Count(const zrangespec& range) const { DCHECK_LE(range.min, range.max); if (score_tree->Size() == 0) @@ -654,7 +652,7 @@ size_t SortedMap::DfImpl::Count(const zrangespec& range) const { return 
max_rank < min_rank ? 0 : max_rank - min_rank + 1; } -size_t SortedMap::DfImpl::LexCount(const zlexrangespec& range) const { +size_t SortedMap::LexCount(const zlexrangespec& range) const { if (score_tree->Size() == 0) return 0; @@ -696,8 +694,8 @@ size_t SortedMap::DfImpl::LexCount(const zlexrangespec& range) const { return max_rank < min_rank ? 0 : max_rank - min_rank + 1; } -bool SortedMap::DfImpl::Iterate(unsigned start_rank, unsigned len, bool reverse, - absl::FunctionRef cb) const { +bool SortedMap::Iterate(unsigned start_rank, unsigned len, bool reverse, + absl::FunctionRef cb) const { DCHECK_GT(len, 0u); unsigned end_rank = start_rank + len - 1; bool success; @@ -712,8 +710,8 @@ bool SortedMap::DfImpl::Iterate(unsigned start_rank, unsigned len, bool reverse, return success; } -uint64_t SortedMap::DfImpl::Scan(uint64_t cursor, - absl::FunctionRef cb) const { +uint64_t SortedMap::Scan(uint64_t cursor, + absl::FunctionRef cb) const { auto scan_cb = [&cb](const void* obj) { sds ele = (sds)obj; cb(string_view{ele, sdslen(ele)}, GetObjScore(obj)); @@ -722,17 +720,6 @@ uint64_t SortedMap::DfImpl::Scan(uint64_t cursor, return this->score_map->Scan(cursor, std::move(scan_cb)); } -/***************************************************************************/ -/* SortedMap */ -/***************************************************************************/ -SortedMap::SortedMap(PMR_NS::memory_resource* mr) : impl_(DfImpl()) { - std::visit(Overload{[mr](auto& impl) { impl.Init(mr); }}, impl_); -} - -SortedMap::~SortedMap() { - std::visit([](auto& impl) { impl.Free(); }, impl_); -} - // taken from zsetConvert SortedMap* SortedMap::FromListPack(PMR_NS::memory_resource* res, const uint8_t* lp) { uint8_t* zl = (uint8_t*)lp; diff --git a/src/core/sorted_map.h b/src/core/sorted_map.h index 2e516987d901..c1e028f7df75 100644 --- a/src/core/sorted_map.h +++ b/src/core/sorted_map.h @@ -23,10 +23,6 @@ namespace dfly { namespace detail { -template struct Overload : Ts... 
{ using Ts::operator()...; }; - -template Overload(Ts...) -> Overload; - /** * @brief SortedMap is a sorted map implementation based on zset.h. It holds unique strings that * are ordered by score and lexicographically. The score is a double value and has higher priority. @@ -37,177 +33,62 @@ class SortedMap { public: using ScoredMember = std::pair; using ScoredArray = std::vector; + using ScoreSds = void*; SortedMap(PMR_NS::memory_resource* res); - SortedMap(const SortedMap&) = delete; - SortedMap& operator=(const SortedMap&) = delete; - ~SortedMap(); - // The ownership for the returned SortedMap stays with the caller, and must be freed via - // placement delete and then res->deallocate(). - static SortedMap* FromListPack(PMR_NS::memory_resource* res, const uint8_t* lp); - - size_t Size() const { - return std::visit(Overload{[](const auto& impl) { return impl.Size(); }}, impl_); - } - - bool Reserve(size_t sz) { - return std::visit(Overload{[&](auto& impl) { return impl.Reserve(sz); }}, impl_); - } - - // Interface equivalent to zsetAdd. - // Does not take ownership over ele string. - // Returns 1 if succeeded, false if the final score became invalid due to the update. - // newscore is set to the new score of the element only if in_flags contains ZADD_IN_INCR. - int Add(double score, sds ele, int in_flags, int* out_flags, double* newscore) { - return std::visit( - Overload{[&](auto& impl) { return impl.Add(score, ele, in_flags, out_flags, newscore); }}, - impl_); - } - - // Takes ownership over member. 
- bool Insert(double score, sds member) { - return std::visit(Overload{[&](auto& impl) { return impl.Insert(score, member); }}, impl_); - } - - uint8_t* ToListPack() const { - return std::visit(Overload{[](const auto& impl) { return impl.ToListPack(); }}, impl_); - } - - size_t MallocSize() const { - return std::visit(Overload{[](const auto& impl) { return impl.MallocSize(); }}, impl_); - } - - uint64_t Scan(uint64_t cursor, absl::FunctionRef cb) const { - return std::visit([&](const auto& impl) { return impl.Scan(cursor, cb); }, impl_); - } - - size_t DeleteRangeByRank(unsigned start, unsigned end) { - return std::visit(Overload{[&](auto& impl) { return impl.DeleteRangeByRank(start, end); }}, - impl_); - } - - size_t DeleteRangeByScore(const zrangespec& range) { - return std::visit(Overload{[&](auto& impl) { return impl.DeleteRangeByScore(range); }}, impl_); - } + SortedMap(const SortedMap&) = delete; + SortedMap& operator=(const SortedMap&) = delete; - size_t DeleteRangeByLex(const zlexrangespec& range) { - return std::visit(Overload{[&](auto& impl) { return impl.DeleteRangeByLex(range); }}, impl_); - } + struct ScoreSdsPolicy { + using KeyT = ScoreSds; - // returns true if the element was deleted. 
- bool Delete(sds ele) { - return std::visit(Overload{[&](auto& impl) { return impl.Delete(ele); }}, impl_); - } + struct KeyCompareTo { + int operator()(KeyT a, KeyT b) const; + }; + }; - std::optional GetScore(sds ele) const { - return std::visit(Overload{[&](const auto& impl) { return impl.GetScore(ele); }}, impl_); - } + bool Reserve(size_t sz); + int Add(double score, sds ele, int in_flags, int* out_flags, double* newscore); + bool Insert(double score, sds member); + bool Delete(sds ele); - std::optional GetRank(sds ele, bool reverse) const { - return std::visit(Overload{[&](const auto& impl) { return impl.GetRank(ele, reverse); }}, - impl_); + size_t Size() const { + return score_map->UpperBoundSize(); } - ScoredArray GetRange(const zrangespec& range, unsigned offset, unsigned limit, - bool reverse) const { - return std::visit( - Overload{[&](const auto& impl) { return impl.GetRange(range, offset, limit, reverse); }}, - impl_); - } + size_t MallocSize() const; - ScoredArray GetLexRange(const zlexrangespec& range, unsigned offset, unsigned limit, - bool reverse) const { - return std::visit( - Overload{[&](const auto& impl) { return impl.GetLexRange(range, offset, limit, reverse); }}, - impl_); - } + size_t DeleteRangeByRank(unsigned start, unsigned end); + size_t DeleteRangeByScore(const zrangespec& range); + size_t DeleteRangeByLex(const zlexrangespec& range); - ScoredArray PopTopScores(unsigned count, bool reverse) { - return std::visit(Overload{[&](auto& impl) { return impl.PopTopScores(count, reverse); }}, - impl_); - } + ScoredArray PopTopScores(unsigned count, bool reverse); - size_t Count(const zrangespec& range) const { - return std::visit(Overload{[&](const auto& impl) { return impl.Count(range); }}, impl_); - } + std::optional GetScore(sds ele) const; + std::optional GetRank(sds ele, bool reverse) const; + ScoredArray GetRange(const zrangespec& r, unsigned offs, unsigned len, bool rev) const; + ScoredArray GetLexRange(const zlexrangespec& r, unsigned 
o, unsigned l, bool rev) const; - size_t LexCount(const zlexrangespec& range) const { - return std::visit(Overload{[&](const auto& impl) { return impl.LexCount(range); }}, impl_); - } + size_t Count(const zrangespec& range) const; + size_t LexCount(const zlexrangespec& range) const; // Runs cb for each element in the range [start_rank, start_rank + len). // Stops iteration if cb returns false. Returns false in this case. bool Iterate(unsigned start_rank, unsigned len, bool reverse, - absl::FunctionRef cb) const { - return std::visit([&](const auto& impl) { return impl.Iterate(start_rank, len, reverse, cb); }, - impl_); - } - - private: - struct DfImpl { - ScoreMap* score_map = nullptr; - using ScoreSds = void*; - - struct ScoreSdsPolicy { - using KeyT = ScoreSds; - - struct KeyCompareTo { - int operator()(KeyT a, KeyT b) const; - }; - }; - - using ScoreTree = BPTree; - ScoreTree* score_tree = nullptr; // just a stub for now. - - void Init(PMR_NS::memory_resource* mr); - - void Free(); - - int Add(double score, sds ele, int in_flags, int* out_flags, double* newscore); + absl::FunctionRef cb) const; - bool Insert(double score, sds member); + uint64_t Scan(uint64_t cursor, absl::FunctionRef cb) const; - bool Delete(sds ele); - - size_t Size() const { - return score_map->UpperBoundSize(); - } - - size_t MallocSize() const; - - bool Reserve(size_t sz); - - size_t DeleteRangeByRank(unsigned start, unsigned end); - - size_t DeleteRangeByScore(const zrangespec& range); - - size_t DeleteRangeByLex(const zlexrangespec& range); - - ScoredArray PopTopScores(unsigned count, bool reverse); - - uint8_t* ToListPack() const; - - std::optional GetScore(sds ele) const; - std::optional GetRank(sds ele, bool reverse) const; - - ScoredArray GetRange(const zrangespec& r, unsigned offs, unsigned len, bool rev) const; - ScoredArray GetLexRange(const zlexrangespec& r, unsigned o, unsigned l, bool rev) const; - - size_t Count(const zrangespec& range) const; - size_t LexCount(const 
zlexrangespec& range) const; - - // Runs cb for each element in the range [start_rank, start_rank + len). - // Stops iteration if cb returns false. Returns false in this case. - bool Iterate(unsigned start_rank, unsigned len, bool reverse, - absl::FunctionRef cb) const; + uint8_t* ToListPack() const; + static SortedMap* FromListPack(PMR_NS::memory_resource* res, const uint8_t* lp); - uint64_t Scan(uint64_t cursor, absl::FunctionRef cb) const; - }; + private: + using ScoreTree = BPTree; - // TODO: remove this variant and get rid of wrapper class - std::variant impl_; + ScoreMap* score_map = nullptr; + ScoreTree* score_tree = nullptr; // just a stub for now. }; // Used by CompactObject. diff --git a/src/facade/error.h b/src/facade/error.h index 3eaf060cfc84..c6aded2b55a0 100644 --- a/src/facade/error.h +++ b/src/facade/error.h @@ -33,6 +33,7 @@ extern const char kInvalidNumericResult[]; extern const char kClusterNotConfigured[]; extern const char kLoadingErr[]; extern const char kUndeclaredKeyErr[]; +extern const char kInvalidDumpValueErr[]; extern const char kSyntaxErrType[]; extern const char kScriptErrType[]; diff --git a/src/facade/facade.cc b/src/facade/facade.cc index 1748c1cf4061..fce39a55cb84 100644 --- a/src/facade/facade.cc +++ b/src/facade/facade.cc @@ -95,6 +95,7 @@ const char kInvalidNumericResult[] = "result is not a number"; const char kClusterNotConfigured[] = "Cluster is not yet configured"; const char kLoadingErr[] = "-LOADING Dragonfly is loading the dataset in memory"; const char kUndeclaredKeyErr[] = "script tried accessing undeclared key"; +const char kInvalidDumpValueErr[] = "DUMP payload version or checksum are wrong"; const char kSyntaxErrType[] = "syntax_error"; const char kScriptErrType[] = "script_error"; diff --git a/src/facade/facade_types.h b/src/facade/facade_types.h index f567ccb18bef..877047941068 100644 --- a/src/facade/facade_types.h +++ b/src/facade/facade_types.h @@ -163,7 +163,7 @@ struct ErrorReply { std::string_view kind = {}) 
// to resolve ambiguity of constructors above : message{std::string_view{msg}}, kind{kind} { } - explicit ErrorReply(OpStatus status) : message{}, kind{}, status{status} { + ErrorReply(OpStatus status) : message{}, kind{}, status{status} { } std::string_view ToSv() const { diff --git a/src/facade/reply_builder.h b/src/facade/reply_builder.h index 9337c7d871e3..70621aa3d470 100644 --- a/src/facade/reply_builder.h +++ b/src/facade/reply_builder.h @@ -79,8 +79,8 @@ class SinkReplyBuilder { } virtual void SendError(std::string_view str, std::string_view type = {}) = 0; // MC and Redis - virtual void SendError(ErrorReply error); virtual void SendError(OpStatus status); + void SendError(ErrorReply error); virtual void SendStored() = 0; // Reply for set commands. virtual void SendSetSkipped() = 0; diff --git a/src/facade/reply_capture.cc b/src/facade/reply_capture.cc index c81e9f5ec168..c286511fe0bc 100644 --- a/src/facade/reply_capture.cc +++ b/src/facade/reply_capture.cc @@ -4,6 +4,7 @@ #include "facade/reply_capture.h" #include "base/logging.h" +#include "facade/conn_context.h" #include "reply_capture.h" #define SKIP_LESS(needed) \ @@ -21,11 +22,6 @@ void CapturingReplyBuilder::SendError(std::string_view str, std::string_view typ Capture(Error{str, type}); } -void CapturingReplyBuilder::SendError(ErrorReply error) { - SKIP_LESS(ReplyMode::ONLY_ERR); - Capture(Error{error.ToSv(), error.kind}); -} - void CapturingReplyBuilder::SendMGetResponse(MGetResponse resp) { SKIP_LESS(ReplyMode::FULL); Capture(std::move(resp)); @@ -140,6 +136,16 @@ void CapturingReplyBuilder::CollapseFilledCollections() { } } +CapturingReplyBuilder::ScopeCapture::ScopeCapture(CapturingReplyBuilder* crb, + ConnectionContext* cntx) + : cntx_{cntx} { + old_ = cntx->Inject(crb); +} + +CapturingReplyBuilder::ScopeCapture::~ScopeCapture() { + cntx_->Inject(old_); +} + CapturingReplyBuilder::CollectionPayload::CollectionPayload(unsigned len, CollectionType type) : len{len}, type{type}, arr{} { 
arr.reserve(type == MAP ? len * 2 : len); diff --git a/src/facade/reply_capture.h b/src/facade/reply_capture.h index 7fe2843d23a7..7ce56e5eb534 100644 --- a/src/facade/reply_capture.h +++ b/src/facade/reply_capture.h @@ -14,6 +14,7 @@ namespace facade { +class ConnectionContext; struct CaptureVisitor; // CapturingReplyBuilder allows capturing replies and retrieveing them with Take(). @@ -24,12 +25,12 @@ class CapturingReplyBuilder : public RedisReplyBuilder { public: void SendError(std::string_view str, std::string_view type = {}) override; - void SendError(ErrorReply error) override; void SendMGetResponse(MGetResponse resp) override; // SendStored -> SendSimpleString("OK") // SendSetSkipped -> SendNull() void SendError(OpStatus status) override; + using RedisReplyBuilder::SendError; void SendNullArray() override; void SendEmptyArray() override; @@ -67,6 +68,16 @@ class CapturingReplyBuilder : public RedisReplyBuilder { struct SimpleString : public std::string {}; // SendSimpleString struct BulkString : public std::string {}; // SendBulkString + public: + struct ScopeCapture { + ScopeCapture(CapturingReplyBuilder* crb, ConnectionContext* cntx); + ~ScopeCapture(); + + private: + SinkReplyBuilder* old_; + ConnectionContext* cntx_; + }; + CapturingReplyBuilder(ReplyMode mode = ReplyMode::FULL) : RedisReplyBuilder{nullptr}, reply_mode_{mode}, stack_{}, current_{} { } diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt index 852f84bb3018..518a16997197 100644 --- a/src/server/CMakeLists.txt +++ b/src/server/CMakeLists.txt @@ -15,7 +15,7 @@ set_property(SOURCE dfly_main.cc APPEND PROPERTY COMPILE_DEFINITIONS if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Linux") SET(TX_LINUX_SRCS tiering/disk_storage.cc tiering/op_manager.cc tiering/small_bins.cc - tiering/io_mgr.cc tiering/external_alloc.cc) + tiering/external_alloc.cc) add_executable(dfly_bench dfly_bench.cc) cxx_link(dfly_bench dfly_facade fibers2 absl::random_random) diff --git a/src/server/acl/acl_family.cc 
b/src/server/acl/acl_family.cc index 3a58e77fb072..d7b18fe0b359 100644 --- a/src/server/acl/acl_family.cc +++ b/src/server/acl/acl_family.cc @@ -64,8 +64,7 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) { for (const auto& [username, user] : registry) { std::string buffer = "user "; - const std::string_view pass = user.Password(); - const std::string password = pass == "nopass" ? "nopass" : PrettyPrintSha(pass); + const std::string password = PasswordsToString(user.Passwords(), user.HasNopass(), false); const std::string acl_keys = AclKeysToString(user.Keys()); const std::string maybe_space_com = acl_keys.empty() ? "" : " "; @@ -75,7 +74,7 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) { using namespace std::string_view_literals; - absl::StrAppend(&buffer, username, " ", user.IsActive() ? "on "sv : "off "sv, password, " ", + absl::StrAppend(&buffer, username, " ", user.IsActive() ? "on "sv : "off "sv, password, acl_keys, maybe_space_com, acl_cat_and_commands); cntx->SendSimpleString(buffer); @@ -196,9 +195,7 @@ std::string AclFamily::RegistryToString() const { std::string result; for (auto& [username, user] : registry) { std::string command = "USER "; - const std::string_view pass = user.Password(); - const std::string password = - pass == "nopass" ? "nopass " : absl::StrCat("#", PrettyPrintSha(pass, true), " "); + const std::string password = PasswordsToString(user.Passwords(), user.HasNopass(), true); const std::string acl_keys = AclKeysToString(user.Keys()); const std::string maybe_space = acl_keys.empty() ? "" : " "; @@ -495,7 +492,10 @@ void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) { } auto& user = registry.find(username)->second; std::string status = user.IsActive() ? 
"on" : "off"; - auto pass = user.Password(); + auto pass = PasswordsToString(user.Passwords(), user.HasNopass(), false); + if (!pass.empty()) { + pass.pop_back(); + } auto* rb = static_cast(cntx->reply_builder()); rb->StartArray(8); @@ -509,7 +509,7 @@ void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) { } rb->SendSimpleString("passwords"); - if (pass != "nopass") { + if (pass != "nopass" && !pass.empty()) { rb->SendSimpleString(pass); } else { rb->SendEmptyArray(); @@ -647,7 +647,7 @@ void AclFamily::Init(facade::Listener* main_listener, UserRegistry* registry) { registry_ = registry; config_registry.RegisterMutable("requirepass", [this](const absl::CommandLineFlag& flag) { User::UpdateRequest rqst; - rqst.password = flag.CurrentValue(); + rqst.passwords.push_back({flag.CurrentValue()}); registry_->MaybeAddAndUpdate("default", std::move(rqst)); return true; }); diff --git a/src/server/acl/acl_family_test.cc b/src/server/acl/acl_family_test.cc index 0aac28dd3f1e..cbee4061fe18 100644 --- a/src/server/acl/acl_family_test.cc +++ b/src/server/acl/acl_family_test.cc @@ -47,16 +47,67 @@ TEST_F(AclFamilyTest, AclSetUser) { EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); auto vec = resp.GetVec(); - EXPECT_THAT( - vec, UnorderedElementsAre("user default on nopass ~* +@all", "user vlad off nopass -@all")); + EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass ~* +@all", "user vlad off -@all")); resp = Run({"ACL", "SETUSER", "vlad", "+ACL"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); vec = resp.GetVec(); - EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass ~* +@all", - "user vlad off nopass -@all +acl")); + EXPECT_THAT(vec, + UnorderedElementsAre("user default on nopass ~* +@all", "user vlad off -@all +acl")); + + resp = Run({"ACL", "SETUSER", "vlad", "on", ">pass", ">temp"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"ACL", "LIST"}); + vec = resp.GetVec(); + EXPECT_THAT(vec.size(), 2); + auto contains_vlad = [](const 
auto& vec) { + const std::string default_user = "user default on nopass ~* +@all"; + const std::string a_permutation = "user vlad on #a6864eb339b0e1f #d74ff0ee8da3b98 -@all +acl"; + const std::string b_permutation = "user vlad on #d74ff0ee8da3b98 #a6864eb339b0e1f -@all +acl"; + std::string_view other; + if (vec[0] == default_user) { + other = vec[1].GetView(); + } else if (vec[1] == default_user) { + other = vec[0].GetView(); + } else { + return false; + } + + return other == a_permutation || other == b_permutation; + }; + + EXPECT_THAT(contains_vlad(vec), true); + + resp = Run({"AUTH", "vlad", "pass"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"AUTH", "vlad", "temp"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"AUTH", "default", R"("")"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"ACL", "SETUSER", "vlad", ">another"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"ACL", "SETUSER", "vlad", " MaybeParseAclKey(std::string_view command) { return ParseKeyResult{std::string(key), op}; } -std::optional MaybeParsePassword(std::string_view command, bool hashed) { +std::optional MaybeParsePassword(std::string_view command, bool hashed) { + using UpPass = User::UpdatePass; if (command == "nopass") { - return std::string(command); + return UpPass{"", false, true}; + } + + if (command == "resetpass") { + return UpPass{"", false, false, true}; } if (command[0] == '>' || (hashed && command[0] == '#')) { - return std::string(command.substr(1)); + return UpPass{std::string(command.substr(1))}; + } + + if (command[0] == '<') { + return UpPass{std::string(command.substr(1)), true}; } return {}; @@ -261,10 +270,8 @@ std::variant ParseAclSetUser(facade::ArgRange a for (std::string_view arg : args) { if (auto pass = MaybeParsePassword(facade::ToSV(arg), hashed); pass) { - if (req.password) { - return ErrorReply("Only one password is allowed"); - } - req.password = std::move(pass); + req.passwords.push_back(std::move(*pass)); + if (hashed && absl::StartsWith(facade::ToSV(arg), "#")) 
{ req.is_hashed = hashed; } @@ -346,4 +353,16 @@ std::string AclKeysToString(const AclKeys& keys) { return result; } +std::string PasswordsToString(const absl::flat_hash_set& passwords, bool nopass, + bool full_sha) { + if (nopass) { + return "nopass "; + } + std::string result; + for (const auto& pass : passwords) { + absl::StrAppend(&result, "#", PrettyPrintSha(pass, full_sha), " "); + } + + return result; +} } // namespace dfly::acl diff --git a/src/server/acl/helpers.h b/src/server/acl/helpers.h index 0840ab817e70..75cbd4d8b491 100644 --- a/src/server/acl/helpers.h +++ b/src/server/acl/helpers.h @@ -10,6 +10,7 @@ #include #include +#include "absl/container/flat_hash_set.h" #include "facade/facade_types.h" #include "server/acl/acl_log.h" #include "server/acl/user.h" @@ -23,7 +24,7 @@ std::string AclCatAndCommandToString(const User::CategoryChanges& cat, std::string PrettyPrintSha(std::string_view pass, bool all = false); // When hashed is true, we allow passwords that start with both # and > -std::optional MaybeParsePassword(std::string_view command, bool hashed = false); +std::optional MaybeParsePassword(std::string_view command, bool hashed = false); std::optional MaybeParseStatus(std::string_view command); @@ -55,4 +56,8 @@ struct ParseKeyResult { std::optional MaybeParseAclKey(std::string_view command); std::string AclKeysToString(const AclKeys& keys); + +std::string PasswordsToString(const absl::flat_hash_set& passwords, bool nopass, + bool full_sha); + } // namespace dfly::acl diff --git a/src/server/acl/user.cc b/src/server/acl/user.cc index 863d9857f81e..48a3be8b326b 100644 --- a/src/server/acl/user.cc +++ b/src/server/acl/user.cc @@ -8,6 +8,7 @@ #include +#include "absl/container/flat_hash_set.h" #include "absl/strings/escaping.h" #include "core/overloaded.h" #include "server/acl/helpers.h" @@ -30,8 +31,20 @@ User::User() { } void User::Update(UpdateRequest&& req) { - if (req.password) { - SetPasswordHash(*req.password, req.is_hashed); + for (auto& pass 
: req.passwords) { + if (pass.nopass) { + SetNopass(); + continue; + } + if (pass.unset) { + UnsetPassword(pass.password); + continue; + } + if (pass.reset_password) { + password_hashes_.clear(); + continue; + } + SetPasswordHash(pass.password, req.is_hashed); } auto cat_visitor = [this](UpdateRequest::CategoryValueType cat) { @@ -68,23 +81,23 @@ void User::Update(UpdateRequest&& req) { } void User::SetPasswordHash(std::string_view password, bool is_hashed) { - if (password == "nopass") { - return; - } - + nopass_ = false; if (is_hashed) { - password_hash_ = absl::HexStringToBytes(password); + password_hashes_.insert(absl::HexStringToBytes(password)); return; } - password_hash_ = StringSHA256(password); + password_hashes_.insert(StringSHA256(password)); +} + +void User::UnsetPassword(std::string_view password) { + password_hashes_.erase(StringSHA256(password)); } bool User::HasPassword(std::string_view password) const { - if (!password_hash_) { + if (nopass_) { return true; } - // hash password and compare - return *password_hash_ == StringSHA256(password); + return password_hashes_.contains(StringSHA256(password)); } void User::SetAclCategoriesAndIncrSeq(uint32_t cat) { @@ -174,10 +187,12 @@ bool User::IsActive() const { return is_active_; } -static const std::string_view default_pass = "nopass"; +const absl::flat_hash_set& User::Passwords() const { + return password_hashes_; +} -std::string_view User::Password() const { - return password_hash_ ? 
*password_hash_ : default_pass; +bool User::HasNopass() const { + return nopass_; } const AclKeys& User::Keys() const { @@ -206,4 +221,9 @@ void User::SetKeyGlobs(std::vector keys) { } } +void User::SetNopass() { + nopass_ = true; + password_hashes_.clear(); +} + } // namespace dfly::acl diff --git a/src/server/acl/user.h b/src/server/acl/user.h index fd3e84a3ff76..3e66491f08c9 100644 --- a/src/server/acl/user.h +++ b/src/server/acl/user.h @@ -14,6 +14,7 @@ #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/hash/hash.h" #include "server/acl/acl_commands_def.h" @@ -30,8 +31,16 @@ class User final { bool reset_keys = false; }; + struct UpdatePass { + std::string password; + // Set to denote remove password + bool unset{false}; + bool nopass{false}; + bool reset_password{false}; + }; + struct UpdateRequest { - std::optional password{}; + std::vector passwords; std::optional is_active{}; @@ -48,6 +57,8 @@ class User final { std::vector keys; bool reset_all_keys{false}; bool allow_all_keys{false}; + // TODO allow reset all + // bool reset_all{false}; }; using CategoryChange = uint32_t; @@ -80,7 +91,9 @@ class User final { bool IsActive() const; - std::string_view Password() const; + const absl::flat_hash_set& Passwords() const; + + bool HasNopass() const; // Selector maps a command string (like HSET, SET etc) to // its respective ID within the commands vector. 
@@ -111,13 +124,19 @@ class User final { // For passwords void SetPasswordHash(std::string_view password, bool is_hashed); + void UnsetPassword(std::string_view password); // For ACL key globs void SetKeyGlobs(std::vector keys); - // when optional is empty, the special `nopass` password is implied - // password hashed with xx64 - std::optional password_hash_; + // Set NOPASS and remove all passwords + void SetNopass(); + + // Passwords for each user + absl::flat_hash_set password_hashes_; + // if `nopass` is used + bool nopass_ = false; + uint32_t acl_categories_{NONE}; // Each element index in the vector corresponds to a familly of commands // Each bit in the uin64_t field at index id, corresponds to a specific diff --git a/src/server/acl/user_registry.cc b/src/server/acl/user_registry.cc index 54510344e86c..9bd9645bff6c 100644 --- a/src/server/acl/user_registry.cc +++ b/src/server/acl/user_registry.cc @@ -75,7 +75,8 @@ UserRegistry::UserWithWriteLock::UserWithWriteLock(std::unique_lock acl{User::Sign::PLUS, acl::ALL}; auto key = User::UpdateKey{"~*", KeyOp::READ_WRITE, true, false}; - return {{}, true, false, {std::move(acl)}, {std::move(key)}}; + auto pass = std::vector{{"", false, true}}; + return {std::move(pass), true, false, {std::move(acl)}, {std::move(key)}}; } void UserRegistry::Init() { @@ -86,11 +87,14 @@ void UserRegistry::Init() { auto default_user = DefaultUserUpdateRequest(); auto maybe_password = absl::GetFlag(FLAGS_requirepass); if (!maybe_password.empty()) { - default_user.password = std::move(maybe_password); + default_user.passwords.front().password = std::move(maybe_password); + default_user.passwords.front().nopass = false; } else if (const char* env_var = getenv("DFLY_PASSWORD"); env_var) { - default_user.password = env_var; + default_user.passwords.front().password = env_var; + default_user.passwords.front().nopass = false; } else if (const char* env_var = getenv("DFLY_requirepass"); env_var) { - default_user.password = env_var; + 
default_user.passwords.front().password = env_var; + default_user.passwords.front().nopass = false; } MaybeAddAndUpdate("default", std::move(default_user)); } diff --git a/src/server/cluster/cluster_family.cc b/src/server/cluster/cluster_family.cc index a69525c81df3..6d4e021abdf9 100644 --- a/src/server/cluster/cluster_family.cc +++ b/src/server/cluster/cluster_family.cc @@ -687,7 +687,7 @@ void ClusterFamily::DflySlotMigrationStatus(CmdArgList args, ConnectionContext* if (filter.empty() || filter == node_id) { error = error.empty() ? "0" : error; reply.push_back(absl::StrCat(direction, " ", node_id, " ", StateToStr(state), - " keys:", keys_number, " errors: ", error)); + " keys:", keys_number, " errors:", error)); } }; @@ -791,15 +791,11 @@ bool RemoveIncomingMigrationImpl(std::vector(removed.ToSlotRanges()); + auto removed_ranges = removed.ToSlotRanges(); LOG_IF(WARNING, migration->GetState() == MigrationState::C_FINISHED) << "Flushing slots of removed FINISHED migration " << migration->GetSourceID() - << ", slots: " << SlotRange::ToString(*removed_ranges); - shard_set->pool()->DispatchOnAll([removed_ranges](unsigned, ProactorBase*) { - if (EngineShard* shard = EngineShard::tlocal(); shard) { - shard->db_slice().FlushSlots(*removed_ranges); - } - }); + << ", slots: " << SlotRange::ToString(removed_ranges); + DeleteSlots(removed_ranges); } return true; @@ -844,7 +840,7 @@ void ClusterFamily::InitMigration(CmdArgList args, ConnectionContext* cntx) { lock_guard lk(migration_mu_); auto was_removed = RemoveIncomingMigrationImpl(incoming_migrations_jobs_, source_id); - LOG_IF(WARNING, was_removed) << "Reinit was happen for migration from:" << source_id; + LOG_IF(WARNING, was_removed) << "Reinit issued for migration from:" << source_id; incoming_migrations_jobs_.emplace_back(make_shared( std::move(source_id), &server_family_->service(), std::move(slots), flows_num)); diff --git a/src/server/cluster/outgoing_slot_migration.cc 
b/src/server/cluster/outgoing_slot_migration.cc index 7261138a67d7..3920f6977905 100644 --- a/src/server/cluster/outgoing_slot_migration.cc +++ b/src/server/cluster/outgoing_slot_migration.cc @@ -58,17 +58,11 @@ class OutgoingMigration::SliceSlotMigration : private ProtocolClient { return; } - // Check if migration was cancelled while we yielded so far. - if (cancelled_) { - return; - } - streamer_.Start(Sock()); } void Cancel() { streamer_.Cancel(); - cancelled_ = true; } void Finalize() { @@ -81,7 +75,6 @@ class OutgoingMigration::SliceSlotMigration : private ProtocolClient { private: RestoreStreamer streamer_; - bool cancelled_ = false; }; OutgoingMigration::OutgoingMigration(MigrationInfo info, ClusterFamily* cf, ServerFamily* sf) @@ -94,6 +87,13 @@ OutgoingMigration::OutgoingMigration(MigrationInfo info, ClusterFamily* cf, Serv OutgoingMigration::~OutgoingMigration() { main_sync_fb_.JoinIfNeeded(); + + // Destroy each flow in its dedicated thread, because we could be the last owner of the db tables + shard_set->pool()->AwaitFiberOnAll([this](util::ProactorBase* pb) { + if (const auto* shard = EngineShard::tlocal(); shard) { + slot_migrations_[shard->shard_id()].reset(); + } + }); } bool OutgoingMigration::ChangeState(MigrationState new_state) { diff --git a/src/server/common.h b/src/server/common.h index e850d1c41d31..d8d3b2d3123a 100644 --- a/src/server/common.h +++ b/src/server/common.h @@ -160,11 +160,13 @@ template std::string GetRandomHex(RandGen& gen, size_t len) { // truthy value; template struct AggregateValue { bool operator=(T val) { + if (!bool(val)) + return false; + std::lock_guard l{mu_}; - if (!bool(current_) && bool(val)) { + if (!bool(current_)) current_ = val; - } - return bool(val); + return true; } T operator*() { diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index 901b2c9996e8..6d5aed56586b 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -458,9 +458,7 @@ OpResult DbSlice::FindInternal(const Context& cntx, 
std: if (!change_cb_.empty()) { auto bump_cb = [&](PrimeTable::bucket_iterator bit) { DVLOG(2) << "Running callbacks for key " << key << " in dbid " << cntx.db_index; - for (const auto& ccb : change_cb_) { - ccb.second(cntx.db_index, bit); - } + CallChangeCallbacks(cntx.db_index, bit); }; db.prime.CVCUponBump(change_cb_.back().first, res.it, bump_cb); } @@ -524,9 +522,7 @@ OpResult DbSlice::AddOrFindInternal(const Context& cnt // It's a new entry. DVLOG(2) << "Running callbacks for key " << key << " in dbid " << cntx.db_index; - for (const auto& ccb : change_cb_) { - ccb.second(cntx.db_index, key); - } + CallChangeCallbacks(cntx.db_index, key); // In case we are loading from rdb file or replicating we want to disable conservative memory // checks (inside PrimeEvictionPolicy::CanGrow) and reject insertions only after we pass max @@ -975,9 +971,7 @@ void DbSlice::PreUpdate(DbIndex db_ind, Iterator it, std::string_view key) { FiberAtomicGuard fg; DVLOG(2) << "Running callbacks in dbid " << db_ind; - for (const auto& ccb : change_cb_) { - ccb.second(db_ind, ChangeReq{it.GetInnerIt()}); - } + CallChangeCallbacks(db_ind, ChangeReq{it.GetInnerIt()}); // If the value has a pending stash, cancel it before any modification are applied. // Note: we don't delete offloaded values before updates, because a read-modify operation (like @@ -1089,6 +1083,13 @@ void DbSlice::ExpireAllIfNeeded() { uint64_t DbSlice::RegisterOnChange(ChangeCallback cb) { uint64_t ver = NextVersion(); + + // TODO rewrite this logic to be more clear + // this mutex lock is needed to check that this method is not called simultaneously with + // change_cb_ calls and journal_slice::change_cb_arr_ calls. 
+ // It can be unlocked anytime because DbSlice::RegisterOnChange + // and journal_slice::RegisterOnChange calls without preemption + std::lock_guard lk(cb_mu_); change_cb_.emplace_back(ver, std::move(cb)); return ver; } @@ -1099,6 +1100,7 @@ void DbSlice::FlushChangeToEarlierCallbacks(DbIndex db_ind, Iterator it, uint64_ // change_cb_ is ordered by version. DVLOG(2) << "Running callbacks in dbid " << db_ind << " with bucket_version=" << bucket_version << ", upper_bound=" << upper_bound; + for (const auto& ccb : change_cb_) { uint64_t cb_version = ccb.first; DCHECK_LE(cb_version, upper_bound); @@ -1113,6 +1115,7 @@ void DbSlice::FlushChangeToEarlierCallbacks(DbIndex db_ind, Iterator it, uint64_ //! Unregisters the callback. void DbSlice::UnregisterOnChange(uint64_t id) { + lock_guard lk(cb_mu_); // we need to wait until callback is finished before remove it for (auto it = change_cb_.begin(); it != change_cb_.end(); ++it) { if (it->first == id) { change_cb_.erase(it); @@ -1506,4 +1509,10 @@ void DbSlice::OnCbFinish() { fetched_items_.clear(); } +void DbSlice::CallChangeCallbacks(DbIndex id, const ChangeReq& cr) const { + for (const auto& ccb : change_cb_) { + ccb.second(id, cr); + } +} + } // namespace dfly diff --git a/src/server/db_slice.h b/src/server/db_slice.h index 5da9ce571617..6e5184a67b3c 100644 --- a/src/server/db_slice.h +++ b/src/server/db_slice.h @@ -469,6 +469,14 @@ class DbSlice { void PerformDeletion(Iterator del_it, DbTable* table); void PerformDeletion(PrimeIterator del_it, DbTable* table); + void LockChangeCb() const { + return cb_mu_.lock_shared(); + } + + void UnlockChangeCb() const { + return cb_mu_.unlock_shared(); + } + private: void PreUpdate(DbIndex db_ind, Iterator it, std::string_view key); void PostUpdate(DbIndex db_ind, Iterator it, std::string_view key, size_t orig_size); @@ -523,6 +531,8 @@ class DbSlice { return version_++; } + void CallChangeCallbacks(DbIndex id, const ChangeReq& cr) const; + private: ShardId shard_id_; uint8_t 
caching_mode_ : 1; @@ -544,6 +554,12 @@ class DbSlice { // Used in temporary computations in Acquire/Release. mutable absl::flat_hash_set uniq_fps_; + // To ensure correct data replication, we must serialize the buckets that each running command + // will modify, followed by serializing the command to the journal. We use a mutex to prevent + // interleaving between bucket and journal registrations, and the command execution with its + // journaling. LockChangeCb is called before the callback, and UnlockChangeCb is called after + // journaling is completed. Register to bucket and journal changes is also does without preemption + mutable util::fb2::SharedMutex cb_mu_; // ordered from the smallest to largest version. std::vector> change_cb_; diff --git a/src/server/detail/save_stages_controller.cc b/src/server/detail/save_stages_controller.cc index fc823074ba46..0ef993c4ce6f 100644 --- a/src/server/detail/save_stages_controller.cc +++ b/src/server/detail/save_stages_controller.cc @@ -254,7 +254,10 @@ void SaveStagesController::SaveDfs() { // Save shard files. auto cb = [this](Transaction* t, EngineShard* shard) { + // a hack to avoid deadlock in Transaction::RunCallback(...) + shard->db_slice().UnlockChangeCb(); SaveDfsSingle(shard); + shard->db_slice().LockChangeCb(); return OpStatus::OK; }; trans_->ScheduleSingleHop(std::move(cb)); @@ -294,7 +297,10 @@ void SaveStagesController::SaveRdb() { } auto cb = [snapshot = snapshot.get()](Transaction* t, EngineShard* shard) { + // a hack to avoid deadlock in Transaction::RunCallback(...) 
+ shard->db_slice().UnlockChangeCb(); snapshot->StartInShard(shard); + shard->db_slice().LockChangeCb(); return OpStatus::OK; }; trans_->ScheduleSingleHop(std::move(cb)); diff --git a/src/server/generic_family.cc b/src/server/generic_family.cc index b026a830af49..154304ce6f19 100644 --- a/src/server/generic_family.cc +++ b/src/server/generic_family.cc @@ -4,6 +4,7 @@ #include "server/generic_family.h" +#include #include #include "facade/reply_builder.h" @@ -44,29 +45,33 @@ namespace { constexpr size_t DUMP_FOOTER_SIZE = sizeof(uint64_t) + sizeof(uint16_t); // version number and crc -bool VerifyFooter(std::string_view msg, int* rdb_version) { +std::optional GetRdbVersion(std::string_view msg) { if (msg.size() <= DUMP_FOOTER_SIZE) { LOG(WARNING) << "got restore payload that is too short - " << msg.size(); - return false; + return std::nullopt; } - const uint8_t* footer = - reinterpret_cast(msg.data()) + (msg.size() - DUMP_FOOTER_SIZE); - uint16_t version = (*(footer + 1) << 8 | (*footer)); - *rdb_version = version; + + const std::uint8_t* footer = + reinterpret_cast(msg.data()) + (msg.size() - DUMP_FOOTER_SIZE); + const RdbVersion version = (*(footer + 1) << 8 | (*footer)); + if (version > RDB_VERSION) { LOG(WARNING) << "got restore payload with illegal version - supporting version up to " << RDB_VERSION << " got version " << version; - return false; + return std::nullopt; } + uint64_t expected_cs = crc64(0, reinterpret_cast(msg.data()), msg.size() - sizeof(uint64_t)); uint64_t actual_cs = absl::little_endian::Load64(footer + sizeof(version)); + if (actual_cs != expected_cs) { LOG(WARNING) << "CRC check failed for restore command, expecting: " << expected_cs << " got " << actual_cs; - return false; + return std::nullopt; } - return true; + + return version; } class InMemSource : public ::io::Source { @@ -96,14 +101,57 @@ ::io::Result InMemSource::ReadSome(const iovec* v, uint32_t len) { return read_total; } +class RestoreArgs { + private: + static constexpr time_t 
NO_EXPIRATION = 0; + + time_t expiration_ = NO_EXPIRATION; + bool abs_time_ = false; + bool replace_ = false; // if true, over-ride existing key + bool sticky_ = false; + + public: + RestoreArgs() = default; + + RestoreArgs(time_t expiration, bool abs_time, bool replace) + : expiration_(expiration), abs_time_(abs_time), replace_(replace) { + } + + constexpr bool Replace() const { + return replace_; + } + + constexpr bool Sticky() const { + return sticky_; + } + + uint64_t ExpirationTime() const { + DCHECK_GE(expiration_, 0); + return expiration_; + } + + [[nodiscard]] constexpr bool Expired() const { + return expiration_ < 0; + } + + [[nodiscard]] constexpr bool HasExpiration() const { + return expiration_ != NO_EXPIRATION; + } + + [[nodiscard]] bool UpdateExpiration(int64_t now_msec); + + static OpResult TryFrom(const CmdArgList& args); +}; + class RdbRestoreValue : protected RdbLoaderBase { public: - RdbRestoreValue(int rdb_version) { + RdbRestoreValue(RdbVersion rdb_version) { rdb_version_ = rdb_version; } - bool Add(std::string_view payload, std::string_view key, DbSlice& db_slice, DbIndex index, - uint64_t expire_ms); + std::optional Add(std::string_view payload, std::string_view key, + DbSlice& db_slice, DbIndex index, + const RestoreArgs& args); private: std::optional Parse(std::string_view payload); @@ -127,53 +175,29 @@ std::optional RdbRestoreValue::Parse(std::string_view } } -bool RdbRestoreValue::Add(std::string_view data, std::string_view key, DbSlice& db_slice, - DbIndex index, uint64_t expire_ms) { +std::optional RdbRestoreValue::Add(std::string_view data, + std::string_view key, DbSlice& db_slice, + DbIndex index, const RestoreArgs& args) { auto opaque_res = Parse(data); if (!opaque_res) { - return false; + return std::nullopt; } PrimeValue pv; if (auto ec = FromOpaque(*opaque_res, &pv); ec) { // we failed - report and exit LOG(WARNING) << "error while trying to save data: " << ec; - return false; - } - - auto res = db_slice.AddNew(DbContext{index, 
GetCurrentTimeMs()}, key, std::move(pv), expire_ms); - return res.ok(); -} - -class RestoreArgs { - static constexpr int64_t NO_EXPIRATION = 0; - - int64_t expiration_ = NO_EXPIRATION; - bool abs_time_ = false; - bool replace_ = false; // if true, over-ride existing key - - public: - constexpr bool Replace() const { - return replace_; - } - - uint64_t ExpirationTime() const { - DCHECK_GE(expiration_, 0); - return expiration_; - } - - [[nodiscard]] constexpr bool Expired() const { - return expiration_ < 0; + return std::nullopt; } - [[nodiscard]] constexpr bool HasExpiration() const { - return expiration_ != NO_EXPIRATION; + auto res = db_slice.AddNew(DbContext{index, GetCurrentTimeMs()}, key, std::move(pv), + args.ExpirationTime()); + res->it->first.SetSticky(args.Sticky()); + if (res) { + return std::move(res.value()); } - - [[nodiscard]] bool UpdateExpiration(int64_t now_msec); - - static OpResult TryFrom(const CmdArgList& args); -}; + return std::nullopt; +} [[nodiscard]] bool RestoreArgs::UpdateExpiration(int64_t now_msec) { if (HasExpiration()) { @@ -214,6 +238,8 @@ OpResult RestoreArgs::TryFrom(const CmdArgList& args) { out_args.replace_ = true; } else if (cur_arg == "ABSTTL") { out_args.abs_time_ = true; + } else if (cur_arg == "STICK") { + out_args.sticky_ = true; } else if (cur_arg == "IDLETIME" && additional) { ++i; cur_arg = ArgS(args, i); @@ -245,159 +271,200 @@ OpStatus OpPersist(const OpArgs& op_args, string_view key); class Renamer { public: - Renamer(ShardId source_id) : src_sid_(source_id) { + Renamer(Transaction* t, std::string_view src_key, std::string_view dest_key, unsigned shard_count) + : transaction_(t), + src_key_(src_key), + dest_key_(dest_key), + src_sid_(Shard(src_key, shard_count)), + dest_sid_(Shard(dest_key, shard_count)) { } - void Find(Transaction* t); - - OpResult status() const { - return status_; - }; - - void Finalize(Transaction* t, bool skip_exist_dest); + ErrorReply Rename(bool destination_should_not_exist); private: - 
OpStatus MoveSrc(Transaction* t, EngineShard* es); - OpStatus UpdateDest(Transaction* t, EngineShard* es); + void FetchData(); + void FinalizeRename(); + + bool KeyExists(Transaction* t, EngineShard* shard, std::string_view key) const; + void SerializeSrc(Transaction* t, EngineShard* shard); - ShardId src_sid_; + OpStatus DelSrc(Transaction* t, EngineShard* shard); + OpStatus DeserializeDest(Transaction* t, EngineShard* shard); - struct FindResult { - string_view key; - PrimeValue ref_val; - uint64_t expire_ts; + struct SerializedValue { + std::string value; + std::optional version; + time_t expire_ts; bool sticky; - bool found = false; }; - PrimeValue pv_; - string str_val_; + private: + Transaction* const transaction_; + + const std::string_view src_key_; + const std::string_view dest_key_; + const ShardId src_sid_; + const ShardId dest_sid_; + + bool src_found_ = false; + bool dest_found_ = false; - FindResult src_res_, dest_res_; // index 0 for source, 1 for destination - OpResult status_; + SerializedValue serialized_value_; }; -void Renamer::Find(Transaction* t) { +ErrorReply Renamer::Rename(bool destination_should_not_exist) { + FetchData(); + + if (!src_found_) { + transaction_->Conclude(); + return OpStatus::KEY_NOTFOUND; + } + + if (!serialized_value_.version) { + transaction_->Conclude(); + return ErrorReply{kInvalidDumpValueErr}; + } + + if (dest_found_ && destination_should_not_exist) { + transaction_->Conclude(); + return OpStatus::KEY_EXISTS; + } + + FinalizeRename(); + return OpStatus::OK; +} + +void Renamer::FetchData() { auto cb = [this](Transaction* t, EngineShard* shard) { auto args = t->GetShardArgs(shard->shard_id()); DCHECK_EQ(1u, args.Size()); - FindResult* res = (shard->shard_id() == src_sid_) ? 
&src_res_ : &dest_res_; + const ShardId shard_id = shard->shard_id(); - res->key = args.Front(); - auto& db_slice = EngineShard::tlocal()->db_slice(); - auto [it, exp_it] = db_slice.FindReadOnly(t->GetDbContext(), res->key); + if (shard_id == src_sid_) { + SerializeSrc(t, shard); + } - res->found = IsValid(it); - if (res->found) { - res->ref_val = it->second.AsRef(); - res->expire_ts = db_slice.ExpireTime(exp_it); - res->sticky = it->first.IsSticky(); + if (shard_id == dest_sid_) { + dest_found_ = KeyExists(t, shard, dest_key_); } + return OpStatus::OK; }; - t->Execute(std::move(cb), false); -}; + transaction_->Execute(std::move(cb), false); +} -void Renamer::Finalize(Transaction* t, bool skip_exist_dest) { - if (!src_res_.found) { - status_ = OpStatus::KEY_NOTFOUND; - t->Conclude(); - return; - } +void Renamer::FinalizeRename() { + auto cb = [this](Transaction* t, EngineShard* shard) { + const ShardId shard_id = shard->shard_id(); + + if (shard_id == src_sid_) { + return DelSrc(t, shard); + } + + if (shard_id == dest_sid_) { + return DeserializeDest(t, shard); + } + + return OpStatus::OK; + }; + + transaction_->Execute(std::move(cb), true); +} + +bool Renamer::KeyExists(Transaction* t, EngineShard* shard, std::string_view key) const { + auto& db_slice = shard->db_slice(); + auto it = db_slice.FindReadOnly(t->GetDbContext(), key).it; + return IsValid(it); +} - if (dest_res_.found && skip_exist_dest) { - status_ = OpStatus::KEY_EXISTS; - t->Conclude(); +void Renamer::SerializeSrc(Transaction* t, EngineShard* shard) { + auto& db_slice = shard->db_slice(); + auto [it, exp_it] = db_slice.FindReadOnly(t->GetDbContext(), src_key_); + + src_found_ = IsValid(it); + if (!src_found_) { return; } - DCHECK(src_res_.ref_val.IsRef()); + DVLOG(1) << "Rename: key '" << src_key_ << "' successfully found, going to dump it"; + + io::StringSink sink; + SerializerBase::DumpObject(it->second, &sink); - // Src key exist and we need to override the destination. 
- // Alternatively, we could apply an optimistic algorithm and move src at Find step. - // We would need to restore the state in case of cleanups. - t->Execute([&](Transaction* t, EngineShard* shard) { return MoveSrc(t, shard); }, false); - t->Execute([&](Transaction* t, EngineShard* shard) { return UpdateDest(t, shard); }, true); + auto rdb_version = GetRdbVersion(sink.str()); + serialized_value_ = {std::move(sink).str(), rdb_version, db_slice.ExpireTime(exp_it), + it->first.IsSticky()}; } -OpStatus Renamer::MoveSrc(Transaction* t, EngineShard* es) { - if (es->shard_id() == src_sid_) { // Handle source key. - auto res = es->db_slice().FindMutable(t->GetDbContext(), src_res_.key); - auto& it = res.it; - CHECK(IsValid(it)); - - // We distinguish because of the SmallString that is pinned to its thread by design, - // thus can not be accessed via another thread. - // Therefore, we copy it to standard string in its thread. - if (it->second.ObjType() == OBJ_STRING) { - it->second.GetString(&str_val_); - } else { - bool has_expire = it->second.HasExpire(); - pv_ = std::move(it->second); - it->second.SetExpire(has_expire); - } +OpStatus Renamer::DelSrc(Transaction* t, EngineShard* shard) { + auto res = shard->db_slice().FindMutable(t->GetDbContext(), src_key_); + auto& it = res.it; - res.post_updater.Run(); - CHECK(es->db_slice().Del(t->GetDbIndex(), it)); // delete the entry with empty value in it. 
- if (es->journal()) { - RecordJournal(t->GetOpArgs(es), "DEL", ArgSlice{src_res_.key}, 2); - } + CHECK(IsValid(it)); + + DVLOG(1) << "Rename: removing the key '" << src_key_; + + res.post_updater.Run(); + CHECK(shard->db_slice().Del(t->GetDbIndex(), it)); + if (shard->journal()) { + RecordJournal(t->GetOpArgs(shard), "DEL"sv, ArgSlice{src_key_}, 2); } return OpStatus::OK; } -OpStatus Renamer::UpdateDest(Transaction* t, EngineShard* es) { - if (es->shard_id() != src_sid_) { - auto& db_slice = es->db_slice(); - string_view dest_key = dest_res_.key; - auto res = db_slice.FindMutable(t->GetDbContext(), dest_key); - auto& dest_it = res.it; - bool is_prior_list = false; - - if (IsValid(dest_it)) { - bool has_expire = dest_it->second.HasExpire(); - is_prior_list = dest_it->second.ObjType() == OBJ_LIST; - - if (src_res_.ref_val.ObjType() == OBJ_STRING) { - dest_it->second.SetString(str_val_); - } else { - dest_it->second = std::move(pv_); - } - dest_it->second.SetExpire(has_expire); // preserve expire flag. 
- db_slice.UpdateExpire(t->GetDbIndex(), dest_it, src_res_.expire_ts); - } else { - if (src_res_.ref_val.ObjType() == OBJ_STRING) { - pv_.SetString(str_val_); - } - auto op_res = - db_slice.AddNew(t->GetDbContext(), dest_key, std::move(pv_), src_res_.expire_ts); - RETURN_ON_BAD_STATUS(op_res); - res = std::move(*op_res); +OpStatus Renamer::DeserializeDest(Transaction* t, EngineShard* shard) { + OpArgs op_args = t->GetOpArgs(shard); + RestoreArgs restore_args{serialized_value_.expire_ts, true, true}; + + if (!restore_args.UpdateExpiration(op_args.db_cntx.time_now_ms)) { + return OpStatus::OUT_OF_RANGE; + } + + auto& db_slice = shard->db_slice(); + auto dest_res = db_slice.FindMutable(op_args.db_cntx, dest_key_); + + if (dest_found_) { + DVLOG(1) << "Rename: deleting the destiny key '" << dest_key_; + dest_res.post_updater.Run(); + CHECK(db_slice.Del(op_args.db_cntx.db_index, dest_res.it)); + } + + if (restore_args.Expired()) { + VLOG(1) << "Rename: the new key '" << dest_key_ << "' already expired, will not save the value"; + + if (dest_found_ && shard->journal()) { // We need to delete old dest_key_ from replica + RecordJournal(op_args, "DEL"sv, ArgSlice{dest_key_}, 2); } - dest_it->first.SetSticky(src_res_.sticky); + return OpStatus::OK; + } + + RdbRestoreValue loader(serialized_value_.version.value()); + auto restored_dest_it = loader.Add(serialized_value_.value, dest_key_, db_slice, + op_args.db_cntx.db_index, restore_args); + + if (restored_dest_it) { + auto& dest_it = restored_dest_it->it; + dest_it->first.SetSticky(serialized_value_.sticky); - if (!is_prior_list && dest_it->second.ObjType() == OBJ_LIST && es->blocking_controller()) { - es->blocking_controller()->AwakeWatched(t->GetDbIndex(), dest_key); + auto bc = shard->blocking_controller(); + if (bc) { + bc->AwakeWatched(t->GetDbIndex(), dest_key_); } - if (es->journal()) { - OpArgs op_args = t->GetOpArgs(es); - string scratch; - // todo insert under multi exec - RecordJournal(op_args, "SET"sv, 
ArgSlice{dest_key, dest_it->second.GetSlice(&scratch)}, 2, - true); - if (dest_it->first.IsSticky()) { - RecordJournal(op_args, "STICK"sv, ArgSlice{dest_key}, 2, true); - } - if (dest_it->second.HasExpire()) { - auto time = absl::StrCat(src_res_.expire_ts); - RecordJournal(op_args, "PEXPIREAT"sv, ArgSlice{dest_key, time}, 2, true); - } - RecordJournalFinish(op_args, 2); + } + + if (shard->journal()) { + auto expire_str = absl::StrCat(serialized_value_.expire_ts); + RecordJournal(op_args, "RESTORE"sv, + ArgSlice{dest_key_, expire_str, serialized_value_.value, "REPLACE"sv, "ABSTTL"sv}, + 2, true); + if (serialized_value_.sticky) { + RecordJournal(op_args, "STICK"sv, ArgSlice{dest_key_}, 2, true); } + RecordJournalFinish(op_args, 2); } return OpStatus::OK; @@ -435,7 +502,7 @@ OpResult OpDump(const OpArgs& op_args, string_view key) { } OpResult OnRestore(const OpArgs& op_args, std::string_view key, std::string_view payload, - RestoreArgs restore_args, int rdb_version) { + RestoreArgs restore_args, RdbVersion rdb_version) { if (!restore_args.UpdateExpiration(op_args.db_cntx.time_now_ms)) { return OpStatus::OUT_OF_RANGE; } @@ -466,9 +533,8 @@ OpResult OnRestore(const OpArgs& op_args, std::string_view key, std::strin } RdbRestoreValue loader(rdb_version); - - return loader.Add(payload, key, db_slice, op_args.db_cntx.db_index, - restore_args.ExpirationTime()); + auto res = loader.Add(payload, key, db_slice, op_args.db_cntx.db_index, restore_args); + return res.has_value(); } bool ScanCb(const OpArgs& op_args, PrimeIterator prime_it, const ScanOpts& opts, string* scratch, @@ -1116,9 +1182,10 @@ void GenericFamily::Sort(CmdArgList args, ConnectionContext* cntx) { void GenericFamily::Restore(CmdArgList args, ConnectionContext* cntx) { std::string_view key = ArgS(args, 0); std::string_view serialized_value = ArgS(args, 2); - int rdb_version = 0; - if (!VerifyFooter(serialized_value, &rdb_version)) { - return cntx->SendError("ERR DUMP payload version or checksum are wrong"); + 
+ auto rdb_version = GetRdbVersion(serialized_value); + if (!rdb_version) { + return cntx->SendError(kInvalidDumpValueErr); } OpResult restore_args = RestoreArgs::TryFrom(args); @@ -1131,7 +1198,8 @@ void GenericFamily::Restore(CmdArgList args, ConnectionContext* cntx) { } auto cb = [&](Transaction* t, EngineShard* shard) { - return OnRestore(t->GetOpArgs(shard), key, serialized_value, restore_args.value(), rdb_version); + return OnRestore(t->GetOpArgs(shard), key, serialized_value, restore_args.value(), + rdb_version.value()); }; OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); @@ -1140,7 +1208,7 @@ void GenericFamily::Restore(CmdArgList args, ConnectionContext* cntx) { if (result.value()) { return cntx->SendOk(); } else { - return cntx->SendError("Bad data format"); + return cntx->SendError("Bad data formatasdfasdf"); } } else { switch (result.status()) { @@ -1212,19 +1280,25 @@ void GenericFamily::Move(CmdArgList args, ConnectionContext* cntx) { } void GenericFamily::Rename(CmdArgList args, ConnectionContext* cntx) { - OpResult st = RenameGeneric(args, false, cntx); - cntx->SendError(st.status()); + auto reply = RenameGeneric(args, false, cntx); + cntx->SendError(reply); } void GenericFamily::RenameNx(CmdArgList args, ConnectionContext* cntx) { - OpResult st = RenameGeneric(args, true, cntx); - OpStatus status = st.status(); - if (status == OpStatus::OK) { + auto reply = RenameGeneric(args, true, cntx); + + if (!reply.status) { + cntx->SendError(reply); + return; + } + + OpStatus st = reply.status.value(); + if (st == OpStatus::OK) { cntx->SendLong(1); - } else if (status == OpStatus::KEY_EXISTS) { + } else if (st == OpStatus::KEY_EXISTS) { cntx->SendLong(0); } else { - cntx->SendError(status); + cntx->SendError(reply); } } @@ -1332,8 +1406,8 @@ void GenericFamily::Time(CmdArgList args, ConnectionContext* cntx) { rb->SendLong(now_usec % 1000000); } -OpResult GenericFamily::RenameGeneric(CmdArgList args, bool skip_exist_dest, - 
ConnectionContext* cntx) { +ErrorReply GenericFamily::RenameGeneric(CmdArgList args, bool destination_should_not_exist, + ConnectionContext* cntx) { string_view key[2] = {ArgS(args, 0), ArgS(args, 1)}; Transaction* transaction = cntx->transaction; @@ -1341,23 +1415,15 @@ OpResult GenericFamily::RenameGeneric(CmdArgList args, bool skip_exist_des if (transaction->GetUniqueShardCnt() == 1) { transaction->ReviveAutoJournal(); // Safe to use RENAME with single shard auto cb = [&](Transaction* t, EngineShard* shard) { - return OpRen(t->GetOpArgs(shard), key[0], key[1], skip_exist_dest); + return OpRen(t->GetOpArgs(shard), key[0], key[1], destination_should_not_exist); }; OpResult result = transaction->ScheduleSingleHopT(std::move(cb)); - return result; + return result.status(); } - unsigned shard_count = shard_set->size(); - Renamer renamer{Shard(key[0], shard_count)}; - - // Phase 1 -> Fetch keys from both shards. - // Phase 2 -> If everything is ok, clone the source object, delete the destination object, and - // set its ptr to cloned one. we also copy the expiration data of the source key. 
- renamer.Find(transaction); - renamer.Finalize(transaction, skip_exist_dest); - - return renamer.status(); + Renamer renamer{transaction, key[0], key[1], shard_set->size()}; + return renamer.Rename(destination_should_not_exist); } void GenericFamily::Echo(CmdArgList args, ConnectionContext* cntx) { @@ -1422,7 +1488,7 @@ OpResult GenericFamily::OpExists(const OpArgs& op_args, const ShardArg } OpResult GenericFamily::OpRen(const OpArgs& op_args, string_view from_key, string_view to_key, - bool skip_exists) { + bool destination_should_not_exist) { auto* es = op_args.shard; auto& db_slice = es->db_slice(); auto from_res = db_slice.FindMutable(op_args.db_cntx, from_key); @@ -1435,7 +1501,7 @@ OpResult GenericFamily::OpRen(const OpArgs& op_args, string_view from_key, bool is_prior_list = false; auto to_res = db_slice.FindMutable(op_args.db_cntx, to_key); if (IsValid(to_res.it)) { - if (skip_exists) + if (destination_should_not_exist) return OpStatus::KEY_EXISTS; is_prior_list = (to_res.it->second.ObjType() == OBJ_LIST); diff --git a/src/server/generic_family.h b/src/server/generic_family.h index 015ed7fcb682..dffae8a50b9a 100644 --- a/src/server/generic_family.h +++ b/src/server/generic_family.h @@ -5,7 +5,7 @@ #pragma once #include "base/flags.h" -#include "facade/op_status.h" +#include "facade/facade_types.h" #include "server/common.h" #include "server/table.h" @@ -17,6 +17,7 @@ class ProactorPool; namespace dfly { +using facade::ErrorReply; using facade::OpResult; using facade::OpStatus; @@ -71,13 +72,13 @@ class GenericFamily { static void RandomKey(CmdArgList args, ConnectionContext* cntx); static void FieldTtl(CmdArgList args, ConnectionContext* cntx); - static OpResult RenameGeneric(CmdArgList args, bool skip_exist_dest, - ConnectionContext* cntx); + static ErrorReply RenameGeneric(CmdArgList args, bool destination_should_not_exist, + ConnectionContext* cntx); static void TtlGeneric(CmdArgList args, ConnectionContext* cntx, TimeUnit unit); static OpResult 
OpTtl(Transaction* t, EngineShard* shard, std::string_view key); static OpResult OpRen(const OpArgs& op_args, std::string_view from, std::string_view to, - bool skip_exists); + bool destination_should_not_exist); static OpStatus OpMove(const OpArgs& op_args, std::string_view key, DbIndex target_db); }; diff --git a/src/server/generic_family_test.cc b/src/server/generic_family_test.cc index dcd8b58e33a5..52e97668167c 100644 --- a/src/server/generic_family_test.cc +++ b/src/server/generic_family_test.cc @@ -695,10 +695,13 @@ TEST_F(GenericFamilyTest, Info) { EXPECT_EQ(1, get_rdb_changes_since_last_save(resp.GetString())); EXPECT_EQ(Run({"bgsave"}), "OK"); - WaitUntilCondition([&]() { - resp = Run({"info", "persistence"}); - return get_rdb_changes_since_last_save(resp.GetString()) == 0; - }); + bool cond = WaitUntilCondition( + [&]() { + resp = Run({"info", "persistence"}); + return get_rdb_changes_since_last_save(resp.GetString()) == 0; + }, + 500ms); + EXPECT_TRUE(cond); EXPECT_EQ(Run({"set", "k3", "3"}), "OK"); resp = Run({"info", "persistence"}); diff --git a/src/server/journal/journal_slice.cc b/src/server/journal/journal_slice.cc index f7c068f4613b..c66cfc8a00f3 100644 --- a/src/server/journal/journal_slice.cc +++ b/src/server/journal/journal_slice.cc @@ -165,6 +165,7 @@ void JournalSlice::AddLogRecord(const Entry& entry, bool await) { item = &dummy; item->opcode = entry.opcode; item->lsn = lsn_++; + item->cmd = entry.payload.cmd; item->slot = entry.slot; io::BufSink buf_sink{&ring_serialize_buf_}; @@ -198,13 +199,14 @@ void JournalSlice::AddLogRecord(const Entry& entry, bool await) { } uint32_t JournalSlice::RegisterOnChange(ChangeCallback cb) { - lock_guard lk(cb_mu_); + // mutex lock isn't needed due to iterators are not invalidated uint32_t id = next_cb_id_++; change_cb_arr_.emplace_back(id, std::move(cb)); return id; } void JournalSlice::UnregisterOnChange(uint32_t id) { + // we need to wait until callback is finished before remove it lock_guard lk(cb_mu_); 
auto it = find_if(change_cb_arr_.begin(), change_cb_arr_.end(), [id](const auto& e) { return e.first == id; }); diff --git a/src/server/journal/journal_slice.h b/src/server/journal/journal_slice.h index 2752eb463c56..8534d78f7aae 100644 --- a/src/server/journal/journal_slice.h +++ b/src/server/journal/journal_slice.h @@ -47,7 +47,6 @@ class JournalSlice { void UnregisterOnChange(uint32_t); bool HasRegisteredCallbacks() const { - std::shared_lock lk(cb_mu_); return !change_cb_arr_.empty(); } @@ -62,8 +61,8 @@ class JournalSlice { std::optional> ring_buffer_; base::IoBuf ring_serialize_buf_; - mutable util::fb2::SharedMutex cb_mu_; - std::vector> change_cb_arr_ ABSL_GUARDED_BY(cb_mu_); + mutable util::fb2::SharedMutex cb_mu_; // to prevent removing callback during call + std::list> change_cb_arr_; LSN lsn_ = 1; diff --git a/src/server/journal/streamer.cc b/src/server/journal/streamer.cc index c00c2ae877af..b6a6f98d0955 100644 --- a/src/server/journal/streamer.cc +++ b/src/server/journal/streamer.cc @@ -34,7 +34,7 @@ uint32_t replication_stream_output_limit_cached = 64_KB; } // namespace JournalStreamer::JournalStreamer(journal::Journal* journal, Context* cntx) - : journal_(journal), cntx_(cntx) { + : cntx_(cntx), journal_(journal) { // cache the flag to avoid accessing it later. 
replication_stream_output_limit_cached = absl::GetFlag(FLAGS_replication_stream_output_limit); } @@ -44,7 +44,7 @@ JournalStreamer::~JournalStreamer() { VLOG(1) << "~JournalStreamer"; } -void JournalStreamer::Start(io::AsyncSink* dest, bool send_lsn) { +void JournalStreamer::Start(util::FiberSocketBase* dest, bool send_lsn) { CHECK(dest_ == nullptr && dest != nullptr); dest_ = dest; journal_cb_id_ = @@ -188,9 +188,13 @@ RestoreStreamer::RestoreStreamer(DbSlice* slice, cluster::SlotSet slots, journal Context* cntx) : JournalStreamer(journal, cntx), db_slice_(slice), my_slots_(std::move(slots)) { DCHECK(slice != nullptr); + db_array_ = slice->databases(); // Inc ref to make sure DB isn't deleted while we use it } -void RestoreStreamer::Start(io::AsyncSink* dest, bool send_lsn) { +void RestoreStreamer::Start(util::FiberSocketBase* dest, bool send_lsn) { + if (fiber_cancelled_) + return; + VLOG(1) << "RestoreStreamer start"; auto db_cb = absl::bind_front(&RestoreStreamer::OnDbChange, this); snapshot_version_ = db_slice_->RegisterOnChange(std::move(db_cb)); @@ -199,7 +203,7 @@ void RestoreStreamer::Start(io::AsyncSink* dest, bool send_lsn) { PrimeTable::Cursor cursor; uint64_t last_yield = 0; - PrimeTable* pt = &db_slice_->databases()[0]->prime; + PrimeTable* pt = &db_array_[0]->prime; do { if (fiber_cancelled_) @@ -244,14 +248,22 @@ RestoreStreamer::~RestoreStreamer() { void RestoreStreamer::Cancel() { auto sver = snapshot_version_; snapshot_version_ = 0; // to prevent double cancel in another fiber + fiber_cancelled_ = true; if (sver != 0) { - fiber_cancelled_ = true; db_slice_->UnregisterOnChange(sver); JournalStreamer::Cancel(); } } bool RestoreStreamer::ShouldWrite(const journal::JournalItem& item) const { + if (item.cmd == "FLUSHALL" || item.cmd == "FLUSHDB") { + // On FLUSH* we restart the migration + CHECK(dest_ != nullptr); + cntx_->ReportError("FLUSH command during migration"); + dest_->Shutdown(SHUT_RDWR); + return false; + } + if (!item.slot.has_value()) { 
return false; } @@ -289,7 +301,7 @@ bool RestoreStreamer::WriteBucket(PrimeTable::bucket_iterator it) { expire = db_slice_->ExpireTime(eit); } - WriteEntry(key, pv, expire); + WriteEntry(key, it->first, pv, expire); } } } @@ -314,8 +326,9 @@ void RestoreStreamer::OnDbChange(DbIndex db_index, const DbSlice::ChangeReq& req } } -void RestoreStreamer::WriteEntry(string_view key, const PrimeValue& pv, uint64_t expire_ms) { - absl::InlinedVector args; +void RestoreStreamer::WriteEntry(string_view key, const PrimeValue& pk, const PrimeValue& pv, + uint64_t expire_ms) { + absl::InlinedVector args; args.push_back(key); string expire_str = absl::StrCat(expire_ms); @@ -327,6 +340,10 @@ void RestoreStreamer::WriteEntry(string_view key, const PrimeValue& pv, uint64_t args.push_back("ABSTTL"); // Means expire string is since epoch + if (pk.IsSticky()) { + args.push_back("STICK"); + } + WriteCommand(journal::Entry::Payload("RESTORE", ArgSlice(args))); } diff --git a/src/server/journal/streamer.h b/src/server/journal/streamer.h index 7cb8b34bf5d4..9cef8352b3ec 100644 --- a/src/server/journal/streamer.h +++ b/src/server/journal/streamer.h @@ -23,7 +23,7 @@ class JournalStreamer { JournalStreamer(JournalStreamer&& other) = delete; // Register journal listener and start writer in fiber. - virtual void Start(io::AsyncSink* dest, bool send_lsn); + virtual void Start(util::FiberSocketBase* dest, bool send_lsn); // Must be called on context cancellation for unblocking // and manual cleanup. 
@@ -48,6 +48,9 @@ class JournalStreamer { void WaitForInflightToComplete(); + util::FiberSocketBase* dest_ = nullptr; + Context* cntx_; + private: void OnCompletion(std::error_code ec, size_t len); @@ -58,8 +61,6 @@ class JournalStreamer { bool IsStalled() const; journal::Journal* journal_; - Context* cntx_; - io::AsyncSink* dest_ = nullptr; std::vector pending_buf_; size_t in_flight_bytes_ = 0; time_t last_lsn_time_ = 0; @@ -74,7 +75,7 @@ class RestoreStreamer : public JournalStreamer { RestoreStreamer(DbSlice* slice, cluster::SlotSet slots, journal::Journal* journal, Context* cntx); ~RestoreStreamer() override; - void Start(io::AsyncSink* dest, bool send_lsn = false) override; + void Start(util::FiberSocketBase* dest, bool send_lsn = false) override; // Cancel() must be called if Start() is called void Cancel() override; @@ -92,10 +93,11 @@ class RestoreStreamer : public JournalStreamer { // Returns whether anything was written bool WriteBucket(PrimeTable::bucket_iterator it); - void WriteEntry(string_view key, const PrimeValue& pv, uint64_t expire_ms); + void WriteEntry(string_view key, const PrimeValue& pk, const PrimeValue& pv, uint64_t expire_ms); void WriteCommand(journal::Entry::Payload cmd_payload); DbSlice* db_slice_; + DbTableArray db_array_; uint64_t snapshot_version_ = 0; cluster::SlotSet my_slots_; bool fiber_cancelled_ = false; diff --git a/src/server/journal/types.h b/src/server/journal/types.h index aeb0286ca65f..63c35b9befc9 100644 --- a/src/server/journal/types.h +++ b/src/server/journal/types.h @@ -95,6 +95,7 @@ struct JournalItem { LSN lsn; Op opcode; std::string data; + std::string_view cmd; std::optional slot; }; diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 399d25a051e8..c0d823b808d5 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1671,13 +1671,6 @@ void Service::Unwatch(CmdArgList args, ConnectionContext* cntx) { return cntx->SendOk(); } -template void 
WithReplies(CapturingReplyBuilder* crb, ConnectionContext* cntx, F&& f) { - SinkReplyBuilder* old_rrb = nullptr; - old_rrb = cntx->Inject(crb); - f(); - cntx->Inject(old_rrb); -} - optional Service::FlushEvalAsyncCmds(ConnectionContext* cntx, bool force) { auto& info = cntx->conn_state.script_info; @@ -1693,9 +1686,10 @@ optional Service::FlushEvalAsyncCmds(ConnectionC cntx->transaction->MultiSwitchCmd(eval_cid); CapturingReplyBuilder crb{ReplyMode::ONLY_ERR}; - WithReplies(&crb, cntx, [&] { + { + CapturingReplyBuilder::ScopeCapture capture{&crb, cntx}; MultiCommandSquasher::Execute(absl::MakeSpan(info->async_cmds), cntx, this, true, true); - }); + } info->async_cmds_heap_mem = 0; info->async_cmds.clear(); diff --git a/src/server/rdb_load.h b/src/server/rdb_load.h index 2f8b8a34218b..824c3065a42a 100644 --- a/src/server/rdb_load.h +++ b/src/server/rdb_load.h @@ -25,6 +25,8 @@ class Service; class DecompressImpl; +using RdbVersion = std::uint16_t; + class RdbLoaderBase { protected: RdbLoaderBase(); @@ -170,7 +172,7 @@ class RdbLoaderBase { std::unique_ptr decompress_impl_; JournalReader journal_reader_{nullptr, 0}; std::optional journal_offset_ = std::nullopt; - int rdb_version_ = RDB_VERSION; + RdbVersion rdb_version_ = RDB_VERSION; }; class RdbLoader : protected RdbLoaderBase { diff --git a/src/server/replica.cc b/src/server/replica.cc index ac85a88adf05..d7e02fa1d497 100644 --- a/src/server/replica.cc +++ b/src/server/replica.cc @@ -38,9 +38,17 @@ ABSL_FLAG(int, master_reconnect_timeout_ms, 1000, "Timeout for re-establishing connection to a replication master"); ABSL_FLAG(bool, replica_partial_sync, true, "Use partial sync to reconnect when a replica connection is interrupted."); -ABSL_FLAG(bool, replica_reconnect_on_master_restart, false, - "When in replica mode, and master restarts, break replication from master."); +ABSL_FLAG(bool, break_replication_on_master_restart, false, + "When in replica mode, and master restarts, break replication from master to avoid " 
+ "flushing the replica's data."); ABSL_DECLARE_FLAG(int32_t, port); +ABSL_FLAG( + int, replica_priority, 100, + "Published by info command for sentinel to pick replica based on score during a failover"); + +// TODO: Remove this flag on release >= 1.22 +ABSL_FLAG(bool, replica_reconnect_on_master_restart, false, + "Deprecated - please use --break_replication_on_master_restart."); namespace dfly { @@ -303,7 +311,8 @@ std::error_code Replica::HandleCapaDflyResp() { // If we're syncing a different replication ID, drop the saved LSNs. string_view master_repl_id = ToSV(LastResponseArgs()[0].GetBuf()); if (master_context_.master_repl_id != master_repl_id) { - if (absl::GetFlag(FLAGS_replica_reconnect_on_master_restart) && + if ((absl::GetFlag(FLAGS_replica_reconnect_on_master_restart) || + absl::GetFlag(FLAGS_break_replication_on_master_restart)) && !master_context_.master_repl_id.empty()) { LOG(ERROR) << "Encountered different master repl id (" << master_repl_id << " vs " << master_context_.master_repl_id << ")"; diff --git a/src/server/search/doc_index.cc b/src/server/search/doc_index.cc index ff34a4ce6101..d2205f316ebd 100644 --- a/src/server/search/doc_index.cc +++ b/src/server/search/doc_index.cc @@ -52,13 +52,44 @@ const absl::flat_hash_map kSchemaTy {"NUMERIC"sv, search::SchemaField::NUMERIC}, {"VECTOR"sv, search::SchemaField::VECTOR}}; +size_t GetProbabilisticBound(size_t hits, size_t requested, optional agg) { + auto intlog2 = [](size_t x) { + size_t l = 0; + while (x >>= 1) + ++l; + return l; + }; + + if (hits == 0 || requested == 0) + return 0; + + size_t shards = shard_set->size(); + + // Estimate how much every shard has with at least 99% prob + size_t avg_shard_min = hits * intlog2(hits) / (12 + shard_set->size() / 10); + avg_shard_min -= min(avg_shard_min, min(hits, size_t(5))); + + // If it turns out that we might have not enough results to cover the request, don't skip any + if (avg_shard_min * shards < requested) + return requested; + + // If all shards 
have at least avg min, keep the bare minimum needed to cover the request + size_t limit = requested / shards + 1; + + // Aggregations like SORTBY and KNN reorder the result and thus introduce some variance + if (agg.has_value()) + limit += max(requested / 4 + 1, 3UL); + + return limit; +} + } // namespace -bool SerializedSearchDoc::operator<(const SerializedSearchDoc& other) const { +bool DocResult::operator<(const DocResult& other) const { return this->score < other.score; } -bool SerializedSearchDoc::operator>=(const SerializedSearchDoc& other) const { +bool DocResult::operator>=(const DocResult& other) const { return this->score >= other.score; } @@ -172,10 +203,11 @@ bool DocIndex::Matches(string_view key, unsigned obj_code) const { } ShardDocIndex::ShardDocIndex(shared_ptr index) - : base_{std::move(index)}, indices_{{}, nullptr}, key_index_{} { + : base_{std::move(index)}, indices_{{}, nullptr}, key_index_{}, write_epoch_{0} { } void ShardDocIndex::Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr) { + write_epoch_++; key_index_ = DocKeyIndex{}; indices_ = search::FieldIndices{base_->schema, mr}; @@ -186,11 +218,13 @@ void ShardDocIndex::Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr) } void ShardDocIndex::AddDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) { + write_epoch_++; auto accessor = GetAccessor(db_cntx, pv); indices_.Add(key_index_.Add(key), accessor.get()); } void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) { + write_epoch_++; auto accessor = GetAccessor(db_cntx, pv); DocId id = key_index_.Remove(key); indices_.Remove(id, accessor.get()); @@ -200,37 +234,83 @@ bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const { return base_->Matches(key, obj_code); } -SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& params, - search::SearchAlgorithm* search_algo) const { +io::Result ShardDocIndex::Search( + const OpArgs& op_args, 
const SearchParams& params, search::SearchAlgorithm* search_algo) const { + size_t requested_count = params.limit_offset + params.limit_total; + auto search_results = search_algo->Search(&indices_, requested_count); + if (!search_results.error.empty()) + return nonstd::make_unexpected(facade::ErrorReply(std::move(search_results.error))); + + size_t return_count = min(requested_count, search_results.ids.size()); + + // Probabilistic optimization: If we are about 99% sure that all shards in total fetch more + // results than needed to satisfy the search request, we can avoid serializing some of the last + // result hits as they likely won't be needed. The `cutoff_bound` indicates how many entries it's + // reasonable to serialize directly, for the rest only ids are stored. In the 1% case they are + // either serialized on another hop or the query is fully repeated without this optimization. + size_t cuttoff_bound = requested_count; + if (params.enable_cutoff && !params.IdsOnly()) { + cuttoff_bound = GetProbabilisticBound(search_results.pre_aggregation_total, requested_count, + search_algo->HasAggregation()); + } + + vector out(return_count); + auto shard_id = EngineShard::tlocal()->shard_id(); + auto& scores = search_results.scores; + for (size_t i = 0; i < out.size(); i++) { + out[i].value = DocResult::DocReference{shard_id, search_results.ids[i], i < cuttoff_bound}; + out[i].score = scores.empty() ? 
search::ResultScore{} : std::move(scores[i]); + } + + Serialize(op_args, params, absl::MakeSpan(out)); + + return SearchResult{write_epoch_, search_results.total, std::move(out), + std::move(search_results.profile)}; +} + +bool ShardDocIndex::Refill(const OpArgs& op_args, const SearchParams& params, + search::SearchAlgorithm* search_algo, SearchResult* result) const { + // If no writes occurred, serialize remaining entries without breaking correctness + if (result->write_epoch == write_epoch_) { + Serialize(op_args, params, absl::MakeSpan(result->docs)); + return true; + } + + // We're already on the cold path and we don't want to gamble any more + DCHECK(!params.enable_cutoff); + + auto new_result = Search(op_args, params, search_algo); + CHECK(new_result.has_value()); // Query should be valid since it passed first step + + *result = std::move(new_result.value()); + return false; +} + +void ShardDocIndex::Serialize(const OpArgs& op_args, const SearchParams& params, + absl::Span docs) const { auto& db_slice = op_args.shard->db_slice(); - auto search_results = search_algo->Search(&indices_, params.limit_offset + params.limit_total); - if (!search_results.error.empty()) - return SearchResult{facade::ErrorReply{std::move(search_results.error)}}; + for (auto& doc : docs) { + if (!holds_alternative(doc.value)) + continue; - vector out; - out.reserve(search_results.ids.size()); + auto ref = get(doc.value); + if (!ref.requested) + return; - size_t expired_count = 0; - for (size_t i = 0; i < search_results.ids.size(); i++) { - auto key = key_index_.Get(search_results.ids[i]); - auto it = db_slice.FindReadOnly(op_args.db_cntx, key, base_->GetObjCode()); + string key{key_index_.Get(ref.doc_id)}; + auto it = db_slice.FindReadOnly(op_args.db_cntx, key, base_->GetObjCode()); if (!it || !IsValid(*it)) { // Item must have expired - expired_count++; + doc.value = DocResult::SerializedValue{std::move(key), {}}; continue; } auto accessor = GetAccessor(op_args.db_cntx, (*it)->second); 
auto doc_data = params.return_fields ? accessor->Serialize(base_->schema, *params.return_fields) : accessor->Serialize(base_->schema); - - auto score = search_results.scores.empty() ? monostate{} : std::move(search_results.scores[i]); - out.push_back(SerializedSearchDoc{string{key}, std::move(doc_data), std::move(score)}); + doc.value = DocResult::SerializedValue{std::move(key), std::move(doc_data)}; } - - return SearchResult{search_results.total - expired_count, std::move(out), - std::move(search_results.profile)}; } vector> ShardDocIndex::SearchForAggregator( diff --git a/src/server/search/doc_index.h b/src/server/search/doc_index.h index ebc689c4a8b5..0ec1afaf5b9f 100644 --- a/src/server/search/doc_index.h +++ b/src/server/search/doc_index.h @@ -25,52 +25,61 @@ using SearchDocData = absl::flat_hash_map ParseSearchFieldType(std::string_view name); std::string_view SearchFieldTypeToString(search::SchemaField::FieldType); -struct SerializedSearchDoc { - std::string key; - SearchDocData values; - search::ResultScore score; +// Represents results returned from a shard doc index that are then aggregated in the coordinator. +struct DocResult { + // Fully serialized value ready to be sent back. + struct SerializedValue { + std::string key; + SearchDocData values; + }; + + // Reference to a document that matched the query, but its serialization was skipped as the + // document was considered unlikely to be contained in the reply. 
+ struct DocReference { + ShardId shard_id; + search::DocId doc_id; + bool requested; + }; - bool operator<(const SerializedSearchDoc& other) const; - bool operator>=(const SerializedSearchDoc& other) const; + bool operator<(const DocResult& other) const; + bool operator>=(const DocResult& other) const; + + public: + std::variant value; + search::ResultScore score; }; struct SearchResult { - SearchResult() = default; + size_t write_epoch = 0; // Write epoch of the index on which the result was created - SearchResult(size_t total_hits, std::vector docs, - std::optional profile) - : total_hits{total_hits}, docs{std::move(docs)}, profile{std::move(profile)} { - } + size_t total_hits = 0; // total number of hits in shard + std::vector docs; // serialized documents of first hits - SearchResult(facade::ErrorReply error) : error{std::move(error)} { - } - - size_t total_hits; - std::vector docs; std::optional profile; - - std::optional error; }; struct SearchParams { using FieldReturnList = std::vector>; - // Parameters for "LIMIT offset total": select total amount documents with a specific offset from - // the whole result set - size_t limit_offset = 0; - size_t limit_total = 10; - - // Set but empty means no fields should be returned - std::optional return_fields; - std::optional sort_option; - search::QueryParams query_params; - bool IdsOnly() const { return return_fields && return_fields->empty(); } bool ShouldReturnField(std::string_view field) const; + + public: + // Parameters for "LIMIT offset total": select total amount documents with a specific offset. + size_t limit_offset = 0; + size_t limit_total = 10; + + // Probabilistic optimizations that avoid serializing documents unlikely to be returned. + bool enable_cutoff = false; + + std::optional return_fields; // Set but empty means no fields should be returned + + std::optional sort_option; + search::QueryParams query_params; }; // Stores basic info about a document index. 
@@ -123,8 +132,14 @@ class ShardDocIndex { ShardDocIndex(std::shared_ptr index); // Perform search on all indexed documents and return results. - SearchResult Search(const OpArgs& op_args, const SearchParams& params, - search::SearchAlgorithm* search_algo) const; + io::Result Search(const OpArgs& op_args, + const SearchParams& params, + search::SearchAlgorithm* search_algo) const; + + // Resolve requested doc references from the result. If no writes occurred, the remaining + // entries are serialized and true is returned, otherwise a full new query is performed. + bool Refill(const OpArgs& op_args, const SearchParams& params, + search::SearchAlgorithm* search_algo, SearchResult* result) const; // Perform search and load requested values - note params might be interpreted differently. std::vector> SearchForAggregator( @@ -142,10 +157,17 @@ class ShardDocIndex { // Clears internal data. Traverses all matching documents and assigns ids. void Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr); + // Serialize prefix of requested doc references. + void Serialize(const OpArgs& op_args, const SearchParams& params, + absl::Span docs) const; + private: std::shared_ptr base_; search::FieldIndices indices_; DocKeyIndex key_index_; + + // Incremented during each Add/Remove. Used to track if changes occurred since last read. + size_t write_epoch_; }; // Stores shard doc indices by name on a specific shard. 
diff --git a/src/server/search/search_family.cc b/src/server/search/search_family.cc index 95331ce3335c..27fce1cee0b3 100644 --- a/src/server/search/search_family.cc +++ b/src/server/search/search_family.cc @@ -363,101 +363,265 @@ optional ParseAggregatorParamsOrReply(CmdArgParser parser, return params; } -void SendSerializedDoc(const SerializedSearchDoc& doc, ConnectionContext* cntx) { +void SendSerializedDoc(const DocResult::SerializedValue& value, ConnectionContext* cntx) { auto* rb = static_cast(cntx->reply_builder()); - rb->SendBulkString(doc.key); - rb->StartCollection(doc.values.size(), RedisReplyBuilder::MAP); - for (const auto& [k, v] : doc.values) { + rb->SendBulkString(value.key); + rb->StartCollection(value.values.size(), RedisReplyBuilder::MAP); + for (const auto& [k, v] : value.values) { rb->SendBulkString(k); rb->SendBulkString(v); } } -void ReplyWithResults(const SearchParams& params, absl::Span results, - ConnectionContext* cntx) { - size_t total_count = 0; - for (const auto& shard_docs : results) - total_count += shard_docs.total_hits; +struct MultishardSearch { + MultishardSearch(ConnectionContext* cntx, std::string_view index_name, + search::SearchAlgorithm* search_algo, SearchParams params) + : cntx_{cntx}, + rb_{static_cast(cntx->reply_builder())}, + index_name_{index_name}, + search_algo_{search_algo}, + params_{std::move(params)} { + sharded_results_.resize(shard_set->size()); + if (search_algo_->IsProfilingEnabled()) + sharded_times_.resize(shard_set->size()); + } - size_t result_count = - min(total_count - min(total_count, params.limit_offset), params.limit_total); + void RunAndReply() { + // First, run search with probabilistic optimizations enabled. + // If the result set was collected successfully, reply. 
+ { + params_.enable_cutoff = true; - facade::SinkReplyBuilder::ReplyAggregator agg{cntx->reply_builder()}; + if (auto err = RunSearch(); err) + return rb_->SendError(std::move(*err)); - bool ids_only = params.IdsOnly(); - size_t reply_size = ids_only ? (result_count + 1) : (result_count * 2 + 1); + auto incomplete_shards = BuildOrder(); + if (incomplete_shards.empty()) + return Reply(); + } - auto* rb = static_cast(cntx->reply_builder()); - rb->StartArray(reply_size); - rb->SendLong(total_count); - - size_t sent = 0; - size_t to_skip = params.limit_offset; - for (const auto& shard_docs : results) { - for (const auto& serialized_doc : shard_docs.docs) { - // Scoring is not implemented yet, so we just cut them in the order they were retrieved - if (to_skip > 0) { - to_skip--; + // VLOG(0) << "Failed completness check, refilling"; + + // Otherwise, some results made it into the result set but were not serialized. + // Try refilling the requested values. If no reordering occurred, reply immediately, otherwise + // try building a new order and reply if it is valid. + { + params_.enable_cutoff = false; + + auto refill_res = RunRefill(); + if (!refill_res.has_value()) + return rb_->SendError(std::move(refill_res.error())); + + if (bool no_reordering = refill_res.value(); no_reordering) + return Reply(); + + if (auto incomplete_shards = BuildOrder(); incomplete_shards.empty()) + return Reply(); + } + + VLOG(1) << "Failed refill and rebuild, re-searching"; + + // At this step all optimizations failed. Run search without any cutoffs. 
+ { + DCHECK(!params_.enable_cutoff); + + if (auto err = RunSearch(); err) + return rb_->SendError(std::move(*err)); + + auto incomplete_shards = BuildOrder(); + DCHECK(incomplete_shards.empty()); + Reply(); + } + } + + struct ProfileInfo { + size_t total = 0; + size_t serialized = 0; + size_t cutoff = 0; + size_t hops = 0; + std::vector> profiles; + }; + + ProfileInfo GetProfileInfo() { + ProfileInfo info; + info.hops = hops_; + + for (size_t i = 0; i < sharded_results_.size(); i++) { + const auto& sd = sharded_results_[i]; + size_t serialized = count_if(sd.docs.begin(), sd.docs.end(), [](const auto& doc_res) { + return holds_alternative(doc_res.value); + }); + + info.total += sd.total_hits; + info.serialized += serialized; + info.cutoff += sd.docs.size() - serialized; + + DCHECK(sd.profile); + info.profiles.push_back({std::move(*sd.profile), sharded_times_[i]}); + } + + return info; + } + + private: + void Reply() { + size_t total_count = 0; + for (const auto& shard_docs : sharded_results_) + total_count += shard_docs.total_hits; + + auto agg_info = search_algo_->HasAggregation(); + if (agg_info && agg_info->limit) + total_count = min(total_count, *agg_info->limit); + + if (agg_info && !params_.ShouldReturnField(agg_info->alias)) + agg_info->alias = ""sv; + + size_t result_count = + min(total_count - min(total_count, params_.limit_offset), params_.limit_total); + + facade::SinkReplyBuilder::ReplyAggregator agg{cntx_->reply_builder()}; + + bool ids_only = params_.IdsOnly(); + size_t reply_size = ids_only ? 
(result_count + 1) : (result_count * 2 + 1); + + rb_->StartArray(reply_size); + rb_->SendLong(total_count); + + for (size_t i = params_.limit_offset; i < ordered_docs_.size(); i++) { + auto& value = get(ordered_docs_[i]->value); + if (ids_only) { + rb_->SendBulkString(value.key); continue; } - if (sent++ >= result_count) - return; + if (agg_info && !agg_info->alias.empty()) + value.values[agg_info->alias] = absl::StrCat(get(ordered_docs_[i]->score)); - if (ids_only) - rb->SendBulkString(serialized_doc.key); - else - SendSerializedDoc(serialized_doc, cntx); + SendSerializedDoc(value, cntx_); } } -} -void ReplySorted(search::AggregationInfo agg, const SearchParams& params, - absl::Span results, ConnectionContext* cntx) { - size_t total = 0; - vector docs; - for (auto& shard_results : results) { - total += shard_results.total_hits; - for (auto& doc : shard_results.docs) { - docs.push_back(&doc); - } + // Run function f on all search indices, return first error + std::optional RunHandler( + std::function(EngineShard*, ShardDocIndex*)> f) { + hops_++; + AggregateValue> err; + cntx_->transaction->ScheduleSingleHop([&](Transaction* t, EngineShard* es) { + optional start; + if (search_algo_->IsProfilingEnabled()) + start = absl::Now(); + + if (auto* index = es->search_indices()->GetIndex(index_name_); index) + err = f(es, index); + else + err = facade::ErrorReply(string{index_name_} + ": no such index"); + + if (start.has_value()) + sharded_times_[es->shard_id()] += (absl::Now() - *start); + + return OpStatus::OK; + }); + return *err; + } + + optional RunSearch() { + cntx_->transaction->Refurbish(); + + return RunHandler([this](EngineShard* es, ShardDocIndex* index) -> optional { + auto res = index->Search(cntx_->transaction->GetOpArgs(es), params_, search_algo_); + if (!res.has_value()) + return std::move(res.error()); + sharded_results_[es->shard_id()] = std::move(res.value()); + return nullopt; + }); } - size_t agg_limit = agg.limit.value_or(total); - size_t prefix = 
min(params.limit_offset + params.limit_total, agg_limit); + io::Result RunRefill() { + cntx_->transaction->Refurbish(); - partial_sort(docs.begin(), docs.begin() + min(docs.size(), prefix), docs.end(), - [desc = agg.descending](const auto* l, const auto* r) { - return desc ? (*l >= *r) : (*l < *r); - }); + atomic_uint failed_refills = 0; + auto err = RunHandler([this, &failed_refills](EngineShard* es, ShardDocIndex* index) { + bool refilled = index->Refill(cntx_->transaction->GetOpArgs(es), params_, search_algo_, + &sharded_results_[es->shard_id()]); + if (!refilled) + failed_refills.fetch_add(1u); + return nullopt; + }); - docs.resize(min(docs.size(), agg_limit)); + if (err) + return nonstd::make_unexpected(std::move(*err)); + return failed_refills == 0; + } - size_t start_idx = min(params.limit_offset, docs.size()); - size_t result_count = min(docs.size() - start_idx, params.limit_total); - bool ids_only = params.IdsOnly(); - size_t reply_size = ids_only ? (result_count + 1) : (result_count * 2 + 1); + // Build order from results collected from shards + absl::flat_hash_set BuildOrder() { + ordered_docs_.clear(); + if (auto agg = search_algo_->HasAggregation(); agg) { + BuildSortedOrder(*agg); + } else { + BuildLinearOrder(); + } + return VerifyOrderCompletness(); + } - // Clear score alias if it's excluded from return values - if (!params.ShouldReturnField(agg.alias)) - agg.alias = ""; + void BuildLinearOrder() { + size_t required = params_.limit_offset + params_.limit_total; - facade::SinkReplyBuilder::ReplyAggregator agg_reply{cntx->reply_builder()}; - auto* rb = static_cast(cntx->reply_builder()); - rb->StartArray(reply_size); - rb->SendLong(min(total, agg_limit)); - for (auto* doc : absl::MakeSpan(docs).subspan(start_idx, result_count)) { - if (ids_only) { - rb->SendBulkString(doc->key); - continue; + for (size_t idx = 0;; idx++) { + bool added = false; + for (auto& shard_result : sharded_results_) { + if (idx < shard_result.docs.size() && 
ordered_docs_.size() < required) { + ordered_docs_.push_back(&shard_result.docs[idx]); + added = true; + } + } + if (!added) + return; + } + } + + void BuildSortedOrder(search::AggregationInfo agg) { + for (auto& shard_result : sharded_results_) { + for (auto& doc : shard_result.docs) { + ordered_docs_.push_back(&doc); + } } - if (!agg.alias.empty() && holds_alternative(doc->score)) - doc->values[agg.alias] = absl::StrCat(get(doc->score)); + size_t agg_limit = agg.limit.value_or(ordered_docs_.size()); + size_t prefix = min(params_.limit_offset + params_.limit_total, agg_limit); + + partial_sort(ordered_docs_.begin(), ordered_docs_.begin() + min(ordered_docs_.size(), prefix), + ordered_docs_.end(), [desc = agg.descending](const auto* l, const auto* r) { + return desc ? (l->score >= r->score) : (l->score < r->score); + }); - SendSerializedDoc(*doc, cntx); + ordered_docs_.resize(min(ordered_docs_.size(), prefix)); } -} + + absl::flat_hash_set VerifyOrderCompletness() { + absl::flat_hash_set incomplete_shards; + for (auto* doc : ordered_docs_) { + if (auto* ref = get_if(&doc->value); ref) { + incomplete_shards.insert(ref->shard_id); + ref->requested = true; + } + } + return incomplete_shards; + } + + private: + ConnectionContext* cntx_; + RedisReplyBuilder* rb_; + std::string_view index_name_; + search::SearchAlgorithm* search_algo_; + SearchParams params_; + + size_t hops_ = 0; + + std::vector sharded_times_; + std::vector ordered_docs_; + std::vector sharded_results_; +}; } // namespace @@ -691,30 +855,7 @@ void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) { if (!search_algo.Init(query_str, ¶ms->query_params, sort_opt)) return cntx->SendError("Query syntax error"); - // Because our coordinator thread may not have a shard, we can't check ahead if the index exists. 
- atomic index_not_found{false}; - vector docs(shard_set->size()); - - cntx->transaction->ScheduleSingleHop([&](Transaction* t, EngineShard* es) { - if (auto* index = es->search_indices()->GetIndex(index_name); index) - docs[es->shard_id()] = index->Search(t->GetOpArgs(es), *params, &search_algo); - else - index_not_found.store(true, memory_order_relaxed); - return OpStatus::OK; - }); - - if (index_not_found.load()) - return cntx->SendError(string{index_name} + ": no such index"); - - for (const auto& res : docs) { - if (res.error) - return cntx->SendError(*res.error); - } - - if (auto agg = search_algo.HasAggregation(); agg) - ReplySorted(std::move(*agg), *params, absl::MakeSpan(docs), cntx); - else - ReplyWithResults(*params, absl::MakeSpan(docs), cntx); + MultishardSearch{cntx, index_name, &search_algo, std::move(*params)}.RunAndReply(); } void SearchFamily::FtProfile(CmdArgList args, ConnectionContext* cntx) { @@ -733,43 +874,41 @@ void SearchFamily::FtProfile(CmdArgList args, ConnectionContext* cntx) { search_algo.EnableProfiling(); absl::Time start = absl::Now(); - atomic_uint total_docs = 0; - atomic_uint total_serialized = 0; - vector> results(shard_set->size()); + CapturingReplyBuilder crb{facade::ReplyMode::ONLY_ERR}; + MultishardSearch mss{cntx, index_name, &search_algo, std::move(*params)}; - cntx->transaction->ScheduleSingleHop([&](Transaction* t, EngineShard* es) { - auto* index = es->search_indices()->GetIndex(index_name); - if (!index) - return OpStatus::OK; - - auto shard_start = absl::Now(); - auto res = index->Search(t->GetOpArgs(es), *params, &search_algo); - - total_docs.fetch_add(res.total_hits); - total_serialized.fetch_add(res.docs.size()); + { + CapturingReplyBuilder::ScopeCapture capture{&crb, cntx}; + mss.RunAndReply(); + } - DCHECK(res.profile); - results[es->shard_id()] = {std::move(*res.profile), absl::Now() - shard_start}; + auto* rb = static_cast(cntx->reply_builder()); - return OpStatus::OK; - }); + auto reply = crb.Take(); + if 
(auto err = CapturingReplyBuilder::GetError(reply); err) + return rb->SendError(err->first, err->second); auto took = absl::Now() - start; - auto* rb = static_cast(cntx->reply_builder()); - rb->StartArray(results.size() + 1); + auto profile = mss.GetProfileInfo(); + + rb->StartArray(profile.profiles.size() + 1); // General stats - rb->StartCollection(3, RedisReplyBuilder::MAP); + rb->StartCollection(5, RedisReplyBuilder::MAP); rb->SendBulkString("took"); rb->SendLong(absl::ToInt64Microseconds(took)); rb->SendBulkString("hits"); - rb->SendLong(total_docs); + rb->SendLong(profile.total); rb->SendBulkString("serialized"); - rb->SendLong(total_serialized); + rb->SendLong(profile.serialized); + rb->SendSimpleString("cutoff"); + rb->SendLong(profile.cutoff); + rb->SendSimpleString("hops"); + rb->SendLong(profile.hops); // Per-shard stats - for (const auto& [profile, shard_took] : results) { + for (const auto& [profile, shard_took] : profile.profiles) { rb->StartCollection(2, RedisReplyBuilder::MAP); rb->SendBulkString("took"); rb->SendLong(absl::ToInt64Microseconds(shard_took)); diff --git a/src/server/search/search_family_test.cc b/src/server/search/search_family_test.cc index 7d86bf84f7d0..6a01feb2dd92 100644 --- a/src/server/search/search_family_test.cc +++ b/src/server/search/search_family_test.cc @@ -603,7 +603,8 @@ TEST_F(SearchFamilyTest, FtProfile) { const auto& top_level = resp.GetVec(); EXPECT_EQ(top_level.size(), shard_set->size() + 1); - EXPECT_THAT(top_level[0].GetVec(), ElementsAre("took", _, "hits", _, "serialized", _)); + EXPECT_THAT(top_level[0].GetVec(), + ElementsAre("took", _, "hits", _, "serialized", _, "cutoff", _, "hops", _)); for (size_t sid = 0; sid < shard_set->size(); sid++) { const auto& shard_resp = top_level[sid + 1].GetVec(); @@ -615,6 +616,167 @@ TEST_F(SearchFamilyTest, FtProfile) { } } +vector> FillShard(ShardId sid, string_view prefix, size_t num, size_t idx = 0) { + vector> out; + size_t entries = 0; + while (entries < num) { + auto 
key = absl::StrCat(prefix, idx++); + if (Shard(key, shard_set->size()) == sid) { + out.emplace_back(vector{"hset", key, "idx", to_string(idx)}); + entries++; + } + } + return out; +} + +// Check basic multi shard search cuts off a big portion of the results +TEST_F(SearchFamilyTest, MultiShardBalanced) { + Run({"ft.create", "i1", "schema", "idx", "numeric"}); + + // Fill two shards with 100 values each + for (auto cmd : FillShard(0, "doc0-", 100)) + Run(absl::MakeSpan(cmd)); + for (auto cmd : FillShard(1, "doc1-", 100)) + Run(absl::MakeSpan(cmd)); + + auto resp = Run({"ft.profile", "i1", "SEARCH", "QUERY", "*", "LIMIT", "0", "50"}); + + auto stats = resp.GetVec()[0].GetVec(); + + // Make sure no refill was needed + EXPECT_EQ(stats[8], "hops"); + EXPECT_THAT(stats[9], IntArg(1)); + + // Make sure at least around half of serialization was cut off + EXPECT_EQ(stats[4], "serialized"); + EXPECT_LE(stats[5].GetInt(), 55); + EXPECT_EQ(stats[6], "cutoff"); + EXPECT_GE(stats[7].GetInt(), 45); +} + +// Simulate an uneven distribution which forces multi shard search to perform a refill +TEST_F(SearchFamilyTest, MultiShardRefill) { + Run({"ft.create", "i1", "schema", "idx", "numeric"}); + + // Place 100 keys ONLY on shard 0 + for (auto cmd : FillShard(0, "doc", 100)) + Run(absl::MakeSpan(cmd)); + + // This will fail the probabilistic bound as well as the refill phase, + // but should still succeed to select enough entries + for (size_t limit : {10, 20, 50}) { + auto resp = Run({"ft.search", "i1", "*", "LIMIT", "0", to_string(limit)}); + EXPECT_THAT(resp.GetVec().size(), 2 * limit + 1); + EXPECT_THAT(resp.GetVec()[0], IntArg(100)); + + resp = Run({"ft.profile", "i1", "SEARCH", "QUERY", "*", "LIMIT", "0", to_string(limit)}); + auto stats = resp.GetVec()[0].GetVec(); + // Make sure only one additional hop was needed for refill + EXPECT_EQ(stats[8], "hops"); + EXPECT_THAT(stats[9], IntArg(2)) << "On limit " << limit; + EXPECT_EQ(stats[6], "cutoff"); + EXPECT_THAT(stats[7], 
IntArg(0)); + } +} + +// Simulate an uneven distribution which forces multi shard search to perform a refill, +// but the refill is interrupted by constant updates, which should lead to a full repeated query +// on a single shard (the interrupted one). After the repeated query, a successful order is built. +TEST_F(SearchFamilyTest, MultiShardRefillRefresh) { + Run({"ft.create", "i1", "schema", "idx", "numeric"}); + + // Place 100 keys ONLY on shard 0 + for (auto cmd : FillShard(0, "doc", 100)) + Run(absl::MakeSpan(cmd)); + + atomic_bool keep_running = true; + string_view key = "doc1"; + EXPECT_EQ(Shard(key, shard_set->size()), 0); + auto fb = pp_->at(2)->LaunchFiber([this, &keep_running, key]() { + while (keep_running.load()) + Run("pressure", {"hset", key, "updates", "more-and-more!"}); + }); + + auto resp = Run({"ft.profile", "i1", "SEARCH", "QUERY", "*", "LIMIT", "0", "10"}); + auto stats = resp.GetVec()[0].GetVec(); + + // Make sure refill didn't succeed because of constant updates + EXPECT_EQ(stats[8], "hops"); + EXPECT_THAT(stats[9], IntArg(2)); + + keep_running.store(false); + ThisFiber::SleepFor(10ms); + fb.Join(); +} + +// Simulate multi shard worst case. A refill needs to be performed, but fails on one shard due to +// constant writes. After a full repeated query, one element less is returned, which leads to an +// invalid order (because the other shard refilled one less than needed). A full repeated query is +// needed, so a total of 3 hops are performed. 
+// TODO: Test will be invalidated with fine-grained refills +TEST_F(SearchFamilyTest, MultiShardRefillRepeat) { + Run({"ft.create", "i1", "schema", "idx", "numeric"}); + + // Place 100 keys ONLY on shard 0 + for (auto cmd : FillShard(0, "doc", 100)) + Run(absl::MakeSpan(cmd)); + + // Place a single key on shard 1 + auto key = "the-destroyer"; + EXPECT_EQ(Shard(key, shard_set->size()), 1); + Run({"hset", key, "idx", "1"}); + + atomic_bool keep_running = true; + auto fb = pp_->at(2)->LaunchFiber([this, &keep_running, key]() { + size_t idx = 0; + while (keep_running.load()) { + if (idx++ % 2 == 0) + Run("pressure", {"del", key}); + else + Run("pressure", {"hset", key, "idx", "1"}); + } + }); + + bool had_3hops = false; + for (size_t tries = 0; tries < 100; tries++) { + auto resp = Run({"ft.profile", "i1", "SEARCH", "QUERY", "*", "LIMIT", "0", "10"}); + auto stats = resp.GetVec()[0].GetVec(); + + EXPECT_EQ(stats[8], "hops"); + if (stats[9].GetInt() == 2) + continue; + + EXPECT_THAT(stats[9], IntArg(3)); + had_3hops = true; + break; + } + + EXPECT_TRUE(had_3hops) << "Failed probabilstic test :("; + + keep_running.store(false); + ThisFiber::SleepFor(10ms); + fb.Join(); +} + +TEST_F(SearchFamilyTest, MultiShardAggregation) { + // Place 50 keys on shards 0 and 1, but values on shard 1 have a larger value + for (auto cmd : FillShard(0, "doc", 50, 0)) + Run(absl::MakeSpan(cmd)); + + for (auto cmd : FillShard(1, "doc", 50, 100)) + Run(absl::MakeSpan(cmd)); + + Run({"ft.create", "i1", "schema", "idx", "numeric", "sortable"}); + + // The distribution is completely unbalanced, so getting the largest values should require two + // hops + auto resp = Run( + {"ft.profile", "i1", "SEARCH", "QUERY", "*", "LIMIT", "0", "20", "SORTBY", "idx", "DESC"}); + auto stats = resp.GetVec()[0].GetVec(); + EXPECT_EQ(stats[8], "hops"); + EXPECT_THAT(stats[9], IntArg(2)); +} + TEST_F(SearchFamilyTest, SimpleExpiry) { EXPECT_EQ(Run({"ft.create", "i1", "schema", "title", "text", "expires-in", 
"numeric"}), "OK"); @@ -628,10 +790,15 @@ TEST_F(SearchFamilyTest, SimpleExpiry) { EXPECT_THAT(Run({"ft.search", "i1", "*"}), AreDocIds("d:1", "d:2", "d:3")); - shard_set->TEST_EnableHeartBeat(); - + // Expired documents are still included in idlist AdvanceTime(60); - ThisFiber::SleepFor(5ms); // Give heartbeat time to delete expired doc + EXPECT_THAT(Run({"ft.search", "i1", "*"}), AreDocIds("d:1", "d:2", "d:3")); + + shard_set->TEST_EnableHeartBeat(); + for (size_t i = 0; i < 5; i++) { // Give heartbeat time to delete expired doc + ThisFiber::SleepFor(2ms); + Run({"incr", "run heatbeat run"}); + } EXPECT_THAT(Run({"ft.search", "i1", "*"}), AreDocIds("d:1", "d:3")); AdvanceTime(60); diff --git a/src/server/server_family.cc b/src/server/server_family.cc index ac2a16afd730..5f7c4725bd00 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -132,6 +132,7 @@ ABSL_DECLARE_FLAG(uint32_t, hz); ABSL_DECLARE_FLAG(bool, tls); ABSL_DECLARE_FLAG(string, tls_ca_cert_file); ABSL_DECLARE_FLAG(string, tls_ca_cert_dir); +ABSL_DECLARE_FLAG(int, replica_priority); bool AbslParseFlag(std::string_view in, ReplicaOfFlag* flag, std::string* err) { #define RETURN_ON_ERROR(cond, m) \ @@ -2284,6 +2285,8 @@ void ServerFamily::Info(CmdArgList args, ConnectionContext* cntx) { append("master_last_io_seconds_ago", rinfo.master_last_io_sec); append("master_sync_in_progress", rinfo.full_sync_in_progress); append("master_replid", rinfo.master_id); + append("slave_priority", GetFlag(FLAGS_replica_priority)); + append("slave_read_only", 1); }; replication_info_cb(replica_->GetInfo()); for (const auto& replica : cluster_replicas_) { diff --git a/src/server/tiering/disk_storage.cc b/src/server/tiering/disk_storage.cc index 215f2517eed7..cbd2e9ee3b42 100644 --- a/src/server/tiering/disk_storage.cc +++ b/src/server/tiering/disk_storage.cc @@ -15,6 +15,8 @@ using namespace ::dfly::tiering::literals; +ABSL_FLAG(bool, backing_file_direct, false, "If true uses O_DIRECT to open backing 
files"); + ABSL_FLAG(uint64_t, registered_buffer_size, 512_KB, "Size of registered buffer for IoUring fixed read/writes"); @@ -57,18 +59,42 @@ void ReturnBuf(UringBuf buf) { DestroyTmpBuf(buf); } +constexpr off_t kInitialSize = 1UL << 28; // 256MB + +template std::error_code DoFiberCall(void (SubmitEntry::*c)(Ts...), Ts... args) { + auto* proactor = static_cast(ProactorBase::me()); + FiberCall fc(proactor); + (fc.operator->()->*c)(std::forward(args)...); + FiberCall::IoResult io_res = fc.Get(); + return io_res < 0 ? std::error_code{-io_res, std::system_category()} : std::error_code{}; +} + } // anonymous namespace DiskStorage::DiskStorage(size_t max_size) : max_size_(max_size) { } std::error_code DiskStorage::Open(std::string_view path) { - RETURN_ON_ERR(io_mgr_.Open(path)); - alloc_.AddStorage(0, io_mgr_.Span()); - DCHECK_EQ(ProactorBase::me()->GetKind(), ProactorBase::IOURING); - auto* up = static_cast(ProactorBase::me()); + CHECK(!backing_file_); + + int kFlags = O_CREAT | O_RDWR | O_TRUNC | O_CLOEXEC; + if (absl::GetFlag(FLAGS_backing_file_direct)) + kFlags |= O_DIRECT; + auto res = OpenLinux(path, kFlags, 0666); + if (!res) + return res.error(); + backing_file_ = std::move(res.value()); + + int fd = backing_file_->fd(); + RETURN_ON_ERR(DoFiberCall(&SubmitEntry::PrepFallocate, fd, 0, 0L, kInitialSize)); + RETURN_ON_ERR(DoFiberCall(&SubmitEntry::PrepFadvise, fd, 0L, 0L, POSIX_FADV_RANDOM)); + + size_ = kInitialSize; + alloc_.AddStorage(0, size_); + + auto* up = static_cast(ProactorBase::me()); if (int io_res = up->RegisterBuffers(absl::GetFlag(FLAGS_registered_buffer_size)); io_res < 0) return std::error_code{-io_res, std::system_category()}; @@ -77,10 +103,11 @@ std::error_code DiskStorage::Open(std::string_view path) { void DiskStorage::Close() { using namespace std::chrono_literals; - while (pending_ops_ > 0) + while (pending_ops_ > 0 || grow_pending_) util::ThisFiber::SleepFor(10ms); - io_mgr_.Shutdown(); + backing_file_->Close(); + backing_file_.reset(); } 
void DiskStorage::Read(DiskSegment segment, ReadCb cb) { @@ -98,7 +125,10 @@ void DiskStorage::Read(DiskSegment segment, ReadCb cb) { }; pending_ops_++; - io_mgr_.ReadAsync(segment.offset, buf, std::move(io_cb)); + if (buf.buf_idx) + backing_file_->ReadFixedAsync(buf.bytes, segment.offset, *buf.buf_idx, std::move(io_cb)); + else + backing_file_->ReadAsync(buf.bytes, segment.offset, std::move(io_cb)); } void DiskStorage::MarkAsFree(DiskSegment segment) { @@ -115,17 +145,9 @@ std::error_code DiskStorage::Stash(io::Bytes bytes, StashCb cb) { // If we've run out of space, block and grow as much as needed if (offset < 0) { - size_t start = io_mgr_.Span(); - size_t grow_size = -offset; + RETURN_ON_ERR(Grow(-offset)); - if (alloc_.capacity() + grow_size >= max_size_) - return std::make_error_code(std::errc::no_space_on_device); - - RETURN_ON_ERR(io_mgr_.Grow(grow_size)); - - alloc_.AddStorage(start, grow_size); offset = alloc_.Malloc(bytes.size()); - if (offset < 0) // we can't fit it even after resizing return std::make_error_code(std::errc::file_too_large); } @@ -145,7 +167,10 @@ std::error_code DiskStorage::Stash(io::Bytes bytes, StashCb cb) { }; pending_ops_++; - io_mgr_.WriteAsync(offset, buf, std::move(io_cb)); + if (buf.buf_idx) + backing_file_->WriteFixedAsync(buf.bytes, offset, *buf.buf_idx, std::move(io_cb)); + else + backing_file_->WriteAsync(buf.bytes, offset, std::move(io_cb)); return {}; } @@ -153,4 +178,22 @@ DiskStorage::Stats DiskStorage::GetStats() const { return {alloc_.allocated_bytes(), alloc_.capacity()}; } +std::error_code DiskStorage::Grow(off_t grow_size) { + off_t start = size_; + + if (off_t(alloc_.capacity()) + grow_size >= max_size_) + return std::make_error_code(std::errc::no_space_on_device); + + if (std::exchange(grow_pending_, true)) + return std::make_error_code(std::errc::operation_in_progress); + + auto err = DoFiberCall(&SubmitEntry::PrepFallocate, backing_file_->fd(), 0, size_, grow_size); + grow_pending_ = false; + 
RETURN_ON_ERR(err); + + size_ += grow_size; + alloc_.AddStorage(start, grow_size); + return {}; +} + } // namespace dfly::tiering diff --git a/src/server/tiering/disk_storage.h b/src/server/tiering/disk_storage.h index bb968e6096a4..0aecd699fa70 100644 --- a/src/server/tiering/disk_storage.h +++ b/src/server/tiering/disk_storage.h @@ -4,13 +4,12 @@ #pragma once -#include #include #include "io/io.h" #include "server/tiering/common.h" #include "server/tiering/external_alloc.h" -#include "server/tiering/io_mgr.h" +#include "util/fibers/uring_file.h" namespace dfly::tiering { @@ -45,9 +44,14 @@ class DiskStorage { Stats GetStats() const; private: - size_t pending_ops_ = 0; - size_t max_size_; - IoMgr io_mgr_; + std::error_code Grow(off_t grow_size); + + private: + off_t size_, max_size_; + size_t pending_ops_ = 0; // number of ongoing ops for safe shutdown + bool grow_pending_ = false; + std::unique_ptr backing_file_; + ExternalAllocator alloc_; }; diff --git a/src/server/tiering/io_mgr.cc b/src/server/tiering/io_mgr.cc deleted file mode 100644 index d3700baf7949..000000000000 --- a/src/server/tiering/io_mgr.cc +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2022, DragonflyDB authors. All rights reserved. -// See LICENSE for licensing terms. 
-// - -#include "server/tiering/io_mgr.h" - -#include -#include - -#include "base/flags.h" -#include "base/logging.h" -#include "facade/facade_types.h" -#include "server/tiering/common.h" -#include "util/fibers/uring_proactor.h" - -ABSL_FLAG(bool, backing_file_direct, false, "If true uses O_DIRECT to open backing files"); - -namespace dfly::tiering { - -using namespace std; -using namespace util; -using namespace facade; - -using Proactor = fb2::UringProactor; -using fb2::ProactorBase; -using fb2::SubmitEntry; - -constexpr size_t kInitialSize = 1UL << 28; // 256MB - -error_code IoMgr::Open(std::string_view path) { - CHECK(!backing_file_); - - int kFlags = O_CREAT | O_RDWR | O_TRUNC | O_CLOEXEC; - if (absl::GetFlag(FLAGS_backing_file_direct)) { - kFlags |= O_DIRECT; - } - auto res = fb2::OpenLinux(path, kFlags, 0666); - if (!res) - return res.error(); - backing_file_ = std::move(res.value()); - Proactor* proactor = (Proactor*)ProactorBase::me(); - { - fb2::FiberCall fc(proactor); - fc->PrepFallocate(backing_file_->fd(), 0, 0, kInitialSize); - fb2::FiberCall::IoResult io_res = fc.Get(); - if (io_res < 0) { - return error_code{-io_res, system_category()}; - } - } - { - fb2::FiberCall fc(proactor); - fc->PrepFadvise(backing_file_->fd(), 0, 0, POSIX_FADV_RANDOM); - fb2::FiberCall::IoResult io_res = fc.Get(); - if (io_res < 0) { - return error_code{-io_res, system_category()}; - } - } - sz_ = kInitialSize; - return error_code{}; -} - -error_code IoMgr::Grow(size_t len) { - Proactor* proactor = (Proactor*)ProactorBase::me(); - - if (exchange(grow_progress_, true)) - return make_error_code(errc::operation_in_progress); - - fb2::FiberCall fc(proactor); - fc->PrepFallocate(backing_file_->fd(), 0, sz_, len); - Proactor::IoResult res = fc.Get(); - - grow_progress_ = false; - - if (res == 0) { - sz_ += len; - return {}; - } else { - return std::error_code(-res, std::iostream_category()); - } -} - -void IoMgr::WriteAsync(size_t offset, util::fb2::UringBuf buf, WriteCb cb) { - 
DCHECK(!buf.bytes.empty()); - - auto* proactor = static_cast(ProactorBase::me()); - auto ring_cb = [cb = std::move(cb)](auto*, Proactor::IoResult res, uint32_t flags) { cb(res); }; - - SubmitEntry se = proactor->GetSubmitEntry(std::move(ring_cb), 0); - if (buf.buf_idx) - se.PrepWriteFixed(backing_file_->fd(), buf.bytes.data(), buf.bytes.size(), offset, - *buf.buf_idx); - else - se.PrepWrite(backing_file_->fd(), buf.bytes.data(), buf.bytes.size(), offset); -} - -void IoMgr::ReadAsync(size_t offset, util::fb2::UringBuf buf, ReadCb cb) { - DCHECK(!buf.bytes.empty()); - - auto* proactor = static_cast(ProactorBase::me()); - auto ring_cb = [cb = std::move(cb)](auto*, Proactor::IoResult res, uint32_t flags) { cb(res); }; - - SubmitEntry se = proactor->GetSubmitEntry(std::move(ring_cb), 0); - if (buf.buf_idx) - se.PrepReadFixed(backing_file_->fd(), buf.bytes.data(), buf.bytes.size(), offset, *buf.buf_idx); - else - se.PrepRead(backing_file_->fd(), buf.bytes.data(), buf.bytes.size(), offset); -} - -void IoMgr::Shutdown() { - while (grow_progress_) { - ThisFiber::SleepFor(200us); // TODO: hacky for now. - } - backing_file_->Close(); - backing_file_.reset(); -} - -} // namespace dfly::tiering diff --git a/src/server/tiering/io_mgr.h b/src/server/tiering/io_mgr.h deleted file mode 100644 index c88102af5f78..000000000000 --- a/src/server/tiering/io_mgr.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2022, DragonflyDB authors. All rights reserved. -// See LICENSE for licensing terms. -// - -#pragma once - -#include -#include - -#include "server/common.h" -#include "util/fibers/uring_file.h" -#include "util/fibers/uring_proactor.h" - -namespace dfly::tiering { - -class IoMgr { - public: - // first arg - io result. - // using WriteCb = fu2::function_base; - using WriteCb = std::function; - - using ReadCb = std::function; - - // blocks until all the pending requests are finished. 
- void Shutdown(); - - std::error_code Open(std::string_view path); - - // Try growing file by that length. Return error if growth failed. - std::error_code Grow(size_t len); - - // Write into offset from src and call cb once done. The callback is guaranteed to be invoked in - // any error case for cleanup. The src buffer must outlive the call, until cb is resolved. - void WriteAsync(size_t offset, util::fb2::UringBuf src, WriteCb cb); - - // Read into dest and call cb once read. The callback is guaranteed to be invoked in any error - // case for cleanup. The dest buffer must outlive the call, until cb is resolved. - void ReadAsync(size_t offset, util::fb2::UringBuf dest, ReadCb cb); - - // Total file span - size_t Span() const { - return sz_; - } - - bool grow_pending() const { - return grow_progress_; - } - - private: - std::unique_ptr backing_file_; - size_t sz_ = 0; - - bool grow_progress_ = false; -}; - -} // namespace dfly::tiering diff --git a/src/server/tiering/op_manager.cc b/src/server/tiering/op_manager.cc index 18374f588fc6..0ff639dcb4df 100644 --- a/src/server/tiering/op_manager.cc +++ b/src/server/tiering/op_manager.cc @@ -11,6 +11,7 @@ #include "io/io.h" #include "server/tiering/common.h" #include "server/tiering/disk_storage.h" +#include "util/fibers/fibers.h" namespace dfly::tiering { namespace { diff --git a/src/server/transaction.cc b/src/server/transaction.cc index 1a17ca2dbde0..002a20988fec 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -627,6 +627,7 @@ void Transaction::RunCallback(EngineShard* shard) { DCHECK_EQ(shard, EngineShard::tlocal()); RunnableResult result; + shard->db_slice().LockChangeCb(); try { result = (*cb_ptr_)(this, shard); @@ -664,7 +665,10 @@ void Transaction::RunCallback(EngineShard* shard) { // Log to journal only once the command finished running if ((coordinator_state_ & COORD_CONCLUDING) || (multi_ && multi_->concluding)) { LogAutoJournalOnShard(shard, result); + 
shard->db_slice().UnlockChangeCb(); MaybeInvokeTrackingCb(); + } else { + shard->db_slice().UnlockChangeCb(); } } @@ -1247,9 +1251,11 @@ OpStatus Transaction::RunSquashedMultiCb(RunnableType cb) { DCHECK_EQ(unique_shard_cnt_, 1u); auto* shard = EngineShard::tlocal(); + shard->db_slice().LockChangeCb(); auto result = cb(this, shard); shard->db_slice().OnCbFinish(); LogAutoJournalOnShard(shard, result); + shard->db_slice().UnlockChangeCb(); MaybeInvokeTrackingCb(); DCHECK_EQ(result.flags, 0); // if it's sophisticated, we shouldn't squash it diff --git a/src/server/transaction.h b/src/server/transaction.h index 2d36d40f508e..766266c68e44 100644 --- a/src/server/transaction.h +++ b/src/server/transaction.h @@ -378,6 +378,9 @@ class Transaction { " local_res=" + std::to_string(int(local_result_)); } + void EnableShard(ShardId sid); + void EnableAllShards(); + private: // Holds number of locks for each IntentLock::Mode: shared and exlusive. struct LockCnt { @@ -495,9 +498,6 @@ class Transaction { // Init with a set of keys. void InitByKeys(const KeyIndex& keys); - void EnableShard(ShardId sid); - void EnableAllShards(); - // Build shard index by distributing the arguments by shards based on the key index. 
void BuildShardIndex(const KeyIndex& keys, std::vector* out); diff --git a/tests/dragonfly/acl_family_test.py b/tests/dragonfly/acl_family_test.py index f044269e91c7..efc17787205a 100644 --- a/tests/dragonfly/acl_family_test.py +++ b/tests/dragonfly/acl_family_test.py @@ -14,53 +14,53 @@ async def test_acl_setuser(async_client): await async_client.execute_command("ACL SETUSER kostas") result = await async_client.execute_command("ACL list") assert 2 == len(result) - assert "user kostas off nopass -@all" in result + assert "user kostas off -@all" in result await async_client.execute_command("ACL SETUSER kostas ON") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@all" in result + assert "user kostas on -@all" in result await async_client.execute_command("ACL SETUSER kostas +@list +@string +@admin") result = await async_client.execute_command("ACL list") # TODO consider printing to lowercase - assert "user kostas on nopass -@all +@list +@string +@admin" in result + assert "user kostas on -@all +@list +@string +@admin" in result await async_client.execute_command("ACL SETUSER kostas -@list -@admin") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@all +@string -@list -@admin" in result + assert "user kostas on -@all +@string -@list -@admin" in result # mix and match await async_client.execute_command("ACL SETUSER kostas +@list -@string") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@all -@admin +@list -@string" in result + assert "user kostas on -@all -@admin +@list -@string" in result # mix and match interleaved await async_client.execute_command("ACL SETUSER kostas +@set -@set +@set") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@all -@admin +@list -@string +@set" in result + assert "user kostas on -@all -@admin +@list -@string +@set" in result await async_client.execute_command("ACL SETUSER kostas 
+@all") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@admin +@list -@string +@set +@all" in result + assert "user kostas on -@admin +@list -@string +@set +@all" in result # commands await async_client.execute_command("ACL SETUSER kostas +set +get +hset") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@admin +@list -@string +@set +@all +set +get +hset" in result + assert "user kostas on -@admin +@list -@string +@set +@all +set +get +hset" in result await async_client.execute_command("ACL SETUSER kostas -set -get +hset") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@admin +@list -@string +@set +@all -set -get +hset" in result + assert "user kostas on -@admin +@list -@string +@set +@all -set -get +hset" in result # interleaved await async_client.execute_command("ACL SETUSER kostas -hset +get -get -@all") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@admin +@list -@string +@set -set -hset -get -@all" in result + assert "user kostas on -@admin +@list -@string +@set -set -hset -get -@all" in result # interleaved with categories await async_client.execute_command("ACL SETUSER kostas +@string +get -get +set") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@admin +@list +@set -hset -@all +@string -get +set" in result + assert "user kostas on -@admin +@list +@set -hset -@all +@string -get +set" in result @pytest.mark.asyncio @@ -332,7 +332,7 @@ async def test_good_acl_file(df_local_factory, tmp_dir): await client.execute_command("ACL LOAD") result = await client.execute_command("ACL list") assert 2 == len(result) - assert "user MrFoo on ea71c25a7a60224 -@all" in result + assert "user MrFoo on #ea71c25a7a60224 -@all" in result assert "user default on nopass ~* +@all" in result await client.execute_command("ACL DELUSER MrFoo") @@ -342,9 +342,9 @@ async def 
test_good_acl_file(df_local_factory, tmp_dir): result = await client.execute_command("ACL list") assert 4 == len(result) - assert "user roy on ea71c25a7a60224 -@all +@string +hset" in result - assert "user shahar off ea71c25a7a60224 -@all +@set" in result - assert "user vlad off nopass ~foo ~bar* -@all +@string" in result + assert "user roy on #ea71c25a7a60224 -@all +@string +hset" in result + assert "user shahar off #ea71c25a7a60224 -@all +@set" in result + assert "user vlad off ~foo ~bar* -@all +@string" in result assert "user default on nopass ~* +@all" in result result = await client.execute_command("ACL DELUSER shahar") @@ -356,8 +356,8 @@ async def test_good_acl_file(df_local_factory, tmp_dir): result = await client.execute_command("ACL list") assert 3 == len(result) - assert "user roy on ea71c25a7a60224 -@all +@string +hset" in result - assert "user vlad off nopass ~foo ~bar* -@all +@string" in result + assert "user roy on #ea71c25a7a60224 -@all +@string +hset" in result + assert "user vlad off ~foo ~bar* -@all +@string" in result assert "user default on nopass ~* +@all" in result await client.close() diff --git a/tests/dragonfly/cluster_test.py b/tests/dragonfly/cluster_test.py index aa2f65fc8f7f..a94275fbe8cf 100644 --- a/tests/dragonfly/cluster_test.py +++ b/tests/dragonfly/cluster_test.py @@ -19,6 +19,16 @@ BASE_PORT = 30001 +async def assert_eventually(e): + iterations = 0 + while True: + if await e(): + return + iterations += 1 + assert iterations < 500 + await asyncio.sleep(0.1) + + class RedisClusterNode: def __init__(self, port): self.port = port @@ -83,7 +93,6 @@ class NodeInfo: client: aioredis.Redis admin_client: aioredis.Redis slots: list - next_slots: list migrations: list id: str @@ -95,7 +104,6 @@ async def create_node_info(instance): client=instance.client(), admin_client=admin_client, slots=[], - next_slots=[], migrations=[], id=await get_node_id(admin_client), ) @@ -1026,6 +1034,59 @@ async def test_config_consistency(df_local_factory: 
DflyInstanceFactory): await close_clients(*[node.client for node in nodes], *[node.admin_client for node in nodes]) +@dfly_args({"proactor_threads": 4, "cluster_mode": "yes"}) +async def test_cluster_flushall_during_migration( + df_local_factory: DflyInstanceFactory, df_seeder_factory +): + # Check data migration from one node to another + instances = [ + df_local_factory.create( + port=BASE_PORT + i, + admin_port=BASE_PORT + i + 1000, + vmodule="cluster_family=9,cluster_slot_migration=9,outgoing_slot_migration=9", + logtostdout=True, + ) + for i in range(2) + ] + + df_local_factory.start_all(instances) + + nodes = [(await create_node_info(instance)) for instance in instances] + nodes[0].slots = [(0, 16383)] + nodes[1].slots = [] + + await push_config(json.dumps(generate_config(nodes)), [node.admin_client for node in nodes]) + + seeder = df_seeder_factory.create(keys=10_000, port=nodes[0].instance.port, cluster_mode=True) + await seeder.run(target_deviation=0.1) + + nodes[0].migrations.append( + MigrationInfo("127.0.0.1", nodes[1].instance.admin_port, [(0, 16383)], nodes[1].id) + ) + + logging.debug("Start migration") + await push_config(json.dumps(generate_config(nodes)), [node.admin_client for node in nodes]) + + await nodes[0].client.execute_command("flushall") + + assert "FINISHED" not in await nodes[1].admin_client.execute_command( + "DFLYCLUSTER", "SLOT-MIGRATION-STATUS", nodes[0].id + ), "Weak test case - finished migration too early" + + await wait_for_status(nodes[0].admin_client, nodes[1].id, "FINISHED") + + logging.debug("Finalizing migration") + nodes[0].migrations = [] + nodes[0].slots = [] + nodes[1].slots = [(0, 16383)] + await push_config(json.dumps(generate_config(nodes)), [node.admin_client for node in nodes]) + logging.debug("Migration finalized") + + assert await nodes[0].client.dbsize() == 0 + + await close_clients(*[node.client for node in nodes], *[node.admin_client for node in nodes]) + + @dfly_args({"proactor_threads": 4, "cluster_mode": 
"yes"}) async def test_cluster_data_migration(df_local_factory: DflyInstanceFactory): # Check data migration from one node to another @@ -1065,12 +1126,12 @@ async def test_cluster_data_migration(df_local_factory: DflyInstanceFactory): await nodes[0].admin_client.execute_command( "DFLYCLUSTER", "SLOT-MIGRATION-STATUS", nodes[1].id ) - ).startswith(f"""out {nodes[1].id} FINISHED keys:7""") + ).startswith(f"out {nodes[1].id} FINISHED keys:7") assert ( await nodes[1].admin_client.execute_command( "DFLYCLUSTER", "SLOT-MIGRATION-STATUS", nodes[0].id ) - ).startswith(f"""in {nodes[0].id} FINISHED keys:7""") + ).startswith(f"in {nodes[0].id} FINISHED keys:7") nodes[0].migrations = [] nodes[0].slots = [(0, 2999)] @@ -1088,6 +1149,58 @@ async def test_cluster_data_migration(df_local_factory: DflyInstanceFactory): await close_clients(*[node.client for node in nodes], *[node.admin_client for node in nodes]) +@dfly_args({"proactor_threads": 2, "cluster_mode": "yes", "cache_mode": "true"}) +async def test_migration_with_key_ttl(df_local_factory): + instances = [ + df_local_factory.create(port=BASE_PORT + i, admin_port=BASE_PORT + i + 1000) + for i in range(2) + ] + + df_local_factory.start_all(instances) + + nodes = [(await create_node_info(instance)) for instance in instances] + nodes[0].slots = [(0, 16383)] + nodes[1].slots = [] + + await push_config(json.dumps(generate_config(nodes)), [node.admin_client for node in nodes]) + + await nodes[0].client.execute_command("set k_with_ttl v1 EX 2") + await nodes[0].client.execute_command("set k_without_ttl v2") + await nodes[0].client.execute_command("set k_sticky v3") + assert await nodes[0].client.execute_command("stick k_sticky") == 1 + + nodes[0].migrations.append( + MigrationInfo("127.0.0.1", instances[1].port, [(0, 16383)], nodes[1].id) + ) + logging.debug("Start migration") + await push_config(json.dumps(generate_config(nodes)), [node.admin_client for node in nodes]) + + await wait_for_status(nodes[0].admin_client, 
nodes[1].id, "FINISHED") + + nodes[0].migrations = [] + nodes[0].slots = [] + nodes[1].slots = [(0, 16383)] + logging.debug("finalize migration") + await push_config(json.dumps(generate_config(nodes)), [node.admin_client for node in nodes]) + + assert await nodes[1].client.execute_command("get k_with_ttl") == "v1" + assert await nodes[1].client.execute_command("get k_without_ttl") == "v2" + assert await nodes[1].client.execute_command("get k_sticky") == "v3" + assert await nodes[1].client.execute_command("ttl k_with_ttl") > 0 + assert await nodes[1].client.execute_command("ttl k_without_ttl") == -1 + assert await nodes[1].client.execute_command("stick k_sticky") == 0 # Sticky bit already set + + await asyncio.sleep(2) # Force expiration + + assert await nodes[1].client.execute_command("get k_with_ttl") == None + assert await nodes[1].client.execute_command("get k_without_ttl") == "v2" + assert await nodes[1].client.execute_command("ttl k_with_ttl") == -2 + assert await nodes[1].client.execute_command("ttl k_without_ttl") == -1 + assert await nodes[1].client.execute_command("stick k_sticky") == 0 + + await close_clients(*[node.client for node in nodes], *[node.admin_client for node in nodes]) + + @dfly_args({"proactor_threads": 4, "cluster_mode": "yes"}) async def test_network_disconnect_during_migration(df_local_factory, df_seeder_factory): instances = [ @@ -1179,8 +1292,14 @@ async def test_cluster_fuzzymigration( # Counter that pushes values to a list async def list_counter(key, client: aioredis.RedisCluster): - for i in itertools.count(start=1): - await client.lpush(key, i) + try: + for i in itertools.count(start=1): + await client.lpush(key, i) + except asyncio.exceptions.CancelledError: + return + # TODO find the reason of TTL exhausted error and is it possible to fix it + except redis.exceptions.ClusterError: + return # Start ten counters counter_keys = [f"_counter{i}" for i in range(10)] @@ -1224,42 +1343,51 @@ async def list_counter(key, client: 
aioredis.RedisCluster): ) ) - nodes[dest_idx].next_slots.extend(dest_slots) - - keeping = node.slots[num_outgoing:] - node.next_slots.extend(keeping) - logging.debug("start migrations") await push_config(json.dumps(generate_config(nodes)), [node.admin_client for node in nodes]) - iterations = 0 - while True: - is_all_finished = True + logging.debug("finish migrations") + + async def all_finished(): + res = True for node in nodes: states = await node.admin_client.execute_command("DFLYCLUSTER", "SLOT-MIGRATION-STATUS") logging.debug(states) - is_all_finished = is_all_finished and ( - all("FINISHED" in s for s in states) or states == "NO_STATE" - ) - - if is_all_finished: - break - - iterations += 1 - assert iterations < 500 - - await asyncio.sleep(0.1) + for state in states: + parsed_state = re.search("([a-z]+) ([a-z0-9]+) ([A-Z]+)", state) + if parsed_state == None: + continue + direction, node_id, st = parsed_state.group(1, 2, 3) + if direction == "out": + if st == "FINISHED": + m_id = [id for id, x in enumerate(node.migrations) if x.node_id == node_id][ + 0 + ] + node.slots = [s for s in node.slots if s not in node.migrations[m_id].slots] + target_node = [n for n in nodes if n.id == node_id][0] + target_node.slots.extend(node.migrations[m_id].slots) + print( + "FINISH migration", + node.id, + ":", + node.migrations[m_id].node_id, + " slots:", + node.migrations[m_id].slots, + ) + node.migrations.pop(m_id) + await push_config( + json.dumps(generate_config(nodes)), + [node.admin_client for node in nodes], + ) + else: + res = False + return res + + await assert_eventually(all_finished) for counter in counters: counter.cancel() - - # clean migrations - for node in nodes: - node.migrations = [] - node.slots = node.next_slots - - logging.debug("remove finished migrations") - await push_config(json.dumps(generate_config(nodes)), [node.admin_client for node in nodes]) + await counter # Check counter consistency cluster_client = aioredis.RedisCluster(host="localhost", 
port=nodes[0].instance.port) @@ -1359,13 +1487,11 @@ async def test_cluster_migration_cancel(df_local_factory: DflyInstanceFactory): nodes[0].migrations = [] await push_config(json.dumps(generate_config(nodes)), [node.admin_client for node in nodes]) assert SIZE == await nodes[0].client.dbsize() - while True: - db_size = await nodes[1].client.dbsize() - if 0 == db_size: - break - logging.debug(f"target dbsize is {db_size}") - logging.debug(await nodes[1].client.execute_command("KEYS", "*")) - await asyncio.sleep(0.1) + + async def node1size0(): + return await nodes[1].client.dbsize() == 0 + + await assert_eventually(node1size0) logging.debug("Reissuing migration") nodes[0].migrations.append( diff --git a/tests/dragonfly/replication_test.py b/tests/dragonfly/replication_test.py index c25ffdc90147..bbf1f2477d68 100644 --- a/tests/dragonfly/replication_test.py +++ b/tests/dragonfly/replication_test.py @@ -689,16 +689,16 @@ async def check_expire(key): await c_master.set("renamekey", "1000", px=50000) await skip_cmd() - # Check RENAME turns into DEL SET and PEXPIREAT + # Check RENAME turns into DEL and RESTORE await check_list_ooo( "RENAME renamekey renamed", - [r"DEL renamekey", r"SET renamed 1000", r"PEXPIREAT renamed (.*?)"], + [r"DEL renamekey", r"RESTORE renamed (.*?) (.*?) REPLACE ABSTTL"], ) await check_expire("renamed") - # Check RENAMENX turns into DEL SET and PEXPIREAT + # Check RENAMENX turns into DEL and RESTORE await check_list_ooo( "RENAMENX renamed renamekey", - [r"DEL renamed", r"SET renamekey 1000", r"PEXPIREAT renamekey (.*?)"], + [r"DEL renamed", r"RESTORE renamekey (.*?) (.*?) 
REPLACE ABSTTL"], ) await check_expire("renamekey") @@ -2164,7 +2164,7 @@ async def test_replica_reconnect(df_local_factory, break_conn): # Connect replica to master master = df_local_factory.create(proactor_threads=1) replica = df_local_factory.create( - proactor_threads=1, replica_reconnect_on_master_restart=break_conn + proactor_threads=1, break_replication_on_master_restart=break_conn ) df_local_factory.start_all([master, replica]) diff --git a/tests/dragonfly/sentinel_test.py b/tests/dragonfly/sentinel_test.py index d65137c84938..8395225ed24a 100644 --- a/tests/dragonfly/sentinel_test.py +++ b/tests/dragonfly/sentinel_test.py @@ -8,6 +8,7 @@ from datetime import datetime from sys import stderr import logging +from . import dfly_args # Helper function to parse some sentinel cli commands output as key value dictionaries. @@ -63,6 +64,7 @@ def start(self): f"port {self.port}", f"sentinel monitor {self.default_deployment} 127.0.0.1 {self.initial_master_port} 1", f"sentinel down-after-milliseconds {self.default_deployment} 3000", + f"slave-priority 100", ] self.config_file.write_text("\n".join(config)) @@ -228,7 +230,7 @@ async def test_master_failure(df_local_factory, sentinel, port_picker): # Simulate master failure. master.stop() - # Verify replica pormoted. + # Verify replica promoted. await await_for( lambda: sentinel.live_master_port(), lambda p: p == replica.port, @@ -239,3 +241,54 @@ async def test_master_failure(df_local_factory, sentinel, port_picker): # Verify we can now write to replica. 
await replica_client.set("key", "value") assert await replica_client.get("key") == b"value" + + +@dfly_args({"info_replication_valkey_compatible": True}) +@pytest.mark.asyncio +async def test_priority_on_failover(df_local_factory, sentinel, port_picker): + master = df_local_factory.create(port=sentinel.initial_master_port) + # lower priority is the best candidate for sentinel + low_priority_repl = df_local_factory.create( + port=port_picker.get_available_port(), replica_priority=20 + ) + mid_priority_repl = df_local_factory.create( + port=port_picker.get_available_port(), replica_priority=60 + ) + high_priority_repl = df_local_factory.create( + port=port_picker.get_available_port(), replica_priority=80 + ) + + master.start() + low_priority_repl.start() + mid_priority_repl.start() + high_priority_repl.start() + + high_client = aioredis.Redis(port=high_priority_repl.port) + await high_client.execute_command("REPLICAOF localhost " + str(master.port)) + + mid_client = aioredis.Redis(port=mid_priority_repl.port) + await mid_client.execute_command("REPLICAOF localhost " + str(master.port)) + + low_client = aioredis.Redis(port=low_priority_repl.port) + await low_client.execute_command("REPLICAOF localhost " + str(master.port)) + + assert sentinel.live_master_port() == master.port + + # Verify sentinel picked up replica. + await await_for( + lambda: sentinel.master(), + lambda m: m["num-slaves"] == "3", + timeout_sec=15, + timeout_msg="Timeout waiting for sentinel to pick up replica.", + ) + + # Simulate master failure. + master.stop() + + # Verify replica promoted. 
+ await await_for( + lambda: sentinel.live_master_port(), + lambda p: p == low_priority_repl.port, + timeout_sec=30, + timeout_msg="Timeout waiting for sentinel to report replica as master.", + ) diff --git a/tests/dragonfly/tiering_test.py b/tests/dragonfly/tiering_test.py index 6fb4baaa46e4..fec382506542 100644 --- a/tests/dragonfly/tiering_test.py +++ b/tests/dragonfly/tiering_test.py @@ -1,39 +1,81 @@ -from . import dfly_args - import async_timeout import asyncio +import itertools +import pytest +import random import redis.asyncio as aioredis -BASIC_ARGS = {"port": 6379, "proactor_threads": 1, "tiered_prefix": "/tmp/tiering_test_backing"} +from . import dfly_args +from .seeder import StaticSeeder + + +BASIC_ARGS = {"port": 6379, "proactor_threads": 4, "tiered_prefix": "/tmp/tiering_test_backing"} -# remove once proudct requirments are tested +@pytest.mark.opt_only @dfly_args(BASIC_ARGS) -async def test_tiering_simple(async_client: aioredis.Redis): - fill_script = """#!lua flags=disable-atomicity - for i = 1, 100 do - redis.call('SET', 'k' .. 
i, string.rep('a', 3000)) - end +async def test_basic_memory_usage(async_client: aioredis.Redis): """ + Loading 1GB of mixed size strings (256b-16kb) will keep most of them on disk and thus RAM remains almost unused + """ + + seeder = StaticSeeder( + key_target=200_000, data_size=2048, variance=8, samples=100, types=["STRING"] + ) + await seeder.run(async_client) + await asyncio.sleep(0.5) + + info = await async_client.info("ALL") + assert info["num_entries"] == 200_000 + + assert info["tiered_entries"] > 195_000 # some remain in unfilled small bins + assert ( + info["tiered_allocated_bytes"] > 195_000 * 2048 * 0.8 + ) # 0.8 just to be sure because it fluctuates due to variance + + assert info["used_memory"] < 50 * 1024 * 1024 + assert ( + info["used_memory_rss"] < 500 * 1024 * 1024 + ) # the grown table itself takes up lots of space + + +@pytest.mark.opt_only +@dfly_args( + { + **BASIC_ARGS, + "maxmemory": "1G", + "tiered_offload_threshold": "0.0", + "tiered_storage_write_depth": 1000, + } +) +async def test_mixed_append(async_client: aioredis.Redis): + """ + Issue conflicting mixed APPEND calls for a limited subset of keys with aggressive offloading in the background. 
+ Make sure no appends were lost + """ + + # Generate operations and shuffle them, key number `k` will have `k` append operations + key_range = list(range(100, 300)) + ops = list(itertools.chain(*map(lambda k: itertools.repeat(k, k), key_range))) + random.shuffle(ops) + + # Split list into n workers and run it + async def run(sub_ops): + p = async_client.pipeline(transaction=False) + for k in sub_ops: + p.append(f"k{k}", 10 * "x") + await p.execute() + + n = 20 + await asyncio.gather(*(run(ops[i::n]) for i in range(n))) + + info = await async_client.info("tiered") + assert info["tiered_entries"] > len(key_range) / 5 + + # Verify lengths + p = async_client.pipeline(transaction=False) + for k in key_range: + p.strlen(f"k{k}") + res = await p.execute() - # Store 100 entries - await async_client.eval(fill_script, 0) - - # Wait for all to be offloaded - with async_timeout.timeout(1): - info = await async_client.info("TIERED") - print(info) - while info["tiered_total_stashes"] != 100: - info = await async_client.info("TIERED") - await asyncio.sleep(0.1) - assert 3000 * 100 <= info["tiered_allocated_bytes"] <= 4096 * 100 - - # Fetch back - for key in (f"k{i}" for i in range(1, 100 + 1)): - assert len(await async_client.execute_command("GET", key)) == 3000 - assert (await async_client.info("TIERED"))["tiered_total_fetches"] == 100 - - # Wait to be deleted - with async_timeout.timeout(1): - while (await async_client.info("TIERED"))["tiered_allocated_bytes"] > 0: - await asyncio.sleep(0.1) + assert res == [10 * k for k in key_range]