diff --git a/crates/burn-remote/src/server/processor.rs b/crates/burn-remote/src/server/processor.rs index ee0cd52e65..1328827748 100644 --- a/crates/burn-remote/src/server/processor.rs +++ b/crates/burn-remote/src/server/processor.rs @@ -9,7 +9,7 @@ use std::sync::mpsc::Sender; use crate::shared::{ConnectionId, TaskResponse, TaskResponseContent}; -/// The goal of the processor is to asynchonously process compute tasks on it own thread. +/// The goal of the processor is to asynchronously process compute tasks on it own thread. pub struct Processor { p: PhantomData, } diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index 92674f8480..8ebe4edfe8 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -20,6 +20,51 @@ use crate::{backend::Backend, check, Bool, Float, Int, Shape, TensorData, Tensor use crate::{DType, Element, TensorPrimitive}; /// A tensor with a given backend, shape and data type. +/// +/// # Indexing +/// Indexing a tensor can be done using [`slice`](Tensor::slice) for all tensor types +/// or [`select`](Tensor::select) for numeric types. +/// +/// ## Example +/// +/// ```rust +/// use burn_tensor::backend::Backend; +/// use burn_tensor::Tensor; +/// use burn_tensor::Int; +/// +/// fn example() { +/// let device = Default::default(); +/// +/// let tensor = Tensor::::from_data( +/// [ +/// [3.0, 4.9, 2.0], +/// [2.0, 1.9, 3.0], +/// [6.0, 1.5, 7.0], +/// [3.0, 4.9, 9.0], +/// ], +/// &device, +/// ); +/// +/// // Slice the tensor to get the second and third rows: +/// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]] +/// // The resulting tensor will have dimensions [2, 3]. +/// let slice = tensor.clone().slice([1..3]); +/// println!("{slice:?}"); +/// +/// // Slice the tensor to get the first two rows and the first 2 columns: +/// // [[3.0, 4.9], [2.0, 1.9]] +/// // The resulting tensor will have dimensions [2, 2]. +/// let slice = tensor.clone().slice([0..2, 0..2]); +/// println!("{slice:?}"); +/// +/// // Index the tensor along the dimension 1 to get the elements 0 and 2: +/// // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]] +/// // The resulting tensor will have dimensions [4, 2] +/// let indices = Tensor::::from_data([0, 2], &device); +/// let indexed = tensor.select(1, indices); +/// println!("{indexed:?}"); +/// } +/// ``` #[derive(new, Clone, Debug)] pub struct Tensor where @@ -56,6 +101,23 @@ where } /// Create an empty tensor of the given shape. + /// + /// # Arguments + /// + /// - shape: The shape of the tensor. + /// - device: The device where the tensor will be created. + /// + /// # Example + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// // Create an empty tensor with dimensions [2, 3, 4]. + /// let tensor = Tensor::::empty([2, 3, 4], &device); + /// } + /// ``` pub fn empty>(shape: S, device: &B::Device) -> Self { let shape = shape.into(); check!(TensorCheck::creation_ops::("Empty", &shape.dims)); @@ -64,12 +126,36 @@ where /// Returns the dimensions of the current tensor. /// - /// Equivalent to `tensor.shape().dims`. + /// # Example + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// let tensor = Tensor::::ones([2, 3, 4], &device); + /// let dims = tensor.dims(); // [2, 3, 4] + /// println!("{dims:?}"); + /// } + /// ``` pub fn dims(&self) -> [usize; D] { Self::shape(self).dims() } /// Returns the shape of the current tensor. + /// + /// # Example + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// let tensor = Tensor::::ones([2, 3, 4], &device); + /// // Shape { dims: [2, 3, 4] } + /// let shape = tensor.shape(); + /// } + /// ``` pub fn shape(&self) -> Shape { K::shape(&self.primitive) } @@ -94,17 +180,18 @@ where /// - If the shape does not match the number of elements of the original shape. /// /// # Example + /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); + /// // Create a tensor with dimensions [2, 3, 4] /// let tensor = Tensor::::ones([2, 3, 4], &device); - /// // Given a 3D tensor with dimensions (2, 3, 4), reshape it to (2, 12) - /// let reshaped_tensor: Tensor:: = tensor.reshape([2, -1]); - /// // The resulting tensor will have dimensions (2, 12). - /// println!("{:?}", reshaped_tensor.shape()); + /// // Reshape it to [2, 12], where 12 is inferred from the number of elements. + /// let reshaped = tensor.reshape([2, -1]); + /// println!("{reshaped:?}"); /// } /// ``` pub fn reshape>(self, shape: S) -> Tensor { @@ -122,6 +209,25 @@ where /// # Returns /// /// The transposed tensor. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// // Create a 2D tensor of shape [2, 3] + /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); + /// + /// // Transpose the tensor: + /// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]] + /// // The resulting tensor will have dimensions [3, 2]. + /// let transposed = tensor.transpose(); + /// println!("{transposed:?}"); + /// } + /// ``` pub fn transpose(self) -> Tensor { Tensor::new(K::transpose(self.primitive)) } @@ -137,6 +243,25 @@ where /// # Returns /// /// The tensor with the dimensions swapped. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// // Create a 2D tensor of shape [2, 3] + /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); + /// + /// // Swap the dimensions 0 and 1 (equivalent to `tensor.transpose()`): + /// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]] + /// // The resulting tensor will have dimensions [3, 2]. + /// let swapped = tensor.swap_dims(0, 1); + /// println!("{swapped:?}"); + /// } + /// ``` pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor { check!(TensorCheck::swap_dims::(dim1, dim2)); Tensor::new(K::swap_dims(self.primitive, dim1, dim2)) @@ -154,6 +279,25 @@ where /// # Returns /// /// The tensor with the dimensions permuted. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// // Create a 2D tensor of shape [3, 2] + /// let tensor = Tensor::::from_data([[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]], &device); + /// + /// // Permute the dimensions 1 and 0: + /// // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]] + /// // The resulting tensor will have dimensions [3, 2]. + /// let permuted = tensor.permute([1, 0]); + /// println!("{permuted:?}"); + /// } + /// ``` pub fn permute(self, axes: [isize; D]) -> Tensor { // Convert the axes to usize and handle negative values without using vector let mut transformed_axes: [usize; D] = [0; D]; @@ -192,7 +336,26 @@ where /// # Returns /// /// The tensor with the dimensions moved. - // This is a semantic sugar for `permute`. It is used widely enough, so we define a separate Op + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// // Create a 3D tensor of shape [3, 2, 1] + /// let tensor = Tensor::::from_data([[[1.0], [5.0]], [[-2.0], [9.0]], [[3.0], [6.0]]], &device); + /// + /// // Move the dimensions 0 and 1: + /// // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]] + /// // The resulting tensor will have dimensions [2, 3, 1]. + /// let moved = tensor.movedim(1, 0); + /// println!("{moved:?}"); + /// } + /// ``` + // This is a syntactic sugar for `permute`. It is used widely enough, so we define a separate Op // for it pub fn movedim(self, src: S1, dst: S2) -> Tensor { let source_dims = src.into_dim_vec::(); @@ -235,6 +398,36 @@ where /// # Returns /// /// The tensor with the axes flipped. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// // Create a 2D tensor with dimensions [4, 3] + /// let tensor = Tensor::::from_data( + /// [ + /// [3.0, 4.9, 2.0], + /// [2.0, 1.9, 3.0], + /// [4.0, 5.9, 8.0], + /// [1.4, 5.8, 6.0], + /// ], + /// &device, + /// ); + /// + /// // Flip the elements in dimensions 0 and 1: + /// // [[6.0, 5.8, 1.4], + /// // [8.0, 5.9, 4.0], + /// // [3.0, 1.9, 2.0], + /// // [2.0, 4.9, 3.0]] + /// // The resulting tensor will have dimensions [4, 3]. + /// let flipped = tensor.flip([0, 1]); + /// println!("{flipped:?}"); + /// } + /// ``` pub fn flip(self, axes: [isize; N]) -> Tensor { // Convert the axes to usize and handle negative values without using vector let mut transformed_axes: [usize; N] = [0; N]; @@ -279,15 +472,14 @@ where /// /// fn example() { /// let device = Default::default(); + /// // Create a 3D tensor with dimensions [2, 3, 4] /// let tensor = Tensor::::ones(Shape::new([2, 3, 4]), &device); /// - /// // Given a 3D tensor with dimensions (2, 3, 4), flatten the dimensions between indices 1 and 2: - /// let flattened_tensor: Tensor:: = tensor.flatten(1, 2); - /// - /// // The resulting tensor will have dimensions (2, 12). - /// println!("{:?}", flattened_tensor.shape()); + /// // Flatten the tensor from dimensions 1 to 2 (inclusive). + /// // The resulting tensor will have dimensions [2, 12] + /// let flattened: Tensor = tensor.flatten(1, 2); + /// println!("{flattened:?}"); /// } - /// /// ``` pub fn flatten(self, start_dim: usize, end_dim: usize) -> Tensor { check!(TensorCheck::flatten::(start_dim, end_dim)); @@ -331,13 +523,16 @@ where /// /// fn example() { /// let device = Default::default(); - /// let tensor = Tensor::::ones(Shape::new([2, 1, 4]), &device); - /// - /// // Given a 3D tensor with dimensions (2, 1, 4), squeeze the dimension 1 - /// let squeezed_tensor: Tensor:: = tensor.squeeze(1); - /// - /// // Resulting tensor will have dimensions (2, 4) - /// println!("{:?}", squeezed_tensor.shape()); + /// // Create a 3D tensor with dimensions [3, 1, 3] + /// let tensor = Tensor::::from_data( + /// [[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]], + /// &device, + /// ); + /// + /// // Squeeze the dimension 1. + /// // The resulting tensor will have dimensions [3, 3]. + /// let squeezed = tensor.squeeze::<2>(1); + /// println!("{squeezed:?}"); /// } /// ``` pub fn squeeze(self, dim: usize) -> Tensor { @@ -380,13 +575,13 @@ where /// /// fn example() { /// let device = Default::default(); + /// // Create a 4D tensor with dimensions [2, 1, 4, 1] /// let tensor = Tensor::::ones(Shape::new([2, 1, 4, 1]), &device); /// - /// // Given a 4D tensor with dimensions (2, 1, 4, 1), squeeze the 1 and 3 dimensions - /// let squeezed_tensor: Tensor:: = tensor.squeeze_dims(&[1, 3]); - /// - /// // Resulting tensor will have dimensions (2, 4) - /// println!("{:?}", squeezed_tensor.shape()); + /// // Squeeze the dimensions 1 and 3. + /// // The resulting tensor will have dimensions [2, 4]. + /// let squeezed: Tensor = tensor.squeeze_dims(&[1, 3]); + /// println!("{squeezed:?}"); /// } /// ``` pub fn squeeze_dims(self, dims: &[isize]) -> Tensor { @@ -453,10 +648,12 @@ where /// /// fn example() { /// let device = Default::default(); + /// // Create a 2D tensor with dimensions [3, 3] /// let tensor = Tensor::::ones(Shape::new([3, 3]), &device); - /// let tensor = tensor.unsqueeze::<4>(); - /// println!("{:?}", tensor.shape()); - /// // Shape { dims: [1, 1, 3, 3] } + /// // Unsqueeze the tensor up to 4 dimensions. + /// // The resulting tensor will have dimensions [1, 1, 3, 3]. + /// let unsqueezed = tensor.unsqueeze::<4>(); + /// println!("{unsqueezed:?}"); /// } /// ``` pub fn unsqueeze(self) -> Tensor { @@ -482,10 +679,12 @@ where /// /// fn example() { /// let device = Default::default(); + /// // Create a 2D tensor with dimensions [3, 3] /// let tensor = Tensor::::ones(Shape::new([3, 3]), &device); - /// let tensor: Tensor = tensor.unsqueeze_dim(1); - /// println!("{:?}", tensor.shape()); - /// // Shape { dims: [3, 1, 3] } + /// // Unsqueeze the dimension 1. + /// // The resulting tensor will have dimensions [3, 1, 3]. + /// let unsqueezed: Tensor = tensor.unsqueeze_dim(1); + /// println!("{unsqueezed:?}"); /// } /// ``` pub fn unsqueeze_dim(self, dim: usize) -> Tensor { @@ -519,10 +718,12 @@ where /// /// fn example() { /// let device = Default::default(); + /// // Create a 3D tensor with dimensions [3, 4, 5] /// let tensor = Tensor::::ones(Shape::new([3, 4, 5]), &device); - /// let tensor: Tensor = tensor.unsqueeze_dims(&[0, -1, -1]); - /// println!("{:?}", tensor.shape()); - /// // Shape { dims: [1, 3, 4, 5, 1, 1] } + /// // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice. + /// // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1]. + /// let unsqueezed: Tensor = tensor.unsqueeze_dims(&[0, -1, -1]); + /// println!("{unsqueezed:?}"); /// } /// ``` pub fn unsqueeze_dims(self, axes: &[isize]) -> Tensor { @@ -727,6 +928,34 @@ where } /// Repeat the tensor along the given dimension. + /// + /// + /// # Arguments + /// - `dim`: The dimension to repeat. + /// - `times`: The number of times to repeat the tensor along the given dimension in the new tensor. + /// + /// # Returns + /// + /// A new tensor with the given dimension repeated `times` times. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// // Create a 2D tensor with dimensions [3, 2] + /// let tensor = Tensor::::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device); + /// + /// // Repeat the tensor along the dimension 0 twice. + /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]] + /// // The resulting tensor will have dimensions [6, 2]. + /// let repeated = tensor.repeat_dim(0, 2); + /// println!("{repeated:?}"); + /// } + /// ``` pub fn repeat_dim(self, dim: usize, times: usize) -> Self { Self::new(K::repeat_dim(self.primitive, dim, times)) } @@ -734,6 +963,29 @@ where /// Repeat the tensor along the given dimensions. /// # Arguments /// - `sizes`: Borrowed slice of the number of times to repeat each dimension. + /// + /// # Returns + /// + /// A new tensor with the given dimensions repeated `times` times. + /// + /// # Example + /// + /// ```rust + /// + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// // Create a 2D tensor with dimensions [3, 2] + /// let tensor = Tensor::::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device); + /// + /// // Repeat the tensor along the dimension 0 twice and the dimension 0 once. + /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]] + /// // The resulting tensor will have dimensions [6, 2]. + /// let repeated = tensor.repeat(&[2, 1]); + /// } + /// ``` pub fn repeat(self, sizes: &[usize]) -> Self { let mut tensor = self; for (dim, ×) in sizes.iter().enumerate() { @@ -749,6 +1001,23 @@ where /// # Panics /// /// If the two tensors don't have the same shape. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// let t1 = Tensor::::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device); + /// let t2 = Tensor::::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device); + /// // Compare the elements of the two 2D tensors with dimensions [3, 2]. + /// // [[false, true], [true, true], [true, true]] + /// let equal = t1.equal(t2); + /// println!("{equal:?}"); + /// } + /// ``` pub fn equal(self, other: Self) -> Tensor { check!(TensorCheck::binary_ops_ew("Equal", &self, &other)); Tensor::new(K::equal(self.primitive, other.primitive)) @@ -759,6 +1028,23 @@ where /// # Panics /// /// If the two tensors don't have the same shape. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// let t1 = Tensor::::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device); + /// let t2 = Tensor::::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device); + /// // Compare the elements of the two 2D tensors for inequality. + /// // [[true, false], [false, false], [false, false]] + /// let not_equal = t1.not_equal(t2); + /// println!("{not_equal:?}"); + /// } + /// ``` pub fn not_equal(self, other: Self) -> Tensor { check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other)); Tensor::new(K::not_equal(self.primitive, other.primitive)) @@ -769,6 +1055,25 @@ where /// # Panics /// /// If all tensors don't have the same shape. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// let t1 = Tensor::::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device); + /// let t2 = Tensor::::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device); + /// + /// // Concatenate the two tensors with shape [2, 3] along the dimension 1. + /// // [[3.0, 4.9, 2.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.4, 5.8, 6.0]] + /// // The resulting tensor will have shape [2, 6]. + /// let concat = Tensor::cat(vec![t1, t2], 1); + /// println!("{concat:?}"); + /// } + /// ``` pub fn cat(tensors: Vec, dim: usize) -> Self { check!(TensorCheck::cat(&tensors, dim)); @@ -784,6 +1089,28 @@ where /// /// If all tensors don't have the same shape. /// Given dimension is not with range of 0..D2 + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// let t1 = Tensor::::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device); + /// let t2 = Tensor::::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device); + /// let t3 = Tensor::::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device); + /// + /// // Concatenate the three tensors with shape [2, 3] along a new dimension, 0. + /// // [[[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], + /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], + /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]] + /// // The resulting tensor will have shape [3, 2, 3]. + /// let stacked= Tensor::stack::<3>(vec![t1, t2, t3], 0); + /// println!("{stacked:?}"); + /// } + /// ``` pub fn stack(tensors: Vec>, dim: usize) -> Tensor { check!(TensorCheck::stack::(&tensors, dim)); let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect(); @@ -799,6 +1126,24 @@ where /// # Returns /// /// A tensor iterator. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// fn example() { + /// let device = Default::default(); + /// let tensor = Tensor::::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device); + /// // Given a 2D tensor with dimensions (2, 3), iterate over slices of tensors along the dimension 0. + /// let iter = tensor.iter_dim(0); + /// for (i,tensor) in iter.enumerate() { + /// println!("Tensor {}: {:?}", i, tensor); + /// // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... } + /// // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... } + /// } + /// } + /// ``` pub fn iter_dim(self, dim: usize) -> DimIter { check!(TensorCheck::dim_ops::("iter_dim", dim)); DimIter::new(self, dim) @@ -814,6 +1159,32 @@ where /// # Returns /// /// A new tensor with the given dimension narrowed to the given range. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// // Create a 2D tensor with dimensions [4, 3] + /// let tensor = Tensor::::from_data( + /// [ + /// [3.0, 4.9, 2.0], + /// [2.0, 1.9, 3.0], + /// [6.0, 1.5, 7.0], + /// [3.0, 4.9, 9.0], + /// ], + /// &device, + /// ); + /// // Narrow the tensor along the dimension 0, keeping 3 elements starting from index 1. + /// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]] + /// // The resulting tensor will have dimensions [3, 3]. + /// let narrowed = tensor.narrow(0, 1, 3); + /// println!("{narrowed:?}"); + /// } + /// ``` pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self { check!(TensorCheck::dim_ops::("narrow", dim)); check!(TensorCheck::narrow(&self, dim, start, length)); @@ -832,6 +1203,34 @@ where /// /// # Returns /// A vector of tensors. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// // Create a 2D tensor with dimensions [4, 3] + /// let tensor = Tensor::::from_data( + /// [ + /// [3.0, 4.9, 2.0], + /// [2.0, 1.9, 3.0], + /// [6.0, 1.5, 7.0], + /// [3.0, 4.9, 9.0], + /// ], + /// &device, + /// ); + /// // Split the tensor along the dimension 1 into 2 chunks. + /// // The first chuck will have shape [4, 2]: + /// // [[3.0, 4.9], [2.0, 1.9], [6.0, 1.5], [3.0, 4.9]] + /// // The second chunk will have shape [4, 1]: + /// // [[2.0], [3.0], [7.0], [9.0]] + /// let chunks = tensor.chunk(2, 1); + /// println!("{chunks:?}"); + /// } + /// ``` pub fn chunk(self, chunks: usize, dim: usize) -> Vec { check!(TensorCheck::dim_ops::("chunk", dim)); K::chunk(self.primitive, chunks, dim) @@ -850,6 +1249,29 @@ where /// /// A boolean tensor `Tensor` containing a single element, True if any element in the input tensor /// evaluates to True, False otherwise. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Bool}; + /// + /// fn example() { + /// let device = Default::default(); + /// let tensor = Tensor::::from_data([[true,false,true],[false,true,false]], &device); + /// let tensor_two = Tensor::::from_data([[false,false,false],[false,false,false]], &device); + /// + /// // Given a 2D tensor with dimensions (2, 3), test if any element in the tensor evaluates to True. + /// let any_tensor = tensor.any(); + /// println!("{:?}", any_tensor); + /// // Tensor { data: [true], ... } + /// + /// // Given a 2D tensor with dimensions (2, 3), test if any element in the tensor evaluates to True. + /// let any_tensor_two = tensor_two.any(); + /// println!("{:?}", any_tensor_two); + /// // Tensor { data: [false], ... } + /// } + /// ``` pub fn any(self) -> Tensor { Tensor::new(K::any(self.primitive)) } @@ -866,6 +1288,23 @@ where /// A boolean tensor `Tensor` with the same size as input `tensor`, except in the `dim` axis /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input /// evaluates to True, False otherwise. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Bool}; + /// + /// fn example() { + /// let device = Default::default(); + /// let tensor = + /// Tensor::::from_data([[true, false, false], [false, true, false]], &device); + /// // Check if any element in the tensor evaluates to True along the dimension 1. + /// // [[true], [true]], + /// let any_dim = tensor.clone().any_dim(1); + /// println!("{any_dim:?}"); + /// } + /// ``` pub fn any_dim(self, dim: usize) -> Tensor { Tensor::new(K::any_dim(self.primitive, dim)) } @@ -880,6 +1319,23 @@ where /// /// A boolean tensor `Tensor` with a single element, True if all elements in the input tensor /// evaluate to True, False otherwise. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Bool}; + /// + /// fn example() { + /// let device = Default::default(); + /// let tensor = + /// Tensor::::from_data([[true, false, true], [true, true, true]], &device); + /// // Check if all elements in the tensor evaluate to True (which is not the case). + /// // [false] + /// let all = tensor.all(); + /// println!("{all:?}"); + /// } + /// ``` pub fn all(self) -> Tensor { Tensor::new(K::all(self.primitive)) } @@ -896,6 +1352,23 @@ where /// A boolean tensor `Tensor` with the same size as input `tensor`, except in the `dim` axis /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input /// evaluates to True, False otherwise. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Bool}; + /// + /// fn example() { + /// let device = Default::default(); + /// let tensor = + /// Tensor::::from_data([[true, true, false], [true, true, true]], &device); + /// // Check if all elements in the tensor evaluate to True along the dimension 1. + /// // [[true, true, false]] + /// let all_dim = tensor.clone().all_dim(0); + /// println!("{all_dim:?}"); + /// } + /// ``` pub fn all_dim(self, dim: usize) -> Tensor { Tensor::new(K::all_dim(self.primitive, dim)) } @@ -906,6 +1379,25 @@ where /// /// If the tensor doesn't have one element. /// If the backend fails to read the tensor data synchronously. + /// + /// # Returns + /// + /// The scalar value of the tensor. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// let tensor = Tensor::::from_data([[3.0]], &device); + /// // Convert the tensor with a single element into a scalar. + /// let scalar = tensor.into_scalar(); + /// println!("{scalar:?}"); + /// } + /// ``` pub fn into_scalar(self) -> K::Elem { crate::try_read_sync(self.into_scalar_async()).expect( "Failed to read tensor data synchronously. This can happen on platforms @@ -940,6 +1432,23 @@ where /// # Returns /// /// A new tensor with the given shape. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// // Create a 2D tensor with dimensions [3, 1] + /// let tensor = Tensor::::from_data([[1.], [2.], [3.]], &device); + /// // Expand the tensor to a new shape [3, 4] + /// // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]] + /// let expanded = tensor.expand([3, 4]); + /// println!("{:?}", expanded); + /// } + /// ``` pub fn expand>(self, shape: S) -> Tensor { let shape = shape.into_shape(&self.shape()); check!(TensorCheck::expand::( diff --git a/crates/burn-tensor/src/tensor/api/sort.rs b/crates/burn-tensor/src/tensor/api/sort.rs index f2ef87d4ea..de66ab50f5 100644 --- a/crates/burn-tensor/src/tensor/api/sort.rs +++ b/crates/burn-tensor/src/tensor/api/sort.rs @@ -37,7 +37,7 @@ where >::Elem: Element, { let device = K::device(&tensor); - let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchonously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."); + let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."); sort_data::(data, dim, &device, descending) } @@ -92,7 +92,7 @@ where >::Elem: Element, { let device = K::device(&tensor); - let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchonously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."); + let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."); sort_data_with_indices::(data, dim, &device, descending) } @@ -188,7 +188,7 @@ where >::Elem: Element, { let device = K::device(&tensor); - let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchonously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."); + let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."); argsort_data::(data, dim, &device, descending) }