diff --git a/common/utils.js b/common/utils.js index c28185fa..a76f3ef0 100644 --- a/common/utils.js +++ b/common/utils.js @@ -199,16 +199,12 @@ export function getMedianValue(array) { (array[array.length / 2 - 1] + array[array.length / 2]) / 2; } -// Set tf.js backend based WebNN's 'MLDeviceType' option -export async function setPolyfillBackend(device) { - // Simulate WebNN's device selection using various tf.js backends. - // MLDeviceType: ['default', 'gpu', 'cpu'] - // 'default' or 'gpu': tfjs-backend-webgl, 'cpu': tfjs-backend-wasm - if (!device) device = 'gpu'; - // Use 'webgl' by default for better performance. +// Set tf.js backend +export async function setPolyfillBackend(backend) { + if (!backend) backend = 'webgpu'; + // Use 'webgpu' by default for better performance. // Note: 'wasm' backend may run failed on some samples since // some ops aren't supported on 'wasm' backend at present - const backend = device === 'cpu' ? 'wasm' : 'webgl'; const context = await navigator.ml.createContext(); const tf = context.tf; if (tf) { @@ -221,8 +217,8 @@ export async function setPolyfillBackend(device) { throw new Error(`Failed to set tf.js backend ${backend}.`); } await tf.ready(); - let backendInfo = backend == 'wasm' ? 'WASM' : 'WebGL'; - if (backendInfo == 'WASM') { + let backendInfo = tf.getBackend(); + if (backendInfo == 'wasm') { const hasSimd = tf.env().features['WASM_HAS_SIMD_SUPPORT']; const hasThreads = tf.env().features['WASM_HAS_MULTITHREAD_SUPPORT']; if (hasThreads && hasSimd) { @@ -277,7 +273,7 @@ export function getUrlParams() { } // Set backend for using WebNN-polyfill or WebNN -export async function setBackend(backend, device) { +export async function setBackend(backend, device, polyfillBackend) { const webnnPolyfillId = 'webnn_polyfill'; const webnnNodeId = 'webnn_node'; const webnnPolyfillElem = document.getElementById(webnnPolyfillId); @@ -304,7 +300,7 @@ export async function setBackend(backend, device) { // Create WebNN-polyfill script await loadScript(webnnPolyfillUrl, webnnPolyfillId); } - await setPolyfillBackend(device); + await setPolyfillBackend(polyfillBackend); } else if (backend === 'webnn') { // For Electron if (isElectron()) { @@ -327,7 +323,7 @@ export async function setBackend(backend, device) { } } } else { - addAlert(`Unknow backend: ${backend}`, 'warning'); + addAlert(`Unknown backend: ${backend}`, 'warning'); } } diff --git a/face_recognition/index.html b/face_recognition/index.html index 0a929c3f..f196e436 100644 --- a/face_recognition/index.html +++ b/face_recognition/index.html @@ -32,10 +32,13 @@
+