Skip to content

Commit

Permalink
Add support for dynamic rank, see #4
Browse files Browse the repository at this point in the history
  • Loading branch information
fre-hu committed Oct 12, 2024
1 parent f0a3ac0 commit b93aaa5
Show file tree
Hide file tree
Showing 22 changed files with 1,195 additions and 837 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ Here are the main features of mdarray:
- Standard Rust mechanisms are used for e.g. indexing and iteration.
- Generic expressions for multidimensional iteration.

The design is inspired from other Rust crates (ndarray, nalgebra, bitvec
and dfdx), the proposed C++ mdarray and mdspan types, and multidimensional
The design is inspired from other Rust crates (ndarray, nalgebra, bitvec, dfdx
and candle), the proposed C++ mdarray and mdspan types, and multidimensional
arrays in other languages.

## License
Expand Down
24 changes: 11 additions & 13 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::expression::Expression;
use crate::index::SliceIndex;
use crate::iter::Iter;
use crate::layout::{Dense, Layout};
use crate::shape::{ConstShape, IntoShape, Shape};
use crate::shape::{ConstShape, Shape};
use crate::slice::Slice;
use crate::tensor::Tensor;
use crate::traits::{Apply, FromExpression, IntoExpression};
Expand All @@ -25,16 +25,16 @@ pub struct Array<T, S: ConstShape>(pub S::Inner<T>);

impl<T, S: ConstShape> Array<T, S> {
/// Creates an array from the given element.
pub fn from_elem<I: IntoShape<IntoShape = S>>(shape: I, elem: T) -> Self
pub fn from_elem(elem: T) -> Self
where
T: Clone,
{
Self::from_expr(expr::from_elem(shape, elem))
Self::from_expr(expr::from_elem(S::default(), elem))
}

/// Creates an array with the results from the given function.
pub fn from_fn<I: IntoShape<IntoShape = S>, F: FnMut(S::Dims) -> T>(shape: I, f: F) -> Self {
Self::from_expr(expr::from_fn(shape, f))
pub fn from_fn<F: FnMut(&[usize]) -> T>(f: F) -> Self {
Self::from_expr(expr::from_fn(S::default(), f))
}

/// Converts an array with a single element into the contained value.
Expand All @@ -45,19 +45,16 @@ impl<T, S: ConstShape> Array<T, S> {
pub fn into_scalar(self) -> T {
assert!(self.len() == 1, "invalid length");

self.into_shape(()).0
self.into_shape::<()>().0
}

/// Converts the array into a reshaped array, which must have the same length.
///
/// # Panics
///
/// Panics if the array length is changed.
pub fn into_shape<I>(self, shape: I) -> Array<T, I::IntoShape>
where
I: IntoShape<IntoShape: ConstShape>,
{
assert!(shape.into_shape().len() == self.len(), "length must not change");
pub fn into_shape<I: ConstShape>(self) -> Array<T, I> {
assert!(I::default().len() == self.len(), "length must not change");

let me = ManuallyDrop::new(self);

Expand All @@ -75,7 +72,7 @@ impl<T, S: ConstShape> Array<T, S> {
index: usize,
}

impl<'a, T, S: ConstShape> Drop for DropGuard<'a, T, S> {
impl<T, S: ConstShape> Drop for DropGuard<'_, T, S> {
fn drop(&mut self) {
let ptr = self.array.as_mut_ptr() as *mut T;

Expand All @@ -85,7 +82,8 @@ impl<T, S: ConstShape> Array<T, S> {
}
}

assert!(expr.dims()[..] == S::default().dims()[..], "invalid shape");
// Ensure that the shape is valid.
_ = expr.shape().with_dims(|dims| S::from_dims(dims));

let mut array = MaybeUninit::uninit();
let mut guard = DropGuard { array: &mut array, index: 0 };
Expand Down
70 changes: 28 additions & 42 deletions src/dim.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
use std::fmt::{Debug, Formatter, Result};

use std::ops::{
Bound, IndexMut, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
};
use std::hash::Hash;

/// Array dimension trait.
pub trait Dim: Copy + Debug + Default + Send + Sync {
pub trait Dim: Copy + Debug + Default + Eq + Hash + Send + Sync {
/// Merge dimensions, where constant size is preferred over dynamic.
type Merge<D: Dim>: Dim;

Expand All @@ -23,50 +20,28 @@ pub trait Dim: Copy + Debug + Default + Send + Sync {
fn size(self) -> usize;
}

/// Array dimensions trait.
pub trait Dims:
Copy
+ Debug
+ Default
+ IndexMut<(Bound<usize>, Bound<usize>), Output = [usize]>
+ IndexMut<usize, Output = usize>
+ IndexMut<Range<usize>, Output = [usize]>
+ IndexMut<RangeFrom<usize>, Output = [usize]>
+ IndexMut<RangeFull, Output = [usize]>
+ IndexMut<RangeInclusive<usize>, Output = [usize]>
+ IndexMut<RangeTo<usize>, Output = [usize]>
+ IndexMut<RangeToInclusive<usize>, Output = [usize]>
+ Send
+ Sync
+ for<'a> TryFrom<&'a [usize], Error: Debug>
{
}

/// Array strides trait.
pub trait Strides:
Copy
#[allow(unreachable_pub)]
pub trait Dims<T: Copy + Debug + Default + Eq + Hash + Send + Sync>:
AsMut<[T]>
+ AsRef<[T]>
+ Clone
+ Debug
+ Default
+ IndexMut<(Bound<usize>, Bound<usize>), Output = [isize]>
+ IndexMut<usize, Output = isize>
+ IndexMut<Range<usize>, Output = [isize]>
+ IndexMut<RangeFrom<usize>, Output = [isize]>
+ IndexMut<RangeFull, Output = [isize]>
+ IndexMut<RangeInclusive<usize>, Output = [isize]>
+ IndexMut<RangeTo<usize>, Output = [isize]>
+ IndexMut<RangeToInclusive<usize>, Output = [isize]>
+ Eq
+ Hash
+ Send
+ Sync
+ for<'a> TryFrom<&'a [isize], Error: Debug>
+ for<'a> TryFrom<&'a [T], Error: Debug>
{
fn new(len: usize) -> Self;
}

/// Type-level constant.
#[derive(Clone, Copy, Default)]
#[derive(Clone, Copy, Default, Eq, Hash, PartialEq)]
pub struct Const<const N: usize>;

/// Dynamically-sized dimension type.
#[derive(Clone, Copy, Debug, Default)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct Dyn(pub usize);

impl<const N: usize> Debug for Const<N> {
Expand Down Expand Up @@ -105,13 +80,24 @@ impl Dim for Dyn {
}
}

macro_rules! impl_dims_strides {
macro_rules! impl_dims {
($($n:tt),+) => {
$(
impl Dims for [usize; $n] {}
impl Strides for [isize; $n] {}
impl<T: Copy + Debug + Default + Eq + Hash + Send + Sync> Dims<T> for [T; $n] {
fn new(len: usize) -> Self {
assert!(len == $n, "invalid length");

Self::default()
}
}
)+
};
}

impl_dims_strides!(0, 1, 2, 3, 4, 5, 6);
impl_dims!(0, 1, 2, 3, 4, 5, 6);

impl<T: Copy + Debug + Default + Eq + Hash + Send + Sync> Dims<T> for Box<[T]> {
fn new(len: usize) -> Self {
vec![T::default(); len].into()
}
}
Loading

0 comments on commit b93aaa5

Please sign in to comment.