From c7581de4f2a3202ea8f305d3cd50190a5a7aa67f Mon Sep 17 00:00:00 2001
From: Alexis Olson
Date: Wed, 22 Oct 2025 22:58:47 -0500
Subject: [PATCH] Initial commit

---
Review notes (ignored by git-am; not part of the commit message):

* Adds a "NodePrior" option (float, range 0..10, default 0 = disabled). When it
  is > 0, FetchSingleNodeResult() gives every edge of the node being processed
  n0 = alpha * num_edges * P virtual visits, with W0 = n0 * eval->q and
  D0 = n0 * eval->d.
* PickNodesToExtendTask() and PrefetchIntoCache() add (n0, W0, D0) to the real
  visit totals of visited children when computing Q, and add n0 to the
  denominator of the U term.
* TODO(review): W0 is summed with child->GetWL() * N, so eval->q has to use the
  same perspective/sign as the children's WL at the point where
  SetVirtualVisits() is called. This patch does not show that; please confirm.
* TODO(review): Edge grows by three floats (n0_, w0_, d0_) next to the 16-bit
  p_; measure the memory impact on large trees.

 src/search/classic/node.h    | 18 +++++++++++++
 src/search/classic/params.cc |  9 ++++++-
 src/search/classic/params.h  |  3 +++
 src/search/classic/search.cc | 52 +++++++++++++++++++++++++++++++-----
 4 files changed, 75 insertions(+), 7 deletions(-)

diff --git a/src/search/classic/node.h b/src/search/classic/node.h
index 8a4e598fdb..5605cb4d53 100644
--- a/src/search/classic/node.h
+++ b/src/search/classic/node.h
@@ -96,6 +96,16 @@ class Edge {
   float GetP() const;
   void SetP(float val);
 
+  // Virtual visits and values for node priors, initialized once at expansion.
+  float GetN0() const { return n0_; }
+  float GetW0() const { return w0_; }
+  float GetD0() const { return d0_; }
+  void SetVirtualVisits(float n0, float w0, float d0) {
+    n0_ = n0;
+    w0_ = w0;
+    d0_ = d0;
+  }
+
   // Debug information about the edge.
   std::string DebugString() const;
 
@@ -108,6 +118,14 @@ class Edge {
   // Probability that this move will be made, from the policy head of the neural
   // network; compressed to a 16 bit format (5 bits exp, 11 bits significand).
   uint16_t p_ = 0;
+
+  // Virtual visit count n0(a) = K * P(a); set once when the node is expanded.
+  float n0_ = 0.0f;
+  // Virtual W total: W0(a) = n0(a) * WL_parent.
+  float w0_ = 0.0f;
+  // Virtual D total: D0(a) = n0(a) * D_parent.
+  float d0_ = 0.0f;
+
   friend class Node;
 };
 
diff --git a/src/search/classic/params.cc b/src/search/classic/params.cc
index e61b0f9c88..a9df16cf7e 100644
--- a/src/search/classic/params.cc
+++ b/src/search/classic/params.cc
@@ -529,6 +529,11 @@ const OptionId BaseSearchParams::kGarbageCollectionDelayId{
     "garbage-collection-delay", "GarbageCollectionDelay",
     "The percentage of expected move time until garbage collection start. "
     "Delay lets search find transpositions to freed search tree branches."};
+const OptionId BaseSearchParams::kNodePriorId{
+    "node-prior", "NodePrior",
+    "Strength of node priors. At expansion, each edge is initialized with "
+    "n0(a) = K*P(a) virtual visits, where K = alpha * num_legal_moves. "
+    "Set to 0 to disable node priors."};
 
 const OptionId SearchParams::kMaxPrefetchBatchId{
     "max-prefetch", "MaxPrefetch",
@@ -631,6 +636,7 @@ void BaseSearchParams::Populate(OptionsParser* options) {
   options->Add<FloatOption>(kUCIRatingAdvId, -10000.0f, 10000.0f) = 0.0f;
   options->Add<BoolOption>(kSearchSpinBackoffId) = false;
   options->Add<FloatOption>(kGarbageCollectionDelayId, 0.0f, 100.0f) = 10.0f;
+  options->Add<FloatOption>(kNodePriorId, 0.0f, 10.0f) = 0.0f;
 }
 
 void SearchParams::Populate(OptionsParser* options) {
@@ -725,7 +731,8 @@ BaseSearchParams::BaseSearchParams(const OptionsDict& options)
       kMaxCollisionVisitsScalingPower(
           options.Get<float>(kMaxCollisionVisitsScalingPowerId)),
       kSearchSpinBackoff(options_.Get<bool>(kSearchSpinBackoffId)),
-      kGarbageCollectionDelay(options_.Get<float>(kGarbageCollectionDelayId)) {}
+      kGarbageCollectionDelay(options_.Get<float>(kGarbageCollectionDelayId)),
+      kNodePrior(options_.Get<float>(kNodePriorId)) {}
 
 SearchParams::SearchParams(const OptionsDict& options)
     : BaseSearchParams(options),
diff --git a/src/search/classic/params.h b/src/search/classic/params.h
index d84dbad5d8..da62ee1db2 100644
--- a/src/search/classic/params.h
+++ b/src/search/classic/params.h
@@ -162,6 +162,7 @@ class BaseSearchParams {
   float GetGarbageCollectionDelay() const {
     return kGarbageCollectionDelay;
   }
+  float GetNodePrior() const { return kNodePrior; }
 
   // Search parameter IDs.
   static const OptionId kMiniBatchSizeId;
@@ -231,6 +232,7 @@ class BaseSearchParams {
   static const OptionId kUCIRatingAdvId;
   static const OptionId kSearchSpinBackoffId;
   static const OptionId kGarbageCollectionDelayId;
+  static const OptionId kNodePriorId;
 
  protected:
   const OptionsDict& options_;
@@ -290,6 +292,7 @@ class BaseSearchParams {
   const float kMaxCollisionVisitsScalingPower;
   const bool kSearchSpinBackoff;
   const float kGarbageCollectionDelay;
+  const float kNodePrior;
 };
 
 class SearchParams : public BaseSearchParams {
diff --git a/src/search/classic/search.cc b/src/search/classic/search.cc
index 101d62e941..2e42bbcd92 100644
--- a/src/search/classic/search.cc
+++ b/src/search/classic/search.cc
@@ -1696,7 +1696,16 @@ void SearchWorker::PickNodesToExtendTask(
       for (Node* child : node->VisitedNodes()) {
         int index = child->Index();
         visited_pol += current_pol[index];
-        float q = child->GetQ(draw_score);
+        // Overlay virtual visits for Q calculation.
+        // Convert averages to totals before adding.
+        const Edge* edge = node->GetEdgeToNode(child);
+        const float real_n = static_cast<float>(child->GetN());
+        const float real_wl = child->GetWL() * real_n;
+        const float real_d = child->GetD() * real_n;
+        const float total_wl = real_wl + edge->GetW0();
+        const float total_d = real_d + edge->GetD0();
+        const float total_n = real_n + edge->GetN0();
+        float q = (total_n > 0.0f) ? (total_wl + draw_score * total_d) / total_n : 0.0f;
         current_util[index] = q + m_evaluator.GetMUtility(child, q);
       }
       const float fpu =
@@ -1732,8 +1741,9 @@ void SearchWorker::PickNodesToExtendTask(
           int nstarted = current_nstarted[idx];
           const float util = current_util[idx];
           if (idx > cache_filled_idx) {
+            const float virtual_n = cur_iters[idx].edge()->GetN0();
             current_score[idx] =
-                current_pol[idx] * puct_mult / (1 + nstarted) + util;
+                current_pol[idx] * puct_mult / (1.0f + virtual_n + nstarted) + util;
             cache_filled_idx++;
           }
           if (is_root_node) {
@@ -1779,7 +1789,8 @@ void SearchWorker::PickNodesToExtendTask(
         if (second_best_edge) {
           int estimated_visits_to_change_best = std::numeric_limits<int>::max();
           if (best_without_u < second_best) {
-            const auto n1 = current_nstarted[best_idx] + 1;
+            const float virtual_n = cur_iters[best_idx].edge()->GetN0();
+            const float n1 = current_nstarted[best_idx] + 1 + virtual_n;
             estimated_visits_to_change_best = static_cast<int>(
                 std::max(1.0f, std::min(current_pol[best_idx] * puct_mult /
                                                 (second_best - best_without_u) -
@@ -1816,8 +1827,9 @@ void SearchWorker::PickNodesToExtendTask(
             child_node->IncrementNInFlight(new_visits);
             current_nstarted[best_idx] += new_visits;
           }
+          const float virtual_n = cur_iters[best_idx].edge()->GetN0();
           current_score[best_idx] = current_pol[best_idx] * puct_mult /
-                                        (1 + current_nstarted[best_idx]) +
+                                        (1.0f + virtual_n + current_nstarted[best_idx]) +
                                     current_util[best_idx];
         }
         if ((decremented &&
@@ -2076,10 +2088,24 @@ int SearchWorker::PrefetchIntoCache(Node* node, int budget, bool is_odd_depth) {
       GetFpu(params_, node, node == search_->root_node_, draw_score);
   for (auto& edge : node->Edges()) {
     if (edge.GetP() == 0.0f) continue;
+    // Compute Q with virtual visits overlay.
+    // Convert averages to totals before adding.
+    float q = edge.GetQ(fpu, draw_score);
+    if (edge.node() && edge.GetN() > 0) {
+      const float real_n = static_cast<float>(edge.node()->GetN());
+      const float real_wl = edge.node()->GetWL() * real_n;
+      const float real_d = edge.node()->GetD() * real_n;
+      const float total_wl = real_wl + edge.edge()->GetW0();
+      const float total_d = real_d + edge.edge()->GetD0();
+      const float total_n = real_n + edge.edge()->GetN0();
+      q = (total_n > 0.0f) ? (total_wl + draw_score * total_d) / total_n : fpu;
+    }
+    // Compute U with virtual visits in denominator.
+    const float n_effective = edge.GetNStarted() + edge.edge()->GetN0();
+    const float u = puct_mult * edge.GetP() / (1.0f + n_effective);
     // Flip the sign of a score to be able to easily sort.
     // TODO: should this use logit_q if set??
-    scores.emplace_back(-edge.GetU(puct_mult) - edge.GetQ(fpu, draw_score),
-                        edge);
+    scores.emplace_back(-u - q, edge);
   }
 
   size_t first_unsorted_index = 0;
@@ -2181,6 +2207,20 @@ void SearchWorker::FetchSingleNodeResult(NodeToProcess* node_to_process) {
     ApplyDirichletNoise(node, params_.GetNoiseEpsilon(),
                         params_.GetNoiseAlpha());
   }
+  // Initialize virtual visits for node priors.
+  const float alpha = params_.GetNodePrior();
+  if (alpha > 0.0f) {
+    const float K = alpha * node->GetNumEdges();
+    // Seed from this node's NN eval. TODO(review): confirm the sign of q.
+    const float parent_wl = node_to_process->eval->q;
+    const float parent_d = node_to_process->eval->d;
+    for (auto& edge : node->Edges()) {
+      const float n0 = K * edge.GetP();
+      const float w0 = n0 * parent_wl;
+      const float d0 = n0 * parent_d;
+      edge.edge()->SetVirtualVisits(n0, w0, d0);
+    }
+  }
   node->SortEdges();
 }
 