Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/search/classic/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ class Edge {
float GetP() const;
void SetP(float val);

// Node-prior accessors. The virtual visit count (n0) and the virtual W and D
// values (w0, d0) are stored on the edge; they are initialized once at
// expansion and default to zero until then.

// Virtual visit count n0 of this edge.
float GetN0() const {
  return n0_;
}

// Virtual W value w0 of this edge.
float GetW0() const {
  return w0_;
}

// Virtual D value d0 of this edge.
float GetD0() const {
  return d0_;
}

// Overwrites all three virtual statistics of this edge in one call.
void SetVirtualVisits(float n0, float w0, float d0) {
  d0_ = d0;
  w0_ = w0;
  n0_ = n0;
}

// Debug information about the edge.
std::string DebugString() const;

Expand All @@ -108,6 +118,14 @@ class Edge {
// Probability that this move will be made, from the policy head of the neural
// network; compressed to a 16 bit format (5 bits exp, 11 bits significand).
uint16_t p_ = 0;

// Node-prior ("virtual") statistics of this edge. All three default to zero,
// which contributes nothing when they are added to real visit totals.
// NOTE(review): this adds three floats to every Edge, next to the 16-bit
// compressed p_. w0_ and d0_ are n0_ scaled by per-parent values, so confirm
// that the extra per-edge memory is acceptable versus deriving them on use.
//
// Virtual visits for node priors: n0(a) = K*P(a), initialized once at expansion.
float n0_ = 0.0f;
// Virtual W value: W0(a) = n0(a) * WL_parent.
float w0_ = 0.0f;
// Virtual D value: D0(a) = n0(a) * D_parent.
float d0_ = 0.0f;

friend class Node;
};

Expand Down
9 changes: 8 additions & 1 deletion src/search/classic/params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,11 @@ const OptionId BaseSearchParams::kGarbageCollectionDelayId{
"garbage-collection-delay", "GarbageCollectionDelay",
"The percentage of expected move time until garbage collection start. "
"Delay lets search find transpositions to freed search tree branches."};
// Option id and help text for the node-prior strength. The help text labels
// the option value itself as "alpha", so the formula for K is fully defined
// for someone who only reads the option help.
const OptionId BaseSearchParams::kNodePriorId{
    "node-prior", "NodePrior",
    "Strength of node priors (alpha). At expansion, each edge is initialized "
    "with n0(a) = K*P(a) virtual visits, where K = alpha * num_legal_moves. "
    "Set to 0 to disable node priors."};

const OptionId SearchParams::kMaxPrefetchBatchId{
"max-prefetch", "MaxPrefetch",
Expand Down Expand Up @@ -631,6 +636,7 @@ void BaseSearchParams::Populate(OptionsParser* options) {
options->Add<FloatOption>(kUCIRatingAdvId, -10000.0f, 10000.0f) = 0.0f;
options->Add<BoolOption>(kSearchSpinBackoffId) = false;
options->Add<FloatOption>(kGarbageCollectionDelayId, 0.0f, 100.0f) = 10.0f;
options->Add<FloatOption>(kNodePriorId, 0.0f, 10.0f) = 0.0f;
}

void SearchParams::Populate(OptionsParser* options) {
Expand Down Expand Up @@ -725,7 +731,8 @@ BaseSearchParams::BaseSearchParams(const OptionsDict& options)
kMaxCollisionVisitsScalingPower(
options.Get<float>(kMaxCollisionVisitsScalingPowerId)),
kSearchSpinBackoff(options_.Get<bool>(kSearchSpinBackoffId)),
kGarbageCollectionDelay(options_.Get<float>(kGarbageCollectionDelayId)) {}
kGarbageCollectionDelay(options_.Get<float>(kGarbageCollectionDelayId)),
kNodePrior(options_.Get<float>(kNodePriorId)) {}

SearchParams::SearchParams(const OptionsDict& options)
: BaseSearchParams(options),
Expand Down
3 changes: 3 additions & 0 deletions src/search/classic/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ class BaseSearchParams {
float GetGarbageCollectionDelay() const {
return kGarbageCollectionDelay;
}
// Node-prior strength, as read from the NodePrior option at construction.
float GetNodePrior() const {
  return kNodePrior;
}

// Search parameter IDs.
static const OptionId kMiniBatchSizeId;
Expand Down Expand Up @@ -231,6 +232,7 @@ class BaseSearchParams {
static const OptionId kUCIRatingAdvId;
static const OptionId kSearchSpinBackoffId;
static const OptionId kGarbageCollectionDelayId;
static const OptionId kNodePriorId;

protected:
const OptionsDict& options_;
Expand Down Expand Up @@ -290,6 +292,7 @@ class BaseSearchParams {
const float kMaxCollisionVisitsScalingPower;
const bool kSearchSpinBackoff;
const float kGarbageCollectionDelay;
const float kNodePrior;
};

class SearchParams : public BaseSearchParams {
Expand Down
52 changes: 46 additions & 6 deletions src/search/classic/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1696,7 +1696,16 @@ void SearchWorker::PickNodesToExtendTask(
for (Node* child : node->VisitedNodes()) {
int index = child->Index();
visited_pol += current_pol[index];
float q = child->GetQ(draw_score);
// Overlay virtual visits for Q calculation.
// Convert averages to totals before adding.
Edge* edge = node->GetEdgeToNode(child);
const float real_n = static_cast<float>(child->GetN());
const float real_wl = child->GetWL() * real_n;
const float real_d = child->GetD() * real_n;
const float total_wl = real_wl + edge->GetW0();
const float total_d = real_d + edge->GetD0();
const float total_n = real_n + edge->GetN0();
float q = (total_n > 0.0f) ? (total_wl + draw_score * total_d) / total_n : 0.0f;
current_util[index] = q + m_evaluator.GetMUtility(child, q);
}
const float fpu =
Expand Down Expand Up @@ -1732,8 +1741,9 @@ void SearchWorker::PickNodesToExtendTask(
int nstarted = current_nstarted[idx];
const float util = current_util[idx];
if (idx > cache_filled_idx) {
const float virtual_n = cur_iters[idx].edge()->GetN0();
current_score[idx] =
current_pol[idx] * puct_mult / (1 + nstarted) + util;
current_pol[idx] * puct_mult / (1.0f + virtual_n + nstarted) + util;
cache_filled_idx++;
}
if (is_root_node) {
Expand Down Expand Up @@ -1779,7 +1789,8 @@ void SearchWorker::PickNodesToExtendTask(
if (second_best_edge) {
int estimated_visits_to_change_best = std::numeric_limits<int>::max();
if (best_without_u < second_best) {
const auto n1 = current_nstarted[best_idx] + 1;
const float virtual_n = cur_iters[best_idx].edge()->GetN0();
const auto n1 = current_nstarted[best_idx] + 1 + virtual_n;
estimated_visits_to_change_best = static_cast<int>(
std::max(1.0f, std::min(current_pol[best_idx] * puct_mult /
(second_best - best_without_u) -
Expand Down Expand Up @@ -1816,8 +1827,9 @@ void SearchWorker::PickNodesToExtendTask(
child_node->IncrementNInFlight(new_visits);
current_nstarted[best_idx] += new_visits;
}
const float virtual_n = cur_iters[best_idx].edge()->GetN0();
current_score[best_idx] = current_pol[best_idx] * puct_mult /
(1 + current_nstarted[best_idx]) +
(1.0f + virtual_n + current_nstarted[best_idx]) +
current_util[best_idx];
}
if ((decremented &&
Expand Down Expand Up @@ -2076,10 +2088,24 @@ int SearchWorker::PrefetchIntoCache(Node* node, int budget, bool is_odd_depth) {
GetFpu(params_, node, node == search_->root_node_, draw_score);
for (auto& edge : node->Edges()) {
if (edge.GetP() == 0.0f) continue;
// Compute Q with virtual visits overlay.
// Convert averages to totals before adding.
float q = edge.GetQ(fpu, draw_score);
if (edge.node() && edge.GetN() > 0) {
const float real_n = static_cast<float>(edge.node()->GetN());
const float real_wl = edge.node()->GetWL() * real_n;
const float real_d = edge.node()->GetD() * real_n;
const float total_wl = real_wl + edge.edge()->GetW0();
const float total_d = real_d + edge.edge()->GetD0();
const float total_n = real_n + edge.edge()->GetN0();
q = (total_n > 0.0f) ? (total_wl + draw_score * total_d) / total_n : fpu;
}
// Compute U with virtual visits in denominator.
float n_effective = edge.GetNStarted() + edge.edge()->GetN0();
float u = puct_mult * edge.GetP() / (1.0f + n_effective);
// Flip the sign of a score to be able to easily sort.
// TODO: should this use logit_q if set??
scores.emplace_back(-edge.GetU(puct_mult) - edge.GetQ(fpu, draw_score),
edge);
scores.emplace_back(-u - q, edge);
}

size_t first_unsorted_index = 0;
Expand Down Expand Up @@ -2181,6 +2207,20 @@ void SearchWorker::FetchSingleNodeResult(NodeToProcess* node_to_process) {
ApplyDirichletNoise(node, params_.GetNoiseEpsilon(),
params_.GetNoiseAlpha());
}
// Initialize virtual visits for node priors.
const float alpha = params_.GetNodePrior();
if (alpha > 0.0f) {
const float K = alpha * node->GetNumEdges();
// Use NN eval for parent values (most robust - always available).
const float parent_wl = node_to_process->eval->q;
const float parent_d = node_to_process->eval->d;
for (auto& edge : node->Edges()) {
const float n0 = K * edge.GetP();
const float w0 = n0 * parent_wl;
const float d0 = n0 * parent_d;
edge.edge()->SetVirtualVisits(n0, w0, d0);
}
}
node->SortEdges();
}

Expand Down