Using Tract and Bevy, guess the number being drawn.
TL;DR: Try it here - Read the code here
One of my colleagues recently starred the Tract GitHub repository, which got me wondering how easy it would be to use. I know how to create an ONNX model with PyTorch Lightning, and there are even pretrained models available. Spoiler: it's very easy to integrate!
For the last couple of months, I've been doing game jams, Ludum Dare 47 and Game Off 2020, with Bevy, so I wanted to check if I could easily use Tract in a game made with Bevy and built for a Wasm target.
Running an ONNX model with Tract
Tract example
Using Tract to run an ONNX model is fairly easy, and the example provided is great.
First, you load the model, specifying its input:
let model = tract_onnx::onnx()
// load the model
.model_for_path("mobilenetv2-1.0.onnx")?
// specify input type and shape
.with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 3, 224, 224)))?
// optimize the model
.into_optimized()?
// make the model runnable and fix its inputs and outputs
.into_runnable()?;
Then take an image and transform it to an array with the expected shape and normalize the values:
// open image, resize it and make a Tensor out of it
let image = image::open("grace_hopper.jpg").unwrap().to_rgb8();
let resized =
image::imageops::resize(&image, 224, 224, ::image::imageops::FilterType::Triangle);
let image: Tensor = tract_ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| {
let mean = [0.485, 0.456, 0.406][c];
let std = [0.229, 0.224, 0.225][c];
(resized[(x as _, y as _)][c] as f32 / 255.0 - mean) / std
})
.into();
And finally run the model and get the result with the best score:
// run the model on the input
let result = model.run(tvec!(image))?;
// find and display the max value with its index
let best = result[0]
.to_array_view::<f32>()?
.iter()
.cloned()
.zip(2..)
.max_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
println!("result: {:?}", best);
With MNIST
For MNIST, I use the model from ONNX MNIST. It takes an image of 28px by 28px as input, so its shape is (1, 1, 28, 28).
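Loading it works exactly like the example above, just with this input shape. A minimal sketch, assuming the downloaded model file is named model.onnx:
let model = tract_onnx::onnx()
    // hypothetical file name for the downloaded MNIST model
    .model_for_path("model.onnx")?
    // a single grayscale 28px by 28px image
    .with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 1, 28, 28)))?
    .into_optimized()?
    .into_runnable()?;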
Its output is an array of 10 floats, the raw (pre-softmax) score of each digit. I can get the digit with the best score:
let result = model.model.run(tvec!(image)).unwrap();
if let Some((value, score)) = result[0]
.to_array_view::<f32>()
.unwrap()
.iter()
.cloned()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
{
if score > 10. {
println!("{:?}", value);
}
}
Game setup
Loading the model
To load the model from an .onnx
file, I need to create a custom asset loader for this format. Following
bevy custom asset loader example,
I first declare my
asset type OnnxModel
#[derive(Debug, TypeUuid)]
#[uuid = "578fae90-a8de-41ab-a4dc-3aca66a31eed"]
pub struct OnnxModel {
pub model: SimplePlan<
TypedFact,
Box<dyn TypedOp>,
tract_onnx::prelude::Graph<TypedFact, Box<dyn TypedOp>>,
>,
}
And then I implement AssetLoader for OnnxModelLoader. Even though I know the input shape of my model, I do not call with_input_fact in the loader, so that the loader stays independent of the model being loaded.
#[derive(Default)]
pub struct OnnxModelLoader;
impl AssetLoader for OnnxModelLoader {
fn load<'a>(
&'a self,
mut bytes: &'a [u8],
load_context: &'a mut LoadContext,
) -> BoxedFuture<'a, Result<(), anyhow::Error>> {
Box::pin(async move {
let model = tract_onnx::onnx()
.model_for_read(&mut bytes)
.unwrap()
.into_optimized()?
.into_runnable()?;
load_context.set_default_asset(LoadedAsset::new(OnnxModel { model }));
Ok(())
})
}
fn extensions(&self) -> &[&str] {
&["onnx"]
}
}
I create a struct to hold the Handle to the model, so that I can reuse the loaded model without needing to reload it every time:
struct State {
model: Handle<OnnxModel>,
}
impl FromResources for State {
fn from_resources(resources: &Resources) -> Self {
let asset_server = resources.get::<AssetServer>().unwrap();
State {
model: asset_server.load("model.onnx"),
}
}
}
And finally, I add my asset loader and my resource to the Bevy app:
App::build()
.add_asset::<OnnxModel>()
.init_asset_loader::<OnnxModelLoader>()
.init_resource::<State>();
Drawing to a texture
To display a texture, I use standard Bevy UI components, here an ImageBundle. I also mark this entity with the Interaction and FocusPolicy components from bevy::ui so that it can react to mouse clicks and movements.
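A minimal setup sketch; Drawable is the marker component queried by the systems below, while the texture asset, the canvas size and the system shape are assumptions:
// Marker component for the drawing area
struct Drawable;

fn setup(
    commands: &mut Commands,
    asset_server: Res<AssetServer>,
    mut materials: ResMut<Assets<ColorMaterial>>,
) {
    // hypothetical blank texture used as the drawing canvas
    let texture: Handle<Texture> = asset_server.load("canvas.png");
    commands
        .spawn(ImageBundle {
            style: Style {
                size: Size::new(Val::Px(280.), Val::Px(280.)),
                ..Default::default()
            },
            material: materials.add(texture.into()),
            ..Default::default()
        })
        // so the entity receives click / hover state
        .with(Interaction::default())
        // so UI nodes behind it do not also receive the interaction
        .with(bevy::ui::FocusPolicy::default())
        .with(Drawable);
}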
To be input type independent (touch and mouse), I create an event Draw that carries the position coordinates.
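The event itself is just an enum variant wrapping the position, matching the Event::Draw(pos) pattern used below:
// Event carrying where to draw, in texture coordinates
pub enum Event {
    Draw(Vec2),
}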
I can then trigger this event on CursorMoved events when the mouse is clicked over the texture, in a system:
fn drawing_mouse(
(mut reader, events): (Local<EventReader<CursorMoved>>, Res<Events<CursorMoved>>),
mut last_mouse_position: Local<Option<Vec2>>,
mut texture_events: ResMut<Events<Event>>,
state: Res<State>,
drawable: Query<(&Interaction, &GlobalTransform, &Style), With<Drawable>>,
) {
for (interaction, transform, style) in drawable.iter() {
if let Interaction::Clicked = interaction {
// Get the width and height of the texture
let width = if let Val::Px(x) = style.size.width {
x
} else {
0.
};
let height = if let Val::Px(x) = style.size.height {
x
} else {
0.
};
// For every `CursorMoved` event
for event in reader.iter(&events) {
if let Some(last_mouse_position) = *last_mouse_position {
// If movement is fast, interpolate positions between last known position
// and current position from event
let steps =
(last_mouse_position.distance(event.position) as u32 / INPUT_SIZE + 1) * 3;
for i in 0..steps {
let lerped =
last_mouse_position.lerp(event.position, i as f32 / steps as f32);
// Change cursor position from window to texture
let x = lerped.x - transform.translation.x + width / 2.;
let y = lerped.y - transform.translation.y + height / 2.;
// And send the event to draw at this position
texture_events.send(Event::Draw(Vec2::new(x, y)));
}
} else {
let x = event.position.x - transform.translation.x + width / 2.;
let y = event.position.y - transform.translation.y + height / 2.;
texture_events.send(Event::Draw(Vec2::new(x, y)));
}
*last_mouse_position = Some(event.position);
}
} else {
*last_mouse_position = None;
}
}
}
And to actually draw on the texture, I listen for this event and use a brush to color the texture around the event coordinates:
fn update_texture(
(mut reader, events): (Local<EventReader<Event>>, Res<Events<Event>>),
materials: Res<Assets<ColorMaterial>>,
mut textures: ResMut<Assets<Texture>>,
mut state: ResMut<State>,
drawable: Query<(&bevy::ui::Node, &Handle<ColorMaterial>), With<Drawable>>,
) {
for event in reader.iter(&events) {
// Retrieving the texture data from its `Handle`
// First, getting the `Handle<ColorMaterial>` of the `ImageBundle`
let (node, mat) = drawable.iter().next().unwrap();
// Then, getting the `ColorMaterial` matching this handle
let material = materials.get(mat).unwrap();
// Finally, getting the texture itself from the `texture` field of the `ColorMaterial`
let texture = textures
.get_mut(material.texture.as_ref().unwrap())
.unwrap();
match event {
Event::Draw(pos) => {
// Use a large round brush instead of drawing pixel by pixel
// `node.size` is the displayed size of the texture
// `texture.size` is the actual size of the texture data
// `INPUT_SIZE` is the expected input size by the model
// The brush will be bigger if the drawing area is bigger to provide
// a smoother drawing experience
let radius = (1.3 * node.size.x / INPUT_SIZE as f32 / 2.) as i32;
let scale = (texture.size.width as f32 / node.size.x) as i32;
for i in -radius..(radius + 1) {
for j in -radius..(radius + 1) {
let target_point = Vec2::new(pos.x + i as f32, pos.y + j as f32);
if pos.distance(target_point) < radius as f32 {
for i in 0..=scale {
for j in 0..=scale {
set_pixel(
(target_point.x as i32) * scale + i,
((node.size.y as f32 - target_point.y) as i32) * scale + j,
255,
texture,
)
}
}
}
}
}
}
}
}
}
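The set_pixel and get_pixel helpers are not part of Bevy; here is a minimal sketch of what they could look like, assuming an RGBA texture with 4 bytes per pixel and treating a pixel as colored when its first channel is set:
// Color the pixel at (x, y) white, ignoring out-of-bounds coordinates
fn set_pixel(x: i32, y: i32, value: u8, texture: &mut Texture) {
    let (width, height) = (texture.size.width as i32, texture.size.height as i32);
    if x < 0 || y < 0 || x >= width || y >= height {
        return;
    }
    let offset = 4 * (y * width + x) as usize;
    texture.data[offset..offset + 4].copy_from_slice(&[value, value, value, 255]);
}

// Return 1 if the pixel at (x, y) is colored, 0 otherwise
fn get_pixel(x: i32, y: i32, texture: &Texture) -> u8 {
    let (width, height) = (texture.size.width as i32, texture.size.height as i32);
    if x < 0 || y < 0 || x >= width || y >= height {
        return 0;
    }
    let offset = 4 * (y * width + x) as usize;
    if texture.data[offset] > 127 { 1 } else { 0 }
}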
Getting the model input from the texture, and inferring the digit
I can now run my model on my texture and guess the digit! This model is fast enough that this can run every frame:
fn infer(
state: Res<State>,
materials: Res<Assets<ColorMaterial>>,
textures: Res<Assets<Texture>>,
models: Res<Assets<OnnxModel>>,
drawable: Query<&Handle<ColorMaterial>, With<Drawable>>,
mut display: Query<&mut Text>,
) {
for mat in drawable.iter() {
// Get the texture from the `Handle<ColorMaterial>`
let material = materials.get(mat).unwrap();
let texture = textures.get(material.texture.as_ref().unwrap()).unwrap();
// As the texture is much larger than the model input, each point in the
// model input will be 1 if at least half of the points in a square
// of `pixel_size` in the texture are colored
let pixel_size = (texture.size.width as u32 / INPUT_SIZE) as i32;
let image = tract_ndarray::Array4::from_shape_fn(
(1, 1, INPUT_SIZE as usize, INPUT_SIZE as usize),
|(_, _, y, x)| {
let mut val = 0;
for i in 0..pixel_size as i32 {
for j in 0..pixel_size as i32 {
val += get_pixel(
x as i32 * pixel_size + i,
y as i32 * pixel_size + j,
texture,
) as i32;
}
}
if val > pixel_size * pixel_size / 2 {
1. as f32
} else {
0. as f32
}
},
)
.into();
if let Some(model) = models.get(state.model.as_weak::<OnnxModel>()) {
// Run the model on the input
let result = model.model.run(tvec!(image)).unwrap();
// Get the best prediction, and display it if its score is high enough
if let Some((value, score)) = result[0]
.to_array_view::<f32>()
.unwrap()
.iter()
.cloned()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
{
if score > 10. {
display.iter_mut().next().unwrap().value = format!("{:?}", value);
} else {
display.iter_mut().next().unwrap().value = "".to_string();
}
}
}
}
}
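For all of this to run, the event and the three systems need to be registered on the app. A sketch, leaving system ordering and stages to the defaults:
App::build()
    // ... assets, resources and setup from the previous sections ...
    .add_event::<Event>()
    .add_system(drawing_mouse.system())
    .add_system(update_texture.system())
    .add_system(infer.system());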
Building for a Wasm target
Thanks to bevy_webgl2, this is actually very straightforward. I just need to add the WebGL2Plugin and disable the default features of Bevy, enabling only the ones available on Wasm.
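A minimal sketch of the plugin registration; gating it on the target architecture is my assumption, so that the same code still builds natively:
fn main() {
    let mut app = App::build();
    // ... plugins, assets, resources and systems from the previous sections ...

    // Only use the WebGL2 render backend when targeting Wasm
    #[cfg(target_arch = "wasm32")]
    app.add_plugin(bevy_webgl2::WebGL2Plugin);

    app.run();
}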
There is a bevy_webgl2 template if you want to build a game that works on all platforms.
GitHub Actions to deploy to itch.io
To host this POC (and my other games) on itch.io, I have a GitHub Actions workflow that builds the game for Windows, Linux, macOS and Wasm, creates a release on GitHub, and pushes everything to itch.io. This workflow is triggered on tags.
It is made of two jobs. The first one will, on each platform:
- set up the environment, install dependencies and tools
- build the game
- perform platform-specific steps (strip, wasm-bindgen, …) and copy assets if needed
- create an archive for the platform (dmg for macOS, zip for the others)
- create a release on GitHub and add those archives to the release
- save the archives as artifacts
The second job takes the artifacts from the first job and sends them to itch.io using butler, itch.io's official command line tool.