Fix target convert in batcher and align guide imports (#2215)

* Fix target convert in batcher

* Align hidden code in training and update loss to use config

* Align imports with example

* Remove unused import and fix guide->crate

laggui committed Aug 29, 2024
1 parent 3664c6a commit 0e445a9

Showing 4 changed files with 32 additions and 39 deletions.

burn-book/src/basic-workflow/data.md (6 additions, 4 deletions)

````diff
@@ -79,10 +79,12 @@ impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
 
         let targets = items
             .iter()
-            .map(|item| Tensor::<B, 1, Int>::from_data(
-                TensorData::from([(item.label as i64).elem()]),
-                &self.device
-            ))
+            .map(|item| {
+                Tensor::<B, 1, Int>::from_data(
+                    [(item.label as i64).elem::<B::IntElem>()],
+                    &self.device,
+                )
+            })
             .collect();
 
         let images = Tensor::cat(images, 0).to_device(&self.device);
````
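
The substantive change: the old code wrapped the label in `TensorData::from` and called `.elem()` with no type annotation, leaving the target element type to inference; the new code converts each label to the backend's `B::IntElem` explicitly. A minimal sketch of the fixed conversion in isolation (the helper name `label_to_target` is ours, not from the commit):

```rust
use burn::{prelude::*, tensor::ElementConversion};

// Convert one dataset label into a rank-1 Int target tensor, pinning the
// element type to the backend's integer element instead of an inferred i64.
fn label_to_target<B: Backend>(label: usize, device: &B::Device) -> Tensor<B, 1, Int> {
    Tensor::<B, 1, Int>::from_data([(label as i64).elem::<B::IntElem>()], device)
}
```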

burn-book/src/basic-workflow/inference.md (2 additions, 5 deletions)

````diff
@@ -10,16 +10,13 @@ cost. Let's create a simple `infer` method in a new file `src/inference.rs` which
 load our trained model.
 
 ```rust , ignore
-# use crate::{data::MnistBatcher, training::TrainingConfig};
 # use burn::{
-#     config::Config,
 #     data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
-#     module::Module,
 #     prelude::*,
 #     record::{CompactRecorder, Recorder},
-#     tensor::backend::Backend,
 # };
-#
+# use crate::{data::MnistBatcher, training::TrainingConfig};
+#
 pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
     let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
         .expect("Config should exist for the model");
````
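
The dropped `use burn::...` items are not lost: `burn::prelude::*`, which the snippet keeps, re-exports `Config`, `Module`, and `Backend` (along with `Tensor`, `Int`, and friends), so the explicit imports were redundant. A quick sketch of leaning on the prelude alone (our illustration, not from the commit; the prelude contents described here are an assumption based on burn 0.14):

```rust
use burn::prelude::*;

// `Backend`, `Tensor`, and `Int` all resolve through the prelude glob.
fn predicted_class<B: Backend>(logits: Tensor<B, 2>) -> Tensor<B, 2, Int> {
    // argmax over the class dimension yields integer class indices.
    logits.argmax(1)
}
```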

burn-book/src/basic-workflow/model.md (1 addition, 1 deletion)

````diff
@@ -221,8 +221,8 @@ impl ModelConfig {
 At a glance, you can view the model configuration by printing the model instance:
 
 ```rust , ignore
-use crate::model::ModelConfig;
 use burn::backend::Wgpu;
+use guide::model::ModelConfig;
 
 fn main() {
     type MyBackend = Wgpu<f32, i32>;
````
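
The `crate::` to `guide::` swap matters because this snippet is a standalone `main.rs` binary: a binary target refers to its companion library by package name (`guide` in the book's example project), while `crate::` only resolves inside the library itself. A hedged completion of the surrounding example (the `ModelConfig::new(10, 512)` values follow the guide's earlier config; treat the body as our reconstruction, not the diff):

```rust
use burn::backend::Wgpu;
use guide::model::ModelConfig; // `crate::model::ModelConfig` would not resolve in main.rs

fn main() {
    type MyBackend = Wgpu<f32, i32>;
    let device = Default::default();
    // Build the model from its config and print the module summary.
    let model = ModelConfig::new(10, 512).init::<MyBackend>(&device);
    println!("{}", model);
}
```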

burn-book/src/basic-workflow/training.md (23 additions, 29 deletions)

````diff
@@ -39,7 +39,9 @@ impl<B: Backend> Model<B> {
         targets: Tensor<B, 1, Int>,
     ) -> ClassificationOutput<B> {
         let output = self.forward(images);
-        let loss = CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
+        let loss = CrossEntropyLossConfig::new()
+            .init(&output.device())
+            .forward(output.clone(), targets.clone());
 
         ClassificationOutput::new(loss, output, targets)
     }
@@ -60,37 +62,33 @@ Moving forward, we will proceed with the implementation of both the training and
 for our model.
 
 ```rust , ignore
-# use crate::{
-#     data::{MnistBatch, MnistBatcher},
-#     model::{Model, ModelConfig},
-# };
 # use burn::{
-#     config::Config,
 #     data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
-#     module::Module,
-#     nn::loss::CrossEntropyLoss,
+#     nn::loss::CrossEntropyLossConfig,
 #     optim::AdamConfig,
 #     prelude::*,
 #     record::CompactRecorder,
-#     tensor::{
-#         backend::{AutodiffBackend, Backend},
-#         Int, Tensor,
-#     },
+#     tensor::backend::AutodiffBackend,
 #     train::{
 #         metric::{AccuracyMetric, LossMetric},
 #         ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
 #     },
 # };
-#
+# use crate::{
+#     data::{MnistBatch, MnistBatcher},
+#     model::{Model, ModelConfig},
+# };
+#
 # impl<B: Backend> Model<B> {
 #     pub fn forward_classification(
 #         &self,
 #         images: Tensor<B, 3>,
 #         targets: Tensor<B, 1, Int>,
 #     ) -> ClassificationOutput<B> {
 #         let output = self.forward(images);
-#         let loss =
-#             CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
+#         let loss = CrossEntropyLossConfig::new()
+#             .init(&output.device())
+#             .forward(output.clone(), targets.clone());
 #
 #         ClassificationOutput::new(loss, output, targets)
 #     }
@@ -147,37 +145,33 @@ Book.
 Let us move on to establishing the practical training configuration.
 
 ```rust , ignore
-# use crate::{
-#     data::{MnistBatch, MnistBatcher},
-#     model::{Model, ModelConfig},
-# };
 # use burn::{
-#     config::Config,
 #     data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
-#     module::Module,
-#     nn::loss::CrossEntropyLoss,
+#     nn::loss::CrossEntropyLossConfig,
 #     optim::AdamConfig,
 #     prelude::*,
 #     record::CompactRecorder,
-#     tensor::{
-#         backend::{AutodiffBackend, Backend},
-#         Int, Tensor,
-#     },
+#     tensor::backend::AutodiffBackend,
 #     train::{
 #         metric::{AccuracyMetric, LossMetric},
 #         ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
 #     },
 # };
-#
+# use crate::{
+#     data::{MnistBatch, MnistBatcher},
+#     model::{Model, ModelConfig},
+# };
+#
 # impl<B: Backend> Model<B> {
 #     pub fn forward_classification(
 #         &self,
 #         images: Tensor<B, 3>,
 #         targets: Tensor<B, 1, Int>,
 #     ) -> ClassificationOutput<B> {
 #         let output = self.forward(images);
-#         let loss =
-#             CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
+#         let loss = CrossEntropyLossConfig::new()
+#             .init(&output.device())
+#             .forward(output.clone(), targets.clone());
 #
 #         ClassificationOutput::new(loss, output, targets)
 #     }
````
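
Across both training.md hunks the substantive change is the same: the loss is now built through `CrossEntropyLossConfig` rather than the positional `CrossEntropyLoss::new(None, &device)` constructor, matching the hidden-import switch to `nn::loss::CrossEntropyLossConfig`. A hedged sketch of the pattern on its own (the free function is our framing, not from the commit):

```rust
use burn::{nn::loss::CrossEntropyLossConfig, prelude::*};

// Build the loss module from its config, then apply it.
// logits: [batch, num_classes]; targets: [batch] class indices.
fn cross_entropy<B: Backend>(logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
    CrossEntropyLossConfig::new()
        .init(&logits.device())
        .forward(logits, targets)
}
```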
