Remove non-kan layers from examples
VlaDexa committed May 27, 2024
1 parent c4ab7d8 commit c74c30a
Showing 2 changed files with 18 additions and 68 deletions.
examples/mnist/main.rs (10 changes: 8 additions & 2 deletions)
@@ -5,6 +5,7 @@ use crate::data::MnistBatcher;
 use burn::{
     backend::{Autodiff, Wgpu},
     data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
+    lr_scheduler::noam::NoamLrSchedulerConfig,
     optim::AdamWConfig,
     prelude::*,
     record::{CompactRecorder, NoStdTrainingRecorder},
@@ -33,6 +34,7 @@ pub struct KanTrainingConfig {
     pub num_workers: usize,
     pub optimizer: AdamWConfig,
     pub kan_options: KanOptions,
+    pub lr_scheduler: NoamLrSchedulerConfig,
 }
 
 fn create_artifact_dir(artifact_dir: &str) {
@@ -48,7 +50,11 @@ where
     create_artifact_dir(ARTIFACT_DIR);
     // Config
     let config_optimizer = burn::optim::AdamWConfig::new().with_weight_decay(1e-4);
-    let config = KanTrainingConfig::new(config_optimizer, KanOptions::new([24 * 22 * 22, 64, 10]));
+    let config = KanTrainingConfig::new(
+        config_optimizer,
+        KanOptions::new([784, 64, 10]),
+        NoamLrSchedulerConfig::new(1e-4),
+    );
     B::seed(config.seed);
 
     // Data
@@ -91,7 +97,7 @@ where
         .build(
             model::Kan::new(&config.kan_options, &device),
             config.optimizer.init(),
-            1e-4,
+            config.lr_scheduler.init(),
         );
 
     let model_trained = learner.fit(dataloader_train, dataloader_test);
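In main.rs the fixed 1e-4 learning rate that was passed straight to .build() is replaced by a Noam scheduler: KanTrainingConfig gains an lr_scheduler field, and config.lr_scheduler.init() now goes where the bare float used to. The snippet below is a standalone illustration of the classic Noam schedule (linear warmup followed by inverse-square-root decay), so the effect of the swap is visible; the warmup_steps and model_size constants are assumptions taken from the original transformer paper, and how burn's NoamLrScheduler scales the 1e-4 value passed to NoamLrSchedulerConfig::new is not shown in this diff.

// Illustrative Noam schedule, not burn's implementation: warmup, then 1/sqrt(step) decay.
fn noam_lr(step: u64, warmup_steps: u64, model_size: usize, init_lr: f64) -> f64 {
    let step = step.max(1) as f64;
    let warmup = warmup_steps as f64;
    init_lr * (model_size as f64).powf(-0.5) * step.powf(-0.5).min(step * warmup.powf(-1.5))
}

fn main() {
    // Assumed constants: 4000 warmup steps, model size 512.
    for step in [1_u64, 500, 2_000, 4_000, 20_000] {
        println!("step {step:>6}: lr = {:.3e}", noam_lr(step, 4_000, 512, 1e-4));
    }
}

The learning rate therefore ramps up over the first few thousand steps and decays afterwards, instead of staying at 1e-4 for the whole run.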
examples/mnist/model.rs (76 changes: 10 additions & 66 deletions)
@@ -1,94 +1,38 @@
 use burn::{
-    nn::{loss::CrossEntropyLossConfig, BatchNorm, Dropout, DropoutConfig, PaddingConfig2d},
-    prelude::*,
-    tensor::backend::AutodiffBackend,
+    module::Module,
+    nn::loss::CrossEntropyLossConfig,
+    tensor::{
+        backend::{AutodiffBackend, Backend},
+        Tensor,
+    },
     train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
 };
 use burn_efficient_kan::{Kan as EfficientKan, KanOptions};
 
 use crate::data::MnistBatch;
 
-#[derive(Module, Debug)]
-pub struct ConvBlock<B: Backend> {
-    conv: nn::conv::Conv2d<B>,
-    norm: BatchNorm<B, 2>,
-    activation: nn::Gelu,
-}
-
-impl<B: Backend> ConvBlock<B> {
-    pub fn new(channels: [usize; 2], kernel_size: [usize; 2], device: &B::Device) -> Self {
-        let conv = nn::conv::Conv2dConfig::new(channels, kernel_size)
-            .with_padding(PaddingConfig2d::Valid)
-            .init(device);
-        let norm = nn::BatchNormConfig::new(channels[1]).init(device);
-
-        Self {
-            conv,
-            norm,
-            activation: nn::Gelu::new(),
-        }
-    }
-
-    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
-        let x = self.conv.forward(input);
-        let x = self.norm.forward(x);
-
-        self.activation.forward(x)
-    }
-}
 #[derive(Module, Debug)]
 pub struct Kan<B: Backend> {
-    conv1: ConvBlock<B>,
-    conv2: ConvBlock<B>,
-    conv3: ConvBlock<B>,
-    dropout: Dropout,
     kan: EfficientKan<B>,
-    activation: nn::PRelu<B>,
 }
 
 impl<B: Backend> Kan<B> {
     pub fn new(options: &KanOptions, device: &B::Device) -> Self
     where
         B::FloatElem: ndarray_linalg::Scalar + ndarray_linalg::Lapack,
     {
-        let conv1 = ConvBlock::new([1, 8], [3, 3], device); // out: [Batch,8,26,26]
-        let conv2 = ConvBlock::new([8, 16], [3, 3], device); // out: [Batch,16,24x24]
-        let conv3 = ConvBlock::new([16, 24], [3, 3], device); // out: [Batch,24,22x22]
-        let dropout = DropoutConfig::new(0.5).init();
-        let hidden_size = 24 * 22 * 22;
-        let mut options = options.clone();
-        options.layers_hidden = [
-            hidden_size,
-            options.layers_hidden[1],
-            options.layers_hidden[2],
-        ];
-        let kan = EfficientKan::new(&options, device);
-        let activation = nn::PReluConfig::new().init(device);
+        let kan = EfficientKan::new(options, device);
 
-        Self {
-            conv1,
-            conv2,
-            conv3,
-            dropout,
-            kan,
-            activation,
-        }
+        Self { kan }
     }
 }

 impl<B: Backend> Kan<B> {
     pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
         let [batch_size, height, width] = input.dims();
-        let x = input.reshape([batch_size, 1, height, width]).detach();
-        let x = self.conv1.forward(x);
-        let x = self.conv2.forward(x);
-        let x = self.conv3.forward(x);
-        let [batch_size, channels, height, width] = x.dims();
-        let x = x.reshape([batch_size, channels * height * width]);
-        let x = self.dropout.forward(x);
+        let x = input.reshape([batch_size, height * width]);
         let x = self.kan.forward(x);
 
-        self.activation.forward(x)
+        x
     }
 
     pub fn forward_classification(&self, item: MnistBatch<B>) -> ClassificationOutput<B> {
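With the convolution blocks, dropout, and PReLU gone, the example model is only the KAN itself: forward() flattens each 28x28 MNIST image into a 784-element row and hands it to burn_efficient_kan, which is why KanOptions in main.rs changed from [24 * 22 * 22, 64, 10] to [784, 64, 10]. A minimal sketch of that input handling, using only tensor calls that already appear in the diff (the helper name flatten_images is made up for illustration; in the committed code this logic lives inline in Kan::forward):

use burn::tensor::{backend::Backend, Tensor};

// Flattens a [batch, 28, 28] image batch into [batch, 784] rows,
// matching what the simplified Kan::forward does before calling the KAN.
fn flatten_images<B: Backend>(input: Tensor<B, 3>) -> Tensor<B, 2> {
    let [batch_size, height, width] = input.dims();
    input.reshape([batch_size, height * width])
}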
