Running ML models in a game (and in Wasm!)

Using Tract and Bevy, guess the number being drawn.

TL;DR: Try it here - Read the code here

One of my colleagues recently starred the Tract GitHub repository, which got me wondering how easy it would be to use. I know how to create an ONNX model with PyTorch Lightning, and there are also pretrained models available. Spoiler: it's very easy to integrate!

Over the last couple of months, I've been doing game jams (Ludum Dare 47 and Game Off 2020) with Bevy, so I wanted to check if I could easily use Tract in a game made with Bevy and build it for a Wasm target.

[Demo animation: guessing all the digits]

Running an ONNX model with Tract

Tract example

Using Tract to run an ONNX model is fairly easy, and the provided example is great.

First, you load the model, specifying its input:

let model = tract_onnx::onnx()
    // load the model
    .model_for_path("mobilenetv2-1.0.onnx")?
    // specify input type and shape
    .with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 3, 224, 224)))?
    // optimize the model
    .into_optimized()?
    // make the model runnable and fix its inputs and outputs
    .into_runnable()?;

Then take an image, transform it into an array with the expected shape, and normalize the values:

// open image, resize it and make a Tensor out of it
let image = image::open("grace_hopper.jpg").unwrap().to_rgb8();
let resized =
    image::imageops::resize(&image, 224, 224, ::image::imageops::FilterType::Triangle);
let image: Tensor = tract_ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| {
    let mean = [0.485, 0.456, 0.406][c];
    let std = [0.229, 0.224, 0.225][c];
    (resized[(x as _, y as _)][c] as f32 / 255.0 - mean) / std
})
.into();

And finally run the model and get the result with the best score:

// run the model on the input
let result = model.run(tvec!(image))?;

// find and display the max value with its index
let best = result[0]
    .to_array_view::<f32>()?
    .iter()
    .cloned()
    .zip(2..)
    .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
println!("result: {:?}", best);

With MNIST

For MNIST, I use the model from ONNX MNIST. It takes a 28px by 28px image as input, so its shape is (1, 1, 28, 28).
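
Loading it works just like in the example above; here is a minimal sketch adapted to this input shape (the file name mnist-8.onnx is an assumption):

let model = tract_onnx::onnx()
    // load the MNIST model (hypothetical file name)
    .model_for_path("mnist-8.onnx")?
    // MNIST expects a single 28x28 grayscale image
    .with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 1, 28, 28)))?
    .into_optimized()?
    .into_runnable()?;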

Its output is an array of 10 floats, one score per digit. I can get the digit with the best score:

let result = model.model.run(tvec!(image)).unwrap();

if let Some((value, score)) = result[0]
    .to_array_view::<f32>()
    .unwrap()
    .iter()
    .cloned()
    .enumerate()
    .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
{
    if score > 10. {
        println!("{:?}", value);
    }
}

Game setup

Loading the model

To load the model from an .onnx file, I need to create a custom asset loader for this format. Following the Bevy custom asset loader example, I first declare my asset type OnnxModel:

#[derive(Debug, TypeUuid)]
#[uuid = "578fae90-a8de-41ab-a4dc-3aca66a31eed"]
pub struct OnnxModel {
    pub model: SimplePlan<
        TypedFact,
        Box<dyn TypedOp>,
        tract_onnx::prelude::Graph<TypedFact, Box<dyn TypedOp>>,
    >,
}

And then I implement AssetLoader for OnnxModelLoader. Even though I know the input shape of my model, I do not call with_input_fact in the loader, to stay independent of the model being loaded.

#[derive(Default)]
pub struct OnnxModelLoader;

impl AssetLoader for OnnxModelLoader {
    fn load<'a>(
        &'a self,
        mut bytes: &'a [u8],
        load_context: &'a mut LoadContext,
    ) -> BoxedFuture<'a, Result<(), anyhow::Error>> {
        Box::pin(async move {
            let model = tract_onnx::onnx()
                .model_for_read(&mut bytes)
                .unwrap()
                .into_optimized()?
                .into_runnable()?;

            load_context.set_default_asset(LoadedAsset::new(OnnxModel { model }));
            Ok(())
        })
    }

    fn extensions(&self) -> &[&str] {
        &["onnx"]
    }
}

I create a struct to hold the Handle to the model so that I can reuse the loaded model without needing to reload it every time:

struct State {
    model: Handle<OnnxModel>,
}

impl FromResources for State {
    fn from_resources(resources: &Resources) -> Self {
        let asset_server = resources.get::<AssetServer>().unwrap();
        State {
            model: asset_server.load("model.onnx"),
        }
    }
}

And finally, I add my asset loader and my resource to the Bevy app:

App::build()
    .add_asset::<OnnxModel>()
    .init_asset_loader::<OnnxModelLoader>()
    .init_resource::<State>();

Drawing to a texture

To display a texture, I use standard Bevy UI components, here an ImageBundle. I also mark this entity with the Interaction and FocusPolicy components from bevy::ui so that it can react to mouse clicks and movements.
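
As a rough sketch, spawning that entity could look like this (Bevy 0.4-era API; the 400px size and texture_handle are assumptions, and Drawable is the marker component used by the systems below):

// Marker component for the drawing area
pub struct Drawable;

commands
    .spawn(ImageBundle {
        style: Style {
            size: Size::new(Val::Px(400.), Val::Px(400.)),
            ..Default::default()
        },
        material: materials.add(texture_handle.into()),
        ..Default::default()
    })
    // React to mouse clicks and hovers
    .with(Interaction::default())
    .with(FocusPolicy::default())
    .with(Drawable);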

To be independent of the input type (touch or mouse), I create a Draw event that carries position coordinates.
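
The event itself can be a simple enum; a minimal sketch, inferred from its usage in the systems below:

pub enum Event {
    // Draw at this position, in texture coordinates
    Draw(Vec2),
}

I can then trigger this event from CursorMoved events when the mouse is clicked over the texture, in the following system: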

fn drawing_mouse(
    (mut reader, events): (Local<EventReader<CursorMoved>>, Res<Events<CursorMoved>>),
    mut last_mouse_position: Local<Option<Vec2>>,
    mut texture_events: ResMut<Events<Event>>,
    state: Res<State>,
    drawable: Query<(&Interaction, &GlobalTransform, &Style), With<Drawable>>,
) {
    for (interaction, transform, style) in drawable.iter() {
        if let Interaction::Clicked = interaction {
            // Get the width and height of the texture
            let width = if let Val::Px(x) = style.size.width {
                x
            } else {
                0.
            };
            let height = if let Val::Px(x) = style.size.height {
                x
            } else {
                0.
            };
            // For every `CursorMoved` event
            for event in reader.iter(&events) {
                if let Some(last_mouse_position) = *last_mouse_position {
                    // If movement is fast, interpolate positions between the last known
                    // position and the current position from the event
                    let steps =
                        (last_mouse_position.distance(event.position) as u32 / INPUT_SIZE + 1) * 3;
                    for i in 0..steps {
                        let lerped =
                            last_mouse_position.lerp(event.position, i as f32 / steps as f32);

                        // Change cursor position from window to texture
                        let x = lerped.x - transform.translation.x + width / 2.;
                        let y = lerped.y - transform.translation.y + height / 2.;

                        // And send the event to draw at this position
                        texture_events.send(Event::Draw(Vec2::new(x, y)));
                    }
                } else {
                    let x = event.position.x - transform.translation.x + width / 2.;
                    let y = event.position.y - transform.translation.y + height / 2.;
                    texture_events.send(Event::Draw(Vec2::new(x, y)));
                }

                *last_mouse_position = Some(event.position);
            }
        } else {
            *last_mouse_position = None;
        }
    }
}

And to actually draw on the texture, I listen for this event and use a brush to color the texture around the event coordinates:

fn update_texture(
    (mut reader, events): (Local<EventReader<Event>>, Res<Events<Event>>),
    materials: Res<Assets<ColorMaterial>>,
    mut textures: ResMut<Assets<Texture>>,
    mut state: ResMut<State>,
    drawable: Query<(&bevy::ui::Node, &Handle<ColorMaterial>), With<Drawable>>,
) {
    for event in reader.iter(&events) {
        // Retrieving the texture data from its `Handle`
        // First, getting the `Handle<ColorMaterial>` of the `ImageBundle`
        let (node, mat) = drawable.iter().next().unwrap();
        // Then, getting the `ColorMaterial` matching this handle
        let material = materials.get(mat).unwrap();
        // Finally, getting the texture itself from the `texture` field of the `ColorMaterial`
        let texture = textures
            .get_mut(material.texture.as_ref().unwrap())
            .unwrap();

        match event {
            Event::Draw(pos) => {
                // Use a large round brush instead of drawing pixel by pixel
                // `node.size` is the displayed size of the texture
                // `texture.size` is the actual size of the texture data
                // `INPUT_SIZE` is the expected input size by the model
                // The brush will be bigger if the drawing area is bigger to provide
                // a smoother drawing experience
                let radius = (1.3 * node.size.x / INPUT_SIZE as f32 / 2.) as i32;
                let scale = (texture.size.width as f32 / node.size.x) as i32;
                for i in -radius..(radius + 1) {
                    for j in -radius..(radius + 1) {
                        let target_point = Vec2::new(pos.x + i as f32, pos.y + j as f32);
                        if pos.distance(target_point) < radius as f32 {
                            for i in 0..=scale {
                                for j in 0..=scale {
                                    set_pixel(
                                        (target_point.x as i32) * scale + i,
                                        ((node.size.y as f32 - target_point.y) as i32) * scale + j,
                                        255,
                                        texture,
                                    )
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
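
The set_pixel and get_pixel helpers are not shown here; a minimal sketch of what they could look like, assuming the texture data is RGBA8 stored row-major (4 bytes per pixel):

/// Writes `value` to the RGB channels of the pixel at (x, y), ignoring
/// out-of-bounds coordinates (a sketch, not the actual implementation)
fn set_pixel(x: i32, y: i32, value: u8, texture: &mut Texture) {
    let (width, height) = (texture.size.width as i32, texture.size.height as i32);
    if x < 0 || y < 0 || x >= width || y >= height {
        return;
    }
    let idx = 4 * (y * width + x) as usize;
    texture.data[idx] = value;
    texture.data[idx + 1] = value;
    texture.data[idx + 2] = value;
    texture.data[idx + 3] = 255;
}

/// Returns 1 if the pixel at (x, y) is colored, 0 otherwise
fn get_pixel(x: i32, y: i32, texture: &Texture) -> u8 {
    let (width, height) = (texture.size.width as i32, texture.size.height as i32);
    if x < 0 || y < 0 || x >= width || y >= height {
        return 0;
    }
    let idx = 4 * (y * width + x) as usize;
    (texture.data[idx] > 127) as u8
}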

Getting the model input from the texture, and inferring the digit

I can now run my model on my texture and guess the digit! The model is fast enough that this can run every frame:

fn infer(
    state: Res<State>,
    materials: Res<Assets<ColorMaterial>>,
    textures: Res<Assets<Texture>>,
    models: Res<Assets<OnnxModel>>,
    drawable: Query<&Handle<ColorMaterial>, With<Drawable>>,
    mut display: Query<&mut Text>,
) {
    for mat in drawable.iter() {
        // Get the texture from the `Handle<ColorMaterial>`
        let material = materials.get(mat).unwrap();
        let texture = textures.get(material.texture.as_ref().unwrap()).unwrap();

        // As the texture is much larger than the model input, each point in the
        // model input will be 1 if at least half of the pixels in a square
        // of `pixel_size` in the texture are colored
        let pixel_size = (texture.size.width as u32 / INPUT_SIZE) as i32;

        let image = tract_ndarray::Array4::from_shape_fn(
            (1, 1, INPUT_SIZE as usize, INPUT_SIZE as usize),
            |(_, _, y, x)| {
                let mut val = 0;
                for i in 0..pixel_size as i32 {
                    for j in 0..pixel_size as i32 {
                        val += get_pixel(
                            x as i32 * pixel_size + i,
                            y as i32 * pixel_size + j,
                            texture,
                        ) as i32;
                    }
                }
                if val > pixel_size * pixel_size / 2 {
                    1.0_f32
                } else {
                    0.0
                }
            },
        )
        .into();

        if let Some(model) = models.get(state.model.as_weak::<OnnxModel>()) {
            // Run the model on the input
            let result = model.model.run(tvec!(image)).unwrap();

            // Get the best prediction, and display it if its score is high enough
            if let Some((value, score)) = result[0]
                .to_array_view::<f32>()
                .unwrap()
                .iter()
                .cloned()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
            {
                if score > 10. {
                    display.iter_mut().next().unwrap().value = format!("{:?}", value);
                } else {
                    display.iter_mut().next().unwrap().value = "".to_string();
                }
            }
        }
    }
}

Building for a Wasm target

Thanks to bevy_webgl2, this is actually very straightforward. I just need to add the WebGL2Plugin plugin and disable the default features of Bevy to enable only the ones available on Wasm.
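
A minimal sketch of the app setup, gating the plugin on the Wasm target (the exact Bevy feature list to enable in Cargo.toml depends on your Bevy version; see the template mentioned below):

let mut builder = App::build();
builder.add_plugins(DefaultPlugins);

// Only add the WebGL2 renderer when compiling for Wasm
#[cfg(target_arch = "wasm32")]
builder.add_plugin(bevy_webgl2::WebGL2Plugin);

builder.run();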

There is a bevy_webgl2 template if you want to build a game that works on all platforms.

GitHub Actions to deploy to itch.io

To host this POC (and my other games) on itch.io, I have a GitHub Actions workflow that builds the game for Windows, Linux, macOS and Wasm, creates a release on GitHub, and pushes everything to itch.io. This workflow is triggered on tags.

It is made of two jobs. The first one will, on each platform:

  • setup the environment, install dependencies and tools
  • build the game
  • perform platform specific steps (strip, wasm-bindgen, …) and copy assets if needed
  • create an archive for the platform (dmg for macOS, zip for the others)
  • create a release on GitHub and add those archives to the release
  • save the archives as artifacts

The second job takes the artifacts from the first job and sends them to itch.io using butler, itch.io's official command-line tool.