DEV Community

Guillaume Lagrange


Transitioning From PyTorch to Burn

Deep learning development requires both high-level abstractions and fast execution, and this is exactly where Burn shines. Burn is a comprehensive deep learning framework in Rust that focuses on flexibility, compute efficiency and portability. It can easily be used with different backends to support many devices and use cases.

Implementing your model with Burn is pretty straightforward thanks to the provided building blocks. But what if you already have a model that you trained with PyTorch and you'd like to use it with Burn? Well, we've got you covered.

In this tutorial, we'll use the popular ResNet family of models as a reference. We'll learn how to:

  1. Define a neural network with Module
  2. Import PyTorch weights with the PyTorchFileRecorder
  3. Save and load the model weights with a supported recorder format
  4. Load and normalize an image for inference
  5. Perform inference with the model and check the results

The code used in this tutorial is available on GitHub.

Update

[15/04/2024] The tutorial has been updated to use the released version of Burn 0.13.0.

Module Definition

Let's start by defining the ResNet module according to the Residual Network architecture, as replicated[1] by the torchvision implementation of the model we will import. Detailed architecture variants with a depth of 18, 34, 50, 101 and 152 layers can be found in the table below.

ResNet architectures

The 18-layer and 34-layer architectures use residual blocks with shortcut connections (also known as skip connections). A residual block has two 3×3 convolutional layers with the same number of output channels. Both convolutional layers are followed by a batch normalization layer. A ReLU activation function follows the first convolution and batch normalization block, but the shortcut connection is added to the output of the second convolution block before the final ReLU activation. This addition requires the output of the second convolutional layer to match the shape of the input. For blocks where that is not the case (i.e., when the block downsamples with a stride of 2 or changes the number of channels), a projection shortcut is used instead: a 1×1 convolution ensures that both inputs to the addition have the same shape. As with the other convolutional layers in the network, it is followed by a batch normalization layer.

(a) Residual block
(b) Bottleneck block
(c) Projection shortcut
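To see why a projection is sometimes needed, recall that a convolution (or pooling) with kernel size k, stride s and padding p maps a spatial size n to ⌊(n + 2p − k)/s⌋ + 1. The plain Rust sketch below uses hypothetical helper names (they are not part of Burn) to compute the shape produced by a block's first 3×3 convolution and to check when the identity can be added directly. The projection condition mirrors the one in torchvision's _make_layer(): a stride other than 1 or a change in the number of channels.

```rust
/// Output spatial size of a 2D convolution (or pooling) along one dimension:
/// floor((n + 2 * padding - kernel) / stride) + 1
fn conv_out_size(n: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    (n + 2 * padding - kernel) / stride + 1
}

/// Shape of a [C, H, W] feature map after a 3x3 convolution with padding 1,
/// i.e. the first convolution of a basic residual block.
fn conv3x3_out(shape: [usize; 3], out_channels: usize, stride: usize) -> [usize; 3] {
    let [_, h, w] = shape;
    [
        out_channels,
        conv_out_size(h, 3, stride, 1),
        conv_out_size(w, 3, stride, 1),
    ]
}

/// A projection shortcut is required whenever the residual branch changes the
/// shape of its input: a stride other than 1 or a different channel count.
fn needs_projection(in_channels: usize, out_channels: usize, stride: usize) -> bool {
    stride != 1 || in_channels != out_channels
}
```

For example, a stride-2 block going from 64 to 128 channels turns a [64, 56, 56] input into [128, 28, 28], so the 1×1 projection (stride 2, no padding) must produce that same shape before the addition.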

For deeper ResNets, the building blocks are a bit different. The so-called bottleneck residual block follows a similar structure to the basic residual block, but uses 1×1 convolutions to create a bottleneck that reduces the number of parameters and matrix multiplications. The first and last 1×1 convolutions are responsible for reducing and then increasing (restoring) the number of channels, which leaves the 3×3 convolution as a bottleneck with smaller input/output dimensions.
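A quick count of convolution weights makes the saving concrete. The sketch below (plain Rust, hypothetical helper names, batch normalization and shortcut weights ignored) compares a bottleneck block operating on 256 channels with a 64-channel intermediate width against two plain 3×3 convolutions on the same 256 channels.

```rust
/// Number of weights in a bias-free 2D convolution: in * out * k * k.
fn conv_params(in_channels: usize, out_channels: usize, kernel: usize) -> usize {
    in_channels * out_channels * kernel * kernel
}

/// Two 3x3 convolutions that keep `channels` channels (basic block, no shortcut).
fn basic_block_params(channels: usize) -> usize {
    2 * conv_params(channels, channels, 3)
}

/// 1x1 reduce -> 3x3 -> 1x1 restore, with `width` intermediate channels
/// (bottleneck block, no shortcut).
fn bottleneck_params(channels: usize, width: usize) -> usize {
    conv_params(channels, width, 1) + conv_params(width, width, 3) + conv_params(width, channels, 1)
}
```

With these numbers, the bottleneck needs 69,632 weights versus 1,179,648 for the two 3×3 convolutions (about 17× fewer), which is comparable to a basic block operating on only 64 channels (73,728 weights).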

To form the resulting network architecture, a number of building blocks are stacked as detailed previously in the architecture table. For example, the ResNet-18 architecture is depicted below.

ResNet-18 architecture
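The number in each variant's name is the count of weighted layers along the main path: the 7×7 stem convolution, every convolution inside the residual blocks and the final fully connected layer (projection shortcuts are not counted). As a sanity check, the sketch below computes it from the per-stage block counts used by torchvision ([2, 2, 2, 2] for ResNet-18, [3, 4, 6, 3] for ResNet-34 and ResNet-50, and so on); the function name is hypothetical.

```rust
/// Depth of a ResNet variant: 1 stem conv + convolutions in all residual
/// blocks (2 per basic block, 3 per bottleneck block) + 1 fully connected layer.
fn resnet_depth(blocks: [usize; 4], convs_per_block: usize) -> usize {
    let total_blocks: usize = blocks.iter().sum();
    1 + convs_per_block * total_blocks + 1
}
```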

Let's have a look at how that translates into code[2].

To follow along with the Burn implementation, we recommend that you split the code in this tutorial into multiple modules. We will have a library crate containing the modules required for our ResNet implementation, as well as a main.rs that serves as the executable for our project. The project structure looks like this:

tutorial
├── Cargo.toml
└── src
    ├── lib.rs          // library
    ├── main.rs         // binary
    └── model
        ├── block.rs    // Residual block module definition
        ├── imagenet.rs // ImageNet utilities
        ├── mod.rs      // Module declaration for model
        └── resnet.rs   // ResNet module definition

We'll provide the required imports for each module as we go.

The content of your Cargo.toml should include the dependencies for this tutorial.

[package]
name = "resnet_burn"
version = "0.2.0"
edition = "2021"

[dependencies]
burn = { version = "0.13.0", features = ["ndarray"] }
burn-import = { version = "0.13.0" }
image = { version = "0.24.9", features = ["png", "jpeg"] }

Don't forget to add the modules to the model module and make model public in the library crate.

// File: model/mod.rs

mod block;
pub mod imagenet;
pub mod resnet;
// File: lib.rs

pub mod model;

At this point the project won't compile just yet since the main function hasn't been defined. You can initialize main.rs with an empty main function for now and we'll complete it later in the tutorial.

// File: main.rs

fn main() {}

Alright, we can look at the code for real now!

We'll first present the reference PyTorch implementation to better illustrate the parallels with Burn when porting your model. Let's start by defining the basic and bottleneck residual blocks, which will be used for the different ResNet variants.

import torch
from torch import nn, Tensor
from typing import List, Optional, Type, Union

def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

As you are likely aware, PyTorch uses modules (nn.Module) to represent neural networks. In the code above, we defined custom modules for the BasicBlock and Bottleneck building blocks. Creating custom neural network modules in Burn follows a similar approach with the Module trait. In practice, you'll use the #[derive(Module)] attribute on top of a struct, and Burn will automatically implement the Module trait for you. While most module implementations in Burn define a forward method for the forward pass, this is only by convention. The derive macro makes no assumption about how the forward pass is declared (and thus, the method could have any name you'd like).

model/block.rs

The residual block module requires the following imports.

use core::marker::PhantomData;

use burn::{
    module::Module,
    nn::{
        conv::{Conv2d, Conv2dConfig},
        BatchNorm, BatchNormConfig, PaddingConfig2d, Relu,
    },
    tensor::{backend::Backend, Device, Tensor},
};

Here is what the BasicBlock and Bottleneck definitions look like in Burn:

// File: model/block.rs

/// ResNet [basic residual block](https://paperswithcode.com/method/residual-block) implementation.
#[derive(Module, Debug)]
pub struct BasicBlock<B: Backend> {
    conv1: Conv2d<B>,
    bn1: BatchNorm<B, 2>,
    relu: Relu,
    conv2: Conv2d<B>,
    bn2: BatchNorm<B, 2>,
    downsample: Option<Downsample<B>>,
}

/// ResNet [bottleneck residual block](https://paperswithcode.com/method/bottleneck-residual-block)
/// implementation.
#[derive(Module, Debug)]
pub struct Bottleneck<B: Backend> {
    conv1: Conv2d<B>,
    bn1: BatchNorm<B, 2>,
    relu: Relu,
    conv2: Conv2d<B>,
    bn2: BatchNorm<B, 2>,
    conv3: Conv2d<B>,
    bn3: BatchNorm<B, 2>,
    downsample: Option<Downsample<B>>,
}

/// Downsample layer applies a 1x1 conv to reduce the resolution (H, W) and adjust the number of channels.
#[derive(Module, Debug)]
pub struct Downsample<B: Backend> {
    conv: Conv2d<B>,
    bn: BatchNorm<B, 2>,
}

You might have noticed something special already: every module struct is generic over a backend type B that implements the Backend trait (struct MyModule<B: Backend>). The Backend trait abstracts the underlying low-level implementations of tensor operations, allowing your new model to run on any backend.

Obviously, the code above only defines the struct as a parameter container. We still have to define the module initialization and forward pass. Let's start with the Downsample module as it is the simplest.

In torchvision, the downsample layer is created inside the _make_layer() method, which will be detailed later.

downsample = nn.Sequential(
    conv1x1(self.inplanes, planes * block.expansion, stride),
    nn.BatchNorm2d(planes * block.expansion),
)

In Burn, we need to define a specific type, which is why we added the Downsample struct above. Now, let's implement the initialization and forward pass for this module as below.

// File: model/block.rs

impl<B: Backend> Downsample<B> {
    pub fn new(in_channels: usize, out_channels: usize, stride: usize, device: &Device<B>) -> Self {
        // conv1x1
        let conv = Conv2dConfig::new([in_channels, out_channels], [1, 1])
            .with_stride([stride, stride])
            .with_padding(PaddingConfig2d::Explicit(0, 0))
            .with_bias(false)
            .init(device);
        let bn = BatchNormConfig::new(out_channels).init(device);

        Self { conv, bn }
    }

    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let out = self.conv.forward(input);
        self.bn.forward(out)
    }
}
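If generic-over-the-backend modules are new to you, here is a deliberately tiny, std-only analogy. The ToyBackend trait, the VecBackend type and the Scale module below are hypothetical and far simpler than Burn's real Backend trait; they only illustrate how a module written once against a trait bound can run on any type implementing that trait.

```rust
/// Hypothetical stand-in for a backend: it owns the tensor type and its ops.
trait ToyBackend {
    /// Backend-specific tensor storage.
    type Tensor: Clone;
    fn from_vec(data: Vec<f32>) -> Self::Tensor;
    fn mul(lhs: &Self::Tensor, rhs: &Self::Tensor) -> Self::Tensor;
    fn to_vec(t: &Self::Tensor) -> Vec<f32>;
}

/// One concrete "backend" that stores tensors as plain vectors.
struct VecBackend;

impl ToyBackend for VecBackend {
    type Tensor = Vec<f32>;
    fn from_vec(data: Vec<f32>) -> Vec<f32> {
        data
    }
    fn mul(lhs: &Vec<f32>, rhs: &Vec<f32>) -> Vec<f32> {
        lhs.iter().zip(rhs).map(|(a, b)| a * b).collect()
    }
    fn to_vec(t: &Vec<f32>) -> Vec<f32> {
        t.clone()
    }
}

/// Element-wise scaling module, generic over the backend like `MyModule<B: Backend>`.
struct Scale<B: ToyBackend> {
    weight: B::Tensor,
}

impl<B: ToyBackend> Scale<B> {
    fn new(weight: Vec<f32>) -> Self {
        Self { weight: B::from_vec(weight) }
    }

    /// The forward pass only talks to the backend through the trait.
    fn forward(&self, input: &B::Tensor) -> B::Tensor {
        B::mul(input, &self.weight)
    }
}
```

Swapping the backend would only require another ToyBackend implementation; Scale itself would not change, which is the same property Burn's modules get from the B: Backend bound.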

In order to have a ResNet module that will work for all architecture variants (i.e., for both BasicBlock and Bottleneck), we'll define a simple ResidualBlock trait that will be implemented for both blocks.

// File: model/block.rs

/// Residual blocks are the building blocks of residual networks (ResNets).
/// The blocks use skip connections to learn residual functions with reference to the layer inputs.
pub trait ResidualBlock<B: Backend> {
    fn init(in_channels: usize, out_channels: usize, stride: usize, device: &Device<B>) -> Self;
    fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4>;
}

impl<B: Backend> ResidualBlock<B> for BasicBlock<B> {
    fn init(in_channels: usize, out_channels: usize, stride: usize, device: &Device<B>) -> Self {
        // conv3x3
        let conv1 = Conv2dConfig::new([in_channels, out_channels], [3, 3])
            .with_stride([stride, stride])
            .with_padding(PaddingConfig2d::Explicit(1, 1))
            .with_bias(false)
            .init(device);
        let bn1 = BatchNormConfig::new(out_channels).init(device);
        let relu = Relu::new();
        // conv3x3
        let conv2 = Conv2dConfig::new([out_channels, out_channels], [3, 3])
            .with_stride([1, 1])
            .with_padding(PaddingConfig2d::Explicit(1, 1))
            .with_bias(false)
            .init(device);
        let bn2 = BatchNormConfig::new(out_channels).init(device);

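        // Projection shortcut whenever the block changes the input shape
        // (stride != 1 or a different channel count, as in torchvision's `_make_layer`)
        let downsample = {
            if stride != 1 || in_channels != out_channels {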
                Some(Downsample::new(in_channels, out_channels, stride, device))
            } else {
                None
            }
        };

        Self {
            conv1,
            bn1,
            relu,
            conv2,
            bn2,
            downsample,
        }
    }

    fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let identity = input.clone();

        // Conv block
        let out = self.conv1.forward(input);
        let out = self.bn1.forward(out);
        let out = self.relu.forward(out);
        let out = self.conv2.forward(out);
        let out = self.bn2.forward(out);

        // Skip connection
        let out = {
            match &self.downsample {
                Some(downsample) => out + downsample.forward(identity),
                None => out + identity,
            }
        };

        // Activation
        self.relu.forward(out)
    }
}

impl<B: Backend> ResidualBlock<B> for Bottleneck<B> {
    fn init(in_channels: usize, out_channels: usize, stride: usize, device: &Device<B>) -> Self {
        // Intermediate output channels w/ expansion = 4
        let int_out_channels = out_channels / 4;
        // conv1x1
        let conv1 = Conv2dConfig::new([in_channels, int_out_channels], [1, 1])
            .with_stride([1, 1])
            .with_padding(PaddingConfig2d::Explicit(0, 0))
            .with_bias(false)
            .init(device);
        let bn1 = BatchNormConfig::new(int_out_channels).init(device);
        let relu = Relu::new();
        // conv3x3
        let conv2 = Conv2dConfig::new([int_out_channels, int_out_channels], [3, 3])
            .with_stride([stride, stride])
            .with_padding(PaddingConfig2d::Explicit(1, 1))
            .with_bias(false)
            .init(device);
        let bn2 = BatchNormConfig::new(int_out_channels).init(device);
        // conv1x1
        let conv3 = Conv2dConfig::new([int_out_channels, out_channels], [1, 1])
            .with_stride([1, 1])
            .with_padding(PaddingConfig2d::Explicit(0, 0))
            .with_bias(false)
            .init(device);
        let bn3 = BatchNormConfig::new(out_channels).init(device);

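        // Projection shortcut whenever the block changes the input shape
        // (stride != 1 or a different channel count, as in torchvision's `_make_layer`)
        let downsample = {
            if stride != 1 || in_channels != out_channels {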
                Some(Downsample::new(in_channels, out_channels, stride, device))
            } else {
                None
            }
        };

        Self {
            conv1,
            bn1,
            relu,
            conv2,
            bn2,
            conv3,
            bn3,
            downsample,
        }
    }

    fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let identity = input.clone();

        // Conv block
        let out = self.conv1.forward(input);
        let out = self.bn1.forward(out);
        let out = self.relu.forward(out);
        let out = self.conv2.forward(out);
        let out = self.bn2.forward(out);
        let out = self.relu.forward(out);
        let out = self.conv3.forward(out);
        let out = self.bn3.forward(out);

        // Skip connection
        let out = {
            match &self.downsample {
                Some(downsample) => out + downsample.forward(identity),
                None => out + identity,
            }
        };

        // Activation
        self.relu.forward(out)
    }
}

In doing so, we can encapsulate a sequence of residual blocks in a vector, similar to the role of nn.Sequential in _make_layer(). We'll call this module a LayerBlock; it is generic over the block module M, and its implementation is restricted to types implementing the ResidualBlock trait we defined earlier.

// File: model/block.rs

/// Collection of sequential residual blocks.
#[derive(Module, Debug)]
pub struct LayerBlock<B: Backend, M> {
    blocks: Vec<M>,
    _backend: PhantomData<B>,
}

impl<B: Backend, M: ResidualBlock<B>> LayerBlock<B, M> {
    pub fn new(
        num_blocks: usize,
        in_channels: usize,
        out_channels: usize,
        stride: usize,
        device: &Device<B>,
    ) -> Self {
        let blocks = (0..num_blocks)
            .map(|b| {
                if b == 0 {
                    // First block uses the specified stride
                    M::init(in_channels, out_channels, stride, device)
                } else {
                    // Other blocks use a stride of 1
                    M::init(out_channels, out_channels, 1, device)
                }
            })
            .collect();

        Self {
            blocks,
            _backend: PhantomData,
        }
    }

    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let mut out = input;
        for block in &self.blocks {
            out = block.forward(out);
        }
        out
    }
}
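The stride and channel bookkeeping performed by LayerBlock::new can be checked without any tensors. The plain Rust sketch below (hypothetical BlockConfig type and layer_configs function) lists the (in_channels, out_channels, stride) arguments each block receives, together with whether it needs a projection shortcut under torchvision's rule (a stride other than 1 or a change in channels).

```rust
/// Arguments passed to `M::init` for one residual block.
#[derive(Debug, PartialEq)]
struct BlockConfig {
    in_channels: usize,
    out_channels: usize,
    stride: usize,
    downsample: bool,
}

/// Mirrors `LayerBlock::new`: only the first block uses the requested stride
/// and input channels; the following blocks keep the shape unchanged.
fn layer_configs(num_blocks: usize, in_channels: usize, out_channels: usize, stride: usize) -> Vec<BlockConfig> {
    (0..num_blocks)
        .map(|b| {
            let (block_in, block_stride) = if b == 0 { (in_channels, stride) } else { (out_channels, 1) };
            BlockConfig {
                in_channels: block_in,
                out_channels,
                stride: block_stride,
                // Projection shortcut when the block changes the input shape
                downsample: block_stride != 1 || block_in != out_channels,
            }
        })
        .collect()
}
```

For ResNet-50's first stage (3 bottleneck blocks, 64 → 256 channels, stride 1), this shows that the first block still needs a projection even with a stride of 1, because it expands the number of channels.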

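You probably noticed a peculiarity of our LayerBlock definition: the PhantomData marker. This zero-sized field is required because the backend type B is not used by any other field (the blocks have the generic type M), and Rust rejects struct type parameters that are never used. The marker ties B to the struct so that the M: ResidualBlock<B> bound can be expressed in the implementation.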
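The same situation can be reproduced in plain Rust. In the hypothetical Stack type below, the parameter T only appears in the Layer<T> bound of the impl block, so the struct needs a PhantomData<T> field to compile; the marker is zero-sized and adds nothing to the struct's memory footprint.

```rust
use std::marker::PhantomData;

/// Hypothetical stand-in for `ResidualBlock<B>`.
trait Layer<T> {
    fn apply(&self, x: T) -> T;
}

struct AddOne;

impl Layer<i32> for AddOne {
    fn apply(&self, x: i32) -> i32 {
        x + 1
    }
}

/// `T` is not used by `layers`, so without `_marker` the compiler would reject
/// the struct with "parameter `T` is never used" (E0392).
struct Stack<T, M> {
    layers: Vec<M>,
    _marker: PhantomData<T>,
}

impl<T, M: Layer<T>> Stack<T, M> {
    fn new(layers: Vec<M>) -> Self {
        Self { layers, _marker: PhantomData }
    }

    /// Applies every layer in sequence, like `LayerBlock::forward`.
    fn forward(&self, input: T) -> T {
        self.layers.iter().fold(input, |x, layer| layer.apply(x))
    }
}
```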

Now that all of the required building blocks are defined, we can implement the final ResNet architecture. In torchvision, an instance of the architecture is created by specifying the block type and a layers list with the number of blocks to add to the network.

class ResNet(nn.Module):
    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
    ) -> None:
        super().__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
    ) -> nn.Sequential:
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

For example, the different network variants can be instantiated via the following functions.

def resnet18(num_classes: int) -> ResNet:
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)

def resnet34(num_classes: int) -> ResNet:
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)

def resnet50(num_classes: int) -> ResNet:
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)

def resnet101(num_classes: int) -> ResNet:
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes)

def resnet152(num_classes: int) -> ResNet:
    return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes)

In Burn, we instead make the ResNet module generic over the block type M, bounded by our ResidualBlock trait, and provide a constructor for each concrete variant.

model/resnet.rs

The ResNet module requires the following imports.

use burn::{
    module::Module,
    nn::{
        conv::{Conv2d, Conv2dConfig},
        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig, MaxPool2d, MaxPool2dConfig},
        BatchNorm, BatchNormConfig, Linear, LinearConfig, PaddingConfig2d, Relu,
    },
    tensor::{backend::Backend, Device, Tensor},
};

use super::block::{BasicBlock, Bottleneck, LayerBlock, ResidualBlock};

With that in mind, here's what the ResNet definition looks like.

// File: model/resnet.rs

#[derive(Module, Debug)]
pub struct ResNet<B: Backend, M> {
    conv1: Conv2d<B>,
    bn1: BatchNorm<B, 2>,
    relu: Relu,
    maxpool: MaxPool2d,
    layer1: LayerBlock<B, M>,
    layer2: LayerBlock<B, M>,
    layer3: LayerBlock<B, M>,
    layer4: LayerBlock<B, M>,
    avgpool: AdaptiveAvgPool2d,
    fc: Linear<B>,
}

impl<B: Backend, M: ResidualBlock<B>> ResNet<B, M> {
    fn new(blocks: [usize; 4], num_classes: usize, expansion: usize, device: &Device<B>) -> Self {
        // 7x7 conv, 64, /2
        let conv1 = Conv2dConfig::new([3, 64], [7, 7])
            .with_stride([2, 2])
            .with_padding(PaddingConfig2d::Explicit(3, 3))
            .with_bias(false)
            .init(device);
        let bn1 = BatchNormConfig::new(64).init(device);
        let relu = Relu::new();
        // 3x3 maxpool, /2
        let maxpool = MaxPool2dConfig::new([3, 3])
            .with_strides([2, 2])
            .with_padding(PaddingConfig2d::Explicit(1, 1))
            .init();

        // Residual blocks
        let layer1 = LayerBlock::new(blocks[0], 64, 64 * expansion, 1, device);
        let layer2 = LayerBlock::new(blocks[1], 64 * expansion, 128 * expansion, 2, device);
        let layer3 = LayerBlock::new(blocks[2], 128 * expansion, 256 * expansion, 2, device);
        let layer4 = LayerBlock::new(blocks[3], 256 * expansion, 512 * expansion, 2, device);

        // Average pooling [B, 512 * expansion, H, W] -> [B, 512 * expansion, 1, 1]
        let avgpool = AdaptiveAvgPool2dConfig::new([1, 1]).init();

        // Output layer
        let fc = LinearConfig::new(512 * expansion, num_classes).init(device);

        Self {
            conv1,
            bn1,
            relu,
            maxpool,
            layer1,
            layer2,
            layer3,
            layer4,
            avgpool,
            fc,
        }
    }

    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 2> {
        // First block
        let out = self.conv1.forward(input);
        let out = self.bn1.forward(out);
        let out = self.relu.forward(out);
        let out = self.maxpool.forward(out);

        // Residual blocks
        let out = self.layer1.forward(out);
        let out = self.layer2.forward(out);
        let out = self.layer3.forward(out);
        let out = self.layer4.forward(out);

        let out = self.avgpool.forward(out);
        // Reshape [B, C, 1, 1] -> [B, C]
        let out = out.flatten(1, 3);

        self.fc.forward(out)
    }
}
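To double-check the shapes implied by the code above, the sketch below traces the spatial resolution of a 224×224 input through the network with the output-size formula ⌊(n + 2p − k)/s⌋ + 1. It is plain Rust with hypothetical helper names; for each stage it follows the strided 3×3 convolution (padding 1) of the first block, which is where both BasicBlock and Bottleneck apply the stage's stride.

```rust
/// floor((n + 2 * padding - kernel) / stride) + 1
fn out_size(n: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    (n + 2 * padding - kernel) / stride + 1
}

/// Spatial size after conv1, after the max pool and after each of the four stages.
fn resnet_spatial_trace(input: usize) -> Vec<usize> {
    let mut sizes = Vec::new();
    // 7x7 conv, stride 2, padding 3
    let mut n = out_size(input, 7, 2, 3);
    sizes.push(n);
    // 3x3 max pool, stride 2, padding 1
    n = out_size(n, 3, 2, 1);
    sizes.push(n);
    // layer1..layer4: the stride is applied by a 3x3 conv with padding 1
    for stride in [1, 2, 2, 2] {
        n = out_size(n, 3, stride, 1);
        sizes.push(n);
    }
    sizes
}
```

The resulting 7×7 feature map is then reduced to 1×1 by AdaptiveAvgPool2d, so the flattened tensor fed to the fully connected layer has shape [B, 512 * expansion].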

Then, we implement the different methods to instantiate the network variants. Since we didn't mark the new() method as public, the only public interface to instantiate a concrete ResNet module is via the methods defined below.

// File: model/resnet.rs

impl<B: Backend> ResNet<B, BasicBlock<B>> {
    pub fn resnet18(num_classes: usize, device: &Device<B>) -> Self {
        Self::new([2, 2, 2, 2], num_classes, 1, device)
    }

    pub fn resnet34(num_classes: usize, device: &Device<B>) -> Self {
        Self::new([3, 4, 6, 3], num_classes, 1, device)
    }
}

impl<B: Backend> ResNet<B, Bottleneck<B>> {
    pub fn resnet50(num_classes: usize, device: &Device<B>) -> Self {
        Self::new([3, 4, 6, 3], num_classes, 4, device)
    }

    pub fn resnet101(num_classes: usize, device: &Device<B>) -> Self {
        Self::new([3, 4, 23, 3], num_classes, 4, device)
    }

    pub fn resnet152(num_classes: usize, device: &Device<B>) -> Self {
        Self::new([3, 8, 36, 3], num_classes, 4, device)
    }
}

Congratulations, you have now defined the well-known ResNet architecture in Burn!

Let's take a look at how you can import the pre-trained PyTorch weights in Burn next.

Importing Pre-trained PyTorch Weights

Burn recently added support for importing PyTorch weights directly into your Burn module definition, so let's take advantage of that!

As you may know, PyTorch saves the weights as a (flattened) state dictionary that maps each parameter name (e.g., layer1.0.conv1.weight) to its tensor. To load the weights correctly, the same names must be found in your module definition.

If the structure or names don't match up exactly, do not worry. The PyTorchFileRecorder provides a way to remap layer names to your new definition. You probably already noticed that our layer names aren't a 1-to-1 match with the PyTorch definition, so we will actually be using this utility in this tutorial.

For this example, we'll use the ResNet-18 pre-trained weights from torchvision. You can download the weights from this url.

The first discrepancy with our module definition is that the projection shortcut (i.e., Downsample) block is explicitly defined, whereas in the torchvision implementation it is added "on-the-fly" as an nn.Sequential containing a convolutional layer and a batch normalization layer. Concretely, we want to map any downsample.0 key to downsample.conv and any downsample.1 key to downsample.bn. For that, we can use the with_key_remap method of the PyTorchFileRecorder's load arguments (LoadArgs), which supports key replacement via regular expressions.

For the downsample block re-mapping, it would look something like this:

let load_args = LoadArgs::new(weights_path)
    // Map *.downsample.0.* -> *.downsample.conv.*
    .with_key_remap("(.+)\\.downsample\\.0\\.(.+)", "$1.downsample.conv.$2")
    // Map *.downsample.1.* -> *.downsample.bn.*
    .with_key_remap("(.+)\\.downsample\\.1\\.(.+)", "$1.downsample.bn.$2");

A second discrepancy also stems from an nn.Sequential block. Recall that we instead defined a LayerBlock containing a vector of residual blocks, so we must adjust the mapping for every layer block. For example, all blocks in layer1 should map from layer1.[block_idx].* to layer1.blocks.[block_idx].*.

To generalize this to all layer blocks (i.e., layer1, layer2, layer3 and layer4), it would look something like this:

let load_args = LoadArgs::new(weights_path)
    // Map layer[i].[j].* -> layer[i].blocks.[j].*
    .with_key_remap("(layer[1-4])\\.([0-9]+)\\.(.+)", "$1.blocks.$2.$3");
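To make the effect of the two rules concrete, the following std-only Rust sketch reproduces them with plain string manipulation on a few torchvision-style keys. It is an approximation for illustration only: the hypothetical remap_key function anchors the layer rule at the start of the key instead of using a regular expression, which gives the same result for the keys in this state dictionary (and the result does not depend on the order in which the two rules are applied).

```rust
/// Approximates the `with_key_remap` rules without the `regex` crate:
/// 1. `*.downsample.0.*` -> `*.downsample.conv.*`, `*.downsample.1.*` -> `*.downsample.bn.*`
/// 2. `layer[1-4].[j].*` -> `layer[1-4].blocks.[j].*`
fn remap_key(key: &str) -> String {
    let mut parts: Vec<&str> = key.split('.').collect();

    // Rule 1: rename the two children of the torchvision `nn.Sequential` shortcut.
    // Like the regex, require a prefix before `downsample` and a suffix after the index.
    for i in 1..parts.len().saturating_sub(2) {
        if parts[i] == "downsample" {
            let renamed = match parts[i + 1] {
                "0" => "conv",
                "1" => "bn",
                other => other,
            };
            parts[i + 1] = renamed;
        }
    }

    // Rule 2: insert `blocks` between the layer name and the block index.
    let is_layer = matches!(parts[0], "layer1" | "layer2" | "layer3" | "layer4");
    let is_index = parts
        .get(1)
        .map_or(false, |p| !p.is_empty() && p.chars().all(|c| c.is_ascii_digit()));
    if is_layer && is_index && parts.len() > 2 {
        parts.insert(1, "blocks");
    }

    parts.join(".")
}
```

For instance, layer2.0.downsample.0.weight becomes layer2.blocks.0.downsample.conv.weight, while top-level keys such as conv1.weight or fc.bias are left untouched.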

With the combined patterns we just described, the pre-trained ResNet-18 weights can be imported in our main program as follows.

main.rs

The main program requires the following imports.

// NOTE: we used `resnet_burn` as the library crate name but you can replace it with your crate name
use resnet_burn::model::{imagenet, resnet::ResNet};

use burn::{
    backend::NdArray,
    module::Module,
    record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder},
    tensor::{backend::Backend, Data, Device, Element, Shape, Tensor},
};
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};



// File: main.rs

const TORCH_WEIGHTS: &str = "resnet18-f37072fd.pth";

fn main() {
    // Create ResNet-18 for ImageNet (1k classes)
    let device = Default::default();
    let model: ResNet<NdArray, _> = ResNet::resnet18(1000, &device);

    // Load weights from torch state_dict
    let load_args = LoadArgs::new(TORCH_WEIGHTS.into())
        // Map *.downsample.0.* -> *.downsample.conv.*
        .with_key_remap("(.+)\\.downsample\\.0\\.(.+)", "$1.downsample.conv.$2")
        // Map *.downsample.1.* -> *.downsample.bn.*
        .with_key_remap("(.+)\\.downsample\\.1\\.(.+)", "$1.downsample.bn.$2")
        // Map layer[i].[j].* -> layer[i].blocks.[j].*
        .with_key_remap("(layer[1-4])\\.([0-9]+)\\.(.+)", "$1.blocks.$2.$3");
    let record = PyTorchFileRecorder::<FullPrecisionSettings>::new()
        .load(load_args, &device)
        .expect("Should load PyTorch model weights");

    // `load_record` takes ownership of the model but we can re-assign the returned value
    let model = model.load_record(record);
}

Note that for this tutorial we are using the NdArray backend, but you can easily change that to any other supported backend.

Before we test the pre-trained model we just imported, let's save it into one of Burn's supported formats so we can load it without having to convert the state dictionary every time.

Saving and Loading Pre-trained Weights with Burn

It's time to take a page straight out of the Burn Book to save the pre-trained model we just imported to a supported format with one of the Recorders. In this example we'll use the default NamedMpkFileRecorder.

// File: main.rs

// Save the model to a supported format and load it back
const MODEL_PATH: &str = "resnet18-ImageNet1k";
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
model
    .clone() // `save_file` takes ownership but we want to load the file after
    .save_file(MODEL_PATH, &recorder)
    .expect("Should be able to save weights to file");
let model = model
    .load_file(MODEL_PATH, &recorder, &device)
    .expect("Should be able to load weights from file");

That's it! Be advised that in this particular case we had to clone() the model when saving it, but only because save_file takes ownership of the model and we still want to use it afterwards. In a typical application, you most likely won't need to do that since saving and loading are usually done separately. Note that cloning only increases the reference count of the parameters; no tensor data is actually copied here.

We're now ready to load a sample image for inference.

Loading and Preparing a Sample Image

For this example, we'll use this image of a yellow Labrador retriever as it is one of the ImageNet classes.

With the help of the image crate, loading an image from disk is fairly straightforward.

// File: main.rs

// Load image
let img_path = std::env::args().nth(1).expect("No image path provided");
let img = image::open(img_path).expect("Should be able to load image");

Following the pre-processing used for these pre-trained weights, the image should be resized to 224×224 [3] and normalized using the ImageNet statistics.

Let's start by resizing the image to the correct spatial dimensions.

// File: main.rs

// Resize to 224x224
const HEIGHT: usize = 224;
const WIDTH: usize = 224;
let resized_img = img.resize_exact(
    WIDTH as u32,
    HEIGHT as u32,
    image::imageops::FilterType::Triangle, // also known as bilinear in 2D
);
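If you would rather mirror the resize-then-center-crop evaluation recipe mentioned in footnote [3], the crop geometry is simple arithmetic. This plain-Rust sketch assumes the torchvision-style variant that scales the shorter side to 256 while keeping the aspect ratio; the helper name `center_crop_plan` and the flooring of the crop offsets are our own assumptions.

```rust
/// Output of the resize-then-crop plan: the intermediate resize dimensions
/// and the top-left corner of the centered square crop.
#[derive(Debug, PartialEq)]
pub struct CropPlan {
    pub resized_width: u32,
    pub resized_height: u32,
    pub crop_x: u32,
    pub crop_y: u32,
    pub crop_size: u32,
}

/// Scales the shorter side to `resize` (keeping the aspect ratio) and centers a
/// `crop` x `crop` square. Integer division truncates, so offsets are floored.
pub fn center_crop_plan(width: u32, height: u32, resize: u32, crop: u32) -> CropPlan {
    assert!(crop <= resize, "crop must fit inside the resized image");
    let (resized_width, resized_height) = if width <= height {
        // Width is the shorter side: set it to `resize`, scale height to match.
        (resize, (resize as u64 * height as u64 / width as u64) as u32)
    } else {
        // Height is the shorter side: set it to `resize`, scale width to match.
        ((resize as u64 * width as u64 / height as u64) as u32, resize)
    };
    CropPlan {
        resized_width,
        resized_height,
        crop_x: (resized_width - crop) / 2,
        crop_y: (resized_height - crop) / 2,
        crop_size: crop,
    }
}
```

With the image crate, these values could feed `resize_exact(plan.resized_width, plan.resized_height, FilterType::Triangle)` followed by `crop_imm(plan.crop_x, plan.crop_y, plan.crop_size, plan.crop_size)`. Other implementations may round the offsets instead of flooring them, so results can differ by a pixel.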

To create a tensor from the image data, we'll define a function to_tensor that builds the tensor, converts it to channels-first [C, H, W] and divides the values by 255 to scale them to [0, 1]. We'll then add a batch dimension B of size one.

// File: main.rs

fn to_tensor<B: Backend, T: Element>(
    data: Vec<T>,
    shape: [usize; 3],
    device: &Device<B>,
) -> Tensor<B, 3> {
    Tensor::<B, 3>::from_data(Data::new(data, Shape::new(shape)).convert(), device)
        // equivalent to permute(2, 0, 1): [H, W, C] -> [C, H, W]
        .swap_dims(2, 1) // [H, C, W]
        .swap_dims(1, 0) // [C, H, W]
        / 255 // normalize between [0, 1]
}

// Create tensor from image data
let img_tensor: Tensor<NdArray, 4> = to_tensor(
    resized_img.into_rgb8().into_raw(),
    [HEIGHT, WIDTH, 3],
    &device,
)
.unsqueeze::<4>(); // [B, C, H, W]
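For readers who want to see the index bookkeeping behind the two `swap_dims` calls, the plain-Rust sketch below performs the same [H, W, C] to [C, H, W] reordering and [0, 1] scaling on a raw byte buffer. The function name is hypothetical and the code only illustrates the layout change; in the tutorial, Burn does this directly on the tensor.

```rust
/// Converts interleaved bytes laid out as [H, W, C] into a channels-first
/// [C, H, W] buffer of `f32` values scaled to [0, 1], mirroring what
/// `to_tensor` does with `swap_dims` and the division by 255.
pub fn hwc_to_chw_normalized(data: &[u8], height: usize, width: usize, channels: usize) -> Vec<f32> {
    assert_eq!(data.len(), height * width * channels, "buffer size must match [H, W, C]");
    let mut out = vec![0.0f32; data.len()];
    for h in 0..height {
        for w in 0..width {
            for c in 0..channels {
                // Source index in [H, W, C] order, destination in [C, H, W] order.
                let src = (h * width + w) * channels + c;
                let dst = (c * height + h) * width + w;
                out[dst] = data[src] as f32 / 255.0;
            }
        }
    }
    out
}
```

For a 1×2 RGB image, the two red values end up first, then the two green values, then the two blue values, exactly as in a [3, 1, 2] tensor.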

Now that we have an image tensor of shape [1, 3, 224, 224], we can apply per-channel normalization with the ImageNet mean and standard deviation values.

model/imagenet.rs

The ImageNet utilities module requires the following imports.

use burn::tensor::{backend::Backend, Device, Tensor};

Since this is something we might want to reuse in the future, we'll define a Normalizer with the correct ImageNet statistics.

// File: model/imagenet.rs

// Pre-computed mean and standard deviation on ImageNet dataset.
// Source: https://github.com/pytorch/vision/blob/main/torchvision/transforms/_presets.py#L44-L45
const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const STD: [f32; 3] = [0.229, 0.224, 0.225];

/// Normalizer for the ImageNet dataset.
pub struct Normalizer<B: Backend> {
    pub mean: Tensor<B, 4>,
    pub std: Tensor<B, 4>,
}

impl<B: Backend> Normalizer<B> {
    pub fn new(device: &Device<B>) -> Self {
        let mean = Tensor::from_floats(MEAN, device).reshape([1, 3, 1, 1]);
        let std = Tensor::from_floats(STD, device).reshape([1, 3, 1, 1]);
        Self { mean, std }
    }

    pub fn normalize(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        (input - self.mean.clone()) / self.std.clone()
    }
}
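The `reshape([1, 3, 1, 1])` is what lets a single mean and standard deviation per channel broadcast over the whole [1, 3, 224, 224] image. Written out with plain loops (a sketch for illustration only, with a hypothetical function name), the same operation on a channels-first buffer looks like this.

```rust
// Same ImageNet statistics as in model/imagenet.rs.
const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const STD: [f32; 3] = [0.229, 0.224, 0.225];

/// Normalizes a channels-first [3, H, W] buffer in place: every value of
/// channel `c` becomes (x - MEAN[c]) / STD[c]. This is what broadcasting a
/// [1, 3, 1, 1] mean/std tensor against a [1, 3, H, W] image computes.
pub fn normalize_chw(data: &mut [f32], height: usize, width: usize) {
    let plane = height * width;
    assert_eq!(data.len(), 3 * plane, "expected a [3, H, W] buffer");
    if plane == 0 {
        return; // nothing to normalize (and `chunks_mut(0)` would panic)
    }
    // Each chunk of `plane` values is one channel.
    for (c, channel) in data.chunks_mut(plane).enumerate() {
        for x in channel.iter_mut() {
            *x = (*x - MEAN[c]) / STD[c];
        }
    }
}
```

A pixel whose values equal the dataset mean maps to zero in every channel, and a fully saturated red value of 1.0 maps to roughly 2.25.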

We can use it in our main program now to normalize the image tensor.

// File: main.rs

// Normalize the image
let x = imagenet::Normalizer::new(&device).normalize(img_tensor);

The pre-processed image is now ready to be passed to our ResNet-18 for inference.

Performing Inference

We're done setting the stage. The pre-trained model has been loaded and the image tensor has been pre-processed. Let's look at how the network performs on a sample image.

Note that the sample below doesn't define the CLASSES constant, since it holds all 1000 ImageNet categories. To complete your code, copy it from the collapsible source below into model/imagenet.rs.

imagenet::CLASSES
// File: model/imagenet.rs

// ImageNet categories
pub const CLASSES: [&str; 1000] = [
    "tench",
    "goldfish",
    "great white shark",
    "tiger shark",
    "hammerhead",
    "electric ray",
    "stingray",
    "cock",
    "hen",
    "ostrich",
    "brambling",
    "goldfinch",
    "house finch",
    "junco",
    "indigo bunting",
    "robin",
    "bulbul",
    "jay",
    "magpie",
    "chickadee",
    "water ouzel",
    "kite",
    "bald eagle",
    "vulture",
    "great grey owl",
    "European fire salamander",
    "common newt",
    "eft",
    "spotted salamander",
    "axolotl",
    "bullfrog",
    "tree frog",
    "tailed frog",
    "loggerhead",
    "leatherback turtle",
    "mud turtle",
    "terrapin",
    "box turtle",
    "banded gecko",
    "common iguana",
    "American chameleon",
    "whiptail",
    "agama",
    "frilled lizard",
    "alligator lizard",
    "Gila monster",
    "green lizard",
    "African chameleon",
    "Komodo dragon",
    "African crocodile",
    "American alligator",
    "triceratops",
    "thunder snake",
    "ringneck snake",
    "hognose snake",
    "green snake",
    "king snake",
    "garter snake",
    "water snake",
    "vine snake",
    "night snake",
    "boa constrictor",
    "rock python",
    "Indian cobra",
    "green mamba",
    "sea snake",
    "horned viper",
    "diamondback",
    "sidewinder",
    "trilobite",
    "harvestman",
    "scorpion",
    "black and gold garden spider",
    "barn spider",
    "garden spider",
    "black widow",
    "tarantula",
    "wolf spider",
    "tick",
    "centipede",
    "black grouse",
    "ptarmigan",
    "ruffed grouse",
    "prairie chicken",
    "peacock",
    "quail",
    "partridge",
    "African grey",
    "macaw",
    "sulphur-crested cockatoo",
    "lorikeet",
    "coucal",
    "bee eater",
    "hornbill",
    "hummingbird",
    "jacamar",
    "toucan",
    "drake",
    "red-breasted merganser",
    "goose",
    "black swan",
    "tusker",
    "echidna",
    "platypus",
    "wallaby",
    "koala",
    "wombat",
    "jellyfish",
    "sea anemone",
    "brain coral",
    "flatworm",
    "nematode",
    "conch",
    "snail",
    "slug",
    "sea slug",
    "chiton",
    "chambered nautilus",
    "Dungeness crab",
    "rock crab",
    "fiddler crab",
    "king crab",
    "American lobster",
    "spiny lobster",
    "crayfish",
    "hermit crab",
    "isopod",
    "white stork",
    "black stork",
    "spoonbill",
    "flamingo",
    "little blue heron",
    "American egret",
    "bittern",
    "crane bird",
    "limpkin",
    "European gallinule",
    "American coot",
    "bustard",
    "ruddy turnstone",
    "red-backed sandpiper",
    "redshank",
    "dowitcher",
    "oystercatcher",
    "pelican",
    "king penguin",
    "albatross",
    "grey whale",
    "killer whale",
    "dugong",
    "sea lion",
    "Chihuahua",
    "Japanese spaniel",
    "Maltese dog",
    "Pekinese",
    "Shih-Tzu",
    "Blenheim spaniel",
    "papillon",
    "toy terrier",
    "Rhodesian ridgeback",
    "Afghan hound",
    "basset",
    "beagle",
    "bloodhound",
    "bluetick",
    "black-and-tan coonhound",
    "Walker hound",
    "English foxhound",
    "redbone",
    "borzoi",
    "Irish wolfhound",
    "Italian greyhound",
    "whippet",
    "Ibizan hound",
    "Norwegian elkhound",
    "otterhound",
    "Saluki",
    "Scottish deerhound",
    "Weimaraner",
    "Staffordshire bullterrier",
    "American Staffordshire terrier",
    "Bedlington terrier",
    "Border terrier",
    "Kerry blue terrier",
    "Irish terrier",
    "Norfolk terrier",
    "Norwich terrier",
    "Yorkshire terrier",
    "wire-haired fox terrier",
    "Lakeland terrier",
    "Sealyham terrier",
    "Airedale",
    "cairn",
    "Australian terrier",
    "Dandie Dinmont",
    "Boston bull",
    "miniature schnauzer",
    "giant schnauzer",
    "standard schnauzer",
    "Scotch terrier",
    "Tibetan terrier",
    "silky terrier",
    "soft-coated wheaten terrier",
    "West Highland white terrier",
    "Lhasa",
    "flat-coated retriever",
    "curly-coated retriever",
    "golden retriever",
    "Labrador retriever",
    "Chesapeake Bay retriever",
    "German short-haired pointer",
    "vizsla",
    "English setter",
    "Irish setter",
    "Gordon setter",
    "Brittany spaniel",
    "clumber",
    "English springer",
    "Welsh springer spaniel",
    "cocker spaniel",
    "Sussex spaniel",
    "Irish water spaniel",
    "kuvasz",
    "schipperke",
    "groenendael",
    "malinois",
    "briard",
    "kelpie",
    "komondor",
    "Old English sheepdog",
    "Shetland sheepdog",
    "collie",
    "Border collie",
    "Bouvier des Flandres",
    "Rottweiler",
    "German shepherd",
    "Doberman",
    "miniature pinscher",
    "Greater Swiss Mountain dog",
    "Bernese mountain dog",
    "Appenzeller",
    "EntleBucher",
    "boxer",
    "bull mastiff",
    "Tibetan mastiff",
    "French bulldog",
    "Great Dane",
    "Saint Bernard",
    "Eskimo dog",
    "malamute",
    "Siberian husky",
    "dalmatian",
    "affenpinscher",
    "basenji",
    "pug",
    "Leonberg",
    "Newfoundland",
    "Great Pyrenees",
    "Samoyed",
    "Pomeranian",
    "chow",
    "keeshond",
    "Brabancon griffon",
    "Pembroke",
    "Cardigan",
    "toy poodle",
    "miniature poodle",
    "standard poodle",
    "Mexican hairless",
    "timber wolf",
    "white wolf",
    "red wolf",
    "coyote",
    "dingo",
    "dhole",
    "African hunting dog",
    "hyena",
    "red fox",
    "kit fox",
    "Arctic fox",
    "grey fox",
    "tabby",
    "tiger cat",
    "Persian cat",
    "Siamese cat",
    "Egyptian cat",
    "cougar",
    "lynx",
    "leopard",
    "snow leopard",
    "jaguar",
    "lion",
    "tiger",
    "cheetah",
    "brown bear",
    "American black bear",
    "ice bear",
    "sloth bear",
    "mongoose",
    "meerkat",
    "tiger beetle",
    "ladybug",
    "ground beetle",
    "long-horned beetle",
    "leaf beetle",
    "dung beetle",
    "rhinoceros beetle",
    "weevil",
    "fly",
    "bee",
    "ant",
    "grasshopper",
    "cricket",
    "walking stick",
    "cockroach",
    "mantis",
    "cicada",
    "leafhopper",
    "lacewing",
    "dragonfly",
    "damselfly",
    "admiral",
    "ringlet",
    "monarch",
    "cabbage butterfly",
    "sulphur butterfly",
    "lycaenid",
    "starfish",
    "sea urchin",
    "sea cucumber",
    "wood rabbit",
    "hare",
    "Angora",
    "hamster",
    "porcupine",
    "fox squirrel",
    "marmot",
    "beaver",
    "guinea pig",
    "sorrel",
    "zebra",
    "hog",
    "wild boar",
    "warthog",
    "hippopotamus",
    "ox",
    "water buffalo",
    "bison",
    "ram",
    "bighorn",
    "ibex",
    "hartebeest",
    "impala",
    "gazelle",
    "Arabian camel",
    "llama",
    "weasel",
    "mink",
    "polecat",
    "black-footed ferret",
    "otter",
    "skunk",
    "badger",
    "armadillo",
    "three-toed sloth",
    "orangutan",
    "gorilla",
    "chimpanzee",
    "gibbon",
    "siamang",
    "guenon",
    "patas",
    "baboon",
    "macaque",
    "langur",
    "colobus",
    "proboscis monkey",
    "marmoset",
    "capuchin",
    "howler monkey",
    "titi",
    "spider monkey",
    "squirrel monkey",
    "Madagascar cat",
    "indri",
    "Indian elephant",
    "African elephant",
    "lesser panda",
    "giant panda",
    "barracouta",
    "eel",
    "coho",
    "rock beauty",
    "anemone fish",
    "sturgeon",
    "gar",
    "lionfish",
    "puffer",
    "abacus",
    "abaya",
    "academic gown",
    "accordion",
    "acoustic guitar",
    "aircraft carrier",
    "airliner",
    "airship",
    "altar",
    "ambulance",
    "amphibian",
    "analog clock",
    "apiary",
    "apron",
    "ashcan",
    "assault rifle",
    "backpack",
    "bakery",
    "balance beam",
    "balloon",
    "ballpoint",
    "Band Aid",
    "banjo",
    "bannister",
    "barbell",
    "barber chair",
    "barbershop",
    "barn",
    "barometer",
    "barrel",
    "barrow",
    "baseball",
    "basketball",
    "bassinet",
    "bassoon",
    "bathing cap",
    "bath towel",
    "bathtub",
    "beach wagon",
    "beacon",
    "beaker",
    "bearskin",
    "beer bottle",
    "beer glass",
    "bell cote",
    "bib",
    "bicycle-built-for-two",
    "bikini",
    "binder",
    "binoculars",
    "birdhouse",
    "boathouse",
    "bobsled",
    "bolo tie",
    "bonnet",
    "bookcase",
    "bookshop",
    "bottlecap",
    "bow",
    "bow tie",
    "brass",
    "brassiere",
    "breakwater",
    "breastplate",
    "broom",
    "bucket",
    "buckle",
    "bulletproof vest",
    "bullet train",
    "butcher shop",
    "cab",
    "caldron",
    "candle",
    "cannon",
    "canoe",
    "can opener",
    "cardigan",
    "car mirror",
    "carousel",
    "carpenter's kit",
    "carton",
    "car wheel",
    "cash machine",
    "cassette",
    "cassette player",
    "castle",
    "catamaran",
    "CD player",
    "cello",
    "cellular telephone",
    "chain",
    "chainlink fence",
    "chain mail",
    "chain saw",
    "chest",
    "chiffonier",
    "chime",
    "china cabinet",
    "Christmas stocking",
    "church",
    "cinema",
    "cleaver",
    "cliff dwelling",
    "cloak",
    "clog",
    "cocktail shaker",
    "coffee mug",
    "coffeepot",
    "coil",
    "combination lock",
    "computer keyboard",
    "confectionery",
    "container ship",
    "convertible",
    "corkscrew",
    "cornet",
    "cowboy boot",
    "cowboy hat",
    "cradle",
    "crane",
    "crash helmet",
    "crate",
    "crib",
    "Crock Pot",
    "croquet ball",
    "crutch",
    "cuirass",
    "dam",
    "desk",
    "desktop computer",
    "dial telephone",
    "diaper",
    "digital clock",
    "digital watch",
    "dining table",
    "dishrag",
    "dishwasher",
    "disk brake",
    "dock",
    "dogsled",
    "dome",
    "doormat",
    "drilling platform",
    "drum",
    "drumstick",
    "dumbbell",
    "Dutch oven",
    "electric fan",
    "electric guitar",
    "electric locomotive",
    "entertainment center",
    "envelope",
    "espresso maker",
    "face powder",
    "feather boa",
    "file",
    "fireboat",
    "fire engine",
    "fire screen",
    "flagpole",
    "flute",
    "folding chair",
    "football helmet",
    "forklift",
    "fountain",
    "fountain pen",
    "four-poster",
    "freight car",
    "French horn",
    "frying pan",
    "fur coat",
    "garbage truck",
    "gasmask",
    "gas pump",
    "goblet",
    "go-kart",
    "golf ball",
    "golfcart",
    "gondola",
    "gong",
    "gown",
    "grand piano",
    "greenhouse",
    "grille",
    "grocery store",
    "guillotine",
    "hair slide",
    "hair spray",
    "half track",
    "hammer",
    "hamper",
    "hand blower",
    "hand-held computer",
    "handkerchief",
    "hard disc",
    "harmonica",
    "harp",
    "harvester",
    "hatchet",
    "holster",
    "home theater",
    "honeycomb",
    "hook",
    "hoopskirt",
    "horizontal bar",
    "horse cart",
    "hourglass",
    "iPod",
    "iron",
    "jack-o'-lantern",
    "jean",
    "jeep",
    "jersey",
    "jigsaw puzzle",
    "jinrikisha",
    "joystick",
    "kimono",
    "knee pad",
    "knot",
    "lab coat",
    "ladle",
    "lampshade",
    "laptop",
    "lawn mower",
    "lens cap",
    "letter opener",
    "library",
    "lifeboat",
    "lighter",
    "limousine",
    "liner",
    "lipstick",
    "Loafer",
    "lotion",
    "loudspeaker",
    "loupe",
    "lumbermill",
    "magnetic compass",
    "mailbag",
    "mailbox",
    "maillot",
    "maillot tank suit",
    "manhole cover",
    "maraca",
    "marimba",
    "mask",
    "matchstick",
    "maypole",
    "maze",
    "measuring cup",
    "medicine chest",
    "megalith",
    "microphone",
    "microwave",
    "military uniform",
    "milk can",
    "minibus",
    "miniskirt",
    "minivan",
    "missile",
    "mitten",
    "mixing bowl",
    "mobile home",
    "Model T",
    "modem",
    "monastery",
    "monitor",
    "moped",
    "mortar",
    "mortarboard",
    "mosque",
    "mosquito net",
    "motor scooter",
    "mountain bike",
    "mountain tent",
    "mouse",
    "mousetrap",
    "moving van",
    "muzzle",
    "nail",
    "neck brace",
    "necklace",
    "nipple",
    "notebook",
    "obelisk",
    "oboe",
    "ocarina",
    "odometer",
    "oil filter",
    "organ",
    "oscilloscope",
    "overskirt",
    "oxcart",
    "oxygen mask",
    "packet",
    "paddle",
    "paddlewheel",
    "padlock",
    "paintbrush",
    "pajama",
    "palace",
    "panpipe",
    "paper towel",
    "parachute",
    "parallel bars",
    "park bench",
    "parking meter",
    "passenger car",
    "patio",
    "pay-phone",
    "pedestal",
    "pencil box",
    "pencil sharpener",
    "perfume",
    "Petri dish",
    "photocopier",
    "pick",
    "pickelhaube",
    "picket fence",
    "pickup",
    "pier",
    "piggy bank",
    "pill bottle",
    "pillow",
    "ping-pong ball",
    "pinwheel",
    "pirate",
    "pitcher",
    "plane",
    "planetarium",
    "plastic bag",
    "plate rack",
    "plow",
    "plunger",
    "Polaroid camera",
    "pole",
    "police van",
    "poncho",
    "pool table",
    "pop bottle",
    "pot",
    "potter's wheel",
    "power drill",
    "prayer rug",
    "printer",
    "prison",
    "projectile",
    "projector",
    "puck",
    "punching bag",
    "purse",
    "quill",
    "quilt",
    "racer",
    "racket",
    "radiator",
    "radio",
    "radio telescope",
    "rain barrel",
    "recreational vehicle",
    "reel",
    "reflex camera",
    "refrigerator",
    "remote control",
    "restaurant",
    "revolver",
    "rifle",
    "rocking chair",
    "rotisserie",
    "rubber eraser",
    "rugby ball",
    "rule",
    "running shoe",
    "safe",
    "safety pin",
    "saltshaker",
    "sandal",
    "sarong",
    "sax",
    "scabbard",
    "scale",
    "school bus",
    "schooner",
    "scoreboard",
    "screen",
    "screw",
    "screwdriver",
    "seat belt",
    "sewing machine",
    "shield",
    "shoe shop",
    "shoji",
    "shopping basket",
    "shopping cart",
    "shovel",
    "shower cap",
    "shower curtain",
    "ski",
    "ski mask",
    "sleeping bag",
    "slide rule",
    "sliding door",
    "slot",
    "snorkel",
    "snowmobile",
    "snowplow",
    "soap dispenser",
    "soccer ball",
    "sock",
    "solar dish",
    "sombrero",
    "soup bowl",
    "space bar",
    "space heater",
    "space shuttle",
    "spatula",
    "speedboat",
    "spider web",
    "spindle",
    "sports car",
    "spotlight",
    "stage",
    "steam locomotive",
    "steel arch bridge",
    "steel drum",
    "stethoscope",
    "stole",
    "stone wall",
    "stopwatch",
    "stove",
    "strainer",
    "streetcar",
    "stretcher",
    "studio couch",
    "stupa",
    "submarine",
    "suit",
    "sundial",
    "sunglass",
    "sunglasses",
    "sunscreen",
    "suspension bridge",
    "swab",
    "sweatshirt",
    "swimming trunks",
    "swing",
    "switch",
    "syringe",
    "table lamp",
    "tank",
    "tape player",
    "teapot",
    "teddy",
    "television",
    "tennis ball",
    "thatch",
    "theater curtain",
    "thimble",
    "thresher",
    "throne",
    "tile roof",
    "toaster",
    "tobacco shop",
    "toilet seat",
    "torch",
    "totem pole",
    "tow truck",
    "toyshop",
    "tractor",
    "trailer truck",
    "tray",
    "trench coat",
    "tricycle",
    "trimaran",
    "tripod",
    "triumphal arch",
    "trolleybus",
    "trombone",
    "tub",
    "turnstile",
    "typewriter keyboard",
    "umbrella",
    "unicycle",
    "upright",
    "vacuum",
    "vase",
    "vault",
    "velvet",
    "vending machine",
    "vestment",
    "viaduct",
    "violin",
    "volleyball",
    "waffle iron",
    "wall clock",
    "wallet",
    "wardrobe",
    "warplane",
    "washbasin",
    "washer",
    "water bottle",
    "water jug",
    "water tower",
    "whiskey jug",
    "whistle",
    "wig",
    "window screen",
    "window shade",
    "Windsor tie",
    "wine bottle",
    "wing",
    "wok",
    "wooden spoon",
    "wool",
    "worm fence",
    "wreck",
    "yawl",
    "yurt",
    "web site",
    "comic book",
    "crossword puzzle",
    "street sign",
    "traffic light",
    "book jacket",
    "menu",
    "plate",
    "guacamole",
    "consomme",
    "hot pot",
    "trifle",
    "ice cream",
    "ice lolly",
    "French loaf",
    "bagel",
    "pretzel",
    "cheeseburger",
    "hotdog",
    "mashed potato",
    "head cabbage",
    "broccoli",
    "cauliflower",
    "zucchini",
    "spaghetti squash",
    "acorn squash",
    "butternut squash",
    "cucumber",
    "artichoke",
    "bell pepper",
    "cardoon",
    "mushroom",
    "Granny Smith",
    "strawberry",
    "orange",
    "lemon",
    "fig",
    "pineapple",
    "banana",
    "jackfruit",
    "custard apple",
    "pomegranate",
    "hay",
    "carbonara",
    "chocolate sauce",
    "dough",
    "meat loaf",
    "pizza",
    "potpie",
    "burrito",
    "red wine",
    "espresso",
    "cup",
    "eggnog",
    "alp",
    "bubble",
    "cliff",
    "coral reef",
    "geyser",
    "lakeside",
    "promontory",
    "sandbar",
    "seashore",
    "valley",
    "volcano",
    "ballplayer",
    "groom",
    "scuba diver",
    "rapeseed",
    "daisy",
    "yellow lady's slipper",
    "corn",
    "acorn",
    "hip",
    "buckeye",
    "coral fungus",
    "agaric",
    "gyromitra",
    "stinkhorn",
    "earthstar",
    "hen-of-the-woods",
    "bolete",
    "ear",
    "toilet tissue",
];

// File: main.rs

// Forward pass
let out = model.forward(x);

// Output class index with score (raw)
let (score, idx) = out.max_dim_with_indices(1);
let idx = idx.into_scalar() as usize;

println!(
    "Predicted: {}\nCategory Id: {}\nScore: {:.4}",
    imagenet::CLASSES[idx],
    idx,
    score.into_scalar()
);
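The printed score is a raw logit, not a probability. As a plain-Rust sketch, assuming you have copied the logits for one image out of the tensor into a slice (the helper names are hypothetical), `argmax` plays the role of `max_dim_with_indices(1)` for a batch of one, and an optional numerically stable `softmax` turns the logits into probabilities if you would rather report a confidence.

```rust
/// Returns the index and value of the largest logit, or `None` for an empty
/// slice. NaN values are skipped and the earliest index wins on ties.
pub fn argmax(logits: &[f32]) -> Option<(usize, f32)> {
    let mut best: Option<(usize, f32)> = None;
    for (i, &v) in logits.iter().enumerate() {
        if v.is_nan() {
            continue; // NaN compares false with everything, skip it explicitly
        }
        match best {
            // Keep the current best unless the new value is strictly larger.
            Some((_, b)) if b >= v => {}
            _ => best = Some((i, v)),
        }
    }
    best
}

/// Converts raw logits into probabilities that sum to 1. Subtracting the
/// maximum first keeps `exp` from overflowing for large logits.
pub fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
```

Softmax preserves the ordering of the logits, so the predicted class is the same either way; only the reported score changes.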

With the snippets collected from this tutorial, you should now be able to run the project executable. To test the inference on a sample image, you can download the sample image here and run the command below.

cargo run --release YellowLabradorLooking_new.jpg

Looks like the model recognized the dog correctly!

Yellow Labrador retriever

Predicted: Labrador retriever
Category Id: 208
Score: 14.4482

Note: There is no need to call model.valid() (the equivalent of PyTorch's .eval()) here since we are not using an AutodiffBackend. Without an AutodiffBackend, no backward graph is created and thus no gradients are tracked, and batch normalization layers run in inference mode using their running statistics. More importantly, nothing related to automatic differentiation is actually compiled, keeping the code as lean as possible for inference with minimal dependencies.

If you want to use ResNet in your application, take a look at the official Burn implementation available on GitHub! It closely follows this tutorial's implementation but further extends it to provide an easy interface to load the pre-trained weights for the whole ResNet family of models.

Still want more? Take a look at the Burn Book or any of our examples. And if you have any questions or comments, don't hesitate to join our Discord!


[1] torchvision's bottleneck block places the downsampling stride in the second convolution (3×3), while the original paper places it in the first convolution (1×1). This variant improves accuracy and is known as ResNet V1.5.

[2] Note that the code presented in this tutorial is simplified to encapsulate the most basic ResNet implementation. The full implementation with matching weight initialization scheme is available on GitHub. Equivalently, the PyTorch implementation presented was simplified to eliminate as much boilerplate configuration as possible to keep it succinct for comparison's sake. If you're curious about the complete PyTorch implementation, take a look at torchvision.models.resnet.ResNet.

[3] Although common practice when evaluating these ImageNet pre-trained models is to resize the image to 256 on its shorter side followed by a 224×224 center crop, a simple resize to 224×224 is sufficient here.
