Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cpp/program/setup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,11 @@ vector<SearchParams> Setup::loadParams(
else if(cfg.contains("humanSLOppExploreProbWeightful"+idxStr)) params.humanSLOppExploreProbWeightful = cfg.getDouble("humanSLOppExploreProbWeightful"+idxStr, 0.0, 1.0);
else if(cfg.contains("humanSLOppExploreProbWeightful")) params.humanSLOppExploreProbWeightful = cfg.getDouble("humanSLOppExploreProbWeightful", 0.0, 1.0);
else params.humanSLOppExploreProbWeightful = 0.0;
if(!hasHumanModel && cfg.contains("humanSLValueProportion"+idxStr)) throwHumanParsingError("humanSLValueProportion"+idxStr);
else if(!hasHumanModel && cfg.contains("humanSLValueProportion")) throwHumanParsingError("humanSLValueProportion");
else if(cfg.contains("humanSLValueProportion"+idxStr)) params.humanSLValueProportion = cfg.getDouble("humanSLValueProportion"+idxStr, 0.0, 1.0);
else if(cfg.contains("humanSLValueProportion")) params.humanSLValueProportion = cfg.getDouble("humanSLValueProportion", 0.0, 1.0);
else params.humanSLValueProportion = 0.0;
if(!hasHumanModel && cfg.contains("humanSLChosenMoveProp"+idxStr)) throwHumanParsingError("humanSLChosenMoveProp"+idxStr);
else if(!hasHumanModel && cfg.contains("humanSLChosenMoveProp")) throwHumanParsingError("humanSLChosenMoveProp");
else if(cfg.contains("humanSLChosenMoveProp"+idxStr)) params.humanSLChosenMoveProp = cfg.getDouble("humanSLChosenMoveProp"+idxStr, 0.0, 1.0);
Expand Down
7 changes: 6 additions & 1 deletion cpp/search/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1068,9 +1068,14 @@ void Search::computeRootValues() {
//Grab a neural net evaluation for the current position and use that as the center
if(!foundExpectedScoreFromTree) {
NNResultBuf nnResultBuf;
NNResultBuf humanResultBuf;
bool includeOwnerMap = true;
computeRootNNEvaluation(nnResultBuf,includeOwnerMap);
bool includeHumanResult = humanEvaluator != NULL && searchParams.humanSLValueProportion > 0;
computeRootNNEvaluation(nnResultBuf,humanResultBuf,includeOwnerMap,includeHumanResult);
expectedScore = nnResultBuf.result->whiteScoreMean;
if(includeHumanResult) {
expectedScore += searchParams.humanSLValueProportion * ((double)(humanResultBuf.result->whiteScoreMean) - expectedScore);
}
}

recentScoreCenter = expectedScore * (1.0 - searchParams.dynamicScoreCenterZeroWeight);
Expand Down
5 changes: 2 additions & 3 deletions cpp/search/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -434,11 +434,9 @@ struct Search {
// searchhelpers.cpp
//----------------------------------------------------------------------------------------
double getResultUtility(double winlossValue, double noResultValue) const;
double getResultUtilityFromNN(const NNOutput& nnOutput) const;
double getScoreUtility(double scoreMeanAvg, double scoreMeanSqAvg) const;
double getScoreUtilityDiff(double scoreMeanAvg, double scoreMeanSqAvg, double delta) const;
double getApproxScoreUtilityDerivative(double scoreMean) const;
double getUtilityFromNN(const NNOutput& nnOutput) const;

//----------------------------------------------------------------------------------------
// Miscellaneous search biasing helpers, root move selection, etc.
Expand Down Expand Up @@ -517,7 +515,7 @@ struct Search {
// Neural net queries
// searchnnhelpers.cpp
//----------------------------------------------------------------------------------------
void computeRootNNEvaluation(NNResultBuf& nnResultBuf, bool includeOwnerMap);
//Evaluates the root position (rootBoard / rootHistory / rootPla), with the main result going into nnResultBuf.
//humanResultBuf is only filled when includeHumanResult is true: in that case humanEvaluator (asserted non-NULL)
//is run on the same position using searchParams.humanSLProfile, with the same includeOwnerMap setting.
void computeRootNNEvaluation(NNResultBuf& nnResultBuf, NNResultBuf& humanResultBuf, bool includeOwnerMap, bool includeHumanResult);
bool initNodeNNOutput(
SearchThread& thread, SearchNode& node,
bool isRoot, bool skipCache, bool isReInit
Expand Down Expand Up @@ -610,6 +608,7 @@ struct Search {
bool assumeNoExistingWeight
);
void addCurrentNNOutputAsLeafValue(SearchNode& node, bool assumeNoExistingWeight);
//Utility computed from this node's own NN output (win/loss/noresult probs, score mean and score mean sq).
//If humanEvaluator is non-NULL and searchParams.humanSLValueProportion > 0, each of those values is first
//moved toward the node's human output by that proportion. Asserts that the outputs it reads are non-NULL.
double getThisNodeNNUtility(const SearchNode& node) const;

double computeWeightFromNNOutput(const NNOutput* nnOutput) const;

Expand Down
4 changes: 2 additions & 2 deletions cpp/search/searchexplorehelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,10 @@ double Search::getFpuValueForChildrenAssumeVisited(
double parentUtilityForFPU = parentUtility;
if(searchParams.fpuParentWeightByVisitedPolicy) {
double avgWeight = std::min(1.0, pow(policyProbMassVisited, searchParams.fpuParentWeightByVisitedPolicyPow));
parentUtilityForFPU = avgWeight * parentUtility + (1.0 - avgWeight) * getUtilityFromNN(*(node.getNNOutput()));
parentUtilityForFPU = avgWeight * parentUtility + (1.0 - avgWeight) * getThisNodeNNUtility(node);
}
else if(searchParams.fpuParentWeight > 0.0) {
parentUtilityForFPU = searchParams.fpuParentWeight * getUtilityFromNN(*(node.getNNOutput())) + (1.0 - searchParams.fpuParentWeight) * parentUtility;
parentUtilityForFPU = searchParams.fpuParentWeight * getThisNodeNNUtility(node) + (1.0 - searchParams.fpuParentWeight) * parentUtility;
}

double fpuValue;
Expand Down
13 changes: 0 additions & 13 deletions cpp/search/searchhelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,13 +261,6 @@ double Search::getResultUtility(double winLossValue, double noResultValue) const
);
}

//Utility of the game-outcome part of an NN output, from white's perspective:
//the win-minus-loss margin scaled by winLossUtilityFactor, plus the no-result
//probability scaled by noResultUtilityForWhite.
double Search::getResultUtilityFromNN(const NNOutput& nnOutput) const {
  //Take the difference in the fields' own type before any scaling, exactly as a single expression would.
  const auto winLossMargin = nnOutput.whiteWinProb - nnOutput.whiteLossProb;
  return
    winLossMargin * searchParams.winLossUtilityFactor +
    nnOutput.whiteNoResultProb * searchParams.noResultUtilityForWhite;
}

double Search::getScoreUtility(double scoreMeanAvg, double scoreMeanSqAvg) const {
double scoreMean = scoreMeanAvg;
double scoreMeanSq = scoreMeanSqAvg;
Expand Down Expand Up @@ -301,12 +294,6 @@ double Search::getApproxScoreUtilityDerivative(double scoreMean) const {
}


//Total utility of an NN output: the outcome-based utility plus the score-based utility
//computed from the net's white score mean and mean-square.
double Search::getUtilityFromNN(const NNOutput& nnOutput) const {
  return
    getResultUtilityFromNN(nnOutput)
    + getScoreUtility(nnOutput.whiteScoreMean, nnOutput.whiteScoreMeanSq);
}


bool Search::isAllowedRootMove(Loc moveLoc) const {
assert(moveLoc == Board::PASS_LOC || rootBoard.isOnBoard(moveLoc));

Expand Down
14 changes: 12 additions & 2 deletions cpp/search/searchnnhelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "../core/using.h"
//------------------------

void Search::computeRootNNEvaluation(NNResultBuf& nnResultBuf, bool includeOwnerMap) {
void Search::computeRootNNEvaluation(NNResultBuf& nnResultBuf, NNResultBuf& humanResultBuf, bool includeOwnerMap, bool includeHumanResult) {
Board board = rootBoard;
const BoardHistory& hist = rootHistory;
Player pla = rootPla;
Expand All @@ -32,6 +32,15 @@ void Search::computeRootNNEvaluation(NNResultBuf& nnResultBuf, bool includeOwner
nnInputParams,
nnResultBuf, skipCache, includeOwnerMap
);

if(includeHumanResult) {
assert(humanEvaluator != NULL);
humanEvaluator->evaluate(
board, hist, pla, &searchParams.humanSLProfile,
nnInputParams,
humanResultBuf, skipCache, includeOwnerMap
);
}
}

bool Search::needsHumanOutputAtRoot() const {
Expand All @@ -42,7 +51,8 @@ bool Search::needsHumanOutputInTree() const {
searchParams.humanSLPlaExploreProbWeightless > 0 ||
searchParams.humanSLPlaExploreProbWeightful > 0 ||
searchParams.humanSLOppExploreProbWeightless > 0 ||
searchParams.humanSLOppExploreProbWeightful > 0
searchParams.humanSLOppExploreProbWeightful > 0 ||
searchParams.humanSLValueProportion > 0
);
}

Expand Down
4 changes: 4 additions & 0 deletions cpp/search/searchparams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ SearchParams::SearchParams()
humanSLPlaExploreProbWeightful(0.0),
humanSLOppExploreProbWeightless(0.0),
humanSLOppExploreProbWeightful(0.0),
humanSLValueProportion(0.0),
humanSLChosenMoveProp(0.0),
humanSLChosenMoveIgnorePass(false),
humanSLChosenMovePiklLambda(1000000000.0)
Expand Down Expand Up @@ -252,6 +253,7 @@ bool SearchParams::operator==(const SearchParams& other) const {
humanSLOppExploreProbWeightless == other.humanSLOppExploreProbWeightless &&
humanSLOppExploreProbWeightful == other.humanSLOppExploreProbWeightful &&

humanSLValueProportion == other.humanSLValueProportion &&
humanSLChosenMoveProp == other.humanSLChosenMoveProp &&
humanSLChosenMoveIgnorePass == other.humanSLChosenMoveIgnorePass &&
humanSLChosenMovePiklLambda == other.humanSLChosenMovePiklLambda
Expand Down Expand Up @@ -499,6 +501,7 @@ json SearchParams::changeableParametersToJson() const {
ret["humanSLOppExploreProbWeightless"] = humanSLOppExploreProbWeightless;
ret["humanSLOppExploreProbWeightful"] = humanSLOppExploreProbWeightful;

ret["humanSLValueProportion"] = humanSLValueProportion;
ret["humanSLChosenMoveProp"] = humanSLChosenMoveProp;
ret["humanSLChosenMoveIgnorePass"] = humanSLChosenMoveIgnorePass;
ret["humanSLChosenMovePiklLambda"] = humanSLChosenMovePiklLambda;
Expand Down Expand Up @@ -650,6 +653,7 @@ void SearchParams::printParams(std::ostream& out) const {
PRINTPARAM(humanSLPlaExploreProbWeightful);
PRINTPARAM(humanSLOppExploreProbWeightless);
PRINTPARAM(humanSLOppExploreProbWeightful);
PRINTPARAM(humanSLValueProportion);
PRINTPARAM(humanSLChosenMoveProp);
PRINTPARAM(humanSLChosenMoveIgnorePass);
PRINTPARAM(humanSLChosenMovePiklLambda);
Expand Down
3 changes: 3 additions & 0 deletions cpp/search/searchparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ struct SearchParams {
double humanSLOppExploreProbWeightless;
double humanSLOppExploreProbWeightful;

//Proportion of the humanSL net's value outputs (win/loss/noresult probs, score mean, score mean sq, lead)
//to linearly blend into a node's own NN value outputs. Config parsing restricts it to [0,1].
//0 (the default) disables the blending; it also only applies when a human evaluator is present.
double humanSLValueProportion;

//These three are PRIOR to the normal chosenMoveTemperature.
double humanSLChosenMoveProp; //Proportion of final move selection probability using human SL policy
bool humanSLChosenMoveIgnorePass; //If true, ignore human SL pass probability and use KataGo's passing logic
Expand Down
10 changes: 10 additions & 0 deletions cpp/search/searchresults.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2285,6 +2285,16 @@ bool Search::getPrunedNodeValues(const SearchNode* nodePtr, ReportedSearchValues
double scoreMean = (double)nnOutput->whiteScoreMean;
double scoreMeanSq = (double)nnOutput->whiteScoreMeanSq;
double lead = (double)nnOutput->whiteLead;
if(humanEvaluator != NULL && searchParams.humanSLValueProportion > 0) {
const NNOutput* humanOutput = node.getHumanOutput();
assert(humanOutput != NULL);
winProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteWinProb) - winProb);
lossProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteLossProb) - lossProb);
noResultProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteNoResultProb) - noResultProb);
scoreMean += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteScoreMean) - scoreMean);
scoreMeanSq += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteScoreMeanSq) - scoreMeanSq);
lead += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteLead) - lead);
}
double utility =
getResultUtility(winProb-lossProb, noResultProb)
+ getScoreUtility(scoreMean, scoreMeanSq);
Expand Down
44 changes: 44 additions & 0 deletions cpp/search/searchupdatehelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,29 @@ void Search::addLeafValue(
}
}

double Search::getThisNodeNNUtility(const SearchNode& node) const {
const NNOutput* nnOutput = node.getNNOutput();
assert(nnOutput != NULL);
double winProb = (double)nnOutput->whiteWinProb;
double lossProb = (double)nnOutput->whiteLossProb;
double noResultProb = (double)nnOutput->whiteNoResultProb;
double scoreMean = (double)nnOutput->whiteScoreMean;
double scoreMeanSq = (double)nnOutput->whiteScoreMeanSq;
if(humanEvaluator != NULL && searchParams.humanSLValueProportion > 0) {
const NNOutput* humanOutput = node.getHumanOutput();
assert(humanOutput != NULL);
winProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteWinProb) - winProb);
lossProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteLossProb) - lossProb);
noResultProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteNoResultProb) - noResultProb);
scoreMean += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteScoreMean) - scoreMean);
scoreMeanSq += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteScoreMeanSq) - scoreMeanSq);
}
double utility =
getResultUtility(winProb-lossProb, noResultProb)
+ getScoreUtility(scoreMean, scoreMeanSq);
return utility;
}

void Search::addCurrentNNOutputAsLeafValue(SearchNode& node, bool assumeNoExistingWeight) {
const NNOutput* nnOutput = node.getNNOutput();
assert(nnOutput != NULL);
Expand All @@ -92,6 +115,16 @@ void Search::addCurrentNNOutputAsLeafValue(SearchNode& node, bool assumeNoExisti
double scoreMeanSq = (double)nnOutput->whiteScoreMeanSq;
double lead = (double)nnOutput->whiteLead;
double weight = computeWeightFromNNOutput(nnOutput);
if(humanEvaluator != NULL && searchParams.humanSLValueProportion > 0) {
const NNOutput* humanOutput = node.getHumanOutput();
assert(humanOutput != NULL);
winProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteWinProb) - winProb);
lossProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteLossProb) - lossProb);
noResultProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteNoResultProb) - noResultProb);
scoreMean += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteScoreMean) - scoreMean);
scoreMeanSq += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteScoreMeanSq) - scoreMeanSq);
lead += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteLead) - lead);
}
addLeafValue(node,winProb-lossProb,noResultProb,scoreMean,scoreMeanSq,lead,weight,false,assumeNoExistingWeight);
}

Expand Down Expand Up @@ -248,6 +281,17 @@ void Search::recomputeNodeStats(SearchNode& node, SearchThread& thread, int numV
double scoreMean = (double)nnOutput->whiteScoreMean;
double scoreMeanSq = (double)nnOutput->whiteScoreMeanSq;
double lead = (double)nnOutput->whiteLead;
if(humanEvaluator != NULL && searchParams.humanSLValueProportion > 0) {
const NNOutput* humanOutput = node.getHumanOutput();
assert(humanOutput != NULL);
winProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteWinProb) - winProb);
lossProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteLossProb) - lossProb);
noResultProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteNoResultProb) - noResultProb);
scoreMean += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteScoreMean) - scoreMean);
scoreMeanSq += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteScoreMeanSq) - scoreMeanSq);
lead += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteLead) - lead);
}

double utility =
getResultUtility(winProb-lossProb, noResultProb)
+ getScoreUtility(scoreMean, scoreMeanSq);
Expand Down
8 changes: 8 additions & 0 deletions cpp/tests/results/runOutputTests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20666,6 +20666,7 @@ humanSLPlaExploreProbWeightless: 0
humanSLPlaExploreProbWeightful: 0
humanSLOppExploreProbWeightless: 0
humanSLOppExploreProbWeightful: 0
humanSLValueProportion: 0
humanSLChosenMoveProp: 0
humanSLChosenMoveIgnorePass: 0
humanSLChosenMovePiklLambda: 1e+09
Expand Down Expand Up @@ -20773,6 +20774,7 @@ humanSLPlaExploreProbWeightless: 0
humanSLPlaExploreProbWeightful: 0
humanSLOppExploreProbWeightless: 0
humanSLOppExploreProbWeightful: 0
humanSLValueProportion: 0
humanSLChosenMoveProp: 0
humanSLChosenMoveIgnorePass: 0
humanSLChosenMovePiklLambda: 1e+09
Expand Down Expand Up @@ -20880,6 +20882,7 @@ humanSLPlaExploreProbWeightless: 0
humanSLPlaExploreProbWeightful: 0
humanSLOppExploreProbWeightless: 0
humanSLOppExploreProbWeightful: 0
humanSLValueProportion: 0
humanSLChosenMoveProp: 0
humanSLChosenMoveIgnorePass: 0
humanSLChosenMovePiklLambda: 1e+09
Expand Down Expand Up @@ -20987,6 +20990,7 @@ humanSLPlaExploreProbWeightless: 0
humanSLPlaExploreProbWeightful: 0
humanSLOppExploreProbWeightless: 0
humanSLOppExploreProbWeightful: 0
humanSLValueProportion: 0
humanSLChosenMoveProp: 0
humanSLChosenMoveIgnorePass: 0
humanSLChosenMovePiklLambda: 1e+09
Expand Down Expand Up @@ -21094,6 +21098,7 @@ humanSLPlaExploreProbWeightless: 0
humanSLPlaExploreProbWeightful: 0
humanSLOppExploreProbWeightless: 0
humanSLOppExploreProbWeightful: 0
humanSLValueProportion: 0
humanSLChosenMoveProp: 0
humanSLChosenMoveIgnorePass: 0
humanSLChosenMovePiklLambda: 1e+09
Expand Down Expand Up @@ -21201,6 +21206,7 @@ humanSLPlaExploreProbWeightless: 0
humanSLPlaExploreProbWeightful: 0
humanSLOppExploreProbWeightless: 0
humanSLOppExploreProbWeightful: 0
humanSLValueProportion: 0
humanSLChosenMoveProp: 0
humanSLChosenMoveIgnorePass: 0
humanSLChosenMovePiklLambda: 1e+09
Expand Down Expand Up @@ -21308,6 +21314,7 @@ humanSLPlaExploreProbWeightless: 0
humanSLPlaExploreProbWeightful: 0
humanSLOppExploreProbWeightless: 0
humanSLOppExploreProbWeightful: 0
humanSLValueProportion: 0
humanSLChosenMoveProp: 0
humanSLChosenMoveIgnorePass: 0
humanSLChosenMovePiklLambda: 1e+09
Expand Down Expand Up @@ -21415,6 +21422,7 @@ humanSLPlaExploreProbWeightless: 0
humanSLPlaExploreProbWeightful: 0
humanSLOppExploreProbWeightless: 0
humanSLOppExploreProbWeightful: 0
humanSLValueProportion: 0
humanSLChosenMoveProp: 0
humanSLChosenMoveIgnorePass: 0
humanSLChosenMovePiklLambda: 1e+09
Expand Down