diff --git a/build.rs b/build.rs index f3440c2..e106499 100644 --- a/build.rs +++ b/build.rs @@ -30,27 +30,37 @@ fn link_cuda() { #[cfg(feature = "driver")] println!("cargo:rustc-link-lib=dylib=cuda"); - #[cfg(feature = "nvrtc")] - println!("cargo:rustc-link-lib=dylib=nvrtc"); - #[cfg(feature = "curand")] - println!("cargo:rustc-link-lib=dylib=curand"); #[cfg(feature = "nccl")] println!("cargo:rustc-link-lib=dylib=nccl"); #[cfg(feature = "static-linking")] { - #[cfg(feature = "cublas")] println!("cargo:rustc-link-lib=dylib=stdc++"); + #[cfg(any(feature = "cublas", feature = "cublaslt"))] { + println!("cargo:rustc-link-lib=dylib=cudart"); + println!("cargo:rustc-link-lib=static=cublasLt_static"); + } #[cfg(feature = "cublas")] println!("cargo:rustc-link-lib=static=cublas_static"); - #[cfg(feature = "cublas")] - println!("cargo:rustc-link-lib=static=cublasLt_static"); + #[cfg(feature = "curand")] { + println!("cargo:rustc-link-lib=dylib=culibos"); + println!("cargo:rustc-link-lib=static=curand_static"); + } + #[cfg(feature = "nvrtc")] { + println!("cargo:rustc-link-lib=static=nvrtc_static"); + println!("cargo:rustc-link-lib=static=nvptxcompiler_static"); + println!("cargo:rustc-link-lib=static=nvrtc-builtins_static"); + } } #[cfg(not(feature = "static-linking"))] { + #[cfg(feature = "nvrtc")] + println!("cargo:rustc-link-lib=dylib=nvrtc"); + #[cfg(feature = "curand")] + println!("cargo:rustc-link-lib=dylib=curand"); #[cfg(feature = "cublas")] println!("cargo:rustc-link-lib=dylib=cublas"); - #[cfg(feature = "cublas")] + #[cfg(any(feature = "cublas", feature = "cublaslt"))] println!("cargo:rustc-link-lib=dylib=cublasLt"); }