Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Feb 26, 2024
1 parent 4970669 commit 36aacf0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 19 deletions.
24 changes: 6 additions & 18 deletions src/cudnn/safe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,32 +207,20 @@ mod tests {
// Create input, filter and output tensors
let x = dev.htod_copy(vec![1.0f32; 32 * 3 * 64 * 64 * 64]).unwrap();
let x_desc = cudnn.create_nd_tensor::<f32>(
&[32, 3, 64, 64, 64],
&[
3 * 64 * 64 * 64,
64 * 64 * 64,
64 * 64,
64,
1
]
&[32, 3, 64, 64, 64],
&[3 * 64 * 64 * 64, 64 * 64 * 64, 64 * 64, 64, 1],
)?;
let filter = dev.htod_copy(vec![1.0f32; 32 * 3 * 4 * 4 * 4]).unwrap();
let filter_desc = cudnn.create_nd_filter::<f32>(
cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
&[32, 3, 4, 4, 4],
)?;
let mut y = dev.alloc_zeros::<f32>(32 * 32 * 61 * 61 * 61).unwrap();
let y_desc = cudnn.create_nd_tensor::<f32>(
&[32, 32, 61, 61, 61],
&[
32 * 61 * 61 * 61,
61 * 61 * 61,
61 * 61,
61,
1
]
&[32, 32, 61, 61, 61],
&[32 * 61 * 61 * 61, 61 * 61 * 61, 61 * 61, 61, 1],
)?;

{
let op = ConvForward {
conv: &conv,
Expand Down
2 changes: 1 addition & 1 deletion src/driver/safe/launch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::driver::{result, sys};
use super::alloc::DeviceRepr;
use super::core::{CudaDevice, CudaFunction, CudaModule, CudaStream};

use std::{vec::Vec, sync::Arc};
use std::{sync::Arc, vec::Vec};

impl CudaDevice {
/// Whether a module and function are currently loaded into the device.
Expand Down

0 comments on commit 36aacf0

Please sign in to comment.