diff --git a/common/utils.js b/common/utils.js
index c28185fa..a76f3ef0 100644
--- a/common/utils.js
+++ b/common/utils.js
@@ -199,16 +199,12 @@ export function getMedianValue(array) {
(array[array.length / 2 - 1] + array[array.length / 2]) / 2;
}
-// Set tf.js backend based WebNN's 'MLDeviceType' option
-export async function setPolyfillBackend(device) {
- // Simulate WebNN's device selection using various tf.js backends.
- // MLDeviceType: ['default', 'gpu', 'cpu']
- // 'default' or 'gpu': tfjs-backend-webgl, 'cpu': tfjs-backend-wasm
- if (!device) device = 'gpu';
- // Use 'webgl' by default for better performance.
+// Set tf.js backend
+export async function setPolyfillBackend(backend) {
+ if (!backend) backend = 'webgpu';
+ // Use 'webgpu' by default for better performance.
// Note: 'wasm' backend may run failed on some samples since
// some ops aren't supported on 'wasm' backend at present
- const backend = device === 'cpu' ? 'wasm' : 'webgl';
const context = await navigator.ml.createContext();
const tf = context.tf;
if (tf) {
@@ -221,8 +217,8 @@ export async function setPolyfillBackend(device) {
throw new Error(`Failed to set tf.js backend ${backend}.`);
}
await tf.ready();
- let backendInfo = backend == 'wasm' ? 'WASM' : 'WebGL';
- if (backendInfo == 'WASM') {
+ let backendInfo = tf.getBackend();
+ if (backendInfo == 'wasm') {
const hasSimd = tf.env().features['WASM_HAS_SIMD_SUPPORT'];
const hasThreads = tf.env().features['WASM_HAS_MULTITHREAD_SUPPORT'];
if (hasThreads && hasSimd) {
@@ -277,7 +273,7 @@ export function getUrlParams() {
}
// Set backend for using WebNN-polyfill or WebNN
-export async function setBackend(backend, device) {
+export async function setBackend(backend, device, polyfillBackend) {
const webnnPolyfillId = 'webnn_polyfill';
const webnnNodeId = 'webnn_node';
const webnnPolyfillElem = document.getElementById(webnnPolyfillId);
@@ -304,7 +300,7 @@ export async function setBackend(backend, device) {
// Create WebNN-polyfill script
await loadScript(webnnPolyfillUrl, webnnPolyfillId);
}
- await setPolyfillBackend(device);
+ await setPolyfillBackend(polyfillBackend);
} else if (backend === 'webnn') {
// For Electron
if (isElectron()) {
@@ -327,7 +323,7 @@ export async function setBackend(backend, device) {
}
}
} else {
- addAlert(`Unknow backend: ${backend}`, 'warning');
+ addAlert(`Unknown backend: ${backend}`, 'warning');
}
}
diff --git a/face_recognition/index.html b/face_recognition/index.html
index 0a929c3f..f196e436 100644
--- a/face_recognition/index.html
+++ b/face_recognition/index.html
@@ -32,10 +32,13 @@
+