Add int one_hot tests and change ops docs in the book (#2519)
* Add int one_hot tests and change tensor creation function formatting in the book.

* Move swap_dims and transpose to the basic ops table in the book.

---------

Co-authored-by: Tiago Sanona <[email protected]>
tsanona and Tiago Sanona authored Nov 21, 2024
1 parent 8be2032 commit a6c7a2b
Showing 2 changed files with 43 additions and 10 deletions.
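For orientation, here is a minimal sketch of the int `one_hot` behaviour the new tests exercise. It is not part of the commit; the `burn-ndarray` backend and the crate paths are assumptions made for illustration.

```rust
use burn_ndarray::NdArray;
use burn_tensor::{Int, Tensor};

fn main() {
    let device = Default::default();
    // Indices 0..5 as a 1-D Int tensor, one-hot encoded over 5 classes.
    let indices = Tensor::<NdArray, 1, Int>::arange(0..5, &device);
    let one_hot = indices.one_hot(5);
    // As the new test asserts, the result equals the 5x5 identity matrix.
    let identity = Tensor::<NdArray, 2, Int>::eye(5, &device);
    one_hot.into_data().assert_eq(&identity.into_data(), false);
}
```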
13 changes: 6 additions & 7 deletions burn-book/src/building-blocks/tensor.md
@@ -164,11 +164,13 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
+| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` |
| `tensor.to_data()` | N/A |
| `tensor.to_device(device)` | `tensor.to(device)` |
+| `tensor.transpose()` | `tensor.T` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
-| `tensor.unsqueeze_dims(dims)` | N/A |
+| `tensor.unsqueeze_dims(dims)` | N/A |

### Numeric Operations

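Aside, not part of the diff: a short sketch of the two ops this change moves into the basic ops table above, again assuming the `burn-ndarray` backend. `swap_dims` exchanges an explicitly chosen pair of dimensions, while `transpose` on a matrix corresponds to PyTorch's `tensor.T`.

```rust
use burn_ndarray::NdArray;
use burn_tensor::Tensor;

fn main() {
    let device = Default::default();

    // swap_dims exchanges two chosen dimensions: [2, 3, 4] -> [4, 3, 2].
    let x = Tensor::<NdArray, 3>::zeros([2, 3, 4], &device);
    println!("{:?}", x.swap_dims(0, 2).dims());

    // transpose of a 2-D tensor: [2, 3] -> [3, 2].
    let m = Tensor::<NdArray, 2>::zeros([2, 3], &device);
    println!("{:?}", m.transpose().dims());
}
```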
@@ -179,7 +181,6 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `Tensor::eye(size, device)` | `torch.eye(size, device=device)` |
| `Tensor::full(shape, fill_value, device)` | `torch.full(shape, fill_value, device=device)` |
| `Tensor::ones(shape, device)` | `torch.ones(shape, device=device)` |
-| `Tensor::zeros(shape)` | `torch.zeros(shape)` |
| `Tensor::zeros(shape, device)` | `torch.zeros(shape, device=device)` |
| `tensor.abs()` | `torch.abs(tensor)` |
| `tensor.add(other)` or `tensor + other` | `tensor + other` |
@@ -255,7 +256,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
Those operations are only available for `Float` tensors.

| Burn API | PyTorch Equivalent |
-|-----------------------------------------------| ---------------------------------- |
+| --------------------------------------------- | ---------------------------------- |
| `Tensor::one_hot(index, num_classes, device)` | N/A |
| `tensor.cast(dtype)` | `tensor.to(dtype)` |
| `tensor.ceil()` | `tensor.ceil()` |
@@ -275,10 +276,8 @@ Those operations are only available for `Float` tensors.
| `tensor.round()` | `tensor.round()` |
| `tensor.sin()` | `tensor.sin()` |
| `tensor.sqrt()` | `tensor.sqrt()` |
-| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` |
| `tensor.tanh()` | `tensor.tanh()` |
| `tensor.to_full_precision()` | `tensor.to(torch.float)` |
-| `tensor.transpose()` | `tensor.T` |
| `tensor.var(dim)` | `tensor.var(dim)` |
| `tensor.var_bias(dim)` | N/A |
| `tensor.var_mean(dim)` | N/A |
@@ -290,8 +289,8 @@ Those operations are only available for `Int` tensors.

| Burn API | PyTorch Equivalent |
| ------------------------------------------------ | ------------------------------------------------------- |
-| `tensor.arange(5..10, device)` | `tensor.arange(start=5, end=10, device=device)` |
-| `tensor.arange_step(5..10, 2, device)` | `tensor.arange(start=5, end=10, step=2, device=device)` |
+| `Tensor::arange(5..10, device)` | `tensor.arange(start=5, end=10, device=device)` |
+| `Tensor::arange_step(5..10, 2, device)` | `tensor.arange(start=5, end=10, step=2, device=device)` |
| `tensor.float()` | `tensor.to(torch.float)` |
| `tensor.from_ints(ints)` | N/A |
| `tensor.int_random(shape, distribution, device)` | N/A |
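The two `arange` rows above switch from method-call to associated-function syntax. A minimal sketch of those calls, with the `burn-ndarray` backend assumed (not part of the commit):

```rust
use burn_ndarray::NdArray;
use burn_tensor::{Int, Tensor};

fn main() {
    let device = Default::default();
    // Equivalent of torch.arange(start=5, end=10): [5, 6, 7, 8, 9].
    let a = Tensor::<NdArray, 1, Int>::arange(5..10, &device);
    // Equivalent of torch.arange(start=5, end=10, step=2): [5, 7, 9].
    let b = Tensor::<NdArray, 1, Int>::arange_step(5..10, 2, &device);
    println!("{:?}", a.into_data());
    println!("{:?}", b.into_data());
}
```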
40 changes: 37 additions & 3 deletions crates/burn-tensor/src/tests/ops/one_hot.rs
@@ -4,7 +4,7 @@ mod tests {
use burn_tensor::{Int, TensorData};

#[test]
-fn should_support_one_hot() {
+fn float_should_support_one_hot() {
let device = Default::default();

let tensor = TestTensor::<1>::one_hot(0, 5, &device);
@@ -26,15 +26,49 @@

#[test]
#[should_panic]
-fn should_panic_when_index_exceeds_number_of_classes() {
+fn float_one_hot_should_panic_when_index_exceeds_number_of_classes() {
let device = Default::default();
let tensor = TestTensor::<1>::one_hot(1, 1, &device);
}

#[test]
#[should_panic]
-fn should_panic_when_number_of_classes_is_zero() {
+fn float_one_hot_should_panic_when_number_of_classes_is_zero() {
let device = Default::default();
let tensor = TestTensor::<1>::one_hot(0, 0, &device);
}

+#[test]
+fn int_should_support_one_hot() {
+let device = Default::default();
+
+let index_tensor = TestTensorInt::<1>::arange(0..5, &device);
+let one_hot_tensor = index_tensor.one_hot(5);
+let expected = TestTensorInt::eye(5, &device).into_data();
+one_hot_tensor.into_data().assert_eq(&expected, false);
+}
+
+#[test]
+#[should_panic]
+fn int_one_hot_should_panic_when_index_exceeds_number_of_classes() {
+let device = Default::default();
+let index_tensor = TestTensorInt::<1>::arange(0..6, &device);
+let one_hot_tensor = index_tensor.one_hot(5);
+}
+
+#[test]
+#[should_panic]
+fn int_one_hot_should_panic_when_number_of_classes_is_zero() {
+let device = Default::default();
+let index_tensor = TestTensorInt::<1>::arange(0..3, &device);
+let one_hot_tensor = index_tensor.one_hot(0);
+}
+
+#[test]
+#[should_panic]
+fn int_one_hot_should_panic_when_number_of_classes_is_1() {
+let device = Default::default();
+let index_tensor = TestTensorInt::<1>::arange(0..3, &device);
+let one_hot_tensor = index_tensor.one_hot(1);
+}
}
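For contrast with the int method form tested above, a sketch of the float `one_hot` constructor that the renamed float tests call (again with the `burn-ndarray` backend assumed; not part of the commit):

```rust
use burn_ndarray::NdArray;
use burn_tensor::Tensor;

fn main() {
    let device = Default::default();
    // Float one_hot is an associated constructor: index 2 of 5 classes
    // should yield [0.0, 0.0, 1.0, 0.0, 0.0].
    let one_hot = Tensor::<NdArray, 1>::one_hot(2, 5, &device);
    println!("{:?}", one_hot.into_data());
}
```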
