BonsaiLocalDriver.cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "Bonsai.h"
using namespace EdgeML;
using namespace EdgeML::Bonsai;
int main(int argc, char **argv)
{
#ifdef LINUX
  // Trap floating-point exceptions and route SIGFPE to fpehandler so numerical
  // errors fail loudly instead of silently propagating NaNs.
  trapfpe();
  struct sigaction sa;
  sigemptyset(&sa.sa_mask);
  sa.sa_flags = SA_SIGINFO;
  sa.sa_sigaction = fpehandler;
  sigaction(SIGFPE, &sa, NULL);
#endif

  // MKL and Eigen must agree on the integer width used for indexing.
  assert(sizeof(MKL_INT) == sizeof(Eigen::Index));

  std::string dataDir;
  std::string currResultsPath;

  // Train a Bonsai model from data on disk; dataDir and currResultsPath are
  // populated from the command-line arguments.
  BonsaiTrainer trainer(DataIngestType::FileIngest, argc, (const char**) argv,
                        dataDir, currResultsPath);

  auto modelBytes = trainer.getModelSize(); // This can be changed to getSparseModelSize() if you need to export a sparse model
  auto model = new char[modelBytes];
  auto meanVarBytes = trainer.getMeanVarSize();
  auto meanVar = new char[meanVarBytes];

  // Serialize the trained model and the feature mean/variance statistics.
  trainer.exportModel(modelBytes, model, currResultsPath); // use exportSparseModel(...) if you need a sparse model
  trainer.exportMeanVar(meanVarBytes, meanVar, currResultsPath);
  trainer.dumpModelMeanVar(currResultsPath);

  // Reload the serialized buffers into a predictor and score the test split.
  BonsaiPredictor predictor(modelBytes, model); // use the constructor predictor(modelBytes, model, false) for loading a sparse model
  predictor.importMeanVar(meanVarBytes, meanVar);
  predictor.batchEvaluate(trainer.data.Xtest, trainer.data.Ytest, dataDir, currResultsPath);

  delete[] model;
  delete[] meanVar;
  return 0;
}
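
// A minimal sketch of the sparse-model path referenced in the inline comments
// above, kept separate so the dense driver in main() stays unchanged. It
// assumes getSparseModelSize() and exportSparseModel() take the same arguments
// as their dense counterparts, and that the trailing `false` in the
// BonsaiPredictor constructor selects sparse loading, as the comments suggest;
// these signatures are not verified against the Bonsai headers.
static void runSparsePath(BonsaiTrainer &trainer,
                          const std::string &dataDir,
                          const std::string &currResultsPath)
{
  // Export the model in sparse form (assumed to mirror the dense export API).
  auto sparseBytes = trainer.getSparseModelSize();
  auto sparseModel = new char[sparseBytes];
  trainer.exportSparseModel(sparseBytes, sparseModel, currResultsPath);

  // Mean/variance export is the same as in the dense path.
  auto meanVarBytes = trainer.getMeanVarSize();
  auto meanVar = new char[meanVarBytes];
  trainer.exportMeanVar(meanVarBytes, meanVar, currResultsPath);

  // The trailing `false` is assumed to tell the predictor to interpret the
  // buffer as a sparse model, per the comment in main().
  BonsaiPredictor predictor(sparseBytes, sparseModel, false);
  predictor.importMeanVar(meanVarBytes, meanVar);
  predictor.batchEvaluate(trainer.data.Xtest, trainer.data.Ytest, dataDir, currResultsPath);

  delete[] sparseModel;
  delete[] meanVar;
}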