Skip to content

Commit

Permalink
[Java] Tidying up the sample MNIST code (microsoft#3824)
Browse files Browse the repository at this point in the history
* Updating the Java sample to load MNIST in libsvm format.
* java - code formatting fix.
Co-authored-by: Adam Pocock <[email protected]>
  • Loading branch information
yuslepukhin authored May 5, 2020
1 parent f7ff5a7 commit 5db30a4
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 21 deletions.
12 changes: 11 additions & 1 deletion java/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@ TBD: maven distribution

The minimum supported Java Runtime is version 8.

An example implementation is located in [src/test/java/sample/ScoreMNIST.java](src/test/java/sample/ScoreMNIST.java)
An example implementation is located in
[src/test/java/sample/ScoreMNIST.java](src/test/java/sample/ScoreMNIST.java).
Once compiled the sample code expects the following arguments `ScoreMNIST
<path-to-mnist-model> <path-to-mnist> <scikit-learn-flag>`. MNIST is expected
to be in libsvm format. If the optional scikit-learn flag is supplied the model
is expected to be produced by skl2onnx (so expects a flat feature vector, and
produces a structured output), otherwise the model is expected to be a CNN from
pytorch (expecting a `[1][1][28][28]` input, producing a vector of
probabilities). Two example models are provided in [testdata](testdata),
`cnn_mnist_pytorch.onnx` and `lr_mnist_scikit.onnx`. The first is a LeNet5 style
CNN trained using PyTorch, the second is a logistic regression trained using scikit-learn.

This project can be built manually using the instructions below.

Expand Down
125 changes: 105 additions & 20 deletions java/src/test/java/sample/ScoreMNIST.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@
import ai.onnxruntime.OrtSession.Result;
import ai.onnxruntime.OrtSession.SessionOptions;
import ai.onnxruntime.OrtSession.SessionOptions.OptLevel;
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.regex.Pattern;

/**
* Demo code, supporting both a pytorch CNN trained on MNIST and a scikit-learn model trained on
Expand All @@ -29,6 +30,8 @@
public class ScoreMNIST {

private static final Logger logger = Logger.getLogger(ScoreMNIST.class.getName());
/** Pattern for splitting libsvm format files. */
private static final Pattern splitPattern = Pattern.compile("\\s+");

/** A named tuple for sparse classification data. */
private static class SparseData {
Expand All @@ -38,29 +41,112 @@ private static class SparseData {

public SparseData(int[] labels, List<int[]> indices, List<float[]> values) {
this.labels = labels;
this.indices = indices;
this.values = values;
this.indices = Collections.unmodifiableList(indices);
this.values = Collections.unmodifiableList(values);
}
}

/**
* Deserialises the data and puts it in a named tuple.
* Converts a List of Integer into an int array.
*
* @param list The list to convert.
* @return The int array.
*/
private static int[] convertInts(List<Integer> list) {
int[] output = new int[list.size()];
for (int i = 0; i < list.size(); i++) {
output[i] = list.get(i);
}
return output;
}

/**
* Converts a List of Float into a float array.
*
* @param list The list to convert.
* @return The float array.
*/
private static float[] convertFloats(List<Float> list) {
float[] output = new float[list.size()];
for (int i = 0; i < list.size(); i++) {
output[i] = list.get(i);
}
return output;
}

/**
* Loads data from a libsvm format file.
*
* @param path The path to load the data from.
* @return A named tuple containing the data.
* @throws IOException If it failed to read the file.
* @throws ClassNotFoundException If a class wasn't found (only uses JDK types so this would be
* very odd).
*/
@SuppressWarnings("unchecked")
private static SparseData load(String path) throws IOException, ClassNotFoundException {
try (ObjectInputStream ois =
new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) {
int[] labels = (int[]) ois.readObject();
List<int[]> indices = (List<int[]>) ois.readObject();
List<float[]> values = (List<float[]>) ois.readObject();
return new SparseData(labels, indices, values);
private static SparseData load(String path) throws IOException {
int pos = 0;
List<int[]> indices = new ArrayList<>();
List<float[]> values = new ArrayList<>();
List<Integer> labels = new ArrayList<>();
String line;
int maxFeatureID = Integer.MIN_VALUE;
try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
for (; ; ) {
line = reader.readLine();
if (line == null) {
break;
}
pos++;
String[] fields = splitPattern.split(line);
int lastID = -1;
try {
boolean valid = true;
List<Integer> curIndices = new ArrayList<>();
List<Float> curValues = new ArrayList<>();
for (int i = 1; i < fields.length && valid; i++) {
int ind = fields[i].indexOf(':');
if (ind < 0) {
logger.warning(String.format("Weird line at %d", pos));
valid = false;
}
String ids = fields[i].substring(0, ind);
int id = Integer.parseInt(ids);
curIndices.add(id);
if (maxFeatureID < id) {
maxFeatureID = id;
}
float val = Float.parseFloat(fields[i].substring(ind + 1));
curValues.add(val);
if (id <= lastID) {
logger.warning(String.format("Repeated features at line %d", pos));
valid = false;
} else {
lastID = id;
}
}
if (valid) {
// Store the label
labels.add(Integer.parseInt(fields[0]));
// Store the features
indices.add(convertInts(curIndices));
values.add(convertFloats(curValues));
} else {
throw new IOException("Invalid LibSVM format file at line " + pos);
}
} catch (NumberFormatException ex) {
logger.warning(String.format("Weird line at %d", pos));
throw new IOException("Invalid LibSVM format file", ex);
}
}
}

logger.info(
"Loaded "
+ maxFeatureID
+ " features, "
+ labels.size()
+ " samples, from + '"
+ path
+ "'.");
return new SparseData(convertInts(labels), indices, values);
}

/**
Expand Down Expand Up @@ -170,11 +256,10 @@ public static int pred(float[] probabilities) {
return idx;
}

public static void main(String[] args) throws OrtException, IOException, ClassNotFoundException {
public static void main(String[] args) throws OrtException, IOException {
if (args.length < 2 || args.length > 3) {
System.out.println("Usage: ScoreMNIST <model-path> <test-data> <optional:scikit-learn-flag>");
System.out.println(
"The test data input format is a Java serialized file containing an array of int labels, a list of int[] feature indices, and a list of float[] feature values");
System.out.println("The test data input should be a libsvm format version of MNIST.");
return;
}

Expand Down Expand Up @@ -232,7 +317,7 @@ public static void main(String[] args) throws OrtException, IOException, ClassNo

confusionMatrix[data.labels[i]][predLabel]++;

if (i % 500 == 0) {
if (i % 2000 == 0) {
logger.log(Level.INFO, "Cur accuracy = " + ((float) correctCount) / (i + 1));
logger.log(Level.INFO, "Output type = " + output.get(0).toString());
if (args.length == 3) {
Expand Down
Binary file added java/testdata/cnn_mnist_pytorch.onnx
Binary file not shown.
Binary file added java/testdata/lr_mnist_scikit.onnx
Binary file not shown.

0 comments on commit 5db30a4

Please sign in to comment.