diff --git a/code/samples/matmul.js b/code/samples/matmul.js index 03d54698..4759c5d6 100644 --- a/code/samples/matmul.js +++ b/code/samples/matmul.js @@ -1,19 +1,17 @@ -const context = await navigator.ml.createContext(); - -// The following code multiplies matrix a [3, 4] with matrix b [4, 3] -// into matrix c [3, 3]. +// Step 0: Create a context and graph builder for 'gpu', 'cpu' or 'npu'. +const context = await navigator.ml.createContext({deviceType: 'gpu'}); const builder = new MLGraphBuilder(context); -const descA = {dataType: 'float32', dimensions: [3, 4]}; -const a = builder.input('a', descA); -const descB = {dataType: 'float32', dimensions: [4, 3]}; -const bufferB = new Float32Array(sizeOfShape(descB.dimensions)).fill(0.5); -const b = builder.constant(descB, bufferB); -const c = builder.gemm(a, b); - +// Step 1: Create a computational graph calculating `c = a * b`. +const a = builder.input('a', {dataType: 'float32', dimensions: [3, 4]}); +const b = builder.input('b', {dataType: 'float32', dimensions: [4, 3]}); +const c = builder.matmul(a, b); +// Step 2: Compile it into an executable graph. const graph = await builder.build({c}); -const bufferA = new Float32Array(sizeOfShape(descA.dimensions)).fill(0.5); -const bufferC = new Float32Array(sizeOfShape([3, 3])); -const inputs = {'a': bufferA}; -const outputs = {'c': bufferC}; -const results = await context.compute(graph, inputs, outputs); +// Step 3: Bind input and output buffers to the graph and execute. +const bufferA = new Float32Array(3*4).fill(1.0); +const bufferB = new Float32Array(4*3).fill(0.8); +const bufferC = new Float32Array(3*3); +const results = await context.compute( + graph, {'a': bufferA, 'b': bufferB}, {'c': bufferC}); +// Step 4: Retrieve the results. console.log(`values: ${results.outputs.c}`); diff --git a/code/samples_repo.js b/code/samples_repo.js index 161dd863..9e36c3c9 100644 --- a/code/samples_repo.js +++ b/code/samples_repo.js @@ -1,8 +1,8 @@ // The samples under ./samples folder const samples = [ + 'matmul.js', 'mul_add.js', 'simple_graph.js', - 'matmul.js', ]; class SamplesRepository {