Deploying machine learning models outside of a Python environment used to be difficult. When the target platform is the browser, the de facto standard for serving predictions has been an API call to a server-side inference engine. For many reasons, server-side inference is a suboptimal solution, and machine learning models are increasingly being deployed natively. TensorFlow has done a good job of supporting this movement by providing cross-platform APIs; however, many of us do not want to be married to a single ecosystem.
In comes the Open Neural Network Exchange (ONNX) project which, since being picked up by Microsoft, has seen massive development effort and is approaching a stable state. It's now easier than ever to deploy machine-learning models, trained using your machine-learning framework of choice, on your platform of choice, with hardware acceleration out of the box. In April this year, onnxruntime-web was introduced (see this Pull Request). onnxruntime-web uses WebAssembly to compile the onnxruntime inference engine to run ONNX models in the browser - it's about time WebAssembly started to flex its muscles. Especially when paired with WebGL, we suddenly have GPU-powered machine learning in the browser, which is pretty cool.
In this tutorial we will dive into onnxruntime-web by deploying a pre-trained PyTorch model to the browser. We will be using AlexNet as our deployment target. AlexNet has been trained as an image classifier on the ImageNet dataset, so we will be building an image classifier - nothing better than re-inventing the wheel. At the end of this tutorial, we will have built a bundled web app that can be run as a standalone static web page or integrated into your JavaScript framework of choice.
Learn from code instead -> onnxruntime-web-tutorial
You will need a trained machine-learning model exported as an ONNX binary protobuf file. There are many ways to achieve this using a number of different deep-learning frameworks. For the sake of this tutorial, I will be using the exported model from the AlexNet example in the PyTorch documentation; the Python code snippet below will help you generate your own model. You can also follow the linked documentation to export your own PyTorch model. If you're coming from TensorFlow, this tutorial will help you with exporting your model to ONNX. Lastly, ONNX doesn't just pride itself on cross-platform deployment, but also on allowing exports from all major deep-learning frameworks. Those of you using another deep-learning framework should be able to find support for exporting to ONNX in the docs of your framework.
import torch
import torchvision

# dummy input with the shape AlexNet expects: [batch, channels, height, width]
dummy_input = torch.randn(1, 3, 224, 224)
model = torchvision.models.alexnet(pretrained=True)

# names used to reference the model's inputs and outputs at inference time
input_names = ["input1"]
output_names = ["output1"]

torch.onnx.export(
    model,
    dummy_input,
    "alexnet.onnx",
    verbose=True,
    input_names=input_names,
    output_names=output_names,
)
Running this script creates alexnet.onnx, a binary protobuf file which contains both the network structure and the parameters of the model you exported (in this case, AlexNet).
ONNX Runtime Web is a JavaScript library for running ONNX models in the browser and on Node.js. ONNX Runtime Web has adopted WebAssembly and WebGL technologies to provide an optimized ONNX model inference runtime for both CPUs and GPUs.
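By default you don't need to configure anything, but the backend can be selected explicitly through the session options when a session is created. The sketch below assumes the executionProviders option described in the onnxruntime-web documentation (provider names and defaults may vary between versions):

// A minimal sketch: explicitly choosing a backend for inference.
// "webgl" targets the GPU, "wasm" runs on the CPU via WebAssembly;
// providers are listed in order of preference.
async function createSession(modelUrl) {
  return ort.InferenceSession.create(modelUrl, {
    executionProviders: ["webgl", "wasm"],
  })
}

We won't need this for the rest of the tutorial - the defaults work fine for AlexNet.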
The official package is hosted on npm under the name onnxruntime-web. When using a bundler or working server-side, this package can be installed using npm install. However, it's also possible to deliver the code via a CDN using a script tag. The bundling process is a bit more involved, so we will start with the script-tag approach and come back to using the npm package later.
Let's start with the core application: model inference. onnxruntime-web exposes a runtime object called an InferenceSession with a method .run() which is used to initiate the forward pass with the desired inputs. Both InferenceSession.create() and the accompanying .run() method return a Promise, so we will run the entire process inside an async context. Before implementing any browser elements, we will check that our model runs with a dummy input tensor, remembering the input and output names and sizes that we defined earlier when exporting the model.
async function run() {
  try {
    // create a new session and load the AlexNet model.
    const session = await ort.InferenceSession.create("./alexnet.onnx")
    // prepare dummy input data
    const dims = [1, 3, 224, 224]
    const size = dims[0] * dims[1] * dims[2] * dims[3]
    const inputData = Float32Array.from({ length: size }, () => Math.random())
    // prepare feeds. use model input names as keys.
    const feeds = { input1: new ort.Tensor("float32", inputData, dims) }
    // feed inputs and run
    const results = await session.run(feeds)
    console.log(results.output1.data)
  } catch (e) {
    console.log(e)
  }
}

run()
We then implement a simple HTML template, index.html, which should load both the pre-compiled onnxruntime-web package and main.js, containing our code.
<!DOCTYPE html>
<html>
  <head>
    <title>ONNX Runtime Web Tutorial</title>
  </head>
  <body>
    <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
    <script src="main.js"></script>
  </body>
</html>
Finally, to run this we can use light-server. If you haven't started an npm project by now, please do so by running npm init in your current working directory. Once you've completed the setup, install light-server (npm install light-server) and serve the static HTML page using npx light-server -s . -p 8080.
Congratulations! You're now running a machine learning model natively in the browser. To check that everything is working, open your browser console and make sure that the output tensor is logged (AlexNet is pretty big, so it's normal for inference to take a few seconds).
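For a slightly stronger sanity check than eyeballing the console, you can also log the output tensor's shape - for AlexNet it should be [1, 1000], one score per ImageNet class. A minimal sketch, reusing the results object from the run() function above:

// inside run(), after `const results = await session.run(feeds)`:
console.log("output dims:", results.output1.dims) // expected [1, 1000] for AlexNet
console.log("first scores:", results.output1.data.slice(0, 5))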
Next we will use webpack to bundle our dependencies, as would be the case if we wanted to deploy the model in a JavaScript app powered by frameworks like React or Vue. Bundling is usually a relatively simple procedure; however, onnxruntime-web requires a slightly more involved webpack configuration because WebAssembly is used to provide the natively assembled runtime.
Ah, the classic pitfall, especially when working with cutting-edge web technology. If your intended users are not using one of the four major browsers (Chrome, Edge, Firefox, Safari) you might want to hold off on integrating WebAssembly components. More information on the WebAssembly support and roadmap can be found here.
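If you do need to cater to older browsers, a small runtime check before creating the session lets you fail (or fall back) gracefully. This is just a sketch - supportsWebAssembly is a hypothetical helper, not part of onnxruntime-web:

// Hypothetical helper: detect WebAssembly support before loading the model.
function supportsWebAssembly() {
  return (
    typeof WebAssembly === "object" &&
    typeof WebAssembly.instantiate === "function"
  )
}

if (!supportsWebAssembly()) {
  // fall back to something else, e.g. a server-side inference API
  console.warn("WebAssembly is not supported in this browser")
}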
The following steps are based on the examples provided by the official ONNX documentation. We're assuming you've already started an npm project.

1. Install the dependencies:

npm install onnxruntime-web && npm install -D webpack webpack-cli copy-webpack-plugin

2. Update main.js to use the new package instead of loading the onnxruntime-web module via a CDN. This is done by adding a one-liner at the start of the script:

const ort = require("onnxruntime-web")

3. Create a webpack configuration file, webpack.config.js:

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
const path = require("path")
const CopyPlugin = require("copy-webpack-plugin")

module.exports = () => {
  return {
    target: ["web"],
    entry: path.resolve(__dirname, "main.js"),
    output: {
      path: path.resolve(__dirname, "dist"),
      filename: "bundle.min.js",
      library: {
        type: "umd",
      },
    },
    plugins: [
      new CopyPlugin({
        // Use copy plugin to copy *.wasm to output folder.
        patterns: [
          {
            from: "node_modules/onnxruntime-web/dist/*.wasm",
            to: "[name][ext]",
          },
        ],
      }),
    ],
    mode: "production",
  }
}
4. Run npx webpack to compile the bundle.

5. Update index.html: remove the ort.min.js script tag to stop loading the compiled package from jsDelivr, and load bundle.min.js (which contains all our dependencies, bundled and minified) instead of main.js.

index.html should now look something like this.
<!DOCTYPE html>
<html>
  <head>
    <title>ONNX Runtime Web Tutorial</title>
  </head>
  <body>
    <script src="bundle.min.js"></script>
  </body>
</html>
To make building and launching the live server easier, you could define build and serve scripts in package.json:
"scripts": {
"build": "npx webpack",
"serve": "npm run build && npx light-server -s . -p 8080"
}
Let's put this model to work and implement the image classification pipeline.
We will need some utility functions to load, resize, and display the image - the canvas object is perfect for this. In addition, image classification systems typically have a lot of magic built into the pre-processing pipeline. This is quite trivial to implement in Python using frameworks like numpy; unfortunately, that's not the case in JavaScript, so we will have to implement our pre-processing from scratch to transform the image data into the correct tensor format.
We will need some HTML elements to interact with and display the data.
<label for="fileIn"><h2>What am I?</h2></label>
<input type="file" id="file-in" name="file-in" />
<img id="input-image" class="input-image"></img>
<img id="scaled-image" class="scaled-image"></img>
<h3 id="target"></h3>
We want to load an image from file and display it - going back to main.js, we will get the file input element and use a FileReader to read the data into memory. Following this, the image data will be passed to handleImage, which will draw the image using the canvas context.
const canvas = document.createElement("canvas"),
  ctx = canvas.getContext("2d")

document.getElementById("file-in").onchange = function (evt) {
  let target = evt.target || window.event.srcElement,
    files = target.files
  if (FileReader && files && files.length) {
    var fileReader = new FileReader()
    fileReader.onload = () => onLoadImage(fileReader)
    fileReader.readAsDataURL(files[0])
  }
}

function onLoadImage(fileReader) {
  var img = document.getElementById("input-image")
  img.onload = () => handleImage(img)
  img.src = fileReader.result
}

function handleImage(img) {
  ctx.drawImage(img, 0, 0)
}
Now that we can load and display an image, we want to move on to extracting and processing the data. Remember that our model takes an input tensor of shape [1, 3, 224, 224], which means we will probably have to resize the image and perhaps also transpose the dimensions, depending on how we extract the image data.

To resize and extract image data, we will use the canvas context again. Let's define a function processImage that does this. processImage has the necessary elements in scope to immediately draw the scaled image, so we will also do that here.
function processImage(img, width) {
  const canvas = document.createElement("canvas"),
    ctx = canvas.getContext("2d")
  // resize image
  canvas.width = width
  canvas.height = canvas.width * (img.height / img.width)
  // draw scaled image
  ctx.drawImage(img, 0, 0, canvas.width, canvas.height)
  document.getElementById("scaled-image").src = canvas.toDataURL()
  // return data
  return ctx.getImageData(0, 0, width, width).data
}
We can now add a line to handleImage which calls processImage:

const resizedImageData = processImage(img, targetWidth)
Finally, let's implement a function called imageDataToTensor which applies the transforms needed to get the image data ready to be used as input to the model. imageDataToTensor should apply three transforms:

1. ctx.getImageData returns RGBA data, so we need to drop the alpha channel and keep only the R, G, and B values.
2. The remaining data is laid out as [224, 224, 3], so we need to transpose it to [3, 224, 224].
3. ctx.getImageData returns a Uint8ClampedArray with integer values ranging from 0 to 255, so we need to convert the values to float32 (dividing by 255, as in the code below) and store them in a Float32Array to construct our tensor input.

function imageDataToTensor(data, dims) {
  // 1. & 2. collect R, G and B values separately (dropping the alpha channel)
  //         and transpose from [224, 224, 3] -> [3, 224, 224]
  const [R, G, B] = [[], [], []]
  for (let i = 0; i < data.length; i += 4) {
    R.push(data[i])
    G.push(data[i + 1])
    B.push(data[i + 2])
    // here we skip data[i + 3], which corresponds to the alpha channel
  }
  const transposedData = R.concat(G).concat(B)
  // 3. convert to float32 and normalize to the range [0, 1]
  let i,
    l = transposedData.length
  const float32Data = new Float32Array(3 * 224 * 224)
  for (i = 0; i < l; i++) {
    float32Data[i] = transposedData[i] / 255.0
  }
  const inputTensor = new ort.Tensor("float32", float32Data, dims)
  return inputTensor
}
Almost there - let's wrap up some loose ends to get the full inference pipeline up and running.

First, stitch together the image processing and inference pipeline in handleImage.
function handleImage(img, targetWidth) {
  // DIMS is assumed to be defined elsewhere as [1, 3, 224, 224];
  // targetWidth (224 for AlexNet) is passed in by the caller
  ctx.drawImage(img, 0, 0)
  const resizedImageData = processImage(img, targetWidth)
  const inputTensor = imageDataToTensor(resizedImageData, DIMS)
  run(inputTensor)
}
The output of the model is a list of activation values indicating how strongly each class is identified in the image. We need to get the most likely classification result by finding the index of the maximum value in the output data; this is done using an argMax function.
function argMax(arr) {
  let max = arr[0]
  let maxIndex = 0
  for (var i = 1; i < arr.length; i++) {
    if (arr[i] > max) {
      maxIndex = i
      max = arr[i]
    }
  }
  return [max, maxIndex]
}
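Strictly speaking, the raw output values are unnormalized activations (logits) rather than probabilities. If you would like to display a confidence score next to the predicted label, you could pass the outputs through a softmax first. The sketch below is an optional addition - the softmax helper and the probs/confidence names are not part of the tutorial code:

// Optional helper: turn raw logits into probabilities for a confidence readout.
function softmax(logits) {
  const max = Math.max(...logits) // subtract the max for numerical stability
  const exps = Array.from(logits, (v) => Math.exp(v - max))
  const sum = exps.reduce((a, b) => a + b, 0)
  return exps.map((v) => v / sum)
}

// usage: const probs = softmax(Array.from(results.output1.data))
//        const confidence = probs[maxIndex]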
Finally, run() needs to be refactored to accept a tensor input. We also need to use the max index to actually retrieve the result from a list of ImageNet classes. I've pre-converted this list to JSON and we will load it into our script using require - you can find the JSON file in the code repository linked at the bottom.
const classes = require("./imagenet_classes.json").data

async function run(inputTensor) {
  try {
    const session = await ort.InferenceSession.create("./alexnet.onnx")
    const feeds = { input1: inputTensor }
    const results = await session.run(feeds)
    const [maxValue, maxIndex] = argMax(results.output1.data)
    // the <h3 id="target"> element is available as a global via its id
    target.innerHTML = `${classes[maxIndex]}`
  } catch (e) {
    console.error(e) // non-fatal error handling
  }
}
All that's left is to rebuild our bundle and serve the app.
That's it - we've built a web app with a machine-learning model running natively in the browser! You can find the full code (including styles and layout) in this code repository on GitHub. I appreciate any and all feedback, so feel free to open an issue or star the repo.
Thank you for reading!