From a6c7a2b0c4159edb549108d86e99dbe57313338e Mon Sep 17 00:00:00 2001 From: Tiago Sanona <40792244+tsanona@users.noreply.github.com> Date: Thu, 21 Nov 2024 17:51:48 +0100 Subject: [PATCH] Add test int one_hot and change ops docs in the book (#2519) * add int one_hot tests, change tensor creation function formatting in book. * move swap_dims and transpose to basic ops in book. --------- Co-authored-by: Tiago Sanona --- burn-book/src/building-blocks/tensor.md | 13 ++++--- crates/burn-tensor/src/tests/ops/one_hot.rs | 40 +++++++++++++++++++-- 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index f5fba9eb1d..493fae4250 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -164,11 +164,13 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`. | `tensor.slice(ranges)` | `tensor[(*ranges,)]` | | `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` | | `tensor.squeeze(dim)` | `tensor.squeeze(dim)` | +| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` | | `tensor.to_data()` | N/A | | `tensor.to_device(device)` | `tensor.to(device)` | +| `tensor.transpose()` | `tensor.T` | | `tensor.unsqueeze()` | `tensor.unsqueeze(0)` | | `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` | -| `tensor.unsqueeze_dims(dims)` | N/A | +| `tensor.unsqueeze_dims(dims)` | N/A | ### Numeric Operations @@ -179,7 +181,6 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`. | `Tensor::eye(size, device)` | `torch.eye(size, device=device)` | | `Tensor::full(shape, fill_value, device)` | `torch.full(shape, fill_value, device=device)` | | `Tensor::ones(shape, device)` | `torch.ones(shape, device=device)` | -| `Tensor::zeros(shape)` | `torch.zeros(shape)` | | `Tensor::zeros(shape, device)` | `torch.zeros(shape, device=device)` | | `tensor.abs()` | `torch.abs(tensor)` | | `tensor.add(other)` or `tensor + other` | `tensor + other` | @@ -255,7 +256,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`. Those operations are only available for `Float` tensors. | Burn API | PyTorch Equivalent | -|-----------------------------------------------| ---------------------------------- | +| --------------------------------------------- | ---------------------------------- | | `Tensor::one_hot(index, num_classes, device)` | N/A | | `tensor.cast(dtype)` | `tensor.to(dtype)` | | `tensor.ceil()` | `tensor.ceil()` | @@ -275,10 +276,8 @@ Those operations are only available for `Float` tensors. | `tensor.round()` | `tensor.round()` | | `tensor.sin()` | `tensor.sin()` | | `tensor.sqrt()` | `tensor.sqrt()` | -| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` | | `tensor.tanh()` | `tensor.tanh()` | | `tensor.to_full_precision()` | `tensor.to(torch.float)` | -| `tensor.transpose()` | `tensor.T` | | `tensor.var(dim)` | `tensor.var(dim)` | | `tensor.var_bias(dim)` | N/A | | `tensor.var_mean(dim)` | N/A | @@ -290,8 +289,8 @@ Those operations are only available for `Int` tensors. | Burn API | PyTorch Equivalent | | ------------------------------------------------ | ------------------------------------------------------- | -| `tensor.arange(5..10, device)` | `tensor.arange(start=5, end=10, device=device)` | -| `tensor.arange_step(5..10, 2, device)` | `tensor.arange(start=5, end=10, step=2, device=device)` | +| `Tensor::arange(5..10, device)` | `tensor.arange(start=5, end=10, device=device)` | +| `Tensor::arange_step(5..10, 2, device)` | `tensor.arange(start=5, end=10, step=2, device=device)` | | `tensor.float()` | `tensor.to(torch.float)` | | `tensor.from_ints(ints)` | N/A | | `tensor.int_random(shape, distribution, device)` | N/A | diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index ca7ce62611..310399119f 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -4,7 +4,7 @@ mod tests { use burn_tensor::{Int, TensorData}; #[test] - fn should_support_one_hot() { + fn float_should_support_one_hot() { let device = Default::default(); let tensor = TestTensor::<1>::one_hot(0, 5, &device); @@ -26,15 +26,49 @@ mod tests { #[test] #[should_panic] - fn should_panic_when_index_exceeds_number_of_classes() { + fn float_one_hot_should_panic_when_index_exceeds_number_of_classes() { let device = Default::default(); let tensor = TestTensor::<1>::one_hot(1, 1, &device); } #[test] #[should_panic] - fn should_panic_when_number_of_classes_is_zero() { + fn float_one_hot_should_panic_when_number_of_classes_is_zero() { let device = Default::default(); let tensor = TestTensor::<1>::one_hot(0, 0, &device); } + + #[test] + fn int_should_support_one_hot() { + let device = Default::default(); + + let index_tensor = TestTensorInt::<1>::arange(0..5, &device); + let one_hot_tensor = index_tensor.one_hot(5); + let expected = TestTensorInt::eye(5, &device).into_data(); + one_hot_tensor.into_data().assert_eq(&expected, false); + } + + #[test] + #[should_panic] + fn int_one_hot_should_panic_when_index_exceeds_number_of_classes() { + let device = Default::default(); + let index_tensor = TestTensorInt::<1>::arange(0..6, &device); + let one_hot_tensor = index_tensor.one_hot(5); + } + + #[test] + #[should_panic] + fn int_one_hot_should_panic_when_number_of_classes_is_zero() { + let device = Default::default(); + let index_tensor = TestTensorInt::<1>::arange(0..3, &device); + let one_hot_tensor = index_tensor.one_hot(0); + } + + #[test] + #[should_panic] + fn int_one_hot_should_panic_when_number_of_classes_is_1() { + let device = Default::default(); + let index_tensor = TestTensorInt::<1>::arange(0..3, &device); + let one_hot_tensor = index_tensor.one_hot(1); + } }