diff --git a/crates/burn-tensor/src/tensor/api/bool.rs b/crates/burn-tensor/src/tensor/api/bool.rs index c0bf5fcdfc..ea7c5b196d 100644 --- a/crates/burn-tensor/src/tensor/api/bool.rs +++ b/crates/burn-tensor/src/tensor/api/bool.rs @@ -139,8 +139,22 @@ where /// /// # Returns /// - /// Returns a boolean tensor where `true` indicates the elements of the matrix that are part of the - /// upper triangle taking into account the specified `offset`. + /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the + /// upper triangle taking into account the specified `offset`. All other elements are `true`. + /// + /// # Example + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Bool}; + /// + /// fn example() { + /// let mask = Tensor::::triu_mask([3, 3], 0, &Default::default()); + /// println!("{mask}"); + /// // [[false, false, false], + /// // [true, false, false], + /// // [true, true, false]] + /// } + /// ``` pub fn triu_mask>(shape: S, offset: i64, device: &B::Device) -> Self { Self::tri_mask(shape, TriPart::Upper, offset, device) } @@ -159,8 +173,22 @@ where /// /// # Returns /// - /// Returns a boolean tensor where `true` indicates the elements of the matrix that are part of the - /// lower triangle taking into account the specified `offset`. + /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the + /// lower triangle taking into account the specified `offset`. All other elements are `true`. + /// + /// # Example + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Bool}; + /// + /// fn example() { + /// let mask = Tensor::::tril_mask([3, 3], 0, &Default::default()); + /// println!("{mask}"); + /// // [[false, true, true], + /// // [false, false, true], + /// // [false, false, false]] + /// } + /// ``` pub fn tril_mask>(shape: S, offset: i64, device: &B::Device) -> Self { Self::tri_mask(shape, TriPart::Lower, offset, device) } @@ -177,8 +205,22 @@ where /// /// # Returns /// - /// Returns a boolean tensor where `true` indicates the elements of the matrix that are part of the - /// diagonal. + /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the + /// diagonal. All other elements are `true`. + /// + /// # Example + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Bool}; + /// + /// fn example() { + /// let mask = Tensor::::diag_mask([3, 3], 0, &Default::default()); + /// println!("{mask}"); + /// // [[false, true, true], + /// // [true, false, true], + /// // [true, true, false]] + /// } + /// ``` pub fn diag_mask>(shape: S, offset: i64, device: &B::Device) -> Self { Self::tri_mask(shape, TriPart::Diagonal, offset, device) }