diff --git a/java/README.md b/java/README.md index 2a80df0dffc48..4d39df5576ff7 100644 --- a/java/README.md +++ b/java/README.md @@ -9,7 +9,17 @@ TBD: maven distribution The minimum supported Java Runtime is version 8. -An example implementation is located in [src/test/java/sample/ScoreMNIST.java](src/test/java/sample/ScoreMNIST.java) +An example implementation is located in +[src/test/java/sample/ScoreMNIST.java](src/test/java/sample/ScoreMNIST.java). +Once compiled the sample code expects the following arguments `ScoreMNIST + `. MNIST is expected +to be in libsvm format. If the optional scikit-learn flag is supplied the model +is expected to be produced by skl2onnx (so expects a flat feature vector, and +produces a structured output), otherwise the model is expected to be a CNN from +pytorch (expecting a `[1][1][28][28]` input, producing a vector of +probabilities). Two example models are provided in [testdata](testdata), +`cnn_mnist_pytorch.onnx` and `lr_mnist_scikit.onnx`. The first is a LeNet5 style +CNN trained using PyTorch, the second is a logistic regression trained using scikit-learn. This project can be built manually using the instructions below. diff --git a/java/src/test/java/sample/ScoreMNIST.java b/java/src/test/java/sample/ScoreMNIST.java index 27ef903a09694..5ad40a5bec4ed 100644 --- a/java/src/test/java/sample/ScoreMNIST.java +++ b/java/src/test/java/sample/ScoreMNIST.java @@ -12,15 +12,16 @@ import ai.onnxruntime.OrtSession.Result; import ai.onnxruntime.OrtSession.SessionOptions; import ai.onnxruntime.OrtSession.SessionOptions.OptLevel; -import java.io.BufferedInputStream; -import java.io.FileInputStream; +import java.io.BufferedReader; +import java.io.FileReader; import java.io.IOException; -import java.io.ObjectInputStream; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.logging.Level; import java.util.logging.Logger; +import java.util.regex.Pattern; /** * Demo code, supporting both a pytorch CNN trained on MNIST and a scikit-learn model trained on @@ -29,6 +30,8 @@ public class ScoreMNIST { private static final Logger logger = Logger.getLogger(ScoreMNIST.class.getName()); + /** Pattern for splitting libsvm format files. */ + private static final Pattern splitPattern = Pattern.compile("\\s+"); /** A named tuple for sparse classification data. */ private static class SparseData { @@ -38,29 +41,112 @@ private static class SparseData { public SparseData(int[] labels, List indices, List values) { this.labels = labels; - this.indices = indices; - this.values = values; + this.indices = Collections.unmodifiableList(indices); + this.values = Collections.unmodifiableList(values); } } /** - * Deserialises the data and puts it in a named tuple. + * Converts a List of Integer into an int array. + * + * @param list The list to convert. + * @return The int array. + */ + private static int[] convertInts(List list) { + int[] output = new int[list.size()]; + for (int i = 0; i < list.size(); i++) { + output[i] = list.get(i); + } + return output; + } + + /** + * Converts a List of Float into a float array. + * + * @param list The list to convert. + * @return The float array. + */ + private static float[] convertFloats(List list) { + float[] output = new float[list.size()]; + for (int i = 0; i < list.size(); i++) { + output[i] = list.get(i); + } + return output; + } + + /** + * Loads data from a libsvm format file. * * @param path The path to load the data from. * @return A named tuple containing the data. * @throws IOException If it failed to read the file. - * @throws ClassNotFoundException If a class wasn't found (only uses JDK types so this would be - * very odd). */ - @SuppressWarnings("unchecked") - private static SparseData load(String path) throws IOException, ClassNotFoundException { - try (ObjectInputStream ois = - new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) { - int[] labels = (int[]) ois.readObject(); - List indices = (List) ois.readObject(); - List values = (List) ois.readObject(); - return new SparseData(labels, indices, values); + private static SparseData load(String path) throws IOException { + int pos = 0; + List indices = new ArrayList<>(); + List values = new ArrayList<>(); + List labels = new ArrayList<>(); + String line; + int maxFeatureID = Integer.MIN_VALUE; + try (BufferedReader reader = new BufferedReader(new FileReader(path))) { + for (; ; ) { + line = reader.readLine(); + if (line == null) { + break; + } + pos++; + String[] fields = splitPattern.split(line); + int lastID = -1; + try { + boolean valid = true; + List curIndices = new ArrayList<>(); + List curValues = new ArrayList<>(); + for (int i = 1; i < fields.length && valid; i++) { + int ind = fields[i].indexOf(':'); + if (ind < 0) { + logger.warning(String.format("Weird line at %d", pos)); + valid = false; + } + String ids = fields[i].substring(0, ind); + int id = Integer.parseInt(ids); + curIndices.add(id); + if (maxFeatureID < id) { + maxFeatureID = id; + } + float val = Float.parseFloat(fields[i].substring(ind + 1)); + curValues.add(val); + if (id <= lastID) { + logger.warning(String.format("Repeated features at line %d", pos)); + valid = false; + } else { + lastID = id; + } + } + if (valid) { + // Store the label + labels.add(Integer.parseInt(fields[0])); + // Store the features + indices.add(convertInts(curIndices)); + values.add(convertFloats(curValues)); + } else { + throw new IOException("Invalid LibSVM format file at line " + pos); + } + } catch (NumberFormatException ex) { + logger.warning(String.format("Weird line at %d", pos)); + throw new IOException("Invalid LibSVM format file", ex); + } + } } + + logger.info( + "Loaded " + + maxFeatureID + + " features, " + + labels.size() + + " samples, from + '" + + path + + "'."); + return new SparseData(convertInts(labels), indices, values); } /** @@ -170,11 +256,10 @@ public static int pred(float[] probabilities) { return idx; } - public static void main(String[] args) throws OrtException, IOException, ClassNotFoundException { + public static void main(String[] args) throws OrtException, IOException { if (args.length < 2 || args.length > 3) { System.out.println("Usage: ScoreMNIST "); - System.out.println( - "The test data input format is a Java serialized file containing an array of int labels, a list of int[] feature indices, and a list of float[] feature values"); + System.out.println("The test data input should be a libsvm format version of MNIST."); return; } @@ -232,7 +317,7 @@ public static void main(String[] args) throws OrtException, IOException, ClassNo confusionMatrix[data.labels[i]][predLabel]++; - if (i % 500 == 0) { + if (i % 2000 == 0) { logger.log(Level.INFO, "Cur accuracy = " + ((float) correctCount) / (i + 1)); logger.log(Level.INFO, "Output type = " + output.get(0).toString()); if (args.length == 3) { diff --git a/java/testdata/cnn_mnist_pytorch.onnx b/java/testdata/cnn_mnist_pytorch.onnx new file mode 100644 index 0000000000000..91df4267b9ae6 Binary files /dev/null and b/java/testdata/cnn_mnist_pytorch.onnx differ diff --git a/java/testdata/lr_mnist_scikit.onnx b/java/testdata/lr_mnist_scikit.onnx new file mode 100644 index 0000000000000..0cdc8efb80227 Binary files /dev/null and b/java/testdata/lr_mnist_scikit.onnx differ