Home
At Bleckwen we use Machine Learning within streaming applications to detect fraud in financial transactions. Our objectives are both to minimize end-to-end latency, especially for Instant Payments, and to increase global throughput (the number of transactions processed per second).
We initially used XGBoost4J, but this package performs poorly in a distributed JVM environment such as Flink and is not suited to on-the-fly predictions, which need a simple feature vector rather than a DMatrix. Additionally, the dependency on the libgomp library is a major deployment constraint on shared infrastructure, and the Java Native Interface (JNI) used by XGBoost4J adds overhead and causes highly fluctuating response times.
We also looked at alternatives such as https://github.com/h2oai/xgboost-predictor, but identified a couple of problems and found a major functional gap: AI interpretability (the predictContrib method) was missing.
Note also that we use a small number of features (at most 100).
Our goal was a simple, fast XGBoost predictor for the JVM, with no external dependencies and with support for AI interpretability.
The package provides two implementations of AI interpretability:
- predictContrib(): uses the SHAP algorithm, as XGBoost4J does. This algorithm is somewhat slow because it has to walk through all possible paths of all decision trees.
- predictApproxContrib(): uses a faster (about x10) but less accurate interpretation algorithm for random forests, described at http://blog.datadive.net/interpreting-random-forests/ (it also exists in XGBoost but is not exposed in the Java interface).
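To make the difference between the two methods concrete, here is a minimal sketch on a single toy tree. It assumes hypothetical Node, Leaf and Split classes (not the package's model types) where each leaf carries a cover, i.e. its training weight. approxContrib follows the path-based method from the linked blog post: the bias is the root's cover-weighted mean value, and each split on the decision path credits its feature with the change in mean value it causes. shapley computes exact Shapley values by brute force over all feature subsets, integrating out unknown features by cover; it is exponential in the number of features and only illustrates what SHAP values represent, not how an efficient TreeSHAP implementation such as the one behind predictContrib works.

```scala
// Minimal sketch: hypothetical tree classes, not the package's actual model types.
sealed trait Node { def cover: Double }
final case class Leaf(value: Double, cover: Double) extends Node
final case class Split(feature: Int, threshold: Float, left: Node, right: Node) extends Node {
  val cover: Double = left.cover + right.cover // a split's cover is the sum of its children's
}

object Contributions {
  // Missing values are ignored here: every feature is assumed to be present.
  private def goesLeft(s: Split, x: Array[Float]): Boolean = x(s.feature) < s.threshold

  def predict(n: Node, x: Array[Float]): Double = n match {
    case Leaf(v, _) => v
    case s: Split   => predict(if (goesLeft(s, x)) s.left else s.right, x)
  }

  // Expected output of a subtree when only the features in `known` are observed;
  // unknown features are integrated out by weighting each branch with its cover.
  def expected(n: Node, x: Array[Float], known: Set[Int]): Double = n match {
    case Leaf(v, _) => v
    case s: Split if known(s.feature) =>
      expected(if (goesLeft(s, x)) s.left else s.right, x, known)
    case s: Split =>
      (expected(s.left, x, known) * s.left.cover +
        expected(s.right, x, known) * s.right.cover) / s.cover
  }

  // Path-based approximation: bias = mean value of the root; each split on the
  // decision path credits its feature with the change in mean value it causes.
  def approxContrib(root: Node, x: Array[Float], nFeatures: Int): (Double, Array[Double]) = {
    def mean(n: Node): Double = expected(n, x, Set.empty)
    val contrib = Array.fill(nFeatures)(0.0)
    def walk(n: Node): Unit = n match {
      case _: Leaf => ()
      case s: Split =>
        val next = if (goesLeft(s, x)) s.left else s.right
        contrib(s.feature) += mean(next) - mean(s)
        walk(next)
    }
    walk(root)
    (mean(root), contrib)
  }

  // Exact Shapley values by enumerating every subset S of the other features:
  // phi_i = sum_S |S|! (M - |S| - 1)! / M! * (E[S + i] - E[S]). Exponential in M.
  def shapley(root: Node, x: Array[Float], nFeatures: Int): Array[Double] = {
    def fact(k: Int): Double = (1 to k).foldLeft(1.0)(_ * _)
    val all = (0 until nFeatures).toSet
    Array.tabulate(nFeatures) { i =>
      (all - i).subsets().map { s =>
        val weight = fact(s.size) * fact(nFeatures - s.size - 1) / fact(nFeatures)
        weight * (expected(root, x, s + i) - expected(root, x, s))
      }.sum
    }
  }
}

object ContributionsDemo {
  // Feature 0 splits first; feature 1 only matters in the right subtree.
  val tree: Node = Split(0, 0.5f, Leaf(-1.0, 2.0), Split(1, 0.5f, Leaf(0.0, 1.0), Leaf(2.0, 1.0)))
  val x: Array[Float] = Array(1.0f, 1.0f)

  def main(args: Array[String]): Unit = {
    val (bias, approx) = Contributions.approxContrib(tree, x, 2)
    println(s"prediction = ${Contributions.predict(tree, x)}, bias = $bias")
    println(s"approx  = ${approx.mkString(", ")}")
    println(s"shapley = ${Contributions.shapley(tree, x, 2).mkString(", ")}")
  }
}
```

On this toy input both methods add up to the prediction of 2.0 (with a bias of 0.0), but the approximation assigns 1.0 to each feature while the exact Shapley values are 1.25 and 0.75, which shows why the faster method is described as less accurate.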
The code is written in Scala, which allows a very concise and clear implementation. The design is fully object-oriented and does not rely on integer arrays or special tricks such as bit-shift operators.
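As an illustration of what an object-oriented tree model can look like, here is a minimal sketch using hypothetical TreeNode, LeafNode, SplitNode and Ensemble classes (they are not the package's actual types). Each split evaluates itself on a plain Array[Float], missing values are encoded as NaN and follow the split's default direction, and the ensemble adds one leaf value per tree. The baseMargin parameter and the sigmoid assume a binary:logistic objective with the base score already expressed as a margin.

```scala
// Minimal sketch with hypothetical classes; not the package's actual model types.
sealed trait TreeNode {
  def eval(x: Array[Float]): Float
}

final case class LeafNode(value: Float) extends TreeNode {
  def eval(x: Array[Float]): Float = value
}

final case class SplitNode(feature: Int, threshold: Float, defaultLeft: Boolean,
                           left: TreeNode, right: TreeNode) extends TreeNode {
  // A present value goes left when it is strictly below the threshold;
  // a missing value (encoded as NaN) follows the split's default direction.
  def eval(x: Array[Float]): Float = {
    val v = x(feature)
    val goLeft = if (v.isNaN) defaultLeft else v < threshold
    (if (goLeft) left else right).eval(x)
  }
}

final case class Ensemble(trees: Seq[TreeNode], baseMargin: Float) {
  // Raw score: the base margin plus one leaf value per tree.
  def margin(x: Array[Float]): Float = trees.foldLeft(baseMargin)((acc, t) => acc + t.eval(x))

  // Probability, assuming a binary:logistic objective (sigmoid of the margin).
  def predictProba(x: Array[Float]): Float = (1.0 / (1.0 + math.exp(-margin(x)))).toFloat
}

object EnsembleDemo {
  val model: Ensemble = Ensemble(
    Seq(
      SplitNode(0, 0.5f, true, LeafNode(-0.4f), LeafNode(0.6f)),   // defaultLeft = true
      SplitNode(1, 10.0f, false, LeafNode(0.2f), LeafNode(-0.1f))  // defaultLeft = false
    ),
    baseMargin = 0.0f
  )

  def main(args: Array[String]): Unit = {
    // Feature 1 is missing, so the second tree uses its default (right) branch.
    val x = Array(1.0f, Float.NaN)
    println(s"margin = ${model.margin(x)}, proba = ${model.predictProba(x)}")
  }
}
```

Each node type carries its own traversal logic, so adding a node kind means adding a case class rather than a new integer code to decode.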
The APIs accept both Scala and Java collections (Array and Map) as input and output. For instance, if you call the API with a Java Map or Array, you get a Java Array as a result; if you use a Scala Map, you get a Scala Array.
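One way to serve both kinds of callers is plain method overloading on the input collection type, with every overload converting to the same dense feature array. The sketch below assumes a hypothetical CollectionsFacade class and a score function standing in for the model; it only shows the input side, and the package's real signatures and return types may differ.

```scala
import java.{lang => jl, util => ju}

// Hypothetical facade; the class name, constructor and signatures are illustrative only.
// `score` stands in for the model and takes a dense feature vector.
final class CollectionsFacade(nFeatures: Int, score: Array[Float] => Array[Float]) {

  // Scala callers: sparse Scala Map from feature index to value.
  def predict(features: Map[Int, Float]): Array[Float] = {
    val dense = Array.fill(nFeatures)(Float.NaN) // absent features are treated as missing
    features.foreach { case (i, v) => dense(i) = v }
    score(dense)
  }

  // Java callers: java.util.Map with boxed keys and values, no Scala collections involved.
  def predict(features: ju.Map[jl.Integer, jl.Float]): Array[Float] = {
    val dense = Array.fill(nFeatures)(Float.NaN)
    val it = features.entrySet().iterator()
    while (it.hasNext) {
      val e = it.next()
      dense(e.getKey.intValue) = e.getValue.floatValue
    }
    score(dense)
  }
}
```

The two overloads erase to different JVM parameter types (scala.collection.immutable.Map and java.util.Map), so they can coexist, and a Scala Array[Float] compiles to a JVM float[] that Java code can use directly.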
We have also added more unit tests, including a comparison test against XGBoost4J, to ensure the predictor's integrity across different model versions.
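A comparison test of this kind usually checks predictions against reference values with a small tolerance rather than exact equality, since a native and a JVM implementation may accumulate floating-point sums in a different order. The helper below is a hypothetical sketch; in a real test the reference values would come from XGBoost4J on the same model and inputs, and the tolerance is an assumed default.

```scala
// Hypothetical helper: compares reference predictions (e.g. produced by XGBoost4J on
// the same model and inputs) with the JVM predictor's output.
object PredictionComparison {
  // Returns the positions where the two vectors differ by more than `tol`.
  def mismatches(expected: Array[Float], actual: Array[Float], tol: Float = 1e-5f): Seq[Int] = {
    require(expected.length == actual.length, "prediction vectors must have the same length")
    expected.indices.filter(i => math.abs(expected(i) - actual(i)) > tol)
  }

  def main(args: Array[String]): Unit = {
    val reference = Array(0.12f, 0.87f, 0.50f) // illustrative values, not real model output
    val candidate = Array(0.12f, 0.87f, 0.51f)
    println(s"mismatching rows: ${mismatches(reference, candidate).mkString(", ")}")
  }
}
```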
As shown in BENCH.md, predict is on average twice as fast, but the gains are much larger in a distributed environment (Kafka, Flink), where we measured a factor of 6.
The maximum throughput (the number of transactions the system can process while keeping end-to-end latency within the 300 ms constraint) was multiplied by at least 4, and CPU utilization was clearly reduced.
See the small examples included in the README. The API is very close to XGBoost4J's, so migrating from it should be straightforward.
We made this project open source because we trust and promote XGBoost, which is probably the only ML algorithm today that can satisfy our business requirements, i.e. on-the-fly predictions and explainable AI. We would also like to share our expertise in this domain and receive valuable feedback.
If you want to enhance this project, feel free to open an issue or create a branch, but remember: the key criterion here is performance! Use the Makefile goal bench to benchmark any change.