1 Project Overview

Burn is a high‐performance, flexible deep learning framework written in Rust. It combines Rust’s safety and concurrency guarantees with GPU‐ and hardware‐specific optimizations to support end‐to‐end model development—from research experiments to production deployments on CPU, GPU, WebAssembly, and embedded devices.

Core Values

  • Performance
    • Automatic kernel fusion
    • Asynchronous execution engine
    • Hardware-specific backends (CUDA, ROCm, and Vulkan/Metal/DirectX via WGPU)
  • Flexibility
    • Modular crates for tensors, layers, optimizers, schedulers
    • Custom training loops and callbacks
    • Import of pre-trained models (ONNX, PyTorch, Safetensors)
  • Safety
    • Rust’s memory and thread safety
    • Zero‐cost abstractions for parallelism
  • Extensibility
    • Pluggable backends (WASM, embedded accelerators)
    • User‐defined layers and operations

Position in the ML Ecosystem

Burn sits at the intersection of research and production:

  • Researchers leverage low‐level control for novel architectures and optimizations.
  • Engineers deploy models across servers, browsers (WASM), and edge devices.
  • System developers extend or integrate custom hardware backends.

High‐Level Repository Structure

Defined in Cargo.toml as a Cargo workspace, the repository organizes functionality into focused crates and examples:

/
├── Cargo.toml             # Workspace config: members, dependencies, profiles
├── README.md              # Project overview, getting started, contributing
├── CITATION.cff           # Citation metadata for academic use
├── POEM.md                # Inspirational overview of Burn’s vision
├── burn-book/             # Static site source for detailed guides
│   ├── src/
│   │   ├── overview.md
│   │   └── motivation.md
│   └── ...
├── crates/                # Core library crates
│   ├── burn/              # Root crate re-exporting the public API
│   ├── burn-core/         # Modules, records, and shared building blocks
│   ├── burn-tensor/       # Tensor API and the Backend trait
│   ├── burn-autodiff/     # Autodiff backend decorator
│   ├── burn-vision/       # Computer-vision operations
│   └── ...                
└── examples/              # End-to-end training & inference projects
    ├── guide/
    ├── mnist/
    └── ...

Workspace Configuration (Cargo.toml snippet)

Developers use this file to track crate versions, workspace policies, and build profiles. The abridged snippet below is illustrative; the actual file lists more members and settings.

[workspace]
members = [
  "crates/burn-core",
  "crates/burn-tensor",
  "crates/burn-autodiff",
  "crates/burn-vision",
  "examples/*",
  "xtask"
]

[profile.release]
opt-level = "z"
lto = true

[profile.release]
opt-level = "z"
lto = true

Use cargo build --workspace to compile all crates, or target individual members with -p <crate-name>. Continuous integration and dependency updates rely on this centralized configuration.

2 Getting Started

This guide takes you from zero to a running “hello-tensor” example and your first training job using Burn’s CI-tested examples. You’ll install Rust, enable backend features, leverage Cargo aliases, and run both the guide and MNIST examples.

2.1 Prerequisites

  • Latest stable Rust (via rustup)
  • Optional GPU drivers for WGPU (Vulkan/Metal) and CUDA/Torch backends
  • Git

2.2 Clone and Build Burn

git clone https://github.com/tracel-ai/burn.git
cd burn
# Build the workspace in release mode
cargo build --release

2.3 Configure Cargo Aliases

Burn provides aliases in .cargo/config.toml to streamline common tasks:

[alias]
xtask      = "run --target-dir target/xtask --color always --package xtask --bin xtask --"
run-checks = "xtask -c all validate --release"

  • cargo xtask <cmd> runs custom scripts in the xtask crate
  • cargo run-checks executes formatting, linting, docs, tests and example runs

2.4 Enabling Backend Features

Burn supports multiple backends, each behind a Cargo feature of the burn crate. Enable them via --features:

# CPU (ndarray), LibTorch (CPU or CUDA), WGPU (Vulkan/Metal/DirectX)
cargo build --release -p burn \
  --features "ndarray,tch,wgpu"

2.5 Hello-Tensor Example

Create a minimal Rust project that prints a tensor:

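cd ..                  # create the project next to the burn clone
cargo new hello-tensor
cd hello-tensor

In Cargo.toml, add Burn as a path dependency (the burn crate lives in crates/burn inside the repository):

[dependencies]
burn = { path = "../burn/crates/burn", features = ["ndarray"] }

Replace src/main.rs with:

use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    // Default device of the NdArray (CPU) backend
    let device = Default::default();
    // Create a 2×2 float tensor
    let t = Tensor::<NdArray, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    println!("Hello, Tensor:\n{}", t);
}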

Run:

cargo run --release

2.6 Running CI-Tested Examples

Burn’s repository includes guided examples, among them guide and mnist, with training and inference workflows. Use cargo run-checks from the repository root to build and validate the workspace, examples included:

# From repo root: runs formatting, linting, docs, tests, and examples
cargo run-checks

2.6.1 Guide Example

cd examples/guide
# Train the model from the Burn Book guide
cargo run --bin train --release
# Run inference with the trained model
cargo run --bin infer --release
# Print the model summary
cargo run --bin print --release

2.6.2 MNIST Example

cd examples/mnist
# Train with ndarray (CPU); the backend is selected by Cargo feature
cargo run --example mnist --release --features ndarray
# Train with tch (LibTorch) on a CUDA GPU
cargo run --example mnist --release --features tch-gpu

This setup gives you a working Burn installation, a “hello-tensor” starter, and two fully CI-tested examples to explore training and inference across backends.

3 Core Concepts

This section introduces Burn’s fundamental abstractions: tensors, backends, modules, and automatic differentiation. Understanding these concepts helps you build and train models efficiently across different hardware.

3.1 Tensors

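Burn’s Tensor is the primary data container for n-dimensional arrays. Its type, Tensor<B, D, K>, is generic over:

  • A backend B implementing Backend (CPU, GPU, autodiff); the backend also fixes the element type (e.g. f32).
  • A dimensionality D (compile-time constant).
  • A tensor kind K: Float (the default), Int, or Bool.

The device (CPU, CUDA, …) is not part of the type: it is a runtime value passed when a tensor is created.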

Creating Tensors

use burn::backend::NdArray;    // CPU backend
use burn::tensor::{Int, Tensor, TensorData};

// 1. Pick a device
let device = Default::default();

// 2. From a primitive array (Float kind, f32 on NdArray)
let t1: Tensor<NdArray, 2> = Tensor::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);

// 3. From TensorData for other kinds or flat buffers: values + shape
let raw = vec![10i32, 20, 30, 40];
let data = TensorData::new(raw, [2, 2]);
let t2: Tensor<NdArray, 2, Int> = Tensor::from_data(data, &device);

// 4. Inspect and retrieve
let copied = t1.to_data();            // TensorData copy; t1 stays usable
let owned = t1.clone().into_data();   // consumes the (cloned) tensor

Common Operations

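// Element-wise math (tensor ops take ownership, so clone to reuse a tensor)
let sum = t1.clone().add(t2.float());    // convert Int -> Float before adding
let scaled = t1.mul_scalar(0.5);

// Reductions
let total: f32 = sum.clone().sum().into_scalar(); // sum() yields a one-element Tensor<_, 1>
let max   = sum.max_dim(1);               // Tensor<NdArray, 2> of shape [2, 1]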
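Reduction shapes are easy to get wrong, so here is a plain-Rust sketch (no Burn dependency; the Mat type and helper names are hypothetical) over the same row-major 2×2 values as t1: summing every element, and taking the max along dimension 1 while keeping that dimension with size 1, which is the shape convention Burn's max_dim follows.

```rust
// Illustrative sketch with hypothetical names; not Burn's implementation.

/// Row-major 2-D buffer: element (r, c) lives at data[r * cols + c].
struct Mat {
    data: Vec<f32>,
    rows: usize,
    cols: usize,
}

/// Sum of every element (the value a full sum reduction produces).
fn sum_all(m: &Mat) -> f32 {
    m.data.iter().sum()
}

/// Max along dim 1, keeping that dim: the result has shape [rows, 1].
fn max_dim1(m: &Mat) -> Mat {
    let data = (0..m.rows)
        .map(|r| {
            m.data[r * m.cols..(r + 1) * m.cols]
                .iter()
                .copied()
                .fold(f32::NEG_INFINITY, f32::max)
        })
        .collect();
    Mat { data, rows: m.rows, cols: 1 }
}

fn main() {
    let m = Mat { data: vec![1.0, 2.0, 3.0, 4.0], rows: 2, cols: 2 };
    let mx = max_dim1(&m);
    println!(
        "sum = {}, max_dim(1) = {:?} with shape [{}, {}]",
        sum_all(&m),
        mx.data,
        mx.rows,
        mx.cols
    );
}
```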

3.2 Backends

Backends provide device-specific implementations of tensor operations via the Backend trait. Burn ships with:

  • NdArray (burn::backend::NdArray): CPU reference implementation.
  • Cuda (burn_cuda::Cuda): GPU via CubeCL.
  • Autodiff (burn_autodiff::Autodiff<B>): Wraps any Backend to enable gradient tracking.

Switching Backends

// CPU
use burn::backend::NdArray as Cpu;
use burn::backend::ndarray::NdArrayDevice;

// GPU
use burn_cuda::{Cuda, CudaDevice};
let device_gpu: CudaDevice = Default::default(); // picks first CUDA GPU
type Gpu = Cuda<f32, i32>;

// Autodiff on CPU (Autodiff<B> uses B's device type)
use burn_autodiff::Autodiff;
type CpuDiff = Autodiff<Cpu>;
let device_cpu = NdArrayDevice::default();

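Example: Random Tensor on GPU

use burn::tensor::{backend::Backend, Distribution, Tensor};

let zeros = Tensor::<Gpu, 2>::zeros([4, 4], &device_gpu);
let rand  = Tensor::<Gpu, 2>::random([4, 4], Distribution::Uniform(0.0, 1.0), &device_gpu);
let out   = zeros.add(rand);
Gpu::sync(&device_gpu);                         // wait for queued GPU work
let host: Vec<f32> = out.to_data().to_vec().unwrap();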
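Switching backends only changes a type parameter because model code is written against a trait rather than a concrete device. The toy below mirrors that idea with a hypothetical ToyBackend trait and two implementations; it is a simplified illustration, not Burn's Backend trait, which has many more associated types and operations.

```rust
// Hypothetical toy trait illustrating backend genericity; not Burn's API.
trait ToyBackend {
    type Storage;
    fn from_vec(v: Vec<f32>) -> Self::Storage;
    fn add(a: &Self::Storage, b: &Self::Storage) -> Self::Storage;
    fn to_vec(s: &Self::Storage) -> Vec<f32>;
}

/// "CPU" backend: plain vectors and an element-wise loop.
struct VecBackend;

impl ToyBackend for VecBackend {
    type Storage = Vec<f32>;
    fn from_vec(v: Vec<f32>) -> Vec<f32> {
        v
    }
    fn add(a: &Vec<f32>, b: &Vec<f32>) -> Vec<f32> {
        a.iter().zip(b).map(|(x, y)| x + y).collect()
    }
    fn to_vec(s: &Vec<f32>) -> Vec<f32> {
        s.clone()
    }
}

/// A second backend with a different storage type, standing in for a
/// device-specific implementation.
struct BoxedBackend;

impl ToyBackend for BoxedBackend {
    type Storage = Box<[f32]>;
    fn from_vec(v: Vec<f32>) -> Box<[f32]> {
        v.into_boxed_slice()
    }
    fn add(a: &Box<[f32]>, b: &Box<[f32]>) -> Box<[f32]> {
        a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
    }
    fn to_vec(s: &Box<[f32]>) -> Vec<f32> {
        s.to_vec()
    }
}

/// "Model" code written once, generic over the backend type parameter.
fn add_ones<B: ToyBackend>(v: Vec<f32>) -> Vec<f32> {
    let n = v.len();
    let x = B::from_vec(v);
    let ones = B::from_vec(vec![1.0; n]);
    B::to_vec(&B::add(&x, &ones))
}

fn main() {
    println!("{:?}", add_ones::<VecBackend>(vec![1.0, 2.0]));
    println!("{:?}", add_ones::<BoxedBackend>(vec![1.0, 2.0]));
}
```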

3.3 Modules

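Modules encapsulate parameters and define a forward pass. Layers and losses live under burn::nn (e.g. burn::nn::Linear, burn::nn::loss::MseLoss), optimizers under burn::optim, and learning-rate schedulers under burn::lr_scheduler; burn::prelude re-exports core items such as Tensor, Module, Config, and Backend.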

Defining a Model

use burn::nn::{Linear, LinearConfig};
use burn::prelude::*; // brings in Tensor, Module, Backend, Config, ...

#[derive(Module, Debug)]
struct SimpleRegressor<B: Backend> {
    linear: Linear<B>,
}

impl<B: Backend> SimpleRegressor<B> {
    fn new(input_dim: usize, output_dim: usize, device: &B::Device) -> Self {
        let cfg = LinearConfig::new(input_dim, output_dim);
        Self { linear: cfg.init(device) }
    }

    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        self.linear.forward(input)
    }
}

Training Loop Snippet

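use burn::nn::loss::{MseLoss, Reduction};
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::tensor::backend::AutodiffBackend;

// Training needs an autodiff backend, e.g. Autodiff<NdArray>
fn train<B: AutodiffBackend>(device: &B::Device) {
    let mut model = SimpleRegressor::<B>::new(10, 1, device);
    let mut optimizer = AdamConfig::new().init();
    let base_lr = 1e-3;

    for epoch in 0..200 {
        let data = Tensor::<B, 2>::from_floats([[...]], device);
        let target = Tensor::<B, 2>::from_floats([[...]], device);

        let pred = model.forward(data);
        let loss = MseLoss::new().forward(pred, target, Reduction::Mean);

        // Automatic differentiation (see next section)
        let grads = GradientsParams::from_grads(loss.backward(), &model);
        // Step decay: divide the LR by 10 every 100 epochs
        let lr = base_lr * 0.1f64.powi((epoch / 100) as i32);
        model = optimizer.step(lr, model, grads);
    }
}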
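A step-decay schedule with the values used above (step size 100, gamma 0.1) follows lr = base_lr × gamma^⌊epoch / step_size⌋. A minimal sketch of that formula, assuming base_lr = 1e-3 (the function name is hypothetical):

```rust
// Hypothetical helper illustrating step decay; not a Burn API.
/// lr(epoch) = base_lr * gamma^(floor(epoch / step_size))
fn step_decay_lr(base_lr: f64, gamma: f64, step_size: usize, epoch: usize) -> f64 {
    base_lr * gamma.powi((epoch / step_size) as i32)
}

fn main() {
    for epoch in [0, 99, 100, 199] {
        println!("epoch {epoch:>3}: lr = {:e}", step_decay_lr(1e-3, 0.1, 100, epoch));
    }
}
```

Epochs 0 to 99 share the base rate, and the rate drops by a factor of 10 at epoch 100.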

3.4 Automatic Differentiation

Burn’s AutodiffBackend trait (implemented by the Autodiff<B> decorator from burn-autodiff) instruments tensors for gradient tracking and provides:

  • require_grad()
  • backward() → Gradients
  • grad() / grad_remove()

Computing Gradients

use burn::tensor::{backend::{AutodiffBackend, Backend}, Tensor};

fn compute_grad<B: AutodiffBackend>(input: Tensor<B, 2>) -> B::Gradients {
    let x = input.require_grad();
    let loss = (x.clone() * x.clone()).sum(); // scalar
    let mut grads = loss.backward();

    // Peek or pop gradients for x
    let grad_x      = x.grad(&grads);
    let grad_x_popped = x.grad_remove(&mut grads);

    grads
}

Tips

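  • Use grad() to inspect a gradient without removing it from the container.
  • Use grad_remove() to take ownership of a gradient (e.g. for a manual parameter update) without cloning it.
  • backward(), grad(), and grad_remove() require B: AutodiffBackend, so calling them with a plain backend fails at compile time.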
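For loss = sum(x * x), x.grad(&grads) should equal 2x. The toy tape below is a hypothetical, scalar-only illustration of reverse-mode differentiation (record each operation, then walk the record backwards applying the chain rule); Burn's burn-autodiff works on whole tensors and is considerably more involved.

```rust
// Toy reverse-mode autodiff over scalars (hypothetical; not Burn's implementation).

/// How each tape entry was produced.
#[derive(Clone, Copy)]
enum Op {
    Leaf,
    Mul(usize, usize),
    Add(usize, usize),
}

struct Tape {
    vals: Vec<f64>,
    ops: Vec<Op>,
}

impl Tape {
    fn new() -> Self {
        Tape { vals: Vec::new(), ops: Vec::new() }
    }
    fn push(&mut self, v: f64, op: Op) -> usize {
        self.vals.push(v);
        self.ops.push(op);
        self.vals.len() - 1
    }
    fn leaf(&mut self, v: f64) -> usize {
        self.push(v, Op::Leaf)
    }
    fn mul(&mut self, a: usize, b: usize) -> usize {
        let v = self.vals[a] * self.vals[b];
        self.push(v, Op::Mul(a, b))
    }
    fn add(&mut self, a: usize, b: usize) -> usize {
        let v = self.vals[a] + self.vals[b];
        self.push(v, Op::Add(a, b))
    }
    /// Walk the tape in reverse, accumulating d(out)/d(node) via the chain rule.
    fn backward(&self, out: usize) -> Vec<f64> {
        let mut grads = vec![0.0; self.vals.len()];
        grads[out] = 1.0;
        for i in (0..=out).rev() {
            let g = grads[i];
            match self.ops[i] {
                Op::Leaf => {}
                Op::Mul(a, b) => {
                    grads[a] += g * self.vals[b];
                    grads[b] += g * self.vals[a];
                }
                Op::Add(a, b) => {
                    grads[a] += g;
                    grads[b] += g;
                }
            }
        }
        grads
    }
}

/// Gradient of loss = sum(x * x) with respect to each x (expected: 2x).
fn grad_of_sum_squares(xs: &[f64]) -> Vec<f64> {
    let mut tape = Tape::new();
    let leaves: Vec<usize> = xs.iter().map(|&v| tape.leaf(v)).collect();
    let mut loss = tape.leaf(0.0);
    for &x in &leaves {
        let sq = tape.mul(x, x);
        loss = tape.add(loss, sq);
    }
    let grads = tape.backward(loss);
    leaves.iter().map(|&i| grads[i]).collect()
}

fn main() {
    println!("{:?}", grad_of_sum_squares(&[1.0, 2.0, 3.0]));
}
```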

4 Training Workflow Guide

This guide walks through an end-to-end recipe for training a MNIST classifier with Burn. You’ll cover configuration, data loading, model instantiation, optimizer/scheduler setup, learner construction, logging, checkpointing, and early stopping.

4.1 Configuration and Command-Line Arguments

Define a TrainConfig struct to capture hyperparameters and paths. Deriving Config generates a constructor, with_* setters for fields that have defaults, and JSON save/load support.

use burn::config::Config;

#[derive(Config, Debug)]
pub struct TrainConfig {
    #[config(default = 32)]
    pub batch_size: usize,

    #[config(default = 10)]
    pub epochs: usize,

    #[config(default = 0.01)]
    pub learning_rate: f64,

    #[config(default = 0.9)]
    pub momentum: f64,

    #[config(default = "String::from(\"data/mnist\")")]
    pub data_path: String,

    #[config(default = "String::from(\"checkpoints\")")]
    pub checkpoint_dir: String,
}

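Create the config at runtime, overriding defaults through the generated with_* setters or loading a previously saved JSON file:

let config = TrainConfig::new().with_batch_size(64);
// or: let config = TrainConfig::load("config.json").expect("valid config file");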
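As a rough, hypothetical picture of the builder pattern the derive provides (plain Rust written by hand, not the macro's actual expansion, which also adds serialization and setters for every defaulted field), defaults and chained overrides behave like this:

```rust
// Hand-written, hypothetical equivalent of a derived config; not macro output.
#[derive(Debug, Clone, PartialEq)]
pub struct TrainConfig {
    pub batch_size: usize,
    pub epochs: usize,
    pub learning_rate: f64,
    pub momentum: f64,
    pub data_path: String,
    pub checkpoint_dir: String,
}

impl TrainConfig {
    /// Every field has a default here, so the constructor takes no arguments.
    pub fn new() -> Self {
        Self {
            batch_size: 32,
            epochs: 10,
            learning_rate: 0.01,
            momentum: 0.9,
            data_path: String::from("data/mnist"),
            checkpoint_dir: String::from("checkpoints"),
        }
    }

    /// Setters consume and return `self`, so overrides can be chained.
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    pub fn with_epochs(mut self, epochs: usize) -> Self {
        self.epochs = epochs;
        self
    }
}

fn main() {
    let config = TrainConfig::new().with_batch_size(64).with_epochs(20);
    println!("{config:?}");
}
```

Fields without a default would instead become required arguments of new().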

4.2 Data Loading with DataLoader

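Use Burn’s MnistDataset and DataLoaderBuilder to prepare shuffled training batches and validation batches. A Batcher, which you implement for your model, turns a Vec of dataset items into tensors.

use burn::data::dataloader::DataLoaderBuilder;
use burn::data::dataset::vision::MnistDataset;
use crate::data::MnistBatcher; // your own Batcher: Vec<MnistItem> -> tensor batch

// MNIST is downloaded and cached on first use; it ships train and test splits
let train_ds = MnistDataset::train();
let val_ds   = MnistDataset::test();   // test split used for validation

// Wrap in data loaders for batching and seeded shuffling
let train_loader = DataLoaderBuilder::new(MnistBatcher::default())
    .batch_size(config.batch_size)
    .shuffle(42)                       // shuffle seed
    .num_workers(4)
    .build(train_ds);

let val_loader = DataLoaderBuilder::new(MnistBatcher::default())
    .batch_size(config.batch_size)
    .build(val_ds);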
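What a loader does with a batch size and a shuffle seed can be sketched in plain Rust: shuffle the item indices reproducibly, then cut them into batches. This is a hypothetical illustration (xorshift PRNG plus Fisher-Yates), not Burn's internal sampler, and it leaves out worker threads.

```rust
// Hypothetical illustration of seeded shuffling and batching; not Burn's sampler.

/// Small deterministic PRNG (xorshift64), used only to make the shuffle
/// reproducible from a seed.
struct XorShift64(u64);

impl XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }
}

/// Shuffle indices 0..len with Fisher-Yates, then split them into batches.
/// The last batch may be smaller than batch_size.
fn shuffled_batches(len: usize, batch_size: usize, seed: u64) -> Vec<Vec<usize>> {
    let mut idx: Vec<usize> = (0..len).collect();
    let mut rng = XorShift64(seed.max(1)); // xorshift must not start at 0
    for i in (1..len).rev() {
        let j = (rng.next_u64() % (i as u64 + 1)) as usize;
        idx.swap(i, j);
    }
    idx.chunks(batch_size).map(|c| c.to_vec()).collect()
}

fn main() {
    for batch in shuffled_batches(10, 4, 42) {
        println!("{batch:?}");
    }
}
```

The same seed always yields the same order, which keeps runs comparable.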

4.3 Model Definition

Instantiate your MNIST model and move it to the selected device (CPU/GPU).

use burn::module::Module;
use burn::tensor::backend::Backend;
use crate::model::MnistModel;

// B is your chosen backend; training needs an autodiff backend, e.g. Autodiff<NdArray<f32>>
let device = B::Device::default();
let model: MnistModel<B> = MnistModel::new().to_device(&device);

4.4 Optimizer and Learning Rate Scheduler

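Configure SGD with momentum and a step-decay scheduler. In Burn the learning rate is not an SgdConfig field: it comes from the scheduler (or is passed to optimizer.step in a custom loop).

use burn::lr_scheduler::step::StepLrSchedulerConfig;
use burn::optim::{momentum::MomentumConfig, SgdConfig};

// SGD with momentum; the learning rate is supplied by the scheduler
let optimizer = SgdConfig::new()
    .with_momentum(Some(MomentumConfig::new().with_momentum(config.momentum)))
    .init();

// Step scheduler: multiply the LR by 0.1 every `step_size` scheduler steps.
// The learner steps it once per iteration, so 5 epochs = 5 * iters_per_epoch
// (iters_per_epoch is a hypothetical value: dataset size / batch size).
let scheduler = StepLrSchedulerConfig::new(config.learning_rate, 5 * iters_per_epoch)
    .with_gamma(0.1)
    .init();   // may return a Result, depending on the Burn version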
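The momentum = 0.9 setting refers to the classic heavy-ball update, sketched below in plain Rust for a parameter vector. This is an illustration under simplifying assumptions: Burn's MomentumConfig also has dampening and Nesterov options, which the sketch leaves out (equivalent to zero dampening and no Nesterov); the function name is hypothetical.

```rust
// Hypothetical sketch of SGD with heavy-ball momentum (no dampening, no Nesterov).
///   v     <- momentum * v + grad
///   theta <- theta - lr * v
fn sgd_momentum_step(theta: &mut [f64], velocity: &mut [f64], grad: &[f64], lr: f64, momentum: f64) {
    for ((t, v), g) in theta.iter_mut().zip(velocity.iter_mut()).zip(grad) {
        *v = momentum * *v + g;
        *t -= lr * *v;
    }
}

fn main() {
    let mut theta = vec![1.0];
    let mut velocity = vec![0.0];
    // A constant gradient of 1.0: the velocity builds up step after step
    for step in 1..=3 {
        sgd_momentum_step(&mut theta, &mut velocity, &[1.0], 0.01, 0.9);
        println!("step {step}: theta = {:.4}, v = {:.2}", theta[0], velocity[0]);
    }
}
```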

4.5 Constructing the Learner

Use LearnerBuilder to wire together the model, optimizer, scheduler, metrics, checkpointing, and early stopping. The model must implement Burn's TrainStep and ValidStep traits; the data loaders are passed to fit.

use burn::record::CompactRecorder;
use burn::train::{
    metric::store::{Aggregate, Direction, Split},
    metric::LossMetric,
    LearnerBuilder, MetricEarlyStoppingStrategy, StoppingCondition,
};

let learner = LearnerBuilder::new(&config.checkpoint_dir)
    // Track the loss on both splits (displayed by the learner's renderer)
    .metric_train_numeric(LossMetric::new())
    .metric_valid_numeric(LossMetric::new())
    // Save model, optimizer, and scheduler checkpoints after each epoch
    .with_file_checkpointer(CompactRecorder::new())
    // Stop training if validation loss hasn’t improved in 3 epochs
    .early_stopping(MetricEarlyStoppingStrategy::new(
        &LossMetric::<B>::new(),
        Aggregate::Mean,
        Direction::Lowest,
        Split::Valid,
        StoppingCondition::NoImprovementSince { n_epochs: 3 },
    ))
    .devices(vec![device.clone()])
    .num_epochs(config.epochs)
    .summary()
    .build(model, optimizer, scheduler);
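The stopping rule in the comment above (stop once validation loss has not improved for 3 epochs) reduces to a patience counter. A plain-Rust sketch with a hypothetical function name; it illustrates the rule only and is not Burn's implementation:

```rust
// Hypothetical patience-based early stopping; not Burn's implementation.
/// Returns the 0-based epoch at which training stops because the validation
/// loss has not improved for `patience` consecutive epochs, or None if
/// training would run to completion.
fn early_stop_epoch(val_losses: &[f64], patience: usize) -> Option<usize> {
    let mut best = f64::INFINITY;
    let mut since_best = 0;
    for (epoch, &loss) in val_losses.iter().enumerate() {
        if loss < best {
            best = loss;
            since_best = 0;
        } else {
            since_best += 1;
            if since_best >= patience {
                return Some(epoch);
            }
        }
    }
    None
}

fn main() {
    let losses = [0.9, 0.7, 0.65, 0.66, 0.67, 0.70, 0.60];
    println!("{:?}", early_stop_epoch(&losses, 3));
}
```

The best loss is 0.65 at epoch 2; epochs 3, 4, and 5 fail to beat it, so training stops after epoch 5 and never sees the later 0.60.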

4.6 Running Training

Kick off training. fit runs the number of epochs configured on the builder, performing forward/backward passes, parameter updates, metric logging, and checkpointing. It is synchronous and returns the trained model.

fn main() {
    // ... build config, data loaders, model, optimizer, scheduler, learner ...

    // Runs all epochs and returns the trained model
    let model_trained = learner.fit(train_loader, val_loader);
}

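4.7 Saving the Trained Model

With a file checkpointer, the learner writes checkpoints under checkpoint_dir during training. To keep a final artifact, save the trained model's record (its parameters) with a recorder. Burn stores records in its own formats; it does not export TorchScript or ONNX graphs.

use burn::module::Module;
use burn::record::CompactRecorder;

// `model_trained` is the model returned by learner.fit(...)
model_trained
    .save_file(format!("{}/model", config.checkpoint_dir), &CompactRecorder::new())
    .expect("failed to save the trained model");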

Practical Tips

• To resume training, call .checkpoint(epoch) on the LearnerBuilder with the same artifact directory.
• Adjust num_workers on DataLoaderBuilder for your hardware to maximize data throughput.
• Swap in another optimizer config (AdamConfig, RmsPropConfig) by replacing SgdConfig.
• Implement LR warm-up with a learning-rate scheduler (one from burn::lr_scheduler or your own LrScheduler implementation).
• For multi-GPU training on one machine, pass several devices to LearnerBuilder::devices.

5 Backends & Performance

This section shows how to pick and configure GPU backends (CUDA, WGPU) in Burn, run tensor operations or training on them, and write custom kernels with CubeCL, the compute language both backends are built on.

5.1 Enabling Backends in Cargo.toml

Control which backends compile by disabling default features and listing the ones you need (std must then be listed explicitly):

[dependencies]
burn = { version = "0.17", default-features = false, features = [
  "std",        # standard library support (part of the defaults)
  "cuda",       # NVIDIA GPUs via the CUDA backend (cudarc)
  "wgpu",       # Cross-API GPU via WGPU (Vulkan/Metal/DirectX/OpenGL/WebGPU)
  "autodiff",   # automatic differentiation decorator
  "fusion"      # kernel fusion decorator
]}

5.2 CUDA Backend

Use the CUDA backend to run tensor ops on NVIDIA GPUs with automatic differentiation.

use burn::backend::{Autodiff, Cuda, cuda::CudaDevice};
use burn::tensor::{activation::relu, Distribution, Tensor};

// Autodiff decorator over CUDA so backward() is available
type B = Autodiff<Cuda>;

fn main() {
    // Select CUDA device 0
    let device = CudaDevice::new(0);
    // Random 512×512 matrix on the GPU, tracked for gradients
    let a: Tensor<B, 2> =
        Tensor::random([512, 512], Distribution::Default, &device).require_grad();
    let b = relu(a.clone());         // forward
    let loss = b.sum();              // one-element tensor
    let grads = loss.backward();     // backprop
    let grad_a = a.grad(&grads).unwrap();
    let grad_first = grad_a.to_data().to_vec::<f32>().unwrap()[0];
    println!("Sum: {:.3}, grad[0]: {:.3}", loss.into_scalar(), grad_first);
}

.require_grad() enables gradient tracking on Autodiff<Cuda>.

5.3 WGPU Backend

Use the WGPU backend for cross-platform GPU acceleration. Kernels are compiled to WGSL by default; with the vulkan feature they are compiled to SPIR-V for Vulkan instead.

use burn::backend::{Wgpu, wgpu::WgpuDevice};
use burn::tensor::{Distribution, Tensor};

fn main() {
    // Default device: lets wgpu pick an adapter
    let device = WgpuDevice::default();

    // Random tensor on GPU
    let x: Tensor<Wgpu, 2> =
        Tensor::random([256, 256], Distribution::Uniform(0.0, 1.0), &device);
    // matmul consumes its operands, so clone x for the left-hand side
    let y = x.clone().matmul(x).clamp(0.0, 1.0);
    println!("Output shape: {:?}", y.shape());
}

– Choose a specific adapter with WgpuDevice variants such as WgpuDevice::DiscreteGpu(0), WgpuDevice::IntegratedGpu(0), or WgpuDevice::Cpu.

5.4 CubeCL Backend & Custom Kernels

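Burn's GPU backends are built on CubeCL, where kernels are written in Rust with the #[cube] attribute and JIT-compiled at runtime for the selected runtime (WGSL or SPIR-V on wgpu, CUDA on NVIDIA GPUs). The sketch below assumes a direct dependency on the cubecl crate with its wgpu feature; launch-argument details vary between CubeCL versions.

use cubecl::prelude::*;
use cubecl::wgpu::{WgpuDevice, WgpuRuntime};

// Element-wise add written in Rust; CubeCL compiles it for the chosen runtime
#[cube(launch_unchecked)]
fn add_kernel<F: Float>(a: &Array<F>, b: &Array<F>, out: &mut Array<F>) {
    // Skip threads past the end of the buffer
    if ABSOLUTE_POS < out.len() {
        out[ABSOLUTE_POS] = a[ABSOLUTE_POS] + b[ABSOLUTE_POS];
    }
}

fn main() {
    let client = WgpuRuntime::client(&WgpuDevice::default());

    // Prepare input data
    let len = 1 << 20;
    let a: Vec<f32> = (0..len).map(|i| i as f32).collect();
    let b: Vec<f32> = vec![1.0; len];

    // Upload inputs and allocate the output buffer on the GPU
    let a_h = client.create(f32::as_bytes(&a));
    let b_h = client.create(f32::as_bytes(&b));
    let out_h = client.empty(len * core::mem::size_of::<f32>());

    // 1D launch: 256 threads per cube, enough cubes to cover len
    unsafe {
        add_kernel::launch_unchecked::<f32, WgpuRuntime>(
            &client,
            CubeCount::Static((len as u32).div_ceil(256), 1, 1),
            CubeDim::new(256, 1, 1),
            ArrayArg::from_raw_parts::<f32>(&a_h, len, 1),
            ArrayArg::from_raw_parts::<f32>(&b_h, len, 1),
            ArrayArg::from_raw_parts::<f32>(&out_h, len, 1),
        );
    }

    // Read back and verify
    let bytes = client.read_one(out_h.binding());
    let out = f32::from_bytes(&bytes);
    assert_eq!(out[0], 0.0 + 1.0);
    assert_eq!(out[len - 1], (len - 1) as f32 + 1.0);
}

CubeDim sets the threads per cube (workgroup) and CubeCount the number of cubes; here a 1D grid of 256-thread cubes covers len elements, and the ABSOLUTE_POS < out.len() guard skips any extra threads.
The same kernel can target another runtime (e.g. CUDA) by changing the runtime type parameter and device.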
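Whatever the kernel API, a 1D launch must cover every element: the number of groups is ⌈len / group size⌉, and threads past the end are skipped by a bounds check. The CPU simulation below (a hypothetical helper, no GPU involved) checks that arithmetic; with 256 threads per group, a 1,000-element buffer needs 4 groups and the last 24 threads do nothing.

```rust
// Hypothetical CPU simulation of a 1-D GPU launch; no GPU API is used.
/// Runs `cube_count` groups of `cube_dim` "threads"; each thread computes its
/// absolute position and skips positions past the end of the buffer.
fn simulate_add_launch(a: &[f32], b: &[f32], cube_dim: usize) -> (usize, Vec<f32>) {
    assert_eq!(a.len(), b.len(), "inputs must have the same length");
    let len = a.len();
    let cube_count = len.div_ceil(cube_dim); // enough groups to cover len
    let mut out = vec![0.0f32; len];
    for cube in 0..cube_count {
        for unit in 0..cube_dim {
            let pos = cube * cube_dim + unit;
            if pos < len {
                out[pos] = a[pos] + b[pos];
            }
        }
    }
    (cube_count, out)
}

fn main() {
    let a: Vec<f32> = (0..1000).map(|i| i as f32).collect();
    let b = vec![1.0f32; 1000];
    let (groups, out) = simulate_add_launch(&a, &b, 256);
    println!("groups = {groups}, out[999] = {}", out[999]);
}
```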

5.5 Performance Tips

• Batch your workload to maximize GPU occupancy (e.g. process many samples in one dispatch).
• Enable the fusion feature to merge element-wise ops into single kernels.
• On Vulkan, try the SPIR-V compiler (vulkan feature) instead of WGSL.
• Kernels are JIT-compiled (and, with autotune, benchmarked) on first use, so exclude warm-up iterations when measuring performance.
• Profile with vendor tools (Nsight, RenderDoc) and tune workgroup sizes or thread-block shapes.
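The fusion tip above is about memory traffic: without fusion, each element-wise op is its own pass that writes an intermediate buffer. A conceptual CPU illustration (hypothetical functions; Burn's fusion generates GPU kernels at runtime and this code does not reflect its implementation):

```rust
// Conceptual illustration of kernel fusion on the CPU; not Burn's fusion backend.

/// Unfused: three passes, two intermediate buffers (like three separate kernels).
fn unfused(x: &[f32]) -> Vec<f32> {
    let scaled: Vec<f32> = x.iter().map(|v| v * 2.0).collect(); // kernel 1
    let shifted: Vec<f32> = scaled.iter().map(|v| v + 1.0).collect(); // kernel 2
    shifted.iter().map(|v| v.max(0.0)).collect() // kernel 3 (relu)
}

/// Fused: one pass computes the whole expression per element, no intermediates.
fn fused(x: &[f32]) -> Vec<f32> {
    x.iter().map(|v| (v * 2.0 + 1.0).max(0.0)).collect()
}

fn main() {
    let x = [-2.0f32, -0.25, 0.0, 1.5];
    println!("unfused = {:?}", unfused(&x));
    println!("fused   = {:?}", fused(&x));
}
```

Both produce the same values; the fused version reads the input once and writes the output once.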

6 Model Import / Export

Leverage the burn-import and onnx-ir crates to ingest pre-trained ONNX, PyTorch, and Safetensors models into Burn. Convert ONNX models to Rust source via the onnx2burn CLI or ModelGen API, load PyTorch/Safetensors weights with dedicated recorders, and export Burn models for production using graph serialization or embedded states.

ModelGen: Converting ONNX Models to Burn Rust Code

burn-import provides a CLI and a builder API that translate one or more .onnx files into Rust sources implementing a Burn model.

1. Command-Line (onnx2burn)

# Convert model.onnx into generated/ (with debug dumps)
#   path/to/model.onnx   input ONNX file
#   generated/           output directory
#   --development        emit .graph.txt + .onnx.txt debug files
#   --half-precision     save parameters as f16 where possible
cargo run --bin onnx2burn -- \
  path/to/model.onnx \
  generated/ \
  --development \
  --half-precision

Omit --development for minimal output. By default, uses full precision and pretty JSON.

2. Build Script (build.rs)

use burn_import::onnx::{ModelGen, RecordType};

fn main() {
    ModelGen::new()
        .input("assets/model.onnx")                 // queue ONNX file
        .out_dir("generated_models")                // relative to OUT_DIR
        .development(false)                         // no debug dumps
        .half_precision(true)                       // f16 weights
        .record_type(RecordType::Bincode)           // required when embed_states is true
        .embed_states(true)                         // inline blobs in .rs
        .run_from_script();                         // writes to ${OUT_DIR}
}

Generated files in ${OUT_DIR}/generated_models/:

  • model.rs — Burn model implementation
  • A separate record file (e.g. .json, .mpk, or .bin, depending on record_type) if embed_states == false
  • Debug dumps (.graph.txt, .onnx.txt) if development == true

API Reference

  • ModelGen::new(): initializes builder and logging.
  • .input(path: &str): add ONNX file.
  • .out_dir(path: &str): target directory.
  • .development(bool): emit human-readable dumps.
  • .half_precision(bool): toggle f16 vs f32/f64.
  • .record_type(RecordType): choose serialization (PrettyJson, NamedMpk, NamedMpkGz, Bincode).
  • .embed_states(bool): inline weight blobs.
  • .run_from_cli() vs .run_from_script(): CLI writes to CWD; script prepends Cargo’s OUT_DIR.

Tips

  • Use run_from_script() in CI/build pipelines.
  • Enable development to inspect ONNX and intermediate graphs.
  • Match runtime precision via half_precision.

Importing PyTorch Model Weights

Load .pt weight files into a ModelRecord and run inference with Burn.

1. Setup Backend & Weight Path

use burn::backend::NdArray;
type B = NdArray<f32>;

const WEIGHTS_FILE: &str = "weights/mnist.pt";

2. Create a Recorder

use burn::record::{FullPrecisionSettings, Recorder}; // Recorder brings .load() into scope
use burn_import::pytorch::PyTorchFileRecorder;

// Maps PyTorch f32 weights directly
let recorder = PyTorchFileRecorder::<FullPrecisionSettings>::default();

3. Load Weights

use import_model_weights::{ModelRecord, infer};

let record: ModelRecord<B> = recorder
    .load(WEIGHTS_FILE.into(), &Default::default())
    .expect("Failed to load PyTorch model weights");

4. Run Inference

infer(record);

Full Example (src/bin/pytorch.rs)

use burn::backend::NdArray;
use burn::record::{FullPrecisionSettings, Recorder}; // Recorder brings .load() into scope
use burn_import::pytorch::PyTorchFileRecorder;
use import_model_weights::{ModelRecord, infer};

type B = NdArray<f32>;
const WEIGHTS_FILE: &str = "weights/mnist.pt";

fn main() {
    println!("Loading PyTorch model weights from file: {}", WEIGHTS_FILE);

    let record: ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load(WEIGHTS_FILE.into(), &Default::default())
        .expect("Failed to load PyTorch model weights");

    infer(record);
}

Tips

  • Ensure .pt layer names & shapes match your Rust model.
  • Adjust WEIGHTS_FILE path relative to the binary’s CWD.
  • Use HalfPrecisionSettings, or implement your own PrecisionSettings, to store weights at a different precision.
  • Run via cargo run --bin pytorch from workspace root.
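One reason layer shapes need care: PyTorch's nn.Linear stores its weight as [out_features, in_features], while Burn's Linear stores [d_input, d_output], and the PyTorch recorder adapts Linear weights during import. The plain-Rust sketch below only shows the index mapping involved (a row-major transpose with a hypothetical helper); it is not burn-import's code.

```rust
// Hypothetical helper showing a row-major transpose; not burn-import's code.
/// Transpose a row-major [rows, cols] matrix into row-major [cols, rows].
fn transpose(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(data.len(), rows * cols, "shape does not match data length");
    let mut out = vec![0.0; data.len()];
    for r in 0..rows {
        for c in 0..cols {
            // element (r, c) moves to (c, r)
            out[c * rows + r] = data[r * cols + c];
        }
    }
    out
}

fn main() {
    // PyTorch-style weight: 2 outputs x 3 inputs
    let w_pt = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    // Burn-style layout: 3 inputs x 2 outputs
    println!("{:?}", transpose(&w_pt, 2, 3));
}
```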

Persisting and Embedding Graph States

Serialize a BurnGraph’s learned parameters to disk or embed them into generated Rust code.

Signature

pub fn with_record(
    self,
    out_file: PathBuf,
    record_type: RecordType,
    embed_states: bool,
) -> Self

How It Works

  • Supports four formats:
    • PrettyJson
    • NamedMpk
    • NamedMpkGz
    • Bincode (no-std friendly)
  • If embed_states == true, only Bincode is allowed. The generated code embeds the record file:
    static EMBEDDED_STATES: &[u8] = include_bytes!("model_states.bin");
    and implements Default / Model::from_embedded(...).
  • Otherwise it generates Default / Model::from_file(...) to load the record at runtime.

Panics if embed_states == true with a non-Bincode RecordType.

Example

use burn_import::burn::{BurnGraph, RecordType};
use std::path::PathBuf;

// 1. Build your graph (MyConv2d / MyReLU stand for your node types)
let mut graph = BurnGraph::<MyPrecision>::default();
graph.register(MyConv2d::new(...));
graph.register(MyReLU::new(...));
graph.register_input_output(vec!["input".into()], vec!["output".into()]);

// 2. Choose ONE record option; with_record consumes the graph and returns it
// 2a. Pretty JSON for debugging
let graph = graph.with_record(
    PathBuf::from("model_debug.json"),
    RecordType::PrettyJson,
    false,
);

// 2b. Compact MessagePack (alternative to 2a)
// let graph = graph.with_record(
//     PathBuf::from("model_states.mpk"),
//     RecordType::NamedMpk,
//     false,
// );

// 2c. Embed Bincode states (alternative to 2a)
// let graph = graph.with_record(
//     PathBuf::from("model_states.bin"),
//     RecordType::Bincode,
//     true,
// );

// 3. Generate Rust module
let tokens = graph.codegen();

Generated Helpers

File-based records:

impl<B: Backend> Default for Model<B> {
    fn default() -> Self {
        Self::from_file("model_states.mpk", &Default::default())
    }
}
impl<B: Backend> Model<B> {
    pub fn from_file(file: &str, device: &B::Device) -> Self { … }
}

Embedded Bincode:

static EMBEDDED_STATES: &[u8] = include_bytes!("model_states.bin");
impl<B: Backend> Default for Model<B> {
    fn default() -> Self {
        Self::from_embedded(&Default::default())
    }
}
impl<B: Backend> Model<B> {
    pub fn from_embedded(device: &B::Device) -> Self { … }
}
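The embedding mechanism itself is ordinary Rust: include_bytes! bakes a file into the binary as a &'static [u8], and the generated loader decodes it at startup. The sketch below replaces the file with an inline byte array and decodes a hypothetical little-endian f32 layout; the real record uses Bincode's encoding, which this does not reproduce.

```rust
// Stand-in for `include_bytes!("model_states.bin")`: two f32 values (1.0, -2.5)
// stored as little-endian bytes. Hypothetical layout, not Bincode's format.
static EMBEDDED_STATES: &[u8] = &[0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x20, 0xc0];

/// Decode consecutive little-endian f32 values from an embedded byte slice.
fn decode_f32_le(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}

fn main() {
    // No file I/O at runtime: the bytes are already part of the binary
    println!("{:?}", decode_f32_le(EMBEDDED_STATES));
}
```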

Tips

  • Use PrettyJson for human-readable inspection.
  • Choose NamedMpk/NamedMpkGz for fast native loads.
  • Embed with Bincode + embed_states = true for a self-contained no-std binary.
  • Always call register_input_output before with_record so inputs/outputs are known.

7 Examples Showcase

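A curated collection of sketches demonstrating common Burn use cases. They are written for this guide rather than copied from repository files; for complete, runnable projects, clone the tracel-ai/burn repository and explore the examples/ folder (e.g. examples/mnist and examples/guide).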

7.1 MNIST Image Classification

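Train a simple feed-forward network on MNIST with a minimal custom loop.
The repository's examples/mnist project shows a full Learner-based version.

use burn::backend::{Autodiff, NdArray};
use burn::data::dataset::{vision::MnistDataset, Dataset};
use burn::nn::{loss::CrossEntropyLossConfig, Linear, LinearConfig, Relu};
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::prelude::*;

type MyBackend = Autodiff<NdArray>;

#[derive(Module, Debug)]
struct Mlp<B: Backend> {
    fc1: Linear<B>,
    act: Relu,
    fc2: Linear<B>,
}

impl<B: Backend> Mlp<B> {
    fn new(device: &B::Device) -> Self {
        Self {
            fc1: LinearConfig::new(784, 128).init(device),
            act: Relu::new(),
            fc2: LinearConfig::new(128, 10).init(device),
        }
    }

    fn forward(&self, images: Tensor<B, 3>) -> Tensor<B, 2> {
        // Flatten [batch, 28, 28] into [batch, 784]
        let x = images.flatten::<2>(1, 2);
        self.fc2.forward(self.act.forward(self.fc1.forward(x)))
    }
}

fn main() {
    let device = Default::default();
    let dataset = MnistDataset::train(); // downloaded and cached on first use
    let mut model = Mlp::<MyBackend>::new(&device);
    let mut optimizer = AdamConfig::new().init();
    let loss_fn = CrossEntropyLossConfig::new().init(&device);
    let batch_size = 64;

    // Train for 5 epochs (sequential batches, no shuffling, for brevity)
    for epoch in 0..5 {
        for start in (0..dataset.len()).step_by(batch_size) {
            let end = (start + batch_size).min(dataset.len());
            let items: Vec<_> = (start..end).filter_map(|i| dataset.get(i)).collect();
            let n = items.len();
            // Pixels are stored as 0..=255 floats; scale them to [0, 1]
            let pixels: Vec<f32> = items
                .iter()
                .flat_map(|item| item.image.iter().flatten().map(|p| p / 255.0))
                .collect();
            let labels: Vec<i64> = items.iter().map(|item| item.label as i64).collect();
            let images = Tensor::<MyBackend, 3>::from_data(TensorData::new(pixels, [n, 28, 28]), &device);
            let targets = Tensor::<MyBackend, 1, Int>::from_data(TensorData::new(labels, [n]), &device);

            let loss = loss_fn.forward(model.forward(images), targets);
            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optimizer.step(1e-3, model, grads);
        }
        println!("finished epoch {epoch}");
    }
}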

7.2 Transfer Learning with ResNet50

Fine-tune a pre-trained ResNet50 on a custom dataset.
ResNet is not part of the core burn crate. This sketch assumes the resnet-burn crate from the tracel-ai/models repository; check its README for the exact constructor and weight names.

use burn::backend::{Autodiff, Wgpu};
use burn::data::dataset::vision::ImageFolderDataset;
use resnet_burn::{weights, ResNet}; // assumed resnet-burn API

type MyBackend = Autodiff<Wgpu>;

fn main() {
    let device = Default::default();

    // Load custom dataset (folder of class subdirectories, e.g. cats_dogs/cat, cats_dogs/dog)
    let dataset = ImageFolderDataset::new_classification("./data/cats_dogs").unwrap();

    // Load pre-trained ResNet-50 and replace the final layer: two classes
    let model: ResNet<MyBackend> =
        ResNet::resnet50_pretrained(weights::ResNet50::ImageNet1kV2, &device)
            .unwrap()
            .with_classes(2);

    // Fine-tune for 3 epochs with a small LR (e.g. AdamConfig with 1e-4): batch
    // `dataset` with a Batcher that resizes and normalizes images, then train with
    // LearnerBuilder (section 4) or a custom loop.
    let _ = (dataset, model);
}

7.3 Defining a Custom Layer

Implement a custom layer that scales inputs by a learnable parameter.
use burn::backend::NdArray;
use burn::module::Param;
use burn::prelude::*;
use burn::tensor::Distribution;

#[derive(Module, Debug)]
pub struct Scaling<B: Backend> {
    // Learnable scalar, stored as a one-element tensor
    scale: Param<Tensor<B, 1>>,
}

impl<B: Backend> Scaling<B> {
    pub fn new(init: f32, device: &B::Device) -> Self {
        Self { scale: Param::from_tensor(Tensor::from_floats([init], device)) }
    }

    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        // Reshape the scalar to [1, ..., 1] so it broadcasts over the input
        input * self.scale.val().unsqueeze::<D>()
    }
}

fn main() {
    let device = Default::default();
    let layer = Scaling::<NdArray>::new(0.5, &device);
    let x = Tensor::<NdArray, 2>::random([4, 4], Distribution::Default, &device);
    let y = layer.forward(x);
    println!("Scaled output: {}", y);
}

7.4 Model Serialization and Inference

Save a trained model to disk and load it for inference.
use burn::backend::NdArray;
use burn::nn::{Linear, LinearConfig, Relu};
use burn::prelude::*;
use burn::record::CompactRecorder;
use burn::tensor::Distribution;

#[derive(Module, Debug)]
struct SimpleModel<B: Backend> {
    fc1: Linear<B>,
    act: Relu,
    fc2: Linear<B>,
}

impl<B: Backend> SimpleModel<B> {
    fn new(device: &B::Device) -> Self {
        Self {
            fc1: LinearConfig::new(10, 20).init(device),
            act: Relu::new(),
            fc2: LinearConfig::new(20, 5).init(device),
        }
    }

    fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        self.fc2.forward(self.act.forward(self.fc1.forward(x)))
    }
}

fn main() {
    let device = Default::default();
    let model = SimpleModel::<NdArray>::new(&device);
    // ... train your model here ...

    // Serialize the parameters (CompactRecorder writes model.mpk)
    model
        .save_file("model", &CompactRecorder::new())
        .expect("failed to save model");

    // Later: build the same architecture and load the saved parameters into it
    let loaded = SimpleModel::<NdArray>::new(&device)
        .load_file("model", &CompactRecorder::new(), &device)
        .expect("failed to load model");
    let input = Tensor::<NdArray, 2>::random([1, 10], Distribution::Default, &device);
    let output = loaded.forward(input);
    println!("Inference output: {}", output);
}

7.5 Multi-GPU Data-Parallel Training

Scale training across several GPUs on one machine by giving the learner multiple devices.

use burn::backend::{cuda::CudaDevice, Autodiff, Cuda};
use burn::nn::{Linear, LinearConfig, Relu};
use burn::optim::AdamConfig;
use burn::prelude::*;
use burn::train::LearnerBuilder;

type MyBackend = Autodiff<Cuda>;

#[derive(Module, Debug)]
struct CifarModel<B: Backend> {
    fc1: Linear<B>,
    act: Relu,
    fc2: Linear<B>,
}

impl<B: Backend> CifarModel<B> {
    fn new(device: &B::Device) -> Self {
        Self {
            fc1: LinearConfig::new(3 * 32 * 32, 256).init(device),
            act: Relu::new(),
            fc2: LinearConfig::new(256, 10).init(device),
        }
    }

    fn forward(&self, images: Tensor<B, 4>) -> Tensor<B, 2> {
        // Flatten [batch, 3, 32, 32] into [batch, 3072]
        let x = images.flatten::<2>(1, 3);
        self.fc2.forward(self.act.forward(self.fc1.forward(x)))
    }
}

fn main() {
    // One device per GPU (here GPUs 0 and 1)
    let devices: Vec<CudaDevice> = (0..2).map(CudaDevice::new).collect();
    let model = CifarModel::<MyBackend>::new(&devices[0]);

    // Assumes CifarModel implements TrainStep/ValidStep and that train_loader and
    // valid_loader are CIFAR-10 data loaders built with DataLoaderBuilder (section 4).
    let learner = LearnerBuilder::new("/tmp/cifar10")
        .devices(devices)
        .num_epochs(10)
        .build(model, AdamConfig::new().init(), 1e-3);
    let _trained = learner.fit(train_loader, valid_loader);
}
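Synchronous data parallelism rests on one step: every replica computes gradients on its own shard of the batch, and the replicas' gradients are combined (usually averaged) before the shared update. The plain-Rust sketch below shows that averaging for a hypothetical linear model y = w * x with squared error; with equal-sized shards, the average of per-shard mean gradients equals the full-batch gradient. How Burn combines gradients across devices internally may differ.

```rust
// Hypothetical illustration of gradient averaging across replicas; not Burn's code.

/// Mean gradient of 0.5 * (w*x - y)^2 over a shard, with respect to w.
fn shard_grad(w: f64, xs: &[f64], ys: &[f64]) -> f64 {
    let n = xs.len() as f64;
    xs.iter().zip(ys).map(|(x, y)| (w * x - y) * x).sum::<f64>() / n
}

/// Average per-replica gradients (what an all-reduce followed by a divide produces).
fn average(grads: &[f64]) -> f64 {
    grads.iter().sum::<f64>() / grads.len() as f64
}

fn main() {
    let xs = [1.0, 2.0, 3.0, 4.0];
    let ys = [2.0, 4.0, 6.0, 8.0];
    let w = 1.0;
    // Two "devices", each holding half of the batch
    let per_device = [
        shard_grad(w, &xs[..2], &ys[..2]),
        shard_grad(w, &xs[2..], &ys[2..]),
    ];
    println!(
        "averaged = {}, full batch = {}",
        average(&per_device),
        shard_grad(w, &xs, &ys)
    );
}
```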

Adjust paths, batch sizes, and hyperparameters to your environment.

8 Contributor Guide

Brief guidelines for contributing to the Burn repository: filing issues, exploring architecture, setting up your environment, following the development workflow, and using workspace automation tasks.

8.1 Filing Issues

Follow these steps to submit clear, actionable issues:

  1. Search existing issues and pull requests for similar reports.
  2. Click “New Issue” and choose the appropriate template (bug, feature, docs).
  3. Fill in:
    • Title: concise summary (e.g., “panic in tensor reshape with zero-dim”).
    • Description: steps to reproduce, input/output, expected vs. actual behavior.
    • Environment: OS, Rust version (rustc --version), Burn version.
    • Backtrace: rerun the failing command with RUST_BACKTRACE=1 (e.g. RUST_BACKTRACE=1 cargo run-checks) and include the output.
  4. Tag maintainers or relevant area labels (e.g., backend-torch, api).

8.2 Project Architecture Overview

Burn is organized as a Cargo workspace with modular crates:

  • burn: root crate that re-exports the public API
  • burn-core: modules, records, and shared building blocks
  • burn-tensor: tensor API and the Backend trait
  • burn-tch: Torch (libtorch) backend integration
  • burn-ndarray: ndarray backend for CPU-only scenarios
  • burn-autodiff: automatic differentiation backend decorator
  • xtask: custom CLI for builds, validation, benchmarking

Each crate exposes its own API, and the root crate (burn) re-exports the core traits and types.

8.3 Environment Setup

  1. Install Rust (MSRV in rust-toolchain.toml):
    rustup install stable
    rustup default stable
    
  2. Clone and enter the repo:
    git clone https://github.com/tracel-ai/burn.git
    cd burn
    
  3. Install libtorch for the Torch backend (see crates/burn-tch/README.md):
    # macOS/Linux example
    export LIBTORCH=/path/to/libtorch
    export LD_LIBRARY_PATH=$LIBTORCH/lib:$LD_LIBRARY_PATH
    
  4. Confirm setup:
    cargo run-checks
    

8.4 Development Workflow

  1. Fork the burn repository and clone your fork.
  2. Create a descriptive branch:
    git checkout -b feature/reshape-zero-dim
    
  3. Implement changes, commit with Conventional Commits:
    feat(tensor): support zero-dimensional reshape
    
  4. Push the branch and open a pull request against the main branch of tracel-ai/burn.
  5. Link related issue(s) in the PR description.

8.5 Pre-Pull Request Validation Checks

Ensure your changes conform to Burn’s formatting, linting, and validation rules:

Burn defines two aliases in .cargo/config.toml:

[alias]
xtask      = "run --target-dir target/xtask --color always --package xtask --bin xtask --"
run-checks = "xtask -c all validate --release"
  • cargo xtask
    Runs the custom xtask binary with consistent target directory and coloring.
  • cargo run-checks
    Shortcut for cargo xtask -c all validate --release, running all validation steps.

Usage:

# Run all checks
cargo run-checks

# Auto-fix formatting and minor lint issues
cargo xtask fix all

# Enable detailed macro backtraces
RUSTC_BOOTSTRAP=1 \
RUSTFLAGS="-Zmacro-backtrace" \
cargo run-checks

Practical Tips:

  • Run on a clean build to catch environment-sensitive errors.
  • Integrate cargo run-checks into pre-commit hooks or CI pipelines.
  • For Torch backend issues, consult crates/burn-tch/README.md#Installation.

8.6 Workspace Automation Tasks

The xtask crate provides helper commands beyond validation:

# Build documentation
cargo xtask doc

# Run benchmarks
cargo xtask bench --suite tensor_ops

# Generate version bump commit
cargo xtask version --bump patch
  • cargo xtask doc
    Builds and opens API docs for all crates.
  • cargo xtask bench
    Runs criterion benchmarks; pass --release for optimized runs.
  • cargo xtask version
    Updates CHANGELOG.md and bumps Cargo.toml versions.

Integrate these into your workflow to maintain consistency across releases and performance tests.