1 Project Overview
Burn is a high‐performance, flexible deep learning framework written in Rust. It combines Rust’s safety and concurrency guarantees with GPU‐ and hardware‐specific optimizations to support end‐to‐end model development—from research experiments to production deployments on CPU, GPU, WebAssembly, and embedded devices.
Core Values
- Performance
  • Automatic kernel fusion
  • Asynchronous execution engine
  • Hardware-specific backends (CUDA, OpenCL, Vulkan)
- Flexibility
  • Modular crates for tensors, layers, optimizers, schedulers
  • Custom training loops and callbacks
  • Import/export of standard formats (ONNX, TensorFlow)
- Safety
  • Rust's memory and thread safety
  • Zero-cost abstractions for parallelism
- Extensibility
  • Pluggable backends (WASM, embedded accelerators); see the sketch below
  • User-defined layers and operations
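Backends compose as ordinary Rust type parameters, which is what makes them pluggable. A minimal sketch (the alias name MyBackend is illustrative; re-export paths can vary by Burn version and enabled features):
use burn::backend::{Autodiff, Wgpu};
use burn::tensor::Tensor;

// Stack the autodiff decorator on top of the WGPU backend.
// Swapping hardware is a one-line change to this alias.
type MyBackend = Autodiff<Wgpu>;

fn main() {
    let device = Default::default();
    let x = Tensor::<MyBackend, 2>::ones([2, 2], &device);
    println!("{x}");
}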
Position in the ML Ecosystem
Burn sits at the intersection of research and production:
- Researchers leverage low‐level control for novel architectures and optimizations.
- Engineers deploy models across servers, browsers (WASM), and edge devices.
- System developers extend or integrate custom hardware backends.
High‐Level Repository Structure
Defined in Cargo.toml as a Cargo workspace, the repository organizes functionality into focused crates and examples:
/
├── Cargo.toml # Workspace config: members, dependencies, profiles
├── README.md # Project overview, getting started, contribute
├── CITATION.cff # Citation metadata for academic use
├── POEM.md # Inspirational overview of Burn’s vision
├── burn-book/ # Static site source for detailed guides
│ ├── src/
│ │ ├── overview.md
│ │ └── motivation.md
│ └── ...
├── crates/ # Core library crates
│ ├── burn-core/ # Core tensor types and operations
│ ├── burn-tensor/ # Multidimensional array backend
│ ├── burn-engine/ # Execution engine and schedulers
│ ├── burn-vision/ # Vision models and utilities
│ └── ...
└── examples/ # End-to-end training & inference scripts
├── cifar10-training/
├── wasm-inference/
└── ...
Workspace Configuration (Cargo.toml snippet)
Developers use this file to track crate versions, workspace policies, and build profiles.
[workspace]
members = [
"crates/burn-core",
"crates/burn-tensor",
"crates/burn-engine",
"crates/burn-vision",
"examples/*",
"burn-book"
]
[profile.release]
opt-level = "z"
lto = true
Use cargo build --workspace to compile all crates, or target individual members with -p <crate-name>; see the commands below. Continuous integration and dependency updates rely on this centralized configuration.
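For example (crate names taken from the tree above; adjust to the crates you actually need):
# Build every member of the workspace in release mode
cargo build --workspace --release
# Build a single crate and its dependencies
cargo build -p burn-core
# Run the tests of one crate only
cargo test -p burn-tensor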
2 Getting Started
This guide takes you from zero to a running “hello-tensor” example and your first training job using Burn’s CI-tested examples. You’ll install Rust, enable backend features, leverage Cargo aliases, and run both the guide and MNIST examples.
2.1 Prerequisites
- Rust 1.70+ (via rustup)
- Optional GPU drivers for WGPU (Vulkan/Metal) and CUDA/Torch backends
- Git
2.2 Clone and Build Burn
git clone https://github.com/tracel-ai/burn.git
cd burn
# Build core library (CPU backend via ndarray)
cargo build --release
2.3 Configure Cargo Aliases
Burn provides aliases in .cargo/config.toml to streamline common tasks:
[alias]
xtask = "run --package xtask --"
run-checks = "xtask run-checks --release"
- cargo xtask <cmd> runs custom automation scripts from the xtask crate.
- cargo run-checks executes formatting, linting, docs, tests, and example runs, as shown below.
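For example (assuming the xtask CLI prints its subcommands with --help, as clap-based tools typically do):
# List the available automation commands
cargo xtask --help
# Run the full validation suite before opening a pull request
cargo run-checks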
2.4 Enabling Backend Features
Burn supports multiple backends. Enable them via --features:
# CPU (ndarray), Torch (CUDA), WGPU (Vulkan/Metal)
cargo build --release \
--features="backend-ndarray,backend-tch,backend-wgpu"
2.5 Hello-Tensor Example
Create a minimal Rust project that prints a tensor:
cargo new hello-tensor
cd hello-tensor
In Cargo.toml, add Burn as a path dependency:
[dependencies]
burn = { path = "../burn", features = ["backend-ndarray"] }
Replace src/main.rs with:
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    // Pick the default CPU device for the ndarray backend
    let device = Default::default();
    // Create a 2×2 tensor of f32
    let t = Tensor::<NdArray, 2>::from_floats([[1., 2.], [3., 4.]], &device);
    println!("Hello, Tensor:\n{}", t);
}
Run:
cargo run --release
2.6 Running CI-Tested Examples
Burn’s repository includes two guided examples—guide and mnist—complete with training and inference workflows. Use cargo run-checks to build and validate both on all enabled backends:
# From repo root: runs formatting, linting, docs, tests, and examples
cargo run-checks
2.6.1 Guide Example
cd examples/guide
# Train on CPU
cargo run -- train --device cpu --batch-size 32
# Inference on CPU
cargo run -- infer --device cpu --input path/to/image.png
# Print model summary
cargo run -- print-model
2.6.2 MNIST Example
cd examples/mnist
# Train with ndarray (CPU)
cargo run -- --backend ndarray --data ./data/mnist/train.csv --train
# Train with tch (CUDA)
cargo run -- --backend tch --data ./data/mnist/train.csv --train
This setup gives you a working Burn installation, a “hello-tensor” starter, and two fully CI-tested examples to explore training and inference across backends.
3 Core Concepts
This section introduces Burn’s fundamental abstractions: tensors, backends, modules, and automatic differentiation. Understanding these concepts helps you build and train models efficiently across different hardware.
3.1 Tensors
Burn’s Tensor is the primary data container for n-dimensional arrays. It is generic over:
- A backend implementing Backend (CPU, GPU, autodiff).
- A dimensionality D, fixed at compile time.
- An element kind (Float, Int, etc.).
Each tensor is also bound to a device (CPU, CUDA, …) chosen at construction.
Creating Tensors
use burn::backend::NdArray; // CPU backend
use burn::tensor::{Int, Tensor, TensorData};

// 1. Pick a device
let device = Default::default();

// 2. From a nested primitive array (float elements)
let t1: Tensor<NdArray, 2> = Tensor::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);

// 3. From TensorData for another element kind or an existing buffer
let raw = vec![10i32, 20, 30, 40];
let data = TensorData::new(raw, [2, 2]);
let t2: Tensor<NdArray, 2, Int> = Tensor::from_data(data, &device);

// 4. Inspect and retrieve
let copied = t1.to_data();          // TensorData copy of the values
let owned = t1.clone().into_data(); // consumes the (cloned) tensor
Common Operations
// Element-wise math (convert the Int tensor to Float before adding)
let sum = t1.clone() + t2.float();
let scaled = t1.clone().mul_scalar(0.5);
// Reductions
let total = sum.clone().sum(); // rank-1 tensor holding the overall sum
let max = sum.max_dim(1);      // per-row maximum, dim 1 kept with size 1
3.2 Backends
Backends provide device-specific implementations of tensor operations via the Backend trait. Burn ships with:
- NdArray (burn::backend::NdArray): CPU reference implementation.
- Cuda (burn_cuda::Cuda): GPU via CubeCL.
- Autodiff (burn_autodiff::Autodiff<B>): wraps any Backend to enable gradient tracking.
Switching Backends
// CPU
use burn::backend::NdArray as Cpu;
// GPU
use burn_cuda::{Cuda, CudaDevice};
let device_gpu: CudaDevice = Default::default(); // picks first CUDA GPU
type Gpu = Cuda<f32, i32>;
// Autodiff on CPU
use burn::tensor::backend::Backend;
use burn_autodiff::Autodiff;
type CpuDiff = Autodiff<Cpu>;
let device_cpu: <Cpu as Backend>::Device = Default::default();
Example: Random Tensor on GPU
use burn::tensor::{Distribution, Tensor};

let zeros = Tensor::<Gpu, 2>::zeros([4, 4], &device_gpu);
let rand = Tensor::<Gpu, 2>::random([4, 4], Distribution::Uniform(0.0, 1.0), &device_gpu);
let out = zeros + rand;  // element-wise addition on the GPU
Gpu::sync(&device_gpu);  // wait for queued kernels to finish
let host: Vec<f32> = out.into_data().to_vec().unwrap();
3.3 Modules
Modules encapsulate parameters and define a forward pass. Burn’s neural-network API lives under burn::nn and exports a prelude for common layers, losses, optimizers, schedulers, and utilities.
Defining a Model
use burn::prelude::*;                 // Tensor, Module, Backend, Config, ...
use burn::nn::{Linear, LinearConfig}; // layers live under burn::nn

#[derive(Module, Debug)]
struct SimpleRegressor<B: Backend> {
    linear: Linear<B>,
}

impl<B: Backend> SimpleRegressor<B> {
    fn new(input_dim: usize, output_dim: usize, device: &B::Device) -> Self {
        let linear = LinearConfig::new(input_dim, output_dim).init(device);
        Self { linear }
    }

    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        self.linear.forward(input)
    }
}
Training Loop Snippet
use burn::nn::loss::{MseLoss, Reduction};
use burn::optim::{AdamConfig, GradientsParams, Optimizer};

// Requires an autodiff backend, e.g. the CpuDiff alias defined above.
let mut model = SimpleRegressor::<CpuDiff>::new(10, 1, &device);
let mut optimizer = AdamConfig::new().init();
// A learning-rate scheduler (burn::lr_scheduler) can supply this value per step;
// a constant rate keeps the snippet short.
let learning_rate = 1e-2;
for _epoch in 0..200 {
    let data = Tensor::from_floats([[...]], &device);
    let target = Tensor::from_floats([[...]], &device);
    let pred = model.forward(data);
    let loss = MseLoss::new().forward(pred, target, Reduction::Mean);
    // Automatic differentiation (see next section)
    let grads = loss.backward();
    let grads = GradientsParams::from_grads(grads, &model);
    model = optimizer.step(learning_rate, model, grads);
}
3.4 Automatic Differentiation
Burn’s AutodiffBackend (in burn-autodiff) instruments tensors for gradient tracking and provides:
- require_grad()
- backward() → Gradients
- grad() / grad_remove()
Computing Gradients
use burn::tensor::{backend::{AutodiffBackend, Backend}, Tensor};
fn compute_grad<B: AutodiffBackend>(input: Tensor<B, 2>) -> B::Gradients {
let x = input.require_grad();
let loss = (x.clone() * x.clone()).sum(); // scalar
let mut grads = loss.backward();
// Peek or pop gradients for x
let grad_x = x.grad(&grads);
let grad_x_popped = x.grad_remove(&mut grads);
grads
}
Tips
- Use grad() to inspect gradients without consuming them.
- Use grad_remove() for in-place updates.
- Backends without AutodiffBackend disable these methods at compile time.
4 Training Workflow Guide
This guide walks through an end-to-end recipe for training a MNIST classifier with Burn. You’ll cover configuration, data loading, model instantiation, optimizer/scheduler setup, learner construction, logging, checkpointing, and early stopping.
4.1 Configuration and Command-Line Arguments
Define a TrainConfig struct to capture hyperparameters and paths. Derive Config to parse from CLI flags or environment variables.
use burn::config::Config;
#[derive(Config, Debug)]
pub struct TrainConfig {
#[config(default = 32)]
pub batch_size: usize,
#[config(default = 10)]
pub epochs: usize,
#[config(default = 0.01)]
pub learning_rate: f64,
#[config(default = 0.9)]
pub momentum: f64,
#[config(default = "data/mnist")]
pub data_path: String,
#[config(default = "checkpoints")]
pub checkpoint_dir: String,
}
Load the config at runtime:
let config = TrainConfig::from_args();
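The configuration can also be built and persisted in code. A short sketch, assuming the with_* setters generated by the Config derive and the save/load helpers on the Config trait:
// Override a few defaults programmatically
let config = TrainConfig::new()
    .with_batch_size(64)
    .with_learning_rate(0.005);

// Persist next to the checkpoints so runs stay reproducible
config.save(format!("{}/train_config.json", config.checkpoint_dir))?;
let restored = TrainConfig::load(format!("{}/train_config.json", config.checkpoint_dir))?;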
4.2 Data Loading with DataLoader
Use Burn’s MnistDataset and DataLoader to prepare shuffled training and validation batches.
use burn::data::{DataLoader, dataset::Dataset};
use burn_dataset::mnist::MnistDataset;
// Load the raw MNIST splits
let train_ds = MnistDataset::new(&config.data_path)?.train();
let val_ds = MnistDataset::new(&config.data_path)?.validation();
// Wrap in DataLoader for batching and shuffling
let train_loader = DataLoader::new(train_ds)
.batch_size(config.batch_size)
.shuffle(true)
.num_workers(4);
let val_loader = DataLoader::new(val_ds)
.batch_size(config.batch_size);
4.3 Model Definition
Instantiate your MNIST model and move it to the selected device (CPU/GPU).
use burn::tensor::backend::Backend;
use crate::model::MnistModel;
// B is your chosen backend, e.g. NdArrayBackend<f32>
let device = burn::tensor::device::Device::default();
let mut model: MnistModel<B> = MnistModel::new();
model = model.to_device(&device);
4.4 Optimizer and Learning Rate Scheduler
Configure SGD with momentum and a step-decay scheduler.
use burn::optim::{OptimizerConfig, SgdConfig};
use burn::lr_scheduler::{SchedulerConfig, StepSchedulerConfig};
// Initialize SGD optimizer
let optimizer = SgdConfig::new()
.learning_rate(config.learning_rate)
.momentum(config.momentum)
.init(&model);
// Initialize step scheduler: decay LR by 0.1 every 5 epochs
let scheduler = StepSchedulerConfig::new()
.step_size(5)
.gamma(0.1)
.init(&optimizer);
4.5 Constructing the Learner
Build a Learner to wire together the model, optimizer, data loaders, scheduler, logging, checkpointing, and early stopping.
use burn::train::{
Learner,
LoggingConfig,
CheckpointConfig,
EarlyStoppingConfig,
};
// Create learner builder
let mut learner = Learner::builder(model, optimizer, train_loader)
// Attach validation loader
.with_validation(val_loader)
// Plug in learning rate scheduler
.with_scheduler(scheduler)
// Enable console logging of metrics
.with_logger(LoggingConfig::default())
// Save checkpoints after each epoch
.with_checkpoint(CheckpointConfig::new(&config.checkpoint_dir))
// Stop training if validation loss hasn’t improved in 3 epochs
.with_early_stopping(EarlyStoppingConfig::new().patience(3))
.build();
4.6 Running Training
Kick off the training loop for the configured number of epochs. This performs forward/backward passes, updates parameters, logs metrics, and manages checkpoints.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// ... load config, data, model, optimizer, scheduler, learner ...
learner.fit(config.epochs).await?;
Ok(())
}
4.7 Saving the Best Model
After training completes, export the best-performing model to disk.
// By default, learner saves the best checkpoint under `checkpoint_dir`
// To export a final TorchScript or ONNX model:
learner.save_model(&format!("{}/best_model.onnx", config.checkpoint_dir))?;
Practical Tips
• To resume training, point CheckpointConfig to an existing directory containing saved state.
• Adjust num_workers in DataLoader to maximize data throughput on your hardware.
• Swap in any OptimizerConfig (Adam, RMSProp) by replacing SgdConfig, as sketched below.
• Integrate custom callbacks (e.g., LR warm-up) by extending the Learner builder.
• For distributed training, configure Burn’s distributed backend and pass multiple devices to Learner::builder.
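For instance, switching from SGD to Adam only changes the optimizer construction; the rest of the learner wiring stays the same. A sketch following the same builder pattern as the SgdConfig example above (exact method names may differ between releases):
use burn::optim::AdamConfig;

// Adam manages its own moment estimates, so no momentum field is needed
let optimizer = AdamConfig::new()
    .learning_rate(config.learning_rate)
    .init(&model);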
5 Backends & Performance
This section shows how to pick and configure GPU backends (CUDA, WGPU, CubeCL) in Burn, run tensor operations or training on them, and write custom kernels for specialized compute.
5.1 Enabling Backends in Cargo.toml
Control which backends compile by disabling default features and listing desired backends:
[dependencies]
burn = { version = "0.17", default-features = false, features = [
"cuda", # NVIDIA GPU via CUDA/cuDARC
"wgpu", # Cross-API GPU via WGPU (Vulkan/Metal/DirectX/OpenGL/WebGPU)
"cubecl", # JIT-compile backend for custom shader kernels
"autodiff", # automatic differentiation decorator
"fusion" # kernel fusion decorator
]}
5.2 CUDA Backend
Use the CUDA backend to run tensor ops on NVIDIA GPUs with automatic differentiation.
use burn::backend::{Cuda, cuda::CudaDevice};
use burn::tensor::{Tensor, Distribution};
fn main() {
// Initialize CUDA on device 0
let device = CudaDevice::discrete(0).unwrap();
// Create two random matrices on GPU
let a: Tensor<Cuda, 2> =
Tensor::random([512, 512], Distribution::Default, &device).require_grad();
let b = a.clone().relu(); // forward
let loss = b.sum(); // scalar
let grads = loss.backward(); // backprop
let grad_a = a.grad(&grads).unwrap();
println!("Sum: {:.3}, grad[0]: {:.3}", loss.to_host()[0], grad_a.to_host()[0]);
}
– .require_grad() enables gradient tracking on Autodiff<Cuda>.
5.3 WGPU Backend
Use the WGPU backend for cross-platform GPU acceleration. Supports SPIR-V or WGSL kernels and task batching.
use burn::backend::{Wgpu, wgpu::WgpuConfig, wgpu::WgpuDevice};
use burn::tensor::{Tensor, Distribution};
#[tokio::main]
async fn main() {
// Enable SPIR-V on Vulkan; fall back to WGSL otherwise
let config = WgpuConfig { use_spirv: true, ..Default::default() };
let device = WgpuDevice::new(config).await.unwrap();
// Random tensor on GPU
let x: Tensor<Wgpu, 2> =
Tensor::random([256, 256], Distribution::Uniform, &device);
let y = x.clone().matmul(x).clamp(0.0, 1.0);
println!("Output shape: {:?}", y.shape());
}
– Tune WgpuConfig fields (backends, concurrent_kernels) for your platform.
5.4 CubeCL Backend & Custom Kernels
The CubeCL backend JIT-compiles shader code (WGSL/GLSL/Metal) at runtime. Use it for writing custom kernels.
use burn_cubecl::{CubeclDevice, JitCompiler};
use burn_cubecl::kernel::Kernel;
use burn::tensor::Distribution;
fn main() -> anyhow::Result<()> {
// Build a JIT compiler targeting WGSL
let compiler = JitCompiler::builder().target("wgsl").build()?;
let device = CubeclDevice::new(compiler)?;
// A simple add-kernel in WGSL
let src = r#"
@kernel fn add(
@buffer a: array<f32>;
@buffer b: array<f32>;
@buffer out: array<f32>
) {
let i = global_id.x;
out[i] = a[i] + b[i];
}
"#;
let kernel: Kernel = device.compile(src, "add")?;
// Prepare input data
let len = 1 << 20;
let a: Vec<f32> = (0..len).map(|i| i as f32).collect();
let b: Vec<f32> = vec![1.0; len];
// Run on GPU
let out = device.run(&kernel, &[&a, &b], len);
// Verify
assert_eq!(out[0], 0.0 + 1.0);
assert_eq!(out[len - 1], (len - 1) as f32 + 1.0);
Ok(())
}
– JitCompiler::target(...) accepts "wgsl", "glsl", "metal", etc.
– device.run(...) enqueues a 1D dispatch of size len.
5.5 Performance Tips
• Batch your workload to maximize GPU occupancy (e.g. process many samples in one dispatch; see the sketch below).
• Enable the fusion feature to merge element-wise ops into single kernels.
• On WGPU, prefer SPIR-V on Vulkan for lower compilation overhead.
• On CUDA, set cudarc flags (e.g. PTX version) via CudaDevice::with_ptx_version(...).
• Profile with vendor tools (Nsight, RenderDoc) and tune workgroup sizes or thread-block shapes.
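The batching tip in practice: one batched matmul keeps the GPU far busier than many tiny ones. A minimal sketch on the WGPU backend (shapes are illustrative):
use burn::backend::Wgpu;
use burn::tensor::{Distribution, Tensor};

let device = Default::default();
// 64 samples of dimension 512, processed in a single kernel dispatch
let inputs = Tensor::<Wgpu, 2>::random([64, 512], Distribution::Default, &device);
let weights = Tensor::<Wgpu, 2>::random([512, 256], Distribution::Default, &device);
let outputs = inputs.matmul(weights); // one [64, 256] result instead of 64 small matmuls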
6 Model Import / Export
Leverage the burn-import and onnx-ir crates to ingest pre-trained ONNX, PyTorch, and Safetensors models into Burn. Convert ONNX models to Rust source via the onnx2burn CLI or the ModelGen API, load PyTorch/Safetensors weights with dedicated recorders, and export Burn models for production using graph serialization or embedded states.
ModelGen: Converting ONNX Models to Burn Rust Code
Burn provides a CLI and a builder API to translate one or more .onnx files into Rust sources implementing a Burn model.
1. Command-Line (onnx2burn)
# Convert model.onnx → generated/ (with debug dumps)
# --development emits .graph.txt + .onnx.txt debug files
# --half-precision saves parameters as f16 where possible
cargo run --bin onnx2burn -- \
    path/to/model.onnx \
    generated/ \
    --development \
    --half-precision
Omit --development for minimal output. By default, the CLI uses full precision and pretty JSON.
2. Build Script (build.rs)
use burn_import::onnx::{ModelGen, RecordType};
fn main() {
ModelGen::new()
.input("assets/model.onnx") // queue ONNX file
.out_dir("generated_models") // relative to OUT_DIR
.development(false) // no debug dumps
.half_precision(true) // f16 weights
.record_type(RecordType::CompactJson) // compact JSON
.embed_states(true) // inline blobs in .rs
.run_from_script(); // writes to ${OUT_DIR}
}
Generated files in ${OUT_DIR}/generated_models/:
- model.rs — Burn model implementation
- Optional .json or .bin blobs if embed_states == false
- Debug dumps (.graph.txt, .onnx.txt) if development == true
API Reference
- ModelGen::new(): initializes builder and logging.
- .input(path: &str): add an ONNX file.
- .out_dir(path: &str): target directory.
- .development(bool): emit human-readable dumps.
- .half_precision(bool): toggle f16 vs f32/f64.
- .record_type(RecordType): choose serialization (PrettyJson, CompactJson, NamedMpk, etc.).
- .embed_states(bool): inline weight blobs.
- .run_from_cli() vs .run_from_script(): the CLI writes to the CWD; the script variant prepends Cargo’s OUT_DIR.
Tips
- Use run_from_script() in CI/build pipelines, then include the generated source as sketched below.
- Enable development to inspect ONNX and intermediate graphs.
- Match runtime precision via half_precision.
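To use the code emitted by the build script above, include the generated file from OUT_DIR. A sketch, assuming the out_dir and file name from the ModelGen example (the wrapping module and MyBackend alias are illustrative):
// src/model/mod.rs — pull in the code generated at build time
mod generated {
    include!(concat!(env!("OUT_DIR"), "/generated_models/model.rs"));
}

// Instantiate it like any other Burn module:
// let model = generated::Model::<MyBackend>::default();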
Importing PyTorch Model Weights
Load .pt weight files into a ModelRecord and run inference with Burn.
1. Setup Backend & Weight Path
use burn::backend::NdArray;
type B = NdArray<f32>;
const WEIGHTS_FILE: &str = "weights/mnist.pt";
2. Create a Recorder
use burn::record::FullPrecisionSettings;
use burn_import::pytorch::PyTorchFileRecorder;
// Maps PyTorch f32 weights directly
let recorder = PyTorchFileRecorder::<FullPrecisionSettings>::default();
3. Load Weights
use import_model_weights::{ModelRecord, infer};
let record: ModelRecord<B> = recorder
.load(WEIGHTS_FILE.into(), &Default::default())
.expect("Failed to load PyTorch model weights");
4. Run Inference
infer(record);
Full Example (src/bin/pytorch.rs)
use burn::backend::NdArray;
use burn::record::FullPrecisionSettings;
use burn_import::pytorch::PyTorchFileRecorder;
use import_model_weights::{ModelRecord, infer};
type B = NdArray<f32>;
const WEIGHTS_FILE: &str = "weights/mnist.pt";
fn main() {
println!("Loading PyTorch model weights from file: {}", WEIGHTS_FILE);
let record: ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load(WEIGHTS_FILE.into(), &Default::default())
.expect("Failed to load PyTorch model weights");
infer(record);
}
Tips
- Ensure .pt layer names and shapes match your Rust model.
- Adjust the WEIGHTS_FILE path relative to the binary’s CWD.
- Implement custom RecorderSettings for mixed-precision setups.
- Run via cargo run --bin pytorch from the workspace root.
Persisting and Embedding Graph States
Serialize a BurnGraph’s learned parameters to disk or embed them into generated Rust code.
Signature
pub fn with_record(
self,
out_file: PathBuf,
record_type: RecordType,
embed_states: bool,
) -> Self
How It Works
- Supports four formats:
  • PrettyJson
  • NamedMpk
  • NamedMpkGz
  • Bincode (no-std friendly)
- If embed_states == true, only Bincode is allowed. Emits
  static EMBEDDED_STATES: &[u8] = include_bytes!("model_states.bin");
  and implements Default / Model::from_embedded(...).
- Otherwise generates Default / Model::from_file(...) to load at runtime.
Panics if embed_states == true with a non-Bincode RecordType.
Example
use burn_import::burn::{BurnGraph, RecordType};
use std::path::PathBuf;
// 1. Build your graph
let graph = BurnGraph::<MyPrecision>::default()
.register_input_output(vec!["input".into()], vec!["output".into()])
.register(MyConv2d::new(...))
.register(MyReLU::new(...));
// 2a. Pretty JSON for debugging
let graph = graph.with_record(
PathBuf::from("model_debug.json"),
RecordType::PrettyJson,
false,
);
// 2b. Compact MessagePack
let graph = graph.with_record(
PathBuf::from("model_states.mpk"),
RecordType::NamedMpk,
false,
);
// 2c. Embed Bincode states
let graph = graph.with_record(
PathBuf::from("model_states.bin"),
RecordType::Bincode,
true,
);
// 3. Generate Rust module
let tokens = graph.codegen();
Generated Helpers
File-based records:
impl<B: Backend> Default for Model<B> {
fn default() -> Self {
Self::from_file("model_states.mpk", &Default::default())
}
}
impl<B: Backend> Model<B> {
pub fn from_file(file: &str, device: &B::Device) -> Self { … }
}
Embedded Bincode:
static EMBEDDED_STATES: &[u8] = include_bytes!("model_states.bin");
impl<B: Backend> Default for Model<B> {
fn default() -> Self {
Self::from_embedded(&Default::default())
}
}
impl<B: Backend> Model<B> {
pub fn from_embedded(device: &B::Device) -> Self { … }
}
Tips
- Use PrettyJson for human-readable inspection.
- Choose NamedMpk / NamedMpkGz for fast native loads.
- Embed with Bincode + embed_states = true for a self-contained no-std binary; loading is then a one-liner, as shown below.
- Always call register_input_output before with_record so inputs/outputs are known.
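Whichever record type is chosen, the generated helpers make loading uniform. A short usage sketch (MyBackend stands in for whatever backend type you compile against):
// File-based records load at startup ...
let model = Model::<MyBackend>::from_file("model_states.mpk", &Default::default());

// ... while embedded Bincode states need no files at runtime.
let model = Model::<MyBackend>::default(); // reads EMBEDDED_STATES internally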
7 Examples Showcase
A curated collection of runnable examples demonstrating common Burn use-cases. Clone the tracel-ai/burn repository and run each example from the /examples folder.
7.1 MNIST Image Classification
Train a simple feed-forward network on MNIST.
File: examples/mnist_classification.rs
use burn::data::dataloader::DataLoader;
use burn::data::dataset::Mnist;
use burn::module::{Linear, Module, ReLU, Sequential};
use burn::optim::{AdamConfig, Adam};
use burn::tensor::{backend::ndarray::NdArrayBackend, Data, Tensor};
use burn::train::{Learner, Metric};
type Backend = NdArrayBackend<f32, Ix2>;
type Input = Tensor<Backend, 2>;
type Output = Tensor<Backend, 2>;
struct MLP {
layers: Sequential<Backend>,
}
impl MLP {
fn new() -> Self {
let layers = Sequential::new(vec![
Linear::new(784, 128),
ReLU::new(),
Linear::new(128, 10),
]);
Self { layers }
}
}
impl Module<Backend> for MLP {
type Input = Input;
type Output = Output;
fn forward(&self, input: &Self::Input) -> Self::Output {
let x = input.clone().reshape([input.shape()[0], 784]);
self.layers.forward(x)
}
}
fn main() {
// Load MNIST dataset
let dataset = Mnist::new("./data").unwrap();
let loader = DataLoader::new(dataset.train(), 64);
// Set up model and optimizer
let mut model = MLP::new();
let optimizer = AdamConfig::new().init();
// Configure learner with accuracy metric
let mut learner = Learner::new(&mut model, optimizer)
.with_metric(Metric::accuracy());
// Train for 5 epochs
learner.fit(loader, 5);
}
7.2 Transfer Learning with ResNet50
Fine-tune a pre-trained ResNet50 on a custom dataset.
File: examples/transfer_resnet.rs
use burn::module::{resnet, Module};
use burn::data::dataloader::DataLoader;
use burn::data::dataset::ImageFolder;
use burn::optim::{AdamConfig, Adam};
use burn::tensor::backend::ndarray::NdArrayBackend;
use burn::train::Learner;
type Backend = NdArrayBackend<f32, Ix4>;
type Model = resnet::ResNet50<Backend>;
fn main() {
// Load custom dataset (folder of class subdirectories)
let dataset = ImageFolder::new("./data/cats_dogs").unwrap();
let loader = DataLoader::new(dataset.train(), 32);
// Load pre-trained ResNet50 and replace final layer
let mut model = Model::pretrained().unwrap();
model.replace_head(2048, 2); // two classes
let optimizer = AdamConfig::new().learning_rate(1e-4).init();
// Fine-tune for 3 epochs
let mut learner = Learner::new(&mut model, optimizer);
learner.fit(loader, 3);
}
7.3 Defining a Custom Layer
Implement a custom layer that scales inputs by a learnable parameter.
File: examples/custom_layer.rs
use burn::module::{Module, Param};
use burn::tensor::{backend::ndarray::NdArrayBackend, Tensor};
type Backend = NdArrayBackend<f32, Ix2>;
type Input = Tensor<Backend, 2>;
type Output = Tensor<Backend, 2>;
#[derive(Module)]
pub struct Scaling {
scale: Param<Backend>,
}
impl Scaling {
pub fn new(init: f32) -> Self {
Self { scale: Param::<Backend>::from_scalar(init) }
}
}
impl Module<Backend> for Scaling {
type Input = Input;
type Output = Output;
fn forward(&self, input: &Self::Input) -> Self::Output {
input * &self.scale.value()
}
}
fn main() {
let layer = Scaling::new(0.5);
let x = Tensor::random([4, 4]);
let y = layer.forward(&x);
println!("Scaled output: {:?}", y);
}
7.4 Model Serialization and Inference
Save a trained model to disk and load it for inference.
File: examples/serialize_infer.rs
use burn::module::{Linear, Module, ReLU, Sequential};
use burn::tensor::backend::ndarray::NdArrayBackend;
use burn::tensor::Tensor;
use burn::serde::{save, load};
type Backend = NdArrayBackend<f32, Ix2>;
type Input = Tensor<Backend, 2>;
type Output = Tensor<Backend, 2>;
#[derive(Module)]
struct SimpleModel {
net: Sequential<Backend>,
}
impl SimpleModel {
fn new() -> Self {
Self {
net: Sequential::new(vec![
Linear::new(10, 20),
ReLU::new(),
Linear::new(20, 5),
]),
}
}
}
fn main() {
// Train or load
let mut model = SimpleModel::new();
// ... train your model here ...
// Serialize
save(&model, "model.bin").unwrap();
// Later: load and run inference
let loaded: SimpleModel = load("model.bin").unwrap();
let input = Tensor::random([1, 10]);
let output = loaded.forward(&input);
println!("Inference output: {:?}", output);
}
7.5 Distributed Data Parallel Training
Scale training across multiple GPUs using NCCL.
File: examples/ddp_training.rs
use burn::distribute::{DistributedEnvironment, DDP};
use burn::data::dataloader::DataLoader;
use burn::data::dataset::Cifar10;
use burn::module::{Linear, Module, ReLU, Sequential};
use burn::optim::{AdamConfig};
use burn::train::Learner;
use burn::tensor::backend::cuda::CudaBackend;
type Backend = CudaBackend<f32>;
type Input = burn::tensor::Tensor<Backend, 4>;
#[derive(Module)]
struct CifarModel {
net: Sequential<Backend>,
}
impl CifarModel {
fn new() -> Self {
Self {
net: Sequential::new(vec![
Linear::new(3 * 32 * 32, 256),
ReLU::new(),
Linear::new(256, 10),
]),
}
}
}
fn main() {
// Initialize distributed environment (rank, world_size)
let env = DistributedEnvironment::init().unwrap();
let device = env.local_rank();
// Wrap model in DDP
let model = CifarModel::new().to_device(device);
let mut ddp_model = DDP::new(model, &env);
// Prepare data loader
let dataset = Cifar10::new("./data").unwrap();
let loader = DataLoader::new(dataset.train(), 128);
// Set up optimizer
let optimizer = AdamConfig::new().init();
// Train across GPUs
let mut learner = Learner::new(&mut ddp_model, optimizer);
learner.fit(loader, 10);
}
Each example lives under examples/. Adjust paths, batch sizes, and hyperparameters to your environment.
8 Contributor Guide
Brief guidelines for contributing to the Burn repository: filing issues, exploring architecture, setting up your environment, following the development workflow, and using workspace automation tasks.
8.1 Filing Issues
Follow these steps to submit clear, actionable issues:
- Search existing issues and pull requests for similar reports.
- Click “New Issue” and choose the appropriate template (bug, feature, docs).
- Fill in:
  - Title: concise summary (e.g., “panic in tensor reshape with zero-dim”).
  - Description: steps to reproduce, input/output, expected vs. actual behavior.
  - Environment: OS, Rust version (rustc --version), Burn version.
  - Backtrace: include via RUST_BACKTRACE=1 cargo run-checks.
- Tag maintainers or relevant area labels (e.g., backend-torch, api).
8.2 Project Architecture Overview
Burn is organized as a Cargo workspace with modular crates:
- burn-core: core tensor types, operations, macros
- burn-tch: Torch (libtorch) backend integration
- burn-ndarray: ndarray backend for CPU-only scenarios
- burn-grad: automatic differentiation engine
- xtask: custom CLI for builds, validation, benchmarking
Each crate exposes a clear API and re-exports core traits/types through the root crate (burn).
8.3 Environment Setup
- Install Rust (MSRV pinned in rust-toolchain.toml):
  rustup install stable
  rustup default stable
- Clone and enter the repo:
  git clone https://github.com/tracel-ai/burn.git
  cd burn
- Install libtorch for the Torch backend (see crates/burn-tch/README.md):
  # macOS/Linux example
  export LIBTORCH=/path/to/libtorch
  export LD_LIBRARY_PATH=$LIBTORCH/lib:$LD_LIBRARY_PATH
- Confirm setup:
  cargo run-checks
8.4 Development Workflow
- Fork the burn repository and clone your fork.
- Create a descriptive branch:
  git checkout -b feature/reshape-zero-dim
- Implement changes and commit with Conventional Commits:
  feat(tensor): support zero-dimensional reshape
- Push the branch and open a pull request against tracel-ai:main.
- Link related issue(s) in the PR description (the full flow is sketched below).
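Put together, a typical contribution flow looks like this (the fork URL and branch name are placeholders):
# Clone your fork and add the upstream remote
git clone https://github.com/<your-user>/burn.git
cd burn
git remote add upstream https://github.com/tracel-ai/burn.git
# Branch, commit, validate, and push
git checkout -b feature/reshape-zero-dim
git commit -am "feat(tensor): support zero-dimensional reshape"
cargo run-checks
git push origin feature/reshape-zero-dim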
8.5 Pre-Pull Request Validation Checks
Ensure your changes conform to Burn’s formatting, linting, and validation rules:
Burn defines two aliases in .cargo/config.toml
:
[alias]
xtask = "run --target-dir target/xtask --color always --package xtask --bin xtask --"
run-checks = "xtask -c all validate --release"
- cargo xtask runs the custom xtask binary with a consistent target directory and coloring.
- cargo run-checks is a shortcut for cargo xtask -c all validate --release, running all validation steps.
Usage:
# Run all checks
cargo run-checks
# Auto-fix formatting and minor lint issues
cargo xtask fix all
# Enable detailed macro backtraces
RUSTC_BOOTSTRAP=1 \
RUSTFLAGS="-Zmacro-backtrace" \
cargo run-checks
Practical Tips:
- Run on a clean build to catch environment-sensitive errors.
- Integrate cargo run-checks into pre-commit hooks or CI pipelines.
- For Torch backend issues, consult crates/burn-tch/README.md#Installation.
8.6 Workspace Automation Tasks
The xtask crate provides helper commands beyond validation:
# Build documentation
cargo xtask doc
# Run benchmarks
cargo xtask bench --suite tensor_ops
# Generate version bump commit
cargo xtask version --bump patch
- cargo xtask doc builds and opens API docs for all crates.
- cargo xtask bench runs Criterion benchmarks; pass --release for optimized runs.
- cargo xtask version updates CHANGELOG.md and bumps Cargo.toml versions.
Integrate these into your workflow to maintain consistency across releases and performance tests.