Skip to content

Commit

Permalink
Fix tri mask ops return docstring (tracel-ai#2517)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui authored Nov 20, 2024
1 parent 76e67bf commit b4e8e45
Showing 1 changed file with 48 additions and 6 deletions.
54 changes: 48 additions & 6 deletions crates/burn-tensor/src/tensor/api/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,22 @@ where
///
/// # Returns
///
/// Returns a boolean tensor where `true` indicates the elements of the matrix that are part of the
/// upper triangle taking into account the specified `offset`.
/// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
/// upper triangle taking into account the specified `offset`. All other elements are `true`.
///
/// # Example
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Bool};
///
/// fn example<B: Backend>() {
/// let mask = Tensor::<B, 2, Bool>::triu_mask([3, 3], 0, &Default::default());
/// println!("{mask}");
/// // [[false, false, false],
/// // [true, false, false],
/// // [true, true, false]]
/// }
/// ```
pub fn triu_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
Self::tri_mask(shape, TriPart::Upper, offset, device)
}
Expand All @@ -159,8 +173,22 @@ where
///
/// # Returns
///
/// Returns a boolean tensor where `true` indicates the elements of the matrix that are part of the
/// lower triangle taking into account the specified `offset`.
/// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
/// lower triangle taking into account the specified `offset`. All other elements are `true`.
///
/// # Example
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Bool};
///
/// fn example<B: Backend>() {
/// let mask = Tensor::<B, 2, Bool>::tril_mask([3, 3], 0, &Default::default());
/// println!("{mask}");
/// // [[false, true, true],
/// // [false, false, true],
/// // [false, false, false]]
/// }
/// ```
pub fn tril_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
Self::tri_mask(shape, TriPart::Lower, offset, device)
}
Expand All @@ -177,8 +205,22 @@ where
///
/// # Returns
///
/// Returns a boolean tensor where `true` indicates the elements of the matrix that are part of the
/// diagonal.
/// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
/// diagonal. All other elements are `true`.
///
/// # Example
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Bool};
///
/// fn example<B: Backend>() {
/// let mask = Tensor::<B, 2, Bool>::diag_mask([3, 3], 0, &Default::default());
/// println!("{mask}");
/// // [[false, true, true],
/// // [true, false, true],
/// // [true, true, false]]
/// }
/// ```
pub fn diag_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
Self::tri_mask(shape, TriPart::Diagonal, offset, device)
}
Expand Down

0 comments on commit b4e8e45

Please sign in to comment.