Using Tract and Bevy, guess the number being drawn.

One of my colleagues recently starred the Tract GitHub repository, which got me wondering how easy it would be to use. I knew how to create an ONNX model with PyTorch Lightning, and that there are pretrained models available. Spoiler: it's very easy to integrate!

For the last couple of months, I've been doing game jams (Ludum Dare 47 and Game Off 2020) with Bevy, so I wanted to check if I could easily use Tract in a game made with Bevy and built for a Wasm target.

## Running an ONNX model with Tract

### Tract example

Using Tract to run an ONNX model is fairly easy, and the example provided is great.

First, you load the model, specifying its input:

```rust
let model = tract_onnx::onnx()
    // load the model
    .model_for_path("mobilenetv2-1.0.onnx")?
    // specify input type and shape
    .with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 3, 224, 224)))?
    // optimize the model
    .into_optimized()?
    // make the model runnable and fix its inputs and outputs
    .into_runnable()?;
```

Then take an image, transform it into an array with the expected shape, and normalize the values:

```rust
// open the image, resize it and make a Tensor out of it
let image = image::open("grace_hopper.jpg").unwrap().to_rgb8();
let resized =
    image::imageops::resize(&image, 224, 224, ::image::imageops::FilterType::Triangle);
let image: Tensor = tract_ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| {
    let mean = [0.485, 0.456, 0.406][c];
    let std = [0.229, 0.224, 0.225][c];
    (resized[(x as _, y as _)][c] as f32 / 255.0 - mean) / std
})
.into();
```

And finally run the model and get the result with the best score:

```rust
// run the model on the input
let result = model.run(tvec!(image))?;

// find and display the max value with its index
let best = result[0]
    .to_array_view::<f32>()?
    .iter()
    .cloned()
    .zip(2..)
    .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
println!("result: {:?}", best);
```

### With MNIST

For MNIST, I use the ONNX MNIST model. It takes a 28px by 28px image as input, so its shape is `(1, 1, 28, 28)`.

Its output is an array of 10 float numbers, representing the score of each digit. I can get the digit with the best score:

```rust
let result = model.model.run(tvec!(image)).unwrap();

if let Some((value, score)) = result[0]
    .to_array_view::<f32>()
    .unwrap()
    .iter()
    .cloned()
    .enumerate()
    .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
{
    if score > 10. {
        println!("{:?}", value);
    }
}
```

## Game setup

### Loading the model

To load the model from an `.onnx` file, I need to create a custom asset loader for this format. Following Bevy's custom asset loader example, I first declare my asset type, `OnnxModel`:

```rust
#[derive(Debug, TypeUuid)]
#[uuid = "578fae90-a8de-41ab-a4dc-3aca66a31eed"]
pub struct OnnxModel {
    pub model: SimplePlan<
        TypedFact,
        Box<dyn TypedOp>,
        tract_onnx::prelude::Graph<TypedFact, Box<dyn TypedOp>>,
    >,
}
```

And then I implement `AssetLoader` for `OnnxModelLoader`. Even though I know the input shape of my model, I do not call `with_input_fact` in the loader, so that it stays independent of the model being loaded.

```rust
#[derive(Default)]
pub struct OnnxModelLoader;

impl AssetLoader for OnnxModelLoader {
    fn load<'a>(
        &'a self,
        mut bytes: &'a [u8],
        load_context: &'a mut LoadContext,
    ) -> BoxedFuture<'a, Result<(), anyhow::Error>> {
        Box::pin(async move {
            let model = tract_onnx::onnx()
                .model_for_read(&mut bytes)
                .unwrap()
                .into_optimized()?
                .into_runnable()?;

            load_context.set_default_asset(LoadedAsset::new(OnnxModel { model }));
            Ok(())
        })
    }

    fn extensions(&self) -> &[&str] {
        &["onnx"]
    }
}
```

I create a `struct` to hold the `Handle` to the model, so that I can reuse the loaded model without needing to reload it every time:

```rust
struct State {
    model: Handle<OnnxModel>,
}

impl FromResources for State {
    fn from_resources(resources: &Resources) -> Self {
        let asset_server = resources.get::<AssetServer>().unwrap();
        State {
            model: asset_server.load("model.onnx"),
        }
    }
}
```

And finally, I add my asset loader and my resource to the Bevy app:

```rust
App::build()
    .add_asset::<OnnxModel>()
    .init_asset_loader::<OnnxModelLoader>()
    .init_resource::<State>();
```

### Drawing to a texture

To display a texture, I use standard Bevy UI components, here an `ImageBundle`. I also mark this entity with the `Interaction` and `FocusPolicy` components from `bevy::ui` so that it can react to mouse clicks and movements.

To be input type independent (touch and mouse), I create a `Draw` event that carries position coordinates. I can then send this event from a system on `CursorMoved` events, while the mouse is pressed over the texture:

```rust
fn drawing_mouse(
    (mut reader, events): (Local<EventReader<CursorMoved>>, Res<Events<CursorMoved>>),
    mut last_mouse_position: Local<Option<Vec2>>,
    mut texture_events: ResMut<Events<Event>>,
    state: Res<State>,
    drawable: Query<(&Interaction, &GlobalTransform, &Style), With<Drawable>>,
) {
    for (interaction, transform, style) in drawable.iter() {
        if let Interaction::Clicked = interaction {
            // Get the width and height of the texture
            let width = if let Val::Px(x) = style.size.width {
                x
            } else {
                0.
            };
            let height = if let Val::Px(x) = style.size.height {
                x
            } else {
                0.
            };
            // For every `CursorMoved` event
            for event in reader.iter(&events) {
                if let Some(last_mouse_position) = *last_mouse_position {
                    // If the movement is fast, interpolate positions between the last
                    // known position and the current position from the event
                    let steps =
                        (last_mouse_position.distance(event.position) as u32 / INPUT_SIZE + 1) * 3;
                    for i in 0..steps {
                        let lerped =
                            last_mouse_position.lerp(event.position, i as f32 / steps as f32);

                        // Change cursor position from window to texture coordinates
                        let x = lerped.x - transform.translation.x + width / 2.;
                        let y = lerped.y - transform.translation.y + height / 2.;

                        // And send the event to draw at this position
                        texture_events.send(Event::Draw(Vec2::new(x, y)));
                    }
                } else {
                    let x = event.position.x - transform.translation.x + width / 2.;
                    let y = event.position.y - transform.translation.y + height / 2.;
                    texture_events.send(Event::Draw(Vec2::new(x, y)));
                }

                *last_mouse_position = Some(event.position);
            }
        } else {
            *last_mouse_position = None;
        }
    }
}
```
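The interpolation step is worth looking at on its own: without it, a fast stroke would leave visible gaps between `CursorMoved` events. Here is a minimal, Bevy-free sketch of that logic; `lerp`, `distance`, and `interpolated_positions` are hypothetical helpers of mine, standing in for Bevy's `Vec2` methods and the loop above:

```rust
/// Linear interpolation between two 2D points (a stand-in for `Vec2::lerp`).
fn lerp(a: (f32, f32), b: (f32, f32), t: f32) -> (f32, f32) {
    (a.0 + (b.0 - a.0) * t, a.1 + (b.1 - a.1) * t)
}

/// Euclidean distance (a stand-in for `Vec2::distance`).
fn distance(a: (f32, f32), b: (f32, f32)) -> f32 {
    ((b.0 - a.0).powi(2) + (b.1 - a.1).powi(2)).sqrt()
}

/// Generate intermediate draw positions between the last known cursor
/// position and the current one, so fast strokes leave no gaps.
fn interpolated_positions(
    last: (f32, f32),
    current: (f32, f32),
    input_size: u32,
) -> Vec<(f32, f32)> {
    // Same step count formula as in `drawing_mouse` above
    let steps = (distance(last, current) as u32 / input_size + 1) * 3;
    (0..steps)
        .map(|i| lerp(last, current, i as f32 / steps as f32))
        .collect()
}

fn main() {
    // A fast horizontal stroke from (0, 0) to (100, 0), with a 28px input grid:
    // distance = 100, 100 / 28 = 3, (3 + 1) * 3 = 12 interpolated points
    let points = interpolated_positions((0.0, 0.0), (100.0, 0.0), 28);
    println!("{} points, first: {:?}", points.len(), points[0]);
}
```

Note that the step count scales with the distance covered since the last event, so slow strokes cost almost nothing extra.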

And to actually draw on the texture, I listen for this event and use a brush to color the texture around the event coordinates:

```rust
fn update_texture(
    (mut reader, events): (Local<EventReader<Event>>, Res<Events<Event>>),
    materials: Res<Assets<ColorMaterial>>,
    mut textures: ResMut<Assets<Texture>>,
    mut state: ResMut<State>,
    drawable: Query<(&bevy::ui::Node, &Handle<ColorMaterial>), With<Drawable>>,
) {
    for event in reader.iter(&events) {
        // Retrieve the texture data from its `Handle`
        // First, get the `Handle<ColorMaterial>` of the `ImageBundle`
        let (node, mat) = drawable.iter().next().unwrap();
        // Then, get the `ColorMaterial` matching this handle
        let material = materials.get(mat).unwrap();
        // Finally, get the texture itself from the `texture` field of the `ColorMaterial`
        let texture = textures
            .get_mut(material.texture.as_ref().unwrap())
            .unwrap();

        match event {
            Event::Draw(pos) => {
                // Use a large round brush instead of drawing pixel by pixel
                // `node.size` is the displayed size of the texture
                // `texture.size` is the actual size of the texture data
                // `INPUT_SIZE` is the input size expected by the model
                // The brush is bigger when the drawing area is bigger, to provide
                // a smoother drawing experience
                let radius = (1.3 * node.size.x / INPUT_SIZE as f32 / 2.) as i32;
                let scale = (texture.size.width as f32 / node.size.x) as i32;
                for i in -radius..(radius + 1) {
                    for j in -radius..(radius + 1) {
                        let target_point = Vec2::new(pos.x + i as f32, pos.y + j as f32);
                        if pos.distance(target_point) < radius as f32 {
                            for i in 0..=scale {
                                for j in 0..=scale {
                                    set_pixel(
                                        (target_point.x as i32) * scale + i,
                                        ((node.size.y as f32 - target_point.y) as i32) * scale + j,
                                        255,
                                        texture,
                                    )
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
```
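Stripped of the Bevy plumbing, the brush boils down to stamping a filled circle into a flat pixel buffer. A standalone sketch of just that part (`stamp_brush`, the buffer layout, and the sizes are my own choices, not the game's actual types):

```rust
/// Stamp a filled circle of `radius` around (cx, cy) into a square
/// grayscale buffer of side `size`, mirroring the round brush above.
fn stamp_brush(pixels: &mut [u8], size: i32, cx: i32, cy: i32, radius: i32) {
    for i in -radius..=radius {
        for j in -radius..=radius {
            let (x, y) = (cx + i, cy + j);
            // Keep only points strictly inside the circle, and clamp to the buffer
            let inside_circle = i * i + j * j < radius * radius;
            if inside_circle && x >= 0 && x < size && y >= 0 && y < size {
                pixels[(y * size + x) as usize] = 255;
            }
        }
    }
}

fn main() {
    let size = 28;
    let mut pixels = vec![0u8; (size * size) as usize];
    // One brush stamp in the middle of a 28x28 canvas
    stamp_brush(&mut pixels, size, 14, 14, 3);
    let painted = pixels.iter().filter(|&&p| p == 255).count();
    println!("painted {} pixels", painted);
}
```

The bounds check matters: in the real system the brush can be centered near the edge of the texture, and writing out of bounds would panic.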

### Getting model input from texture, and inferring the digit

I can now run my model on my texture and guess the digit! The model is fast enough that this can run every frame:

```rust
fn infer(
    state: Res<State>,
    materials: Res<Assets<ColorMaterial>>,
    textures: Res<Assets<Texture>>,
    models: Res<Assets<OnnxModel>>,
    drawable: Query<&Handle<ColorMaterial>, With<Drawable>>,
    mut display: Query<&mut Text>,
) {
    for mat in drawable.iter() {
        // Get the texture from the `Handle<ColorMaterial>`
        let material = materials.get(mat).unwrap();
        let texture = textures.get(material.texture.as_ref().unwrap()).unwrap();

        // As the texture is much larger than the model input, each point in the
        // model input will be 1 if at least half of the points in a square
        // of `pixel_size` in the texture are colored
        let pixel_size = (texture.size.width as u32 / INPUT_SIZE) as i32;

        let image = tract_ndarray::Array4::from_shape_fn(
            (1, 1, INPUT_SIZE as usize, INPUT_SIZE as usize),
            |(_, _, y, x)| {
                let mut val = 0;
                for i in 0..pixel_size {
                    for j in 0..pixel_size {
                        val += get_pixel(
                            x as i32 * pixel_size + i,
                            y as i32 * pixel_size + j,
                            texture,
                        ) as i32;
                    }
                }
                if val > pixel_size * pixel_size / 2 {
                    1.0f32
                } else {
                    0.0f32
                }
            },
        )
        .into();

        if let Some(model) = models.get(state.model.as_weak::<OnnxModel>()) {
            // Run the model on the input
            let result = model.model.run(tvec!(image)).unwrap();

            // Get the best prediction, and display it if its score is high enough
            if let Some((value, score)) = result[0]
                .to_array_view::<f32>()
                .unwrap()
                .iter()
                .cloned()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
            {
                if score > 10. {
                    display.iter_mut().next().unwrap().value = format!("{:?}", value);
                } else {
                    display.iter_mut().next().unwrap().value = "".to_string();
                }
            }
        }
    }
}
```
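The downsampling inside `from_shape_fn` can also be checked in isolation: each cell of the model input covers a `pixel_size` × `pixel_size` block of the texture, and becomes 1.0 when more than half of that block is colored. A self-contained sketch (the `downsample` helper and the flat buffer layout are mine, and it counts colored pixels rather than summing raw intensities):

```rust
/// Downsample a square grayscale buffer of side `width` down to
/// `input_size` x `input_size`, setting a cell to 1.0 when more than
/// half of its source block is colored.
fn downsample(pixels: &[u8], width: usize, input_size: usize) -> Vec<f32> {
    let pixel_size = width / input_size;
    let mut out = vec![0.0f32; input_size * input_size];
    for y in 0..input_size {
        for x in 0..input_size {
            // Count colored pixels in the block covered by this cell
            let mut colored = 0;
            for j in 0..pixel_size {
                for i in 0..pixel_size {
                    if pixels[(y * pixel_size + j) * width + x * pixel_size + i] > 0 {
                        colored += 1;
                    }
                }
            }
            if colored > pixel_size * pixel_size / 2 {
                out[y * input_size + x] = 1.0;
            }
        }
    }
    out
}

fn main() {
    // A 4x4 texture downsampled to 2x2: color only the top-left 2x2 block
    let mut pixels = vec![0u8; 16];
    for (x, y) in [(0, 0), (1, 0), (0, 1), (1, 1)].iter() {
        pixels[y * 4 + x] = 255;
    }
    println!("{:?}", downsample(&pixels, 4, 2));
}
```

Thresholding to hard 0/1 values matches how the MNIST model was trained on near-binary strokes, and makes the input independent of brush intensity.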

## Build for Wasm target

Thanks to bevy_webgl2, this is actually very straightforward. I just need to add the `WebGL2Plugin` plugin and disable the default features of Bevy, enabling only the ones available on Wasm.

There is a bevy_webgl2 template if you want to build a game that works on all platforms.
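As a rough sketch, the `Cargo.toml` side looks something like this; the exact feature list and crate versions depend on your Bevy version, so treat these as assumptions to verify against the bevy_webgl2 template:

```toml
[dependencies]
bevy_webgl2 = "0.1"

[dependencies.bevy]
version = "0.4"
# Disable default features: some (like the default render backend)
# are not available on Wasm
default-features = false
# Hypothetical feature list; check the bevy_webgl2 template
# for the set matching your Bevy version
features = ["bevy_winit", "render", "ui"]
```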

## Github Actions to deploy to itch.io

To host this POC (and my other games) on itch.io, I have a workflow on GitHub Actions that builds the game for Windows, Linux, macOS and Wasm, creates a release on GitHub, and pushes everything to itch.io. This workflow is triggered on tags.

It is made of two jobs. The first one will, on each platform:

• setup the environment, install dependencies and tools
• build the game
• perform platform specific steps (`strip`, `wasm-bindgen`, …) and copy assets if needed
• create an archive for the platform (`dmg` for macOS, `zip` for the others)
• create a release on GitHub and add those archives to the release
• save the archives as artifacts

The second job will take the artifacts from the first job and push them to itch.io using butler, itch.io's official command-line tool.
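The two-job structure above can be sketched as a heavily abridged workflow file; the job and step names are mine, the real workflow has one matrix entry per target (including Wasm) and more platform-specific steps:

```yaml
name: release
on:
  push:
    tags: ["*"]

jobs:
  build:
    strategy:
      matrix:
        os: [ubuntu-latest, windows-latest, macos-latest]
    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v2
      - name: Build
        run: cargo build --release
      # Platform-specific packaging (strip, wasm-bindgen, dmg/zip)
      # and asset copying would go here
      - uses: actions/upload-artifact@v2
        with:
          name: build-${{ matrix.os }}
          path: target/release/

  publish:
    needs: build
    runs-on: ubuntu-latest
    steps:
      - uses: actions/download-artifact@v2
      # Push each downloaded archive to itch.io with butler
      # (butler must be installed first; <user>/<game> is a placeholder)
      - run: butler push build-ubuntu-latest <user>/<game>:linux
```

Running the publish step in a separate job keeps the itch.io credentials out of the per-platform build environments.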