Skip to content

Commit

Permalink
Don't require Vector types to be sized (e.g. slices)
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris00 committed Mar 25, 2024
1 parent 3e5c78e commit e48a6d3
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 52 deletions.
104 changes: 58 additions & 46 deletions src/cblas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@ use crate::vector::Vector;

/// Return the length of `x` as a `i32` value (to use in CBLAS calls).
#[inline]
fn len<F, T: Vector<F>>(x: &T) -> i32 {
fn len<F, T: Vector<F> + ?Sized>(x: &T) -> i32 {
x.len().try_into().expect("Length must fit in `i32`")
}

#[inline]
fn as_ptr<F, T: Vector<F>>(x: &T) -> *const F {
fn as_ptr<F, T: Vector<F> + ?Sized>(x: &T) -> *const F {
x.as_slice().as_ptr()
}

#[inline]
fn as_mut_ptr<F, T: Vector<F>>(x: &mut T) -> *mut F {
fn as_mut_ptr<F, T: Vector<F> + ?Sized>(x: &mut T) -> *mut F {
x.as_mut_slice().as_mut_ptr()
}

/// Return the stride of `x` as a `i32` value (to use in CBLAS calls).
#[inline]
fn stride<F, T: Vector<F>>(x: &T) -> i32 {
fn stride<F, T: Vector<F> + ?Sized>(x: &T) -> i32 {
x.stride().try_into().expect("Stride must fit in `i32`")
}

Expand All @@ -34,28 +34,28 @@ pub mod level1 {

/// Return the sum of `alpha` and the dot product of `x` and `y`.
#[doc(alias = "cblas_sdsdot")]
pub fn sdsdot<T: Vector<f32>>(alpha: f32, x: &T, y: &T) -> f32 {
pub fn sdsdot<T: Vector<f32> + ?Sized>(alpha: f32, x: &T, y: &T) -> f32 {
check_equal_len(x, y).expect("The length of `x` and `y` must be equal");
unsafe { sys::cblas_sdsdot(len(x), alpha, as_ptr(x), stride(x), as_ptr(y), stride(y)) }
}

/// Return the dot product of `x` and `y`.
#[doc(alias = "cblas_dsdot")]
pub fn dsdot<T: Vector<f32>>(x: &T, y: &T) -> f64 {
pub fn dsdot<T: Vector<f32> + ?Sized>(x: &T, y: &T) -> f64 {
check_equal_len(x, y).expect("The length of `x` and `y` must be equal");
unsafe { sys::cblas_dsdot(len(x), as_ptr(x), stride(x), as_ptr(y), stride(y)) }
}

/// Return the dot product of `x` and `y`.
#[doc(alias = "cblas_sdot")]
pub fn sdot<T: Vector<f32>>(x: &T, y: &T) -> f32 {
pub fn sdot<T: Vector<f32> + ?Sized>(x: &T, y: &T) -> f32 {
check_equal_len(x, y).expect("The length of `x` and `y` must be equal");
unsafe { sys::cblas_sdot(len(x), as_ptr(x), stride(x), as_ptr(y), stride(y)) }
}

/// Return the dot product of `x` and `y`.
#[doc(alias = "cblas_ddot")]
pub fn ddot<T: Vector<f64>>(x: &T, y: &T) -> f64 {
pub fn ddot<T: Vector<f64> + ?Sized>(x: &T, y: &T) -> f64 {
check_equal_len(x, y).expect("The length of `x` and `y` must be equal");
unsafe { sys::cblas_ddot(len(x), as_ptr(x), stride(x), as_ptr(y), stride(y)) }
}
Expand All @@ -73,7 +73,10 @@ pub mod level1 {
/// assert_eq!(cdotu(&x, &x), Complex::new(3., 6.))
/// ```
#[doc(alias = "cblas_cdotu_sub")]
pub fn cdotu<T: Vector<Complex<f32>>>(x: &T, y: &T) -> Complex<f32> {
pub fn cdotu<T>(x: &T, y: &T) -> Complex<f32>
where
T: Vector<Complex<f32>> + ?Sized,
{
check_equal_len(x, y).expect("The length of `x` and `y` must be equal");
let mut dotu: Complex<f32> = Complex::new(0., 0.);
unsafe {
Expand Down Expand Up @@ -102,7 +105,10 @@ pub mod level1 {
/// assert_eq!(cdotc(&x, &x), Complex::new(7., 0.))
/// ```
#[doc(alias = "cblas_cdotc_sub")]
pub fn cdotc<T: Vector<Complex<f32>>>(x: &T, y: &T) -> Complex<f32> {
pub fn cdotc<T>(x: &T, y: &T) -> Complex<f32>
where
T: Vector<Complex<f32>> + ?Sized,
{
check_equal_len(x, y).expect("The length of `x` and `y` must be equal");
let mut dotc: Complex<f32> = Complex::new(0., 0.);
unsafe {
Expand Down Expand Up @@ -131,7 +137,10 @@ pub mod level1 {
/// assert_eq!(zdotu(&x, &x), Complex::new(3., 6.))
/// ```
#[doc(alias = "cblas_zdotu_sub")]
pub fn zdotu<T: Vector<Complex<f64>>>(x: &T, y: &T) -> Complex<f64> {
pub fn zdotu<T>(x: &T, y: &T) -> Complex<f64>
where
T: Vector<Complex<f64>> + ?Sized,
{
check_equal_len(x, y).expect("The length of `x` and `y` must be equal");
let mut dotu: Complex<f64> = Complex::new(0., 0.);
unsafe {
Expand Down Expand Up @@ -160,7 +169,10 @@ pub mod level1 {
/// assert_eq!(zdotc(&x, &x), Complex::new(7., 0.))
/// ```
#[doc(alias = "cblas_zdotc_sub")]
pub fn zdotc<T: Vector<Complex<f64>>>(x: &T, y: &T) -> Complex<f64> {
pub fn zdotc<T>(x: &T, y: &T) -> Complex<f64>
where
T: Vector<Complex<f64>> + ?Sized,
{
check_equal_len(x, y).expect("The length of `x` and `y` must be equal");
let mut dotc: Complex<f64> = Complex::new(0., 0.);
unsafe {
Expand All @@ -178,27 +190,27 @@ pub mod level1 {

/// Return the Euclidean norm of `x`.
#[doc(alias = "cblas_snrm2")]
pub fn snrm2<T: Vector<f32>>(x: &T) -> f32 {
pub fn snrm2<T: Vector<f32> + ?Sized>(x: &T) -> f32 {
unsafe { sys::cblas_snrm2(len(x), as_ptr(x), stride(x)) }
}

/// Return the sum of the absolute values of the elements of `x`
/// (i.e., its L¹-norm).
#[doc(alias = "cblas_sasum")]
pub fn sasum<T: Vector<f32>>(x: &T) -> f32 {
pub fn sasum<T: Vector<f32> + ?Sized>(x: &T) -> f32 {
unsafe { sys::cblas_sasum(len(x), as_ptr(x), stride(x)) }
}

/// Return the Euclidean norm of `x`.
#[doc(alias = "cblas_dnrm2")]
pub fn dnrm2<T: Vector<f64>>(x: &T) -> f64 {
pub fn dnrm2<T: Vector<f64> + ?Sized>(x: &T) -> f64 {
unsafe { sys::cblas_dnrm2(len(x), as_ptr(x), stride(x)) }
}

/// Return the sum of the absolute values of the elements of `x`
/// (i.e., its L¹-norm).
#[doc(alias = "cblas_dasum")]
pub fn dasum<T: Vector<f64>>(x: &T) -> f64 {
pub fn dasum<T: Vector<f64> + ?Sized>(x: &T) -> f64 {
unsafe { sys::cblas_dasum(len(x), as_ptr(x), stride(x)) }
}

Expand All @@ -214,15 +226,15 @@ pub mod level1 {
/// assert_eq!(scnrm2(&x), 7f32.sqrt())
/// ```
#[doc(alias = "cblas_scnrm2")]
pub fn scnrm2<T: Vector<Complex<f32>>>(x: &T) -> f32 {
pub fn scnrm2<T: Vector<Complex<f32>> + ?Sized>(x: &T) -> f32 {
unsafe { sys::cblas_scnrm2(len(x), as_ptr(x) as *const _, stride(x)) }
}

#[cfg(feature = "complex")]
/// Return the sum of the modulus of the elements of `x`
/// (i.e., its L¹-norm).
#[doc(alias = "cblas_scasum")]
pub fn scasum<T: Vector<Complex<f32>>>(x: &T) -> f32 {
pub fn scasum<T: Vector<Complex<f32>> + ?Sized>(x: &T) -> f32 {
unsafe { sys::cblas_scasum(len(x), as_ptr(x) as *const _, stride(x)) }
}

Expand All @@ -238,61 +250,61 @@ pub mod level1 {
/// assert_eq!(dznrm2(&x), 7f64.sqrt())
/// ```
#[doc(alias = "cblas_dznrm2")]
pub fn dznrm2<T: Vector<Complex<f64>>>(x: &T) -> f64 {
pub fn dznrm2<T: Vector<Complex<f64>> + ?Sized>(x: &T) -> f64 {
unsafe { ::sys::cblas_dznrm2(len(x), as_ptr(x) as *const _, stride(x)) }
}

#[cfg(feature = "complex")]
/// Return the sum of the modulus of the elements of `x`
/// (i.e., its L¹-norm).
#[doc(alias = "cblas_dzasum")]
pub fn dzasum<T: Vector<Complex<f64>>>(x: &T) -> f64 {
pub fn dzasum<T: Vector<Complex<f64>> + ?Sized>(x: &T) -> f64 {
unsafe { ::sys::cblas_dzasum(len(x), as_ptr(x) as *const _, stride(x)) }
}

/// Return the index of the element with maximum absolute value.
#[doc(alias = "cblas_isamax")]
pub fn isamax<T: Vector<f32>>(x: &T) -> usize {
pub fn isamax<T: Vector<f32> + ?Sized>(x: &T) -> usize {
unsafe { ::sys::cblas_isamax(len(x), as_ptr(x), stride(x)) }
}

/// Return the index of the element with maximum absolute value.
#[doc(alias = "cblas_idamax")]
pub fn idamax<T: Vector<f64>>(x: &T) -> usize {
pub fn idamax<T: Vector<f64> + ?Sized>(x: &T) -> usize {
unsafe { ::sys::cblas_idamax(len(x), as_ptr(x), stride(x)) }
}

#[cfg(feature = "complex")]
/// Return the index of the element with maximum modulus.
#[doc(alias = "cblas_icamax")]
pub fn icamax<T: Vector<Complex<f32>>>(x: &T) -> usize {
pub fn icamax<T: Vector<Complex<f32>> + ?Sized>(x: &T) -> usize {
unsafe { ::sys::cblas_icamax(len(x), as_ptr(x) as *const _, stride(x)) }
}

#[cfg(feature = "complex")]
/// Return the index of the element with maximum modulus.
#[doc(alias = "cblas_izamax")]
pub fn izamax<T: Vector<Complex<f64>>>(x: &T) -> usize {
pub fn izamax<T: Vector<Complex<f64>> + ?Sized>(x: &T) -> usize {
unsafe { ::sys::cblas_izamax(len(x), as_ptr(x) as *const _, stride(x)) }
}

/// Swap vectors `x` and `y`.
#[doc(alias = "cblas_sswap")]
pub fn sswap<T: Vector<f32>>(x: &mut T, y: &mut T) {
pub fn sswap<T: Vector<f32> + ?Sized>(x: &mut T, y: &mut T) {
check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length");
unsafe { ::sys::cblas_sswap(len(x), as_mut_ptr(x), stride(x), as_mut_ptr(y), stride(y)) }
}

/// Copy the content of `x` into `y`.
#[doc(alias = "cblas_scopy")]
pub fn scopy<T: Vector<f32>>(x: &T, y: &mut T) {
pub fn scopy<T: Vector<f32> + ?Sized>(x: &T, y: &mut T) {
check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length");
unsafe { ::sys::cblas_scopy(len(x), as_ptr(x), stride(x), as_mut_ptr(y), stride(y)) }
}

/// `y` := `alpha` * `x` + `y`.
#[doc(alias = "cblas_saxpy")]
pub fn saxpy<T: Vector<f32>>(alpha: f32, x: &T, y: &mut T) {
pub fn saxpy<T: Vector<f32> + ?Sized>(alpha: f32, x: &T, y: &mut T) {
check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length");
unsafe {
::sys::cblas_saxpy(
Expand All @@ -308,21 +320,21 @@ pub mod level1 {

/// Swap vectors `x` and `y`.
#[doc(alias = "cblas_dswap")]
pub fn dswap<T: Vector<f64>>(x: &mut T, y: &mut T) {
pub fn dswap<T: Vector<f64> + ?Sized>(x: &mut T, y: &mut T) {
check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length");
unsafe { sys::cblas_dswap(len(x), as_mut_ptr(x), stride(x), as_mut_ptr(y), stride(y)) }
}

/// Copy the content of `x` into `y`.
#[doc(alias = "cblas_dcopy")]
pub fn dcopy<T: Vector<f64>>(x: &T, y: &mut T) {
pub fn dcopy<T: Vector<f64> + ?Sized>(x: &T, y: &mut T) {
check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length");
unsafe { sys::cblas_dcopy(len(x), as_ptr(x), stride(x), as_mut_ptr(y), stride(y)) }
}

/// `y` := `alpha` * `x` + `y`.
#[doc(alias = "cblas_daxpy")]
pub fn daxpy<T: Vector<f64>>(alpha: f64, x: &T, y: &mut T) {
pub fn daxpy<T: Vector<f64> + ?Sized>(alpha: f64, x: &T, y: &mut T) {
check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length");
unsafe {
sys::cblas_daxpy(
Expand All @@ -339,7 +351,7 @@ pub mod level1 {
#[cfg(feature = "complex")]
/// Swap vectors `x` and `y`.
#[doc(alias = "cblas_cswap")]
pub fn cswap<T: Vector<Complex<f32>>>(x: &mut T, y: &mut T) {
pub fn cswap<T: Vector<Complex<f32>> + ?Sized>(x: &mut T, y: &mut T) {
check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length");
unsafe {
::sys::cblas_cswap(
Expand All @@ -355,7 +367,7 @@ pub mod level1 {
#[cfg(feature = "complex")]
/// Copy the content of `x` into `y`.
#[doc(alias = "cblas_ccopy")]
pub fn ccopy<T: Vector<Complex<f32>>>(x: &T, y: &mut T) {
pub fn ccopy<T: Vector<Complex<f32>> + ?Sized>(x: &T, y: &mut T) {
check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length");
unsafe {
sys::cblas_ccopy(
Expand All @@ -371,7 +383,7 @@ pub mod level1 {
#[cfg(feature = "complex")]
/// `y` := `alpha` * `x` + `y`.
#[doc(alias = "cblas_caxpy")]
pub fn caxpy<T: Vector<Complex<f32>>>(alpha: &Complex<f32>, x: &T, y: &mut T) {
pub fn caxpy<T: Vector<Complex<f32>> + ?Sized>(alpha: &Complex<f32>, x: &T, y: &mut T) {
check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length");
unsafe {
sys::cblas_caxpy(
Expand All @@ -388,7 +400,7 @@ pub mod level1 {
#[cfg(feature = "complex")]
/// Swap vectors `x` and `y`.
#[doc(alias = "cblas_zswap")]
pub fn zswap<T: Vector<Complex<f64>>>(x: &mut T, y: &mut T) {
pub fn zswap<T: Vector<Complex<f64>> + ?Sized>(x: &mut T, y: &mut T) {
check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length");
unsafe {
::sys::cblas_zswap(
Expand All @@ -404,7 +416,7 @@ pub mod level1 {
#[cfg(feature = "complex")]
/// Copy the content of `x` into `y`.
#[doc(alias = "cblas_zcopy")]
pub fn zcopy<T: Vector<Complex<f64>>>(x: &T, y: &mut T) {
pub fn zcopy<T: Vector<Complex<f64>> + ?Sized>(x: &T, y: &mut T) {
check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length");
unsafe {
::sys::cblas_zcopy(
Expand All @@ -420,7 +432,7 @@ pub mod level1 {
#[cfg(feature = "complex")]
/// `y` := `alpha` * `x` + `y`.
#[doc(alias = "cblas_zaxpy")]
pub fn zaxpy<T: Vector<Complex<f64>>>(alpha: &Complex<f64>, x: &T, y: &mut T) {
pub fn zaxpy<T: Vector<Complex<f64>> + ?Sized>(alpha: &Complex<f64>, x: &T, y: &mut T) {
check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length");
unsafe {
::sys::cblas_zaxpy(
Expand Down Expand Up @@ -544,7 +556,7 @@ pub mod level1 {
///
/// for all indices i.
#[doc(alias = "cblas_srot")]
pub fn srot<T: Vector<f32>>(x: &mut T, y: &mut T, c: f32, s: f32) {
pub fn srot<T: Vector<f32> + ?Sized>(x: &mut T, y: &mut T, c: f32, s: f32) {
check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length");
unsafe {
::sys::cblas_srot(
Expand All @@ -566,7 +578,7 @@ pub mod level1 {
///
/// for all indices i.
#[doc(alias = "cblas_srotm")]
pub fn srotm<T: Vector<f32>>(x: &mut T, y: &mut T, h: H<f32>) {
pub fn srotm<T: Vector<f32> + ?Sized>(x: &mut T, y: &mut T, h: H<f32>) {
check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length");
let p = match h {
H::Full { h11, h21, h12, h22 } => [-1.0, h11, h21, h12, h22],
Expand Down Expand Up @@ -664,7 +676,7 @@ pub mod level1 {
///
/// for all indices i.
#[doc(alias = "cblas_drot")]
pub fn drot<T: Vector<f64>>(x: &mut T, y: &mut T, c: f64, s: f64) {
pub fn drot<T: Vector<f64> + ?Sized>(x: &mut T, y: &mut T, c: f64, s: f64) {
check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length");
unsafe {
::sys::cblas_drot(
Expand All @@ -686,7 +698,7 @@ pub mod level1 {
///
/// for all indices i.
#[doc(alias = "cblas_drotm")]
pub fn drotm<T: Vector<f64>>(x: &mut T, y: &mut T, h: H<f64>) {
pub fn drotm<T: Vector<f64> + ?Sized>(x: &mut T, y: &mut T, h: H<f64>) {
check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length");
let p = match h {
H::Full { h11, h21, h12, h22 } => [-1.0, h11, h21, h12, h22],
Expand All @@ -708,20 +720,20 @@ pub mod level1 {

/// Multiply each element of `x` by `alpha`.
#[doc(alias = "cblas_sscal")]
pub fn sscal<T: Vector<f32>>(alpha: f32, x: &mut T) {
pub fn sscal<T: Vector<f32> + ?Sized>(alpha: f32, x: &mut T) {
unsafe { ::sys::cblas_sscal(len(x), alpha, as_mut_ptr(x), stride(x)) }
}

/// Multiply each element of `x` by `alpha`.
#[doc(alias = "cblas_dscal")]
pub fn dscal<T: Vector<f64>>(alpha: f64, x: &mut T) {
pub fn dscal<T: Vector<f64> + ?Sized>(alpha: f64, x: &mut T) {
unsafe { ::sys::cblas_dscal(len(x), alpha, as_mut_ptr(x), stride(x)) }
}

#[cfg(feature = "complex")]
/// Multiply each element of `x` by `alpha`.
#[doc(alias = "cblas_cscal")]
pub fn cscal<T: Vector<Complex<f32>>>(alpha: &Complex<f32>, x: &mut T) {
pub fn cscal<T: Vector<Complex<f32>> + ?Sized>(alpha: &Complex<f32>, x: &mut T) {
unsafe {
::sys::cblas_cscal(
len(x),
Expand All @@ -735,7 +747,7 @@ pub mod level1 {
#[cfg(feature = "complex")]
/// Multiply each element of `x` by `alpha`.
#[doc(alias = "cblas_zscal")]
pub fn zscal<T: Vector<Complex<f64>>>(alpha: &Complex<f64>, x: &mut T) {
pub fn zscal<T: Vector<Complex<f64>> + ?Sized>(alpha: &Complex<f64>, x: &mut T) {
unsafe {
::sys::cblas_zscal(
len(x),
Expand All @@ -749,14 +761,14 @@ pub mod level1 {
#[cfg(feature = "complex")]
/// Multiply each element of `x` by `alpha`.
#[doc(alias = "cblas_csscal")]
pub fn csscal<T: Vector<Complex<f32>>>(alpha: f32, x: &mut T) {
pub fn csscal<T: Vector<Complex<f32>> + ?Sized>(alpha: f32, x: &mut T) {
unsafe { ::sys::cblas_csscal(len(x), alpha, as_mut_ptr(x) as *mut _, stride(x)) }
}

#[cfg(feature = "complex")]
/// Multiple each element of a matrix/vector by a constant.
#[doc(alias = "cblas_zdscal")]
pub fn zdscal<T: Vector<Complex<f64>>>(alpha: f64, x: &mut T) {
pub fn zdscal<T: Vector<Complex<f64>> + ?Sized>(alpha: f64, x: &mut T) {
unsafe { ::sys::cblas_zdscal(len(x), alpha, as_mut_ptr(x) as *mut _, stride(x)) }
}
}
Expand Down
Loading

0 comments on commit e48a6d3

Please sign in to comment.