From e48a6d369f08d39ddb0a42fd1f4bdef07a49a540 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Mon, 25 Mar 2024 10:23:26 +0100 Subject: [PATCH] Don't require Vector types to be sized (e.g. slices) --- src/cblas.rs | 104 ++++++++++++++++++++++++-------------------- src/fit.rs | 11 +++-- src/types/vector.rs | 4 +- 3 files changed, 67 insertions(+), 52 deletions(-) diff --git a/src/cblas.rs b/src/cblas.rs index e0c39b4..9ad98a0 100644 --- a/src/cblas.rs +++ b/src/cblas.rs @@ -6,23 +6,23 @@ use crate::vector::Vector; /// Return the length of `x` as a `i32` value (to use in CBLAS calls). #[inline] -fn len>(x: &T) -> i32 { +fn len + ?Sized>(x: &T) -> i32 { x.len().try_into().expect("Length must fit in `i32`") } #[inline] -fn as_ptr>(x: &T) -> *const F { +fn as_ptr + ?Sized>(x: &T) -> *const F { x.as_slice().as_ptr() } #[inline] -fn as_mut_ptr>(x: &mut T) -> *mut F { +fn as_mut_ptr + ?Sized>(x: &mut T) -> *mut F { x.as_mut_slice().as_mut_ptr() } /// Return the stride of `x` as a `i32` value (to use in CBLAS calls). #[inline] -fn stride>(x: &T) -> i32 { +fn stride + ?Sized>(x: &T) -> i32 { x.stride().try_into().expect("Stride must fit in `i32`") } @@ -34,28 +34,28 @@ pub mod level1 { /// Return the sum of `alpha` and the dot product of `x` and `y`. #[doc(alias = "cblas_sdsdot")] - pub fn sdsdot>(alpha: f32, x: &T, y: &T) -> f32 { + pub fn sdsdot + ?Sized>(alpha: f32, x: &T, y: &T) -> f32 { check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); unsafe { sys::cblas_sdsdot(len(x), alpha, as_ptr(x), stride(x), as_ptr(y), stride(y)) } } /// Return the dot product of `x` and `y`. #[doc(alias = "cblas_dsdot")] - pub fn dsdot>(x: &T, y: &T) -> f64 { + pub fn dsdot + ?Sized>(x: &T, y: &T) -> f64 { check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); unsafe { sys::cblas_dsdot(len(x), as_ptr(x), stride(x), as_ptr(y), stride(y)) } } /// Return the dot product of `x` and `y`. #[doc(alias = "cblas_sdot")] - pub fn sdot>(x: &T, y: &T) -> f32 { + pub fn sdot + ?Sized>(x: &T, y: &T) -> f32 { check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); unsafe { sys::cblas_sdot(len(x), as_ptr(x), stride(x), as_ptr(y), stride(y)) } } /// Return the dot product of `x` and `y`. #[doc(alias = "cblas_ddot")] - pub fn ddot>(x: &T, y: &T) -> f64 { + pub fn ddot + ?Sized>(x: &T, y: &T) -> f64 { check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); unsafe { sys::cblas_ddot(len(x), as_ptr(x), stride(x), as_ptr(y), stride(y)) } } @@ -73,7 +73,10 @@ pub mod level1 { /// assert_eq!(cdotu(&x, &x), Complex::new(3., 6.)) /// ``` #[doc(alias = "cblas_cdotu_sub")] - pub fn cdotu>>(x: &T, y: &T) -> Complex { + pub fn cdotu(x: &T, y: &T) -> Complex + where + T: Vector> + ?Sized, + { check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); let mut dotu: Complex = Complex::new(0., 0.); unsafe { @@ -102,7 +105,10 @@ pub mod level1 { /// assert_eq!(cdotc(&x, &x), Complex::new(7., 0.)) /// ``` #[doc(alias = "cblas_cdotc_sub")] - pub fn cdotc>>(x: &T, y: &T) -> Complex { + pub fn cdotc(x: &T, y: &T) -> Complex + where + T: Vector> + ?Sized, + { check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); let mut dotc: Complex = Complex::new(0., 0.); unsafe { @@ -131,7 +137,10 @@ pub mod level1 { /// assert_eq!(zdotu(&x, &x), Complex::new(3., 6.)) /// ``` #[doc(alias = "cblas_zdotu_sub")] - pub fn zdotu>>(x: &T, y: &T) -> Complex { + pub fn zdotu(x: &T, y: &T) -> Complex + where + T: Vector> + ?Sized, + { check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); let mut dotu: Complex = Complex::new(0., 0.); unsafe { @@ -160,7 +169,10 @@ pub mod level1 { /// assert_eq!(zdotc(&x, &x), Complex::new(7., 0.)) /// ``` #[doc(alias = "cblas_zdotc_sub")] - pub fn zdotc>>(x: &T, y: &T) -> Complex { + pub fn zdotc(x: &T, y: &T) -> Complex + where + T: Vector> + ?Sized, + { check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); let mut dotc: Complex = Complex::new(0., 0.); unsafe { @@ -178,27 +190,27 @@ pub mod level1 { /// Return the Euclidean norm of `x`. #[doc(alias = "cblas_snrm2")] - pub fn snrm2>(x: &T) -> f32 { + pub fn snrm2 + ?Sized>(x: &T) -> f32 { unsafe { sys::cblas_snrm2(len(x), as_ptr(x), stride(x)) } } /// Return the sum of the absolute values of the elements of `x` /// (i.e., its L¹-norm). #[doc(alias = "cblas_sasum")] - pub fn sasum>(x: &T) -> f32 { + pub fn sasum + ?Sized>(x: &T) -> f32 { unsafe { sys::cblas_sasum(len(x), as_ptr(x), stride(x)) } } /// Return the Euclidean norm of `x`. #[doc(alias = "cblas_dnrm2")] - pub fn dnrm2>(x: &T) -> f64 { + pub fn dnrm2 + ?Sized>(x: &T) -> f64 { unsafe { sys::cblas_dnrm2(len(x), as_ptr(x), stride(x)) } } /// Return the sum of the absolute values of the elements of `x` /// (i.e., its L¹-norm). #[doc(alias = "cblas_dasum")] - pub fn dasum>(x: &T) -> f64 { + pub fn dasum + ?Sized>(x: &T) -> f64 { unsafe { sys::cblas_dasum(len(x), as_ptr(x), stride(x)) } } @@ -214,7 +226,7 @@ pub mod level1 { /// assert_eq!(scnrm2(&x), 7f32.sqrt()) /// ``` #[doc(alias = "cblas_scnrm2")] - pub fn scnrm2>>(x: &T) -> f32 { + pub fn scnrm2> + ?Sized>(x: &T) -> f32 { unsafe { sys::cblas_scnrm2(len(x), as_ptr(x) as *const _, stride(x)) } } @@ -222,7 +234,7 @@ pub mod level1 { /// Return the sum of the modulus of the elements of `x` /// (i.e., its L¹-norm). #[doc(alias = "cblas_scasum")] - pub fn scasum>>(x: &T) -> f32 { + pub fn scasum> + ?Sized>(x: &T) -> f32 { unsafe { sys::cblas_scasum(len(x), as_ptr(x) as *const _, stride(x)) } } @@ -238,7 +250,7 @@ pub mod level1 { /// assert_eq!(dznrm2(&x), 7f64.sqrt()) /// ``` #[doc(alias = "cblas_dznrm2")] - pub fn dznrm2>>(x: &T) -> f64 { + pub fn dznrm2> + ?Sized>(x: &T) -> f64 { unsafe { ::sys::cblas_dznrm2(len(x), as_ptr(x) as *const _, stride(x)) } } @@ -246,53 +258,53 @@ pub mod level1 { /// Return the sum of the modulus of the elements of `x` /// (i.e., its L¹-norm). #[doc(alias = "cblas_dzasum")] - pub fn dzasum>>(x: &T) -> f64 { + pub fn dzasum> + ?Sized>(x: &T) -> f64 { unsafe { ::sys::cblas_dzasum(len(x), as_ptr(x) as *const _, stride(x)) } } /// Return the index of the element with maximum absolute value. #[doc(alias = "cblas_isamax")] - pub fn isamax>(x: &T) -> usize { + pub fn isamax + ?Sized>(x: &T) -> usize { unsafe { ::sys::cblas_isamax(len(x), as_ptr(x), stride(x)) } } /// Return the index of the element with maximum absolute value. #[doc(alias = "cblas_idamax")] - pub fn idamax>(x: &T) -> usize { + pub fn idamax + ?Sized>(x: &T) -> usize { unsafe { ::sys::cblas_idamax(len(x), as_ptr(x), stride(x)) } } #[cfg(feature = "complex")] /// Return the index of the element with maximum modulus. #[doc(alias = "cblas_icamax")] - pub fn icamax>>(x: &T) -> usize { + pub fn icamax> + ?Sized>(x: &T) -> usize { unsafe { ::sys::cblas_icamax(len(x), as_ptr(x) as *const _, stride(x)) } } #[cfg(feature = "complex")] /// Return the index of the element with maximum modulus. #[doc(alias = "cblas_izamax")] - pub fn izamax>>(x: &T) -> usize { + pub fn izamax> + ?Sized>(x: &T) -> usize { unsafe { ::sys::cblas_izamax(len(x), as_ptr(x) as *const _, stride(x)) } } /// Swap vectors `x` and `y`. #[doc(alias = "cblas_sswap")] - pub fn sswap>(x: &mut T, y: &mut T) { + pub fn sswap + ?Sized>(x: &mut T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { ::sys::cblas_sswap(len(x), as_mut_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } /// Copy the content of `x` into `y`. #[doc(alias = "cblas_scopy")] - pub fn scopy>(x: &T, y: &mut T) { + pub fn scopy + ?Sized>(x: &T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { ::sys::cblas_scopy(len(x), as_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_saxpy")] - pub fn saxpy>(alpha: f32, x: &T, y: &mut T) { + pub fn saxpy + ?Sized>(alpha: f32, x: &T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { ::sys::cblas_saxpy( @@ -308,21 +320,21 @@ pub mod level1 { /// Swap vectors `x` and `y`. #[doc(alias = "cblas_dswap")] - pub fn dswap>(x: &mut T, y: &mut T) { + pub fn dswap + ?Sized>(x: &mut T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_dswap(len(x), as_mut_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } /// Copy the content of `x` into `y`. #[doc(alias = "cblas_dcopy")] - pub fn dcopy>(x: &T, y: &mut T) { + pub fn dcopy + ?Sized>(x: &T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_dcopy(len(x), as_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_daxpy")] - pub fn daxpy>(alpha: f64, x: &T, y: &mut T) { + pub fn daxpy + ?Sized>(alpha: f64, x: &T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_daxpy( @@ -339,7 +351,7 @@ pub mod level1 { #[cfg(feature = "complex")] /// Swap vectors `x` and `y`. #[doc(alias = "cblas_cswap")] - pub fn cswap>>(x: &mut T, y: &mut T) { + pub fn cswap> + ?Sized>(x: &mut T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { ::sys::cblas_cswap( @@ -355,7 +367,7 @@ pub mod level1 { #[cfg(feature = "complex")] /// Copy the content of `x` into `y`. #[doc(alias = "cblas_ccopy")] - pub fn ccopy>>(x: &T, y: &mut T) { + pub fn ccopy> + ?Sized>(x: &T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_ccopy( @@ -371,7 +383,7 @@ pub mod level1 { #[cfg(feature = "complex")] /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_caxpy")] - pub fn caxpy>>(alpha: &Complex, x: &T, y: &mut T) { + pub fn caxpy> + ?Sized>(alpha: &Complex, x: &T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_caxpy( @@ -388,7 +400,7 @@ pub mod level1 { #[cfg(feature = "complex")] /// Swap vectors `x` and `y`. #[doc(alias = "cblas_zswap")] - pub fn zswap>>(x: &mut T, y: &mut T) { + pub fn zswap> + ?Sized>(x: &mut T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { ::sys::cblas_zswap( @@ -404,7 +416,7 @@ pub mod level1 { #[cfg(feature = "complex")] /// Copy the content of `x` into `y`. #[doc(alias = "cblas_zcopy")] - pub fn zcopy>>(x: &T, y: &mut T) { + pub fn zcopy> + ?Sized>(x: &T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { ::sys::cblas_zcopy( @@ -420,7 +432,7 @@ pub mod level1 { #[cfg(feature = "complex")] /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_zaxpy")] - pub fn zaxpy>>(alpha: &Complex, x: &T, y: &mut T) { + pub fn zaxpy> + ?Sized>(alpha: &Complex, x: &T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { ::sys::cblas_zaxpy( @@ -544,7 +556,7 @@ pub mod level1 { /// /// for all indices i. #[doc(alias = "cblas_srot")] - pub fn srot>(x: &mut T, y: &mut T, c: f32, s: f32) { + pub fn srot + ?Sized>(x: &mut T, y: &mut T, c: f32, s: f32) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { ::sys::cblas_srot( @@ -566,7 +578,7 @@ pub mod level1 { /// /// for all indices i. #[doc(alias = "cblas_srotm")] - pub fn srotm>(x: &mut T, y: &mut T, h: H) { + pub fn srotm + ?Sized>(x: &mut T, y: &mut T, h: H) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); let p = match h { H::Full { h11, h21, h12, h22 } => [-1.0, h11, h21, h12, h22], @@ -664,7 +676,7 @@ pub mod level1 { /// /// for all indices i. #[doc(alias = "cblas_drot")] - pub fn drot>(x: &mut T, y: &mut T, c: f64, s: f64) { + pub fn drot + ?Sized>(x: &mut T, y: &mut T, c: f64, s: f64) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { ::sys::cblas_drot( @@ -686,7 +698,7 @@ pub mod level1 { /// /// for all indices i. #[doc(alias = "cblas_drotm")] - pub fn drotm>(x: &mut T, y: &mut T, h: H) { + pub fn drotm + ?Sized>(x: &mut T, y: &mut T, h: H) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); let p = match h { H::Full { h11, h21, h12, h22 } => [-1.0, h11, h21, h12, h22], @@ -708,20 +720,20 @@ pub mod level1 { /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_sscal")] - pub fn sscal>(alpha: f32, x: &mut T) { + pub fn sscal + ?Sized>(alpha: f32, x: &mut T) { unsafe { ::sys::cblas_sscal(len(x), alpha, as_mut_ptr(x), stride(x)) } } /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_dscal")] - pub fn dscal>(alpha: f64, x: &mut T) { + pub fn dscal + ?Sized>(alpha: f64, x: &mut T) { unsafe { ::sys::cblas_dscal(len(x), alpha, as_mut_ptr(x), stride(x)) } } #[cfg(feature = "complex")] /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_cscal")] - pub fn cscal>>(alpha: &Complex, x: &mut T) { + pub fn cscal> + ?Sized>(alpha: &Complex, x: &mut T) { unsafe { ::sys::cblas_cscal( len(x), @@ -735,7 +747,7 @@ pub mod level1 { #[cfg(feature = "complex")] /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_zscal")] - pub fn zscal>>(alpha: &Complex, x: &mut T) { + pub fn zscal> + ?Sized>(alpha: &Complex, x: &mut T) { unsafe { ::sys::cblas_zscal( len(x), @@ -749,14 +761,14 @@ pub mod level1 { #[cfg(feature = "complex")] /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_csscal")] - pub fn csscal>>(alpha: f32, x: &mut T) { + pub fn csscal> + ?Sized>(alpha: f32, x: &mut T) { unsafe { ::sys::cblas_csscal(len(x), alpha, as_mut_ptr(x) as *mut _, stride(x)) } } #[cfg(feature = "complex")] /// Multiple each element of a matrix/vector by a constant. #[doc(alias = "cblas_zdscal")] - pub fn zdscal>>(alpha: f64, x: &mut T) { + pub fn zdscal> + ?Sized>(alpha: f64, x: &mut T) { unsafe { ::sys::cblas_zdscal(len(x), alpha, as_mut_ptr(x) as *mut _, stride(x)) } } } diff --git a/src/fit.rs b/src/fit.rs index 58c6d32..a9c82db 100644 --- a/src/fit.rs +++ b/src/fit.rs @@ -40,7 +40,10 @@ use crate::{ /// # Ok::<(), rgsl::Value>(()) /// ``` #[doc(alias = "gsl_fit_linear")] -pub fn linear>(x: &T, y: &T) -> Result<(f64, f64, f64, f64, f64, f64), Value> { +pub fn linear(x: &T, y: &T) -> Result<(f64, f64, f64, f64, f64, f64), Value> +where + T: Vector + ?Sized, +{ check_equal_len(x, y)?; let mut c0 = 0.; let mut c1 = 0.; @@ -81,7 +84,7 @@ pub fn linear>(x: &T, y: &T) -> Result<(f64, f64, f64, f64, f64, /// /// Returns `(c0, c1, cov00, cov01, cov11, chisq)`. #[doc(alias = "gsl_fit_wlinear")] -pub fn wlinear>( +pub fn wlinear + ?Sized>( x: &T, w: &T, y: &T, @@ -143,7 +146,7 @@ pub fn linear_est( /// /// Returns `(c1, cov11, sumsq)`. #[doc(alias = "gsl_fit_mul")] -pub fn mul>(x: &T, y: &T) -> Result<(f64, f64, f64), Value> { +pub fn mul + ?Sized>(x: &T, y: &T) -> Result<(f64, f64, f64), Value> { check_equal_len(x, y)?; let mut c1 = 0.; let mut cov11 = 0.; @@ -165,7 +168,7 @@ pub fn mul>(x: &T, y: &T) -> Result<(f64, f64, f64), Value> { /// Returns `(c1, cov11, sumsq)`. #[doc(alias = "gsl_fit_wmul")] -pub fn wmul>(x: &T, w: &T, y: &T) -> Result<(f64, f64, f64), Value> { +pub fn wmul + ?Sized>(x: &T, w: &T, y: &T) -> Result<(f64, f64, f64), Value> { check_equal_len(x, y)?; check_equal_len(x, w)?; let mut c1 = 0.; diff --git a/src/types/vector.rs b/src/types/vector.rs index 3bcecbc..87cf1ae 100644 --- a/src/types/vector.rs +++ b/src/types/vector.rs @@ -612,7 +612,7 @@ macro_rules! impl_AsRef { ($ty: ty) => { impl Vector<$ty> for T where - T: AsRef<[$ty]> + AsMut<[$ty]>, + T: AsRef<[$ty]> + AsMut<[$ty]> + ?Sized, { fn len(&self) -> usize { self.as_ref().len() @@ -640,7 +640,7 @@ impl_AsRef!(Complex); #[inline] pub(crate) fn check_equal_len(x: &T, y: &T) -> Result<(), Value> where - T: Vector, + T: Vector + ?Sized, { if x.len() != y.len() { return Err(Value::Invalid);