diff --git a/src/cudnn/result.rs b/src/cudnn/result.rs index 0269d36..0207452 100644 --- a/src/cudnn/result.rs +++ b/src/cudnn/result.rs @@ -46,7 +46,21 @@ pub fn get_cudart_version() -> usize { /// Runs all *VersionCheck functions. pub fn version_check() -> Result<(), CudnnError> { - #[cfg(not(any(feature = "cuda-12030", feature = "cuda-12040", feature = "cuda-12050")))] + #[cfg(feature = "cuda-11040")] + unsafe { + lib().cudnnAdvVersionCheck().result()?; + lib().cudnnCnnVersionCheck().result()?; + lib().cudnnOpsVersionCheck().result()?; + } + #[cfg(any( + feature = "cuda-11050", + feature = "cuda-11060", + feature = "cuda-11070", + feature = "cuda-11080", + feature = "cuda-12000", + feature = "cuda-12010", + feature = "cuda-12020" + ))] unsafe { lib().cudnnAdvInferVersionCheck().result()?; lib().cudnnAdvTrainVersionCheck().result()?; @@ -55,7 +69,7 @@ pub fn version_check() -> Result<(), CudnnError> { lib().cudnnOpsInferVersionCheck().result()?; lib().cudnnOpsTrainVersionCheck().result()?; } - #[cfg(feature = "cuda-12030")] + #[cfg(any(feature = "cuda-12030", feature = "cuda-12050"))] unsafe { lib().cudnnAdvVersionCheck().result()?; lib().cudnnCnnVersionCheck().result()?; @@ -63,16 +77,13 @@ pub fn version_check() -> Result<(), CudnnError> { } #[cfg(feature = "cuda-12040")] unsafe { + lib().cudnnAdvTrainVersionCheck().result()?; + lib().cudnnCnnTrainVersionCheck().result()?; + lib().cudnnOpsTrainVersionCheck().result()?; lib().cudnnAdvInferVersionCheck().result()?; lib().cudnnCnnInferVersionCheck().result()?; lib().cudnnOpsInferVersionCheck().result()?; } - #[cfg(feature = "cuda-12050")] - unsafe { - lib().cudnnAdvVersionCheck().result()?; - lib().cudnnCnnVersionCheck().result()?; - lib().cudnnOpsVersionCheck().result()?; - } Ok(()) }