Favil Orbedios

fast.ai Book in Rust - Chapter 2 - Part 2

Introduction

In the last article we made our datasets and learners more generic. Before writing this article, I realized we have no indication of the model's loss during training, and no way to measure its loss on the validation set afterwards. Let's add both capabilities in this article.

Saving weights is easy

I want to save the weights of the model, and I'd like to do it often, so that we can kill the program and pick up from where we left off fairly easily.

I've decided to add an option to our Learner builder that saves after every batch. And while we're at it, so I don't have to revisit this code over multiple passes, we also need a way to check the validation loss of our learner. It would also be nice to log the training loss each epoch, and beyond that, to change the start epoch and load an old model so we can resume from the middle of a training session. So let's add all of it! I'll comment each new piece.

    pub struct Builder<'a, T, const N: usize, Category> {
        pub(super) device: AutoDevice,
        // Renamed this
        pub(super) train_dataset: Option<&'a DirectoryImageDataset<'a, N, Category>>,
        // Added validation set
        pub(super) valid_dataset: Option<&'a DirectoryImageDataset<'a, N, Category>>,
        pub(super) model: Option<Resnet34Model<N, f32>>,
        // If true, save the model after every batch
        pub(super) save_each_block: bool,
        // The epoch to start (or resume) training at
        pub(super) start_epoch: usize,
        pub(super) _phantom: PhantomData<T>,
    }


Then we can add an impl block to Builder that is generic over the typestate parameter T, so these methods are available in every builder state. (A short sketch of the state markers follows the next snippet.)

impl<'a, const N: usize, Category: IntoOneHot<N>, T> Builder<'a, T, N, Category> {
    // Enable saving the model after every batch
    pub fn save_each_block(mut self) -> Self {
        self.save_each_block = true;
        self
    }

    // Set the epoch to start training at
    pub fn start_epoch(mut self, start_epoch: usize) -> Self {
        self.start_epoch = start_epoch;
        self
    }

    // Add the validation dataset to the builder
    pub fn with_valid_dataset(
        mut self,
        valid_dataset: &'a DirectoryImageDataset<'a, N, Category>,
    ) -> Self {
        self.valid_dataset = Some(valid_dataset);
        self
    }
}
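Those builder states come from the typestate pattern set up in the previous article. As a reminder, here is a minimal sketch of the markers; Empty is an assumption on my part, while Ready is the state that build() requires below:

    // Minimal sketch of the zero-sized typestate markers (assumed to match
    // the previous article's definitions). `Ready` is reached once a model
    // and training dataset have been supplied, which unlocks `build()`.
    pub struct Empty;
    pub struct Ready;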

And now we can change the build() method to construct the new shape of VisualLearner.

    impl<'a, const N: usize, Category: IntoOneHot<N>> Builder<'a, Ready, N, Category> {
        pub fn build(self) -> VisualLearner<'a, N, Category> {
            // Take the model out of its Option once; unwrapping it a second
            // time below would be a use of a moved value.
            let model = self.model.unwrap();
            VisualLearner::new(
                self.device,
                self.train_dataset.unwrap(),
                self.valid_dataset,
                model,
                self.save_each_block,
                self.start_epoch,
            )
        }
    }

And quickly modify the private constructor for VisualLearner.

pub struct VisualLearner<'a, const N: usize, Category> {
    device: AutoDevice,
    train_dataset: &'a DirectoryImageDataset<'a, N, Category>,
    // New, make this optional, so we don't always need to specify it.
    valid_dataset: Option<&'a DirectoryImageDataset<'a, N, Category>>,
    model: Resnet34Model<N, f32>,
    optimizer: Adam<Resnet34Built<N, f32>, f32, AutoDevice>,
    save_each_block: bool,
    start_epoch: usize,
}

const BATCH_SIZE: usize = 16;

impl<'a, const N: usize, Category: IntoOneHot<N>> VisualLearner<'a, N, Category> {
    // [snip builder]
    fn new(
        device: AutoDevice,
        train_dataset: &'a DirectoryImageDataset<'a, N, Category>,
        valid_dataset: Option<&'a DirectoryImageDataset<'a, N, Category>>,
        model: Resnet34Model<N, f32>,
        save_each_block: bool,
        start_epoch: usize,
    ) -> Self {
        let adam = Adam::new(&model.model, AdamConfig::default());
        Self {
            device,
            train_dataset,
            valid_dataset,
            model,
            optimizer: adam,
            save_each_block,
            start_epoch,
        }
    }
}

Now we just need to update the train() method.

pub fn train(&mut self, epochs: usize) -> Result<(), Error> {
    let mut rng = rand::thread_rng();
    let mut grads = self.model.model.alloc_grads();
    for epoch in self.start_epoch..self.start_epoch + epochs {
        log::info!("Epoch {}", epoch);
        // Per-epoch stats: reset the loss accumulator, batch counter, and
        // timer at the top of each epoch so the log line below reports this
        // epoch alone.
        let mut total_epoch_loss = 0.0;
        let mut num_batches = 0;
        let start = Instant::now();
        for (image, is_cat) in self
            .train_dataset
            .shuffled(&mut rng)
            .map(Result::unwrap)
            .map(|(image, is_cat)| (image, is_cat.into_one_hot(&self.device)))
            .batch_exact(Const::<BATCH_SIZE>)
            .collate()
            .stack()
            .progress()
        {
            let logits = self.model.model.forward_mut(image.traced(grads));
            let loss = cross_entropy_with_logits_loss(logits, is_cat);
            total_epoch_loss += loss.array();
            num_batches += 1;

            grads = loss.backward();
            self.optimizer.update(&mut self.model.model, &grads)?;
            self.model.model.zero_grads(&mut grads);

            // Save the model after each batch, so a killed run can resume
            if self.save_each_block {
                self.save(format!("model-epoch-{}.safetensors", epoch))?;
            }
        }
        let dur = start.elapsed();

        // Log out the stats for the epoch
        log::info!(
            "Epoch {epoch} in {dur:?} ({:.3} batches/s): avg sample loss: {:.5}",
            num_batches as f32 / dur.as_secs_f32(),
            BATCH_SIZE as f32 * total_epoch_loss / num_batches as f32,
        );

        // If we have a valid set, go ahead and log out the validation loss.
        // Though we can remove this if it takes too long.
        if self.valid_dataset.is_some() {
            let valid_loss = self.valid_loss()?;
           log::info!("Valid loss: {:.5}", valid_loss);
        }
    }
    Ok(())
}

As you can see, saving the model each batch is a single call to the save() method. Let's define that and valid_loss() now.

impl<'a, const N: usize, Category: IntoOneHot<N>> VisualLearner<'a, N, Category> {
    pub fn valid_loss(&mut self) -> Result<f32, Error> {
        let mut total_epoch_loss = 0.0;
        let mut num_batches = 0;
        log::info!("Calculating validation loss");
        for (img, is_cat) in self
            .valid_dataset
            // If this is called without setting valid_dataset, it's an error
            .ok_or(Error::NoValidationDataset)?
            // NOT `shuffled()`, this just needs to iterate through once
            .iter()
            .map(Result::unwrap)
            .map(|(image, is_cat)| (image, is_cat.into_one_hot(&self.device)))
            .batch_exact(Const::<BATCH_SIZE>)
            .collate()
            .stack()
            .progress()
        {
            let logits = self.model.model.forward(img);
            let loss = cross_entropy_with_logits_loss(logits, is_cat);
            total_epoch_loss += loss.array();
            num_batches += 1;
        }
        Ok(BATCH_SIZE as f32 * total_epoch_loss / num_batches as f32)
    }

    pub fn save(&self, path: impl AsRef<Path>) -> Result<(), Error> {
        Ok(self.model.model.save_safetensors(path)?)
    }
}

Saving is easy: just a single call to save_safetensors(). The validation loss is a little more involved. It's very similar to the training loop, but instead of calling shuffled() we just call iter(), because order doesn't matter here; we only need one pass through the validation set, accumulating the loss as we go.
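One loose end: valid_loss() returns Error::NoValidationDataset, but the Error type itself never appears in this article. Here's a minimal sketch of the variant the code above needs, assuming the thiserror crate; the real type in the repository likely has more variants for dfdx and I/O errors:

use thiserror::Error;

// Minimal sketch (assumption): only the variant referenced by valid_loss()
// is pinned down by this article's code.
#[derive(Debug, Error)]
pub enum Error {
    #[error("no validation dataset was configured on the learner")]
    NoValidationDataset,
}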

The final step is to actually use this new stuff in the main.rs file. I'm going to make our chapter1 executable a little more capable by adding some command line options. This is pretty easy with the clap crate and its derive feature flag.
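If clap isn't in the project yet, it needs the derive feature enabled in Cargo.toml; the version here is an assumption, and the #[derive(Parser)] below also needs use clap::Parser; in scope:

[dependencies]
clap = { version = "4", features = ["derive"] }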

#[derive(Debug, Parser)]
#[command(author = "Favil Orbedios")]
struct Args {
    /// The seed to create the [AutoDevice] with, default 0
    #[clap(long, short = 's', default_value = "0")]
    seed: u64,

    /// If set, load the model from this file
    #[clap(long = "model", short = 'm')]
    model_file: Option<PathBuf>,

    /// The epoch to start training at, default 0
    #[clap(long = "epoch", short = 'e', default_value = "0")]
    start_epoch: usize,

    /// The number of epochs to train for, default 3
    #[clap(long, short = 'n', default_value = "3")]
    epochs: usize,
}
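Every flag has a default (seed 0, start epoch 0, 3 epochs, no model file), so a fresh training run needs no arguments at all:

cargo run --release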

The structure is parsed from the command line arguments, so to resume an earlier session we can run something like

cargo run --release -- --seed 42 --epoch 10 --epochs 5 --model "model-epoch-9.safetensors"

And that will load the model from the model-epoch-9.safetensors file, start at epoch 10, train for 5 epochs, and use seed 42 to initialize the AutoDevice. How do we use it now? I'll just paste the entire main() function we have so far.

fn main() -> Result<()> {
    env_logger::Builder::new()
        .filter_level(log::LevelFilter::Info)
        .init();
    color_eyre::install()?;

    // NEW: Parse the arguments from the command line
    let args = Args::parse();

    let path: PathBuf = untar_images(DatasetUrl::Pets)
        .context("downloading Pets")?
        .join("images");
    log::info!("Images are in: {}", path.display());

    // NEW: Set the seed directly, instead of using 0
    let dev = AutoDevice::seed_from_u64(args.seed);

    // Silly thing about the Pets dataset, all the cats have a capital first letter in their
    // filename, all the dogs are lowercase only
    let is_cat = |path: &Path| {
        path.file_name()
            .and_then(|n| n.to_str())
            .and_then(|n| n.chars().next().map(|c| c.is_uppercase()))
            .unwrap_or(false)
    };

    let dataset_loader = DirectoryImageDataLoader::builder(path, dev.clone())
        .with_label_fn(&is_cat)
        .with_splitter(RatioSplitter::with_seed_validation(0, 0.1))
        .build()?;
    let dataset = dataset_loader.training();
    log::info!("Found {} files", dataset.files().len());

    log::info!("Building the ResNet-34 model");
    let mut model = Resnet34Model::<2, f32>::build(dev.clone());
    log::info!("Done building model");

    // NEW: If we set the model file, let's load it
    if let Some(model_file) = args.model_file {
        log::info!("Loading old model");
        model.load_model(model_file)?;
        log::info!("Done loading old model");
    }

    let mut learner = VisualLearner::builder(dev.clone())
        .save_each_block()
        // NEW: starting epoch
        .start_epoch(args.start_epoch)
        .with_valid_dataset(dataset_loader.validation())
        .with_train_dataset(dataset)
        .with_model(model)
        .build();

    let valid_loss = learner.valid_loss()?;
    log::info!("Valid loss: {:.5}", valid_loss);

    log::info!("Training");
    // NEW: train for the specified number of epochs
    learner.train(args.epochs)?;
    log::info!("Done training");

    learner.save("model.safetensors")?;

    Ok(())
}

Run times

On my CPU-bound laptop this takes about 5 hours to run in release mode with the default arguments, because my GPU isn't supported yet. If we loaded it up on a CUDA-capable machine, it would probably run a lot faster. I may spin up an EC2 instance and see how the run times compare.

If anyone runs this code on a supported computer with the cuda feature flag, let me know in the comments how much faster it runs.
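For anyone trying that, the invocation would look something like the line below; I'm assuming the project exposes a cuda feature that forwards to dfdx, so check the repository's Cargo.toml for the exact feature name:

cargo run --release --features cuda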

Conclusion

As always, I've uploaded the code to GitHub. You can check it out from the chapter-2-2 tag.
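If you already have the repository cloned from the earlier articles, that's just:

git fetch --tags
git checkout chapter-2-2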

Tune in next time to learn about visualizing the images on the command line! I'm very excited to see how we do.
