From 146785108f6ffff8a4a3465a7f8da3ce957f7bc3 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Tue, 20 Aug 2024 10:52:11 -0400 Subject: [PATCH] #285 fixing version check in cudnn for 11.4 (#288) --- src/cudnn/result.rs | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/cudnn/result.rs b/src/cudnn/result.rs index 0269d361..02074524 100644 --- a/src/cudnn/result.rs +++ b/src/cudnn/result.rs @@ -46,7 +46,21 @@ pub fn get_cudart_version() -> usize { /// Runs all *VersionCheck functions. pub fn version_check() -> Result<(), CudnnError> { - #[cfg(not(any(feature = "cuda-12030", feature = "cuda-12040", feature = "cuda-12050")))] + #[cfg(feature = "cuda-11040")] + unsafe { + lib().cudnnAdvVersionCheck().result()?; + lib().cudnnCnnVersionCheck().result()?; + lib().cudnnOpsVersionCheck().result()?; + } + #[cfg(any( + feature = "cuda-11050", + feature = "cuda-11060", + feature = "cuda-11070", + feature = "cuda-11080", + feature = "cuda-12000", + feature = "cuda-12010", + feature = "cuda-12020" + ))] unsafe { lib().cudnnAdvInferVersionCheck().result()?; lib().cudnnAdvTrainVersionCheck().result()?; @@ -55,7 +69,7 @@ pub fn version_check() -> Result<(), CudnnError> { lib().cudnnOpsInferVersionCheck().result()?; lib().cudnnOpsTrainVersionCheck().result()?; } - #[cfg(feature = "cuda-12030")] + #[cfg(any(feature = "cuda-12030", feature = "cuda-12050"))] unsafe { lib().cudnnAdvVersionCheck().result()?; lib().cudnnCnnVersionCheck().result()?; @@ -63,16 +77,13 @@ pub fn version_check() -> Result<(), CudnnError> { } #[cfg(feature = "cuda-12040")] unsafe { + lib().cudnnAdvTrainVersionCheck().result()?; + lib().cudnnCnnTrainVersionCheck().result()?; + lib().cudnnOpsTrainVersionCheck().result()?; lib().cudnnAdvInferVersionCheck().result()?; lib().cudnnCnnInferVersionCheck().result()?; lib().cudnnOpsInferVersionCheck().result()?; } - #[cfg(feature = "cuda-12050")] - unsafe { - lib().cudnnAdvVersionCheck().result()?; - lib().cudnnCnnVersionCheck().result()?; - lib().cudnnOpsVersionCheck().result()?; - } Ok(()) }