Skip to content

Commit e653379

Browse files
cburgardPhmonski
authored andcommitted
test-driven-development: adding a few tests and make sure they succeed
1 parent 2fd2bcc commit e653379

File tree

3 files changed

+185
-75
lines changed

3 files changed

+185
-75
lines changed

roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,10 @@ class RooJSONFactoryWSTool {
233233
void exportVariables(const RooArgSet &allElems, RooFit::Detail::JSONNode &n);
234234

235235
void exportAllObjects(RooFit::Detail::JSONNode &n);
236-
236+
237237
void exportModelConfig(RooFit::Detail::JSONNode &rootnode, RooStats::ModelConfig const &mc,
238-
const std::vector<RooJSONFactoryWSTool::CombinedData> &d);
238+
const std::vector<RooJSONFactoryWSTool::CombinedData> &combined,
239+
const std::vector<RooAbsData*> &single);
239240

240241
void exportSingleModelConfig(RooFit::Detail::JSONNode &rootnode, RooStats::ModelConfig const &mc,
241242
std::string const &analysisName,

roofit/hs3/src/RooJSONFactoryWSTool.cxx

Lines changed: 67 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,7 @@ void importAnalysis(const JSONNode &rootnode, const JSONNode &analysisNode, cons
636636
observables.add(*d->get());
637637
}
638638
}
639-
if(!found) throw std::runtime_error("dataset '"+nameNode.val()+"' cannot be found!");
639+
if(nameNode.val() != "0" && !found) throw std::runtime_error("dataset '"+nameNode.val()+"' cannot be found!");
640640
}
641641

642642
JSONNode const *pdfNameNode = mcAuxNode ? mcAuxNode->find("pdfName") : nullptr;
@@ -1051,6 +1051,7 @@ void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node)
10511051
void RooJSONFactoryWSTool::exportVariables(const RooArgSet &allElems, JSONNode &n)
10521052
{
10531053
// export a list of RooRealVar objects
1054+
n.set_seq();
10541055
for (RooAbsArg *arg : allElems) {
10551056
exportVariable(arg, n);
10561057
}
@@ -1357,9 +1358,9 @@ void RooJSONFactoryWSTool::exportHisto(RooArgSet const &vars, std::size_t n, dou
13571358
auto &observablesNode = output["axes"].set_seq();
13581359
// axes have to be ordered to get consistent bin indices
13591360
for (auto *var : static_range_cast<RooRealVar *>(vars)) {
1360-
JSONNode &obsNode = observablesNode.append_child().set_map();
13611361
std::string name = var->GetName();
13621362
RooJSONFactoryWSTool::testValidName(name, false);
1363+
JSONNode &obsNode = observablesNode.append_child().set_map();
13631364
obsNode["name"] << name;
13641365
if (var->getBinning().isUniform()) {
13651366
obsNode["min"] << var->getMin();
@@ -1526,6 +1527,7 @@ RooJSONFactoryWSTool::CombinedData RooJSONFactoryWSTool::exportCombinedData(RooA
15261527
void RooJSONFactoryWSTool::exportData(RooAbsData const &data)
15271528
{
15281529
// find category observables
1530+
15291531
RooAbsCategory *cat = nullptr;
15301532
for (RooAbsArg *obs : *data.get()) {
15311533
if (dynamic_cast<RooAbsCategory *>(obs)) {
@@ -1556,15 +1558,14 @@ void RooJSONFactoryWSTool::exportData(RooAbsData const &data)
15561558
if (auto weightVar = variables.find("weightVar")) {
15571559
variables.remove(*weightVar);
15581560
}
1559-
1560-
for (RooAbsArg *arg : variables) {
1561-
exportVariable(arg, output["axes"]);
1562-
}
1563-
1561+
15641562
// this is a regular binned dataset
15651563
if (auto dh = dynamic_cast<RooDataHist const *>(&data)) {
15661564
output["type"] << "binned";
1567-
return exportHisto(*dh->get(), dh->numEntries(), dh->weightArray(), output);
1565+
for(auto* var : static_range_cast<RooRealVar*>(variables)){
1566+
_domains->readVariable(*var);
1567+
}
1568+
return exportHisto(variables, dh->numEntries(), dh->weightArray(), output);
15681569
}
15691570

15701571
// Check if this actually represents a binned dataset, and then import it
@@ -1588,26 +1589,16 @@ void RooJSONFactoryWSTool::exportData(RooAbsData const &data)
15881589
isBinnedData = true;
15891590
if (isBinnedData) {
15901591
output["type"] << "binned";
1591-
/*std::string datasetName = data.GetName();
1592-
if (datasetName.find("combData_ZvvH126_5") != std::string::npos) {
1593-
file << variables << std::endl;
1594-
for (size_t idx = 0; idx < data.numEntries(); ++idx) {
1595-
file << data.get(idx)->getRealValue("obs_x_ZvvH126_dot_5") << std::endl;
1596-
1597-
}
1598-
// Write the contents vector values to the file
1599-
for (const auto& val : contents) {
1600-
file << val << std::endl;
1601-
}
1602-
}
1603-
1604-
file.close();*/
1592+
for(auto* var : static_range_cast<RooRealVar*>(variables)){
1593+
_domains->readVariable(*var);
1594+
}
16051595
return exportHisto(variables, data.numEntries(), contents.data(), output);
16061596
}
16071597
}
16081598

16091599
// this really is an unbinned dataset
16101600
output["type"] << "unbinned";
1601+
exportVariables(variables, output["axes"]);
16111602
auto &coords = output["entries"].set_seq();
16121603
std::vector<double> weightVals;
16131604
bool hasNonUnityWeights = false;
@@ -1788,32 +1779,45 @@ void RooJSONFactoryWSTool::importDependants(const JSONNode &n)
17881779
}
17891780

17901781
void RooJSONFactoryWSTool::exportModelConfig(JSONNode &rootnode, RooStats::ModelConfig const &mc,
1791-
const std::vector<CombinedData> &combDataSets)
1792-
{
1793-
auto pdf = dynamic_cast<RooSimultaneous const *>(mc.GetPdf());
1794-
if (pdf == nullptr) {
1795-
warning("RooFitHS3 only supports ModelConfigs with RooSimultaneous! Skipping ModelConfig.");
1796-
return;
1797-
}
1798-
1799-
for (std::size_t i = 0; i < std::max(combDataSets.size(), std::size_t(1)); ++i) {
1800-
const bool hasdata = i < combDataSets.size();
1801-
if (hasdata && !matches(combDataSets.at(i), pdf))
1782+
const std::vector<CombinedData> &combDataSets,
1783+
const std::vector<RooAbsData*> &singleDataSets)
1784+
{
1785+
auto pdf = mc.GetPdf();
1786+
auto simpdf = dynamic_cast<RooSimultaneous const *>(pdf);
1787+
if (simpdf){
1788+
for (std::size_t i = 0; i < std::max(combDataSets.size(), std::size_t(1)); ++i) {
1789+
const bool hasdata = i < combDataSets.size();
1790+
if (hasdata && !matches(combDataSets.at(i), simpdf))
18021791
continue;
1803-
1804-
std::string analysisName(pdf->GetName());
1805-
if (hasdata)
1792+
1793+
std::string analysisName(simpdf->GetName());
1794+
if (hasdata)
18061795
analysisName += "_" + combDataSets[i].name;
1807-
1808-
exportSingleModelConfig(rootnode, mc, analysisName, hasdata ? &combDataSets[i].components : nullptr);
1796+
1797+
exportSingleModelConfig(rootnode, mc, analysisName, hasdata ? &combDataSets[i].components : nullptr);
1798+
}
1799+
} else {
1800+
RooArgSet observables(*mc.GetObservables());
1801+
int founddata = 0;
1802+
for(auto* data : singleDataSets){
1803+
if(observables.equals(*(data->get()))){
1804+
std::map<std::string,std::string> mapping;
1805+
mapping[pdf->GetName()] = data->GetName();
1806+
exportSingleModelConfig(rootnode, mc, std::string(pdf->GetName()) + "_" + data->GetName(), &mapping);
1807+
++founddata;
1808+
}
1809+
}
1810+
if(founddata == 0){
1811+
exportSingleModelConfig(rootnode, mc, pdf->GetName(), nullptr);
1812+
}
18091813
}
18101814
}
18111815

18121816
void RooJSONFactoryWSTool::exportSingleModelConfig(JSONNode &rootnode, RooStats::ModelConfig const &mc,
18131817
std::string const &analysisName,
18141818
std::map<std::string, std::string> const *dataComponents)
18151819
{
1816-
auto pdf = static_cast<RooSimultaneous const *>(mc.GetPdf());
1820+
auto pdf = mc.GetPdf();
18171821

18181822
JSONNode &analysisNode = appendNamedChild(rootnode["analyses"], analysisName);
18191823

@@ -1826,11 +1830,22 @@ void RooJSONFactoryWSTool::exportSingleModelConfig(JSONNode &rootnode, RooStats:
18261830
nllNode["data"].set_seq();
18271831

18281832
if (dataComponents) {
1829-
for (auto const &item : pdf->indexCat()) {
1833+
auto simPdf = static_cast<RooSimultaneous const *>(pdf);
1834+
if(simPdf){
1835+
for (auto const &item : simPdf->indexCat()) {
18301836
const auto &dataComp = dataComponents->find(item.first);
1831-
nllNode["distributions"].append_child() << pdf->getPdf(item.first)->GetName();
1837+
nllNode["distributions"].append_child() << simPdf->getPdf(item.first)->GetName();
18321838
nllNode["data"].append_child() << dataComp->second;
1833-
}
1839+
}
1840+
} else {
1841+
for(auto it:*dataComponents){
1842+
nllNode["distributions"].append_child() << it.first;
1843+
nllNode["data"].append_child() << it.second;
1844+
}
1845+
}
1846+
} else {
1847+
nllNode["distributions"].append_child() << pdf->GetName();
1848+
nllNode["data"].append_child() << 0;
18341849
}
18351850

18361851
if (mc.GetExternalConstraints()) {
@@ -1842,7 +1857,7 @@ void RooJSONFactoryWSTool::exportSingleModelConfig(JSONNode &rootnode, RooStats:
18421857
}
18431858

18441859
auto writeList = [&](const char *name, RooArgSet const *args) {
1845-
if (!args)
1860+
if (!args || !args->size())
18461861
return;
18471862

18481863
std::vector<std::string> names;
@@ -1857,7 +1872,7 @@ void RooJSONFactoryWSTool::exportSingleModelConfig(JSONNode &rootnode, RooStats:
18571872

18581873
auto &domainsNode = rootnode["domains"];
18591874

1860-
if (mc.GetNuisanceParameters()) {
1875+
if (mc.GetNuisanceParameters() && mc.GetNuisanceParameters()->size() > 0){
18611876
std::string npDomainName = analysisName + "_nuisance_parameters";
18621877
domains.append_child() << npDomainName;
18631878
RooFit::JSONIO::Detail::Domains::ProductDomain npDomain;
@@ -1867,7 +1882,7 @@ void RooJSONFactoryWSTool::exportSingleModelConfig(JSONNode &rootnode, RooStats:
18671882
npDomain.writeJSON(appendNamedChild(domainsNode, npDomainName));
18681883
}
18691884

1870-
if (mc.GetGlobalObservables()) {
1885+
if (mc.GetGlobalObservables() && mc.GetGlobalObservables()->size() > 0){
18711886
std::string globDomainName = analysisName + "_global_observables";
18721887
domains.append_child() << globDomainName;
18731888
RooFit::JSONIO::Detail::Domains::ProductDomain globDomain;
@@ -1877,7 +1892,7 @@ void RooJSONFactoryWSTool::exportSingleModelConfig(JSONNode &rootnode, RooStats:
18771892
globDomain.writeJSON(appendNamedChild(domainsNode, globDomainName));
18781893
}
18791894

1880-
if (mc.GetParametersOfInterest()) {
1895+
if (mc.GetParametersOfInterest() && mc.GetParametersOfInterest()->size() > 0){
18811896
std::string poiDomainName = analysisName + "_parameters_of_interest";
18821897
domains.append_child() << poiDomainName;
18831898
RooFit::JSONIO::Detail::Domains::ProductDomain poiDomain;
@@ -1938,28 +1953,31 @@ void RooJSONFactoryWSTool::exportAllObjects(JSONNode &n)
19381953
exportAttributes(arg, n);
19391954
}
19401955

1941-
// export all datasets
1956+
// collect all datasets
19421957
std::vector<RooAbsData *> alldata;
19431958
for (auto &d : _workspace.allData()) {
19441959
alldata.push_back(d);
19451960
}
19461961
sortByName(alldata);
19471962
// first, take care of combined datasets
1963+
std::vector<RooAbsData *> singleData;
19481964
std::vector<RooJSONFactoryWSTool::CombinedData> combData;
19491965
for (auto &d : alldata) {
19501966
auto data = this->exportCombinedData(*d);
19511967
if (!data.components.empty())
19521968
combData.push_back(data);
1969+
else
1970+
singleData.push_back(d);
19531971
}
1954-
// next, take care of regular datasets
1972+
// next, take care datasets
19551973
for (auto &d : alldata) {
19561974
this->exportData(*d);
19571975
}
19581976

19591977
// export all ModelConfig objects and attached Pdfs
19601978
for (TObject *obj : _workspace.allGenericObjects()) {
19611979
if (auto mc = dynamic_cast<RooStats::ModelConfig *>(obj)) {
1962-
exportModelConfig(n, *mc, combData);
1980+
exportModelConfig(n, *mc, combData, singleData);
19631981
}
19641982
}
19651983

0 commit comments

Comments
 (0)