From 0e445a9680969be4e9e5811e8d5437514cb5a9e1 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 29 Aug 2024 08:58:51 -0400 Subject: [PATCH] Fix target convert in batcher and align guide imports (#2215) * Fix target convert in batcher * Align hidden code in training and update loss to use config * Align imports with example * Remove unused import and fix guide->crate --- burn-book/src/basic-workflow/data.md | 10 +++-- burn-book/src/basic-workflow/inference.md | 7 +-- burn-book/src/basic-workflow/model.md | 2 +- burn-book/src/basic-workflow/training.md | 52 ++++++++++------------- 4 files changed, 32 insertions(+), 39 deletions(-) diff --git a/burn-book/src/basic-workflow/data.md b/burn-book/src/basic-workflow/data.md index 4e3683c219..dab324e950 100644 --- a/burn-book/src/basic-workflow/data.md +++ b/burn-book/src/basic-workflow/data.md @@ -79,10 +79,12 @@ impl Batcher> for MnistBatcher { let targets = items .iter() - .map(|item| Tensor::::from_data( - TensorData::from([(item.label as i64).elem()]), - &self.device - )) + .map(|item| { + Tensor::::from_data( + [(item.label as i64).elem::()], + &self.device, + ) + }) .collect(); let images = Tensor::cat(images, 0).to_device(&self.device); diff --git a/burn-book/src/basic-workflow/inference.md b/burn-book/src/basic-workflow/inference.md index 1195055ae7..88ae9afc76 100644 --- a/burn-book/src/basic-workflow/inference.md +++ b/burn-book/src/basic-workflow/inference.md @@ -10,16 +10,13 @@ cost. Let's create a simple `infer` method in a new file `src/inference.rs` whic load our trained model. ```rust , ignore +# use crate::{data::MnistBatcher, training::TrainingConfig}; # use burn::{ -# config::Config, # data::{dataloader::batcher::Batcher, dataset::vision::MnistItem}, -# module::Module, +# prelude::*, # record::{CompactRecorder, Recorder}, -# tensor::backend::Backend, # }; # -# use crate::{data::MnistBatcher, training::TrainingConfig}; -# pub fn infer(artifact_dir: &str, device: B::Device, item: MnistItem) { let config = TrainingConfig::load(format!("{artifact_dir}/config.json")) .expect("Config should exist for the model"); diff --git a/burn-book/src/basic-workflow/model.md b/burn-book/src/basic-workflow/model.md index 28d0682a4b..8952b78d64 100644 --- a/burn-book/src/basic-workflow/model.md +++ b/burn-book/src/basic-workflow/model.md @@ -221,8 +221,8 @@ impl ModelConfig { At a glance, you can view the model configuration by printing the model instance: ```rust , ignore +use crate::model::ModelConfig; use burn::backend::Wgpu; -use guide::model::ModelConfig; fn main() { type MyBackend = Wgpu; diff --git a/burn-book/src/basic-workflow/training.md b/burn-book/src/basic-workflow/training.md index cc40a6e3ed..6705beed17 100644 --- a/burn-book/src/basic-workflow/training.md +++ b/burn-book/src/basic-workflow/training.md @@ -39,7 +39,9 @@ impl Model { targets: Tensor, ) -> ClassificationOutput { let output = self.forward(images); - let loss = CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone()); + let loss = CrossEntropyLossConfig::new() + .init(&output.device()) + .forward(output.clone(), targets.clone()); ClassificationOutput::new(loss, output, targets) } @@ -60,28 +62,23 @@ Moving forward, we will proceed with the implementation of both the training and for our model. ```rust , ignore +# use crate::{ +# data::{MnistBatch, MnistBatcher}, +# model::{Model, ModelConfig}, +# }; # use burn::{ -# config::Config, # data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset}, -# module::Module, -# nn::loss::CrossEntropyLoss, +# nn::loss::CrossEntropyLossConfig, # optim::AdamConfig, +# prelude::*, # record::CompactRecorder, -# tensor::{ -# backend::{AutodiffBackend, Backend}, -# Int, Tensor, -# }, +# tensor::backend::AutodiffBackend, # train::{ # metric::{AccuracyMetric, LossMetric}, # ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep, # }, # }; # -# use crate::{ -# data::{MnistBatch, MnistBatcher}, -# model::{Model, ModelConfig}, -# }; -# # impl Model { # pub fn forward_classification( # &self, @@ -89,8 +86,9 @@ for our model. # targets: Tensor, # ) -> ClassificationOutput { # let output = self.forward(images); -# let loss = -# CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone()); +# let loss = CrossEntropyLossConfig::new() +# .init(&output.device()) +# .forward(output.clone(), targets.clone()); # # ClassificationOutput::new(loss, output, targets) # } @@ -147,28 +145,23 @@ Book. Let us move on to establishing the practical training configuration. ```rust , ignore +# use crate::{ +# data::{MnistBatch, MnistBatcher}, +# model::{Model, ModelConfig}, +# }; # use burn::{ -# config::Config, # data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset}, -# module::Module, -# nn::loss::CrossEntropyLoss, +# nn::loss::CrossEntropyLossConfig, # optim::AdamConfig, +# prelude::*, # record::CompactRecorder, -# tensor::{ -# backend::{AutodiffBackend, Backend}, -# Int, Tensor, -# }, +# tensor::backend::AutodiffBackend, # train::{ # metric::{AccuracyMetric, LossMetric}, # ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep, # }, # }; # -# use crate::{ -# data::{MnistBatch, MnistBatcher}, -# model::{Model, ModelConfig}, -# }; -# # impl Model { # pub fn forward_classification( # &self, @@ -176,8 +169,9 @@ Let us move on to establishing the practical training configuration. # targets: Tensor, # ) -> ClassificationOutput { # let output = self.forward(images); -# let loss = -# CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone()); +# let loss = CrossEntropyLossConfig::new() +# .init(&output.device()) +# .forward(output.clone(), targets.clone()); # # ClassificationOutput::new(loss, output, targets) # }