BleckwenAI/xgboost-predictor4j

xgboost-predictor4j

Bleckwen JVM implementation of XGBoost Predictor

Features

  • Faster than XGBoost4J, especially on distributed frameworks like Flink or Spark
  • No dependency (no need to install libgomp)
  • Designed for streaming (on-the-fly prediction)
  • Scala and Java APIs with flexible input (Array or FVector)
  • Compatible with XGBoost models 0.90 and 1.0.0
  • Supports ML interpretability with a fast approximate algorithm (predictApproxContrib) and a slower SHAP algorithm (predictContrib)
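
To make the interpretability feature concrete, here is a minimal Java sketch of how additive contributions relate to a binary:logistic score. It assumes the usual SHAP additivity property (a bias term plus the per-feature contributions equals the raw margin) and the standard logistic link. The class `ContribSketch` and its methods are hypothetical and are not part of this library; how predictContrib lays out the bias in its result is not stated in this README.

```java
public final class ContribSketch {
    private ContribSketch() {}

    /** Logistic function: maps a raw margin to a probability in (0, 1). */
    public static double sigmoid(double margin) {
        return 1.0 / (1.0 + Math.exp(-margin));
    }

    /** Raw margin implied by additive contributions: bias plus the sum of per-feature values. */
    public static double margin(double bias, double[] featureContribs) {
        double sum = bias;
        for (double c : featureContribs) {
            sum += c;
        }
        return sum;
    }

    public static void main(String[] args) {
        // Made-up contribution values, for illustration only.
        double[] contribs = {0.5, -0.25, -0.25};
        double m = margin(0.0, contribs);
        System.out.println(m + " -> " + sigmoid(m)); // prints "0.0 -> 0.5"
    }
}
```

A check like this can be handy when comparing contribution output against the predicted score, provided the assumption about additivity holds for the model at hand.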

Limitations

  • Only binary classification (binary:logistic) is supported in this release
  • predictContrib() uses the SHAP algorithm described in this paper but does not check for duplicate indexes (rewind is not implemented). The impact is negligible because this happens only in very rare situations (a comparison with XGBoost4J performed on 1_000_000 random records did not show any discrepancy)

Release History

  • 1.0 06/07/2020 first version
  • 1.1 12/04/2021 compatibility with 1.4.0 binary files
  • 1.2 01/20/2021 release for Scala 2.12
  • 1.3 26/05/2023 fix: compare float values

Integration

  • Maven
<dependency>
  <groupId>ai.bleckwen</groupId>
  <artifactId>xgboost-predictor4j</artifactId>
  <version>1.0</version>
</dependency>
  • SBT
libraryDependencies += "ai.bleckwen" % "xgboost-predictor4j" % "1.0"

The package was built and published with Scala 2.12.13, but you can rebuild it with Scala 2.13 by using the Maven profile scala213 or the Makefile goal.

Using Predictor in Scala

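  // Load the binary model from the classpath (requires commons-io on the classpath)
  val bytes = org.apache.commons.io.IOUtils.toByteArray(this.getClass.getResourceAsStream("/path_to.model"))
  val predictor = Predictor(bytes)
  // Dense feature vector: one value per feature, in model order
  val denseArray = Array(0.23, 0.0, 1.0, 0.5)
  // predict returns an array; the first element is the score
  val score = predictor.predict(denseArray).head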

Using Predictor in Java

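   // Load the binary model from the classpath (requires commons-io on the classpath)
   byte[] bytes = org.apache.commons.io.IOUtils.toByteArray(this.getClass().getResourceAsStream("/path_to.model"));
   Predictor predictor = new PredictorBuilder().build(bytes);
   // Dense feature vector: one value per feature, in model order
   double[] denseArray = {0, 0, 32, 0, 0, 16, -8, 0, 0, 0};
   // predict returns an array; the first element is the score
   double score = predictor.predict(denseArray)[0];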
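
The snippets above read the model with commons-io. If you prefer to avoid that extra dependency, the bytes can probably be loaded with the JDK alone. The following is a minimal sketch that assumes Java 9 or later (for `InputStream.readAllBytes()`); the helper class `ModelBytes` is hypothetical and not part of this library.

```java
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;

public final class ModelBytes {
    private ModelBytes() {}

    /** Reads the whole stream into a byte array and closes the stream. */
    public static byte[] read(InputStream in) {
        if (in == null) {
            // getResourceAsStream returns null when the resource is missing
            throw new IllegalArgumentException("model stream is null (resource not found?)");
        }
        try (InputStream stream = in) {
            return stream.readAllBytes(); // Java 9+
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Loads a classpath resource, resolved relative to the given class. */
    public static byte[] fromResource(Class<?> anchor, String path) {
        return read(anchor.getResourceAsStream(path));
    }
}
```

The resulting array can then be passed to `new PredictorBuilder().build(bytes)` exactly as in the Java example above.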

Benchmarks

See BENCH.md