Skip to content

Commit

Permalink
Add WebGPU ep (does not link yet).
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelmenges committed Mar 10, 2025
1 parent eb4cd90 commit 8a174a9
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 10 deletions.
2 changes: 1 addition & 1 deletion examples/wasm-emscripten/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ version = "0.0.0"
edition = "2021"

[dependencies]
ort = { path = "../../" }
ort = { path = "../../", default-features = false, features = ["ndarray", "webgpu"] }
ndarray = "0.16"
image = "0.25"

Expand Down
4 changes: 2 additions & 2 deletions examples/wasm-emscripten/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ Example how to use `ort` to run `onnxruntime` in the Web with multi-threaded exe
1. Install the Rust nightly toolchain with `rustup install nightly`.
1. Add Emscripten as Rust target with `rustup target add wasm32-unknown-emscripten --toolchain nightly`.
1. Clone Emscripten SDK via `git clone https://github.com/emscripten-core/emsdk.git --depth 1`.
1. Install Emscripten SDK 3.1.59 locally to [match version used in ONNX runtime](https://github.com/microsoft/onnxruntime/blob/1d97d6ef55433298dee58634b0ea59f736e8a72e/.gitmodules#L10) via `./emsdk/emsdk install 3.1.59`.
1. Prepare local Emscripten SDK via `./emsdk/emsdk activate 3.1.59`.
1. Install Emscripten SDK 4.0.3 locally to [match version used in ONNX runtime](https://github.com/microsoft/onnxruntime/blob/754ee21f83518bf127ba481cf1bedf58ee3b5374/.gitmodules#L10) via `./emsdk/emsdk install 4.0.3`.
1. Prepare local Emscripten SDK via `./emsdk/emsdk activate 4.0.3`.

Environment tested on Ubuntu 24.04 and macOS 14.7.1.

Expand Down
8 changes: 4 additions & 4 deletions examples/wasm-emscripten/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ fn main() {
// Download precompiled libonnxruntime.a.
{
// Request archive.
let mut request = get("https://github.com/alfatraining/onnxruntime-wasm-builds/releases/download/v1.20.1/libonnxruntime-v1.20.1-wasm.zip")
let mut request = get("https://github.com/alfatraining/ort-artifacts-staging/releases/download/754ee21/ort_static-main-wasm32-unknown-emscripten.zip")
.expect("Cannot request precompiled onnxruntime.");
let mut buf = Vec::<u8>::new();
request.read_to_end(&mut buf).expect("Cannot read precompiled onnxruntime.");
Expand All @@ -40,11 +40,11 @@ fn main() {
let mut zip = ZipArchive::new(reader).expect("Cannot incept unzipper.");

// Extract precompiled library.
// TODO: For debug builds, link to a debug build of onnxruntime.
{
let mut buf = Vec::<u8>::new();
let mut mode_title_case = mode.to_string();
mode_title_case = format!("{}{mode_title_case}", mode_title_case.remove(0).to_uppercase());
zip.by_name(format!("{mode_title_case}/libonnxruntime.a").as_str())

zip.by_name("onnxruntime/lib/libonnxruntime.a")
.expect("Cannot find precompiled onnxruntime.")
.read_to_end(&mut buf)
.expect("Cannot read precompiled onnxruntime.");
Expand Down
18 changes: 15 additions & 3 deletions examples/wasm-emscripten/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,10 @@ pub extern "C" fn dealloc(ptr: *mut std::os::raw::c_void, size: usize) {
pub extern "C" fn detect_objects(ptr: *const u8, width: u32, height: u32) {
ort::init()
.with_global_thread_pool(ort::environment::GlobalThreadPoolOptions::default())
.with_execution_providers([ort::execution_providers::cpu::CPUExecutionProvider::default().build()])
.commit()
.expect("Cannot initialize ort.");

let mut session = ort::session::Session::builder()
let mut builder = ort::session::Session::builder()
.expect("Cannot create Session builder.")
.with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)
.expect("Cannot optimize graph.")
Expand All @@ -54,7 +53,20 @@ pub extern "C" fn detect_objects(ptr: *const u8, width: u32, height: u32) {
.with_intra_threads(2)
.expect("Cannot set intra thread count.")
.with_inter_threads(1)
.expect("Cannot set inter thread count.")
.expect("Cannot set inter thread count.");

let use_webgpu = true; // TODO: Make `use_webgpu` a parameter of `detect_objects`? Or say in README to change it here.
if use_webgpu {
use ort::execution_providers::ExecutionProvider;
let ep = ort::execution_providers::WebGPUExecutionProvider::default();
if ep.is_available().expect("Cannot check for availability of WebGPU ep.") {
ep.register(&mut builder).expect("Cannot register WebGPU ep.");
} else {
println!("WebGPU ep is not available.");
}
}

let mut session = builder
.commit_from_memory(include_bytes!("../yolov8m.onnx"))
.expect("Cannot commit model.");

Expand Down

0 comments on commit 8a174a9

Please sign in to comment.