-
Notifications
You must be signed in to change notification settings - Fork 449
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
360 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
use backend_comparison::persistence::save; | ||
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor}; | ||
use burn_common::benchmark::{run_benchmark, Benchmark}; | ||
use cubecl::client::SyncType; | ||
|
||
// Files retrieved during build to avoid reimplementing ResNet for benchmarks | ||
mod block { | ||
extern crate alloc; | ||
include!(concat!(env!("OUT_DIR"), "/block.rs")); | ||
} | ||
|
||
mod model { | ||
include!(concat!(env!("OUT_DIR"), "/resnet.rs")); | ||
} | ||
|
||
pub struct ResNetBenchmark<B: Backend> { | ||
shape: Shape<4>, | ||
device: B::Device, | ||
} | ||
|
||
impl<B: Backend> Benchmark for ResNetBenchmark<B> { | ||
type Args = (model::ResNet<B>, Tensor<B, 4>); | ||
|
||
fn name(&self) -> String { | ||
"resnet50".into() | ||
} | ||
|
||
fn shapes(&self) -> Vec<Vec<usize>> { | ||
vec![self.shape.dims.into()] | ||
} | ||
|
||
fn execute(&self, (model, input): Self::Args) { | ||
let _out = model.forward(input); | ||
} | ||
|
||
fn prepare(&self) -> Self::Args { | ||
// 1k classes like ImageNet | ||
let model = model::ResNet::resnet50(1000, &self.device); | ||
let input = Tensor::random(self.shape.clone(), Distribution::Default, &self.device); | ||
|
||
(model, input) | ||
} | ||
|
||
fn sync(&self) { | ||
B::sync(&self.device, SyncType::Wait) | ||
} | ||
} | ||
|
||
#[allow(dead_code)] | ||
fn bench<B: Backend>( | ||
device: &B::Device, | ||
feature_name: &str, | ||
url: Option<&str>, | ||
token: Option<&str>, | ||
) { | ||
let benchmark = ResNetBenchmark::<B> { | ||
shape: [1, 3, 224, 224].into(), | ||
device: device.clone(), | ||
}; | ||
|
||
save::<B>( | ||
vec![run_benchmark(benchmark)], | ||
device, | ||
feature_name, | ||
url, | ||
token, | ||
) | ||
.unwrap(); | ||
} | ||
|
||
fn main() { | ||
backend_comparison::bench_on_backend!(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,278 @@ | ||
use std::env; | ||
use std::fs; | ||
use std::path::Path; | ||
use std::process::Command; | ||
|
||
const MODELS_DIR: &str = "/tmp/models"; | ||
const MODELS_REPO: &str = "https://github.com/tracel-ai/models.git"; | ||
|
||
// Patch resnet code (remove pretrained feature code) | ||
const PATCH: &str = r#"diff --git a/resnet-burn/resnet/src/resnet.rs b/resnet-burn/resnet/src/resnet.rs | ||
index e7f8787..3967049 100644 | ||
--- a/resnet-burn/resnet/src/resnet.rs | ||
+++ b/resnet-burn/resnet/src/resnet.rs | ||
@@ -12,13 +12,6 @@ use burn::{ | ||
use super::block::{LayerBlock, LayerBlockConfig}; | ||
-#[cfg(feature = "pretrained")] | ||
-use { | ||
- super::weights::{self, WeightsMeta}, | ||
- burn::record::{FullPrecisionSettings, Recorder, RecorderError}, | ||
- burn_import::pytorch::{LoadArgs, PyTorchFileRecorder}, | ||
-}; | ||
- | ||
// ResNet residual layer block configs | ||
const RESNET18_BLOCKS: [usize; 4] = [2, 2, 2, 2]; | ||
const RESNET34_BLOCKS: [usize; 4] = [3, 4, 6, 3]; | ||
@@ -77,29 +70,6 @@ impl<B: Backend> ResNet<B> { | ||
ResNetConfig::new(RESNET18_BLOCKS, num_classes, 1).init(device) | ||
} | ||
- /// ResNet-18 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385) | ||
- /// with pre-trained weights. | ||
- /// | ||
- /// # Arguments | ||
- /// | ||
- /// * `weights`: Pre-trained weights to load. | ||
- /// * `device` - Device to create the module on. | ||
- /// | ||
- /// # Returns | ||
- /// | ||
- /// A ResNet-18 module with pre-trained weights. | ||
- #[cfg(feature = "pretrained")] | ||
- pub fn resnet18_pretrained( | ||
- weights: weights::ResNet18, | ||
- device: &Device<B>, | ||
- ) -> Result<Self, RecorderError> { | ||
- let weights = weights.weights(); | ||
- let record = Self::load_weights_record(&weights, device)?; | ||
- let model = ResNet::<B>::resnet18(weights.num_classes, device).load_record(record); | ||
- | ||
- Ok(model) | ||
- } | ||
- | ||
/// ResNet-34 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385). | ||
/// | ||
/// # Arguments | ||
@@ -114,29 +84,6 @@ impl<B: Backend> ResNet<B> { | ||
ResNetConfig::new(RESNET34_BLOCKS, num_classes, 1).init(device) | ||
} | ||
- /// ResNet-34 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385) | ||
- /// with pre-trained weights. | ||
- /// | ||
- /// # Arguments | ||
- /// | ||
- /// * `weights`: Pre-trained weights to load. | ||
- /// * `device` - Device to create the module on. | ||
- /// | ||
- /// # Returns | ||
- /// | ||
- /// A ResNet-34 module with pre-trained weights. | ||
- #[cfg(feature = "pretrained")] | ||
- pub fn resnet34_pretrained( | ||
- weights: weights::ResNet34, | ||
- device: &Device<B>, | ||
- ) -> Result<Self, RecorderError> { | ||
- let weights = weights.weights(); | ||
- let record = Self::load_weights_record(&weights, device)?; | ||
- let model = ResNet::<B>::resnet34(weights.num_classes, device).load_record(record); | ||
- | ||
- Ok(model) | ||
- } | ||
- | ||
/// ResNet-50 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385). | ||
/// | ||
/// # Arguments | ||
@@ -151,29 +98,6 @@ impl<B: Backend> ResNet<B> { | ||
ResNetConfig::new(RESNET50_BLOCKS, num_classes, 4).init(device) | ||
} | ||
- /// ResNet-50 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385) | ||
- /// with pre-trained weights. | ||
- /// | ||
- /// # Arguments | ||
- /// | ||
- /// * `weights`: Pre-trained weights to load. | ||
- /// * `device` - Device to create the module on. | ||
- /// | ||
- /// # Returns | ||
- /// | ||
- /// A ResNet-50 module with pre-trained weights. | ||
- #[cfg(feature = "pretrained")] | ||
- pub fn resnet50_pretrained( | ||
- weights: weights::ResNet50, | ||
- device: &Device<B>, | ||
- ) -> Result<Self, RecorderError> { | ||
- let weights = weights.weights(); | ||
- let record = Self::load_weights_record(&weights, device)?; | ||
- let model = ResNet::<B>::resnet50(weights.num_classes, device).load_record(record); | ||
- | ||
- Ok(model) | ||
- } | ||
- | ||
/// ResNet-101 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385). | ||
/// | ||
/// # Arguments | ||
@@ -188,29 +112,6 @@ impl<B: Backend> ResNet<B> { | ||
ResNetConfig::new(RESNET101_BLOCKS, num_classes, 4).init(device) | ||
} | ||
- /// ResNet-101 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385) | ||
- /// with pre-trained weights. | ||
- /// | ||
- /// # Arguments | ||
- /// | ||
- /// * `weights`: Pre-trained weights to load. | ||
- /// * `device` - Device to create the module on. | ||
- /// | ||
- /// # Returns | ||
- /// | ||
- /// A ResNet-101 module with pre-trained weights. | ||
- #[cfg(feature = "pretrained")] | ||
- pub fn resnet101_pretrained( | ||
- weights: weights::ResNet101, | ||
- device: &Device<B>, | ||
- ) -> Result<Self, RecorderError> { | ||
- let weights = weights.weights(); | ||
- let record = Self::load_weights_record(&weights, device)?; | ||
- let model = ResNet::<B>::resnet101(weights.num_classes, device).load_record(record); | ||
- | ||
- Ok(model) | ||
- } | ||
- | ||
/// ResNet-152 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385). | ||
/// | ||
/// # Arguments | ||
@@ -225,29 +126,6 @@ impl<B: Backend> ResNet<B> { | ||
ResNetConfig::new(RESNET152_BLOCKS, num_classes, 4).init(device) | ||
} | ||
- /// ResNet-152 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385) | ||
- /// with pre-trained weights. | ||
- /// | ||
- /// # Arguments | ||
- /// | ||
- /// * `weights`: Pre-trained weights to load. | ||
- /// * `device` - Device to create the module on. | ||
- /// | ||
- /// # Returns | ||
- /// | ||
- /// A ResNet-152 module with pre-trained weights. | ||
- #[cfg(feature = "pretrained")] | ||
- pub fn resnet152_pretrained( | ||
- weights: weights::ResNet152, | ||
- device: &Device<B>, | ||
- ) -> Result<Self, RecorderError> { | ||
- let weights = weights.weights(); | ||
- let record = Self::load_weights_record(&weights, device)?; | ||
- let model = ResNet::<B>::resnet152(weights.num_classes, device).load_record(record); | ||
- | ||
- Ok(model) | ||
- } | ||
- | ||
/// Re-initialize the last layer with the specified number of output classes. | ||
pub fn with_classes(mut self, num_classes: usize) -> Self { | ||
let [d_input, _d_output] = self.fc.weight.dims(); | ||
@@ -256,32 +134,6 @@ impl<B: Backend> ResNet<B> { | ||
} | ||
} | ||
-#[cfg(feature = "pretrained")] | ||
-impl<B: Backend> ResNet<B> { | ||
- /// Load specified pre-trained PyTorch weights as a record. | ||
- fn load_weights_record( | ||
- weights: &weights::Weights, | ||
- device: &Device<B>, | ||
- ) -> Result<ResNetRecord<B>, RecorderError> { | ||
- // Download torch weights | ||
- let torch_weights = weights.download().map_err(|err| { | ||
- RecorderError::Unknown(format!("Could not download weights.\nError: {err}")) | ||
- })?; | ||
- | ||
- // Load weights from torch state_dict | ||
- let load_args = LoadArgs::new(torch_weights) | ||
- // Map *.downsample.0.* -> *.downsample.conv.* | ||
- .with_key_remap("(.+)\\.downsample\\.0\\.(.+)", "$1.downsample.conv.$2") | ||
- // Map *.downsample.1.* -> *.downsample.bn.* | ||
- .with_key_remap("(.+)\\.downsample\\.1\\.(.+)", "$1.downsample.bn.$2") | ||
- // Map layer[i].[j].* -> layer[i].blocks.[j].* | ||
- .with_key_remap("(layer[1-4])\\.([0-9]+)\\.(.+)", "$1.blocks.$2.$3"); | ||
- let record = PyTorchFileRecorder::<FullPrecisionSettings>::new().load(load_args, device)?; | ||
- | ||
- Ok(record) | ||
- } | ||
-} | ||
- | ||
/// [ResNet](ResNet) configuration. | ||
struct ResNetConfig { | ||
conv1: Conv2dConfig, | ||
"#; | ||
|
||
fn run<F>(name: &str, mut configure: F) | ||
where | ||
F: FnMut(&mut Command) -> &mut Command, | ||
{ | ||
let mut command = Command::new(name); | ||
let configured = configure(&mut command); | ||
println!("Executing {:?}", configured); | ||
if !configured.status().unwrap().success() { | ||
panic!("failed to execute {:?}", configured); | ||
} | ||
println!("Command {:?} finished successfully", configured); | ||
} | ||
|
||
fn main() { | ||
// Checkout ResNet code from models repo | ||
let models_dir = Path::new(MODELS_DIR); | ||
if !models_dir.join(".git").exists() { | ||
run("git", |command| { | ||
command | ||
.arg("clone") | ||
.arg("--depth=1") | ||
.arg("--no-checkout") | ||
.arg(MODELS_REPO) | ||
.arg(MODELS_DIR) | ||
}); | ||
|
||
run("git", |command| { | ||
command | ||
.current_dir(models_dir) | ||
.arg("sparse-checkout") | ||
.arg("set") | ||
.arg("resnet-burn") | ||
}); | ||
|
||
run("git", |command| { | ||
command.current_dir(models_dir).arg("checkout") | ||
}); | ||
|
||
let patch_file = models_dir.join("benchmark.patch"); | ||
|
||
fs::write(&patch_file, PATCH).expect("should write to file successfully"); | ||
|
||
// Apply patch | ||
run("git", |command| { | ||
command | ||
.current_dir(models_dir) | ||
.arg("apply") | ||
.arg(patch_file.to_str().unwrap()) | ||
}); | ||
} | ||
|
||
// Copy contents to output dir | ||
let out_dir = env::var("OUT_DIR").unwrap(); | ||
let source_path = models_dir.join("resnet-burn").join("resnet").join("src"); | ||
let dest_path = Path::new(&out_dir); | ||
|
||
for file in fs::read_dir(source_path).unwrap() { | ||
let source_file = file.unwrap().path(); | ||
let dest_file = dest_path.join(source_file.file_name().unwrap()); | ||
fs::copy(source_file, dest_file).expect("should copy file successfully"); | ||
} | ||
|
||
// Delete cloned repository contents | ||
fs::remove_dir_all(models_dir.join(".git")).unwrap(); | ||
fs::remove_dir_all(models_dir).unwrap(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters