This crate demonstrates how to run an MNIST-trained model in the browser for inference.
- Build: `./build-for-web.sh {backend}`. The backend can be either `flex` or `wgpu`. Note that `wgpu` only works in browsers with support for WebGPU.
- Run the server: `./run-server.sh`
- Open http://localhost:8000/ in the browser.
The inference components of burn with the `flex` backend can be built with `#![no_std]`. This
makes it possible to build and run the model with the `wasm32-unknown-unknown` target without a
special system library, such as WASI. (See `Cargo.toml` on how to
include burn dependencies without `std`.)
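As a rough sketch of what such a no-`std` dependency declaration looks like (the exact feature flags and version belong to the crate's actual `Cargo.toml`, so treat these lines as illustrative only):

```toml
# Illustrative only -- check the real Cargo.toml for the exact
# version and feature names. The key point is disabling default
# features so the std-dependent parts of burn are not pulled in.
[dependencies]
burn = { version = "*", default-features = false }
```

Disabling default features this way is what allows the compiler to target `wasm32-unknown-unknown` without a WASI-style system library.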
For this demo, we use the trained parameters (`model.bin`) and the model definition (`model.rs`) from the
burn MNIST example.
The inference API for JavaScript is exposed with the help of
the wasm-bindgen library and tools.
JavaScript (`index.js`) is used to transform hand-drawn digits into a format that the inference API
accepts. The transformation includes cropping the image, scaling it down, and converting it to grayscale
values.
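The demo performs this preprocessing in JavaScript, but the same idea can be sketched in plain Rust: scale a square grayscale canvas down by block averaging and normalize the pixels to `[0, 1]` for the model input. The `downscale` function and the 56x56 source size below are illustrative assumptions, not the demo's actual code.

```rust
/// Downscale a square 8-bit grayscale image by averaging `block x block`
/// pixel blocks, and normalize the result to [0, 1] floats.
/// (Illustrative sketch; the real demo does this in index.js.)
fn downscale(src: &[u8], src_size: usize, dst_size: usize) -> Vec<f32> {
    assert_eq!(src.len(), src_size * src_size);
    assert_eq!(src_size % dst_size, 0, "source must be a multiple of target");
    let block = src_size / dst_size;
    let mut dst = vec![0.0f32; dst_size * dst_size];
    for y in 0..dst_size {
        for x in 0..dst_size {
            let mut sum = 0u32;
            for by in 0..block {
                for bx in 0..block {
                    sum += src[(y * block + by) * src_size + (x * block + bx)] as u32;
                }
            }
            // Average the block, then normalize to [0, 1].
            dst[y * dst_size + x] = sum as f32 / (block * block) as f32 / 255.0;
        }
    }
    dst
}

fn main() {
    // A 56x56 all-white canvas downscales to a 28x28 input of 1.0s.
    let src = vec![255u8; 56 * 56];
    let out = downscale(&src, 56, 28);
    assert_eq!(out.len(), 28 * 28);
    assert!((out[0] - 1.0).abs() < 1e-6);
    println!("ok");
}
```

The resulting 28x28 `f32` buffer matches the input shape the MNIST model expects.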
Layers:
- Input Image (28, 28, 1ch)
- `Conv2d(3x3, 64ch)`, `BatchNorm2d`, `Gelu`, `MaxPool(2x2)`
- `Conv2d(3x3, 16ch)`, `BatchNorm2d`, `Gelu`, `MaxPool(2x2)`
- `Linear(1600, 128)`, `Relu`
- `Linear(128, 128)`, `Relu`
- `Linear(128, 10)`
- Softmax Output
The total number of parameters is 260,810.
The model is trained for 18 epochs, and the final test accuracy is 95.83%.
Random transformations are used for data augmentation.
The training and hyperparameter information can be found in the
burn MNIST example.
The main differentiating factor between this example's approach (compiling a Rust model into Wasm) and
other popular tools, such as TensorFlow.js,
ONNX Runtime JS, and
TVM Web, is the absence of runtime code: the Rust
compiler optimizes and includes only the burn routines that are actually used. Of the 1,866,491-byte
Wasm file, 1,509,747 bytes are the model's parameters; the remaining 356,744 bytes contain all the code (including
burn's nn components, the data deserialization library, and math operations).
Several enhancements are planned.
Two online MNIST demos inspired this demo and helped with its implementation: MNIST Draw by Marc (@mco-gh) and MNIST Web Demo (no code was copied, but they helped tremendously with the implementation approach).