From 9a2b8416655ab55d7c1dbcd16a05108a9ee98aac Mon Sep 17 00:00:00 2001
From: Guillaume Lagrange
Date: Thu, 21 Nov 2024 13:19:33 -0500
Subject: [PATCH] Add float cast op for JIT backend (#2511)

* Fix backend FloatElem docstring
* Remove elem type generic from JitTensor (float ops still using default dtype)
* Use execute_with_dtype macro for float ops
* Add cast for fusion
* Add warning for type promotion
* Add cast to backend router
* Add Primitive trait to correctly display the tensor dtype
* Fix primitive associated type
* Fix jit bool tensor display
* Fix CI
* Fix clippy
* Whoops
* Fix candle bool tensor display
* Add shape to primitive trait
* Add primitive tests
* Rename trait to TensorMetadata
* Add missing changes to example
* Fix split from merge
---
 .../backend-extension/custom-cubecl-kernel.md | 28 +-
 .../backend-extension/custom-wgpu-kernel.md | 21 +-
 crates/burn-autodiff/src/grads.rs | 4 +-
 crates/burn-autodiff/src/ops/base.rs | 4 +-
 crates/burn-autodiff/src/ops/bool_tensor.rs | 4 -
 crates/burn-autodiff/src/ops/int_tensor.rs | 4 -
 crates/burn-autodiff/src/ops/tensor.rs | 99 ++--
 crates/burn-autodiff/src/tensor.rs | 12 +-
 crates/burn-candle/src/ops/base.rs | 2 +-
 crates/burn-candle/src/ops/bool_tensor.rs | 6 +-
 crates/burn-candle/src/ops/int_tensor.rs | 4 -
 crates/burn-candle/src/ops/tensor.rs | 4 -
 crates/burn-candle/src/tensor.rs | 36 +-
 crates/burn-fusion/src/ops/binary.rs | 54 +-
 crates/burn-fusion/src/ops/boolean.rs | 8 +-
 crates/burn-fusion/src/ops/float.rs | 414 ++++++-------
 crates/burn-fusion/src/ops/int.rs | 8 +-
 crates/burn-fusion/src/ops/qtensor.rs | 4 +-
 crates/burn-fusion/src/tensor.rs | 25 +-
 crates/burn-jit/src/backend.rs | 8 +-
 crates/burn-jit/src/bridge.rs | 18 +-
 crates/burn-jit/src/fusion/base.rs | 17 +-
 crates/burn-jit/src/kernel/binary.rs | 32 +-
 crates/burn-jit/src/kernel/cast/base.rs | 17 +-
 crates/burn-jit/src/kernel/cast/bool_cast.rs | 7 +-
 crates/burn-jit/src/kernel/clamp.rs | 4 +-
 crates/burn-jit/src/kernel/comparison.rs | 101 ++--
 crates/burn-jit/src/kernel/contiguous.rs | 25 +-
 .../burn-jit/src/kernel/conv/conv2d/base.rs | 18 +-
 .../burn-jit/src/kernel/conv/conv2d/col2im.rs | 36 +-
 .../burn-jit/src/kernel/conv/conv2d/direct.rs | 20 +-
 .../burn-jit/src/kernel/conv/conv2d/im2col.rs | 38 +-
 .../src/kernel/conv/conv2d/implicit_gemm.rs | 26 +-
 .../kernel/conv/conv2d/transpose_direct.rs | 20 +-
 .../src/kernel/conv/conv2d/tune/conv2d.rs | 24 +-
 .../conv/conv2d/tune/conv_transpose2d.rs | 24 +-
 crates/burn-jit/src/kernel/conv/conv3d.rs | 20 +-
 .../src/kernel/conv/conv_transpose3d.rs | 20 +-
 .../burn-jit/src/kernel/conv/deform_conv2d.rs | 31 +-
 .../kernel/conv/deform_conv_transpose2d.rs | 79 +--
 crates/burn-jit/src/kernel/index/flip.rs | 18 +-
 crates/burn-jit/src/kernel/index/gather.rs | 14 +-
 .../burn-jit/src/kernel/index/repeat_dim.rs | 10 +-
 crates/burn-jit/src/kernel/index/scatter.rs | 14 +-
 crates/burn-jit/src/kernel/index/select.rs | 12 +-
 .../src/kernel/index/select_assign.rs | 12 +-
 crates/burn-jit/src/kernel/index/slice.rs | 20 +-
 .../burn-jit/src/kernel/index/slice_assign.rs | 10 +-
 .../burn-jit/src/kernel/interpolate/base.rs | 21 +-
 .../src/kernel/interpolate/bicubic.rs | 10 +-
 .../src/kernel/interpolate/bilinear.rs | 10 +-
 .../src/kernel/interpolate/nearest.rs | 10 +-
 .../kernel/interpolate/nearest_backward.rs | 10 +-
 crates/burn-jit/src/kernel/mask/base.rs | 16 +-
 crates/burn-jit/src/kernel/mask/mask_fill.rs | 34 +-
 crates/burn-jit/src/kernel/mask/mask_where.rs | 46 +-
 crates/burn-jit/src/kernel/matmul/base.rs | 12 +-
crates/burn-jit/src/kernel/matmul/simple.rs | 20 +- .../burn-jit/src/kernel/matmul/tune/base.rs | 62 +- crates/burn-jit/src/kernel/matmul/utils.rs | 13 +- .../src/kernel/pool/adaptive_avg_pool2d.rs | 10 +- .../pool/adaptive_avg_pool2d_backward.rs | 11 +- crates/burn-jit/src/kernel/pool/avg_pool2d.rs | 10 +- .../src/kernel/pool/avg_pool2d_backward.rs | 12 +- crates/burn-jit/src/kernel/pool/max_pool2d.rs | 24 +- .../src/kernel/pool/max_pool2d_backward.rs | 16 +- crates/burn-jit/src/kernel/prng/base.rs | 6 +- crates/burn-jit/src/kernel/prng/bernoulli.rs | 2 +- crates/burn-jit/src/kernel/prng/normal.rs | 2 +- crates/burn-jit/src/kernel/prng/uniform.rs | 6 +- .../src/kernel/quantization/dequantize.rs | 35 +- .../src/kernel/quantization/quantize.rs | 39 +- crates/burn-jit/src/kernel/reduce/base.rs | 10 +- .../src/kernel/reduce/naive/kernel.rs | 8 +- crates/burn-jit/src/kernel/reduce/prod.rs | 10 +- .../src/kernel/reduce/shared/kernel.rs | 8 +- .../src/kernel/reduce/subcube/kernel.rs | 8 +- crates/burn-jit/src/kernel/reduce/sum.rs | 10 +- .../burn-jit/src/kernel/reduce/tune/base.rs | 10 +- crates/burn-jit/src/kernel/reduce/tune/key.rs | 2 +- crates/burn-jit/src/kernel/unary.rs | 12 +- crates/burn-jit/src/ops/base.rs | 51 +- crates/burn-jit/src/ops/bool_ops.rs | 24 +- crates/burn-jit/src/ops/float_ops.rs | 548 ++++++++++++------ crates/burn-jit/src/ops/int_ops.rs | 100 ++-- crates/burn-jit/src/ops/module_ops.rs | 46 +- crates/burn-jit/src/ops/numeric.rs | 66 +-- crates/burn-jit/src/ops/qtensor.rs | 10 +- crates/burn-jit/src/template/base.rs | 2 +- crates/burn-jit/src/tensor/base.rs | 170 +++++- crates/burn-jit/src/tensor/qtensor.rs | 59 +- crates/burn-jit/src/tests/mask_fill.rs | 12 +- crates/burn-jit/src/tests/mask_where.rs | 32 +- crates/burn-jit/src/tests/reduce.rs | 10 +- .../burn-ndarray/src/ops/adaptive_avgpool.rs | 2 +- crates/burn-ndarray/src/ops/avgpool.rs | 2 +- crates/burn-ndarray/src/ops/base.rs | 3 +- crates/burn-ndarray/src/ops/bool_tensor.rs | 6 +- crates/burn-ndarray/src/ops/conv.rs | 2 +- crates/burn-ndarray/src/ops/deform_conv.rs | 5 +- crates/burn-ndarray/src/ops/int_tensor.rs | 5 +- crates/burn-ndarray/src/ops/interpolate.rs | 2 +- crates/burn-ndarray/src/ops/matmul.rs | 2 +- crates/burn-ndarray/src/ops/maxpool.rs | 2 +- crates/burn-ndarray/src/ops/padding.rs | 2 +- crates/burn-ndarray/src/ops/qtensor.rs | 2 +- crates/burn-ndarray/src/ops/tensor.rs | 6 +- crates/burn-ndarray/src/tensor.rs | 20 +- crates/burn-remote/src/client/runner.rs | 2 +- crates/burn-router/src/backend.rs | 9 +- crates/burn-router/src/client/base.rs | 4 +- crates/burn-router/src/ops/op_bool.rs | 6 +- crates/burn-router/src/ops/op_float.rs | 27 +- crates/burn-router/src/ops/op_int.rs | 8 +- crates/burn-router/src/runner.rs | 15 +- crates/burn-router/src/tensor.rs | 16 +- crates/burn-router/src/types.rs | 6 +- crates/burn-tch/src/ops/base.rs | 2 +- crates/burn-tch/src/ops/bool_tensor.rs | 8 +- crates/burn-tch/src/ops/int_tensor.rs | 8 +- crates/burn-tch/src/ops/module.rs | 10 +- crates/burn-tch/src/ops/qtensor.rs | 2 +- crates/burn-tch/src/ops/tensor.rs | 8 +- crates/burn-tch/src/tensor.rs | 42 +- crates/burn-tensor/src/tensor/api/base.rs | 46 +- crates/burn-tensor/src/tensor/api/chunk.rs | 4 +- crates/burn-tensor/src/tensor/api/float.rs | 4 + crates/burn-tensor/src/tensor/api/kind.rs | 29 +- crates/burn-tensor/src/tensor/api/narrow.rs | 4 +- crates/burn-tensor/src/tensor/api/split.rs | 4 +- crates/burn-tensor/src/tensor/backend/base.rs | 16 +- crates/burn-tensor/src/tensor/element/base.rs | 31 + 
.../burn-tensor/src/tensor/ops/activation.rs | 3 +- .../burn-tensor/src/tensor/ops/bool_tensor.rs | 23 +- .../burn-tensor/src/tensor/ops/int_tensor.rs | 33 +- .../src/tensor/ops/modules/base.rs | 10 +- .../burn-tensor/src/tensor/ops/modules/cat.rs | 11 +- .../src/tensor/ops/modules/conv.rs | 120 ++-- .../src/tensor/ops/modules/pool.rs | 30 +- .../src/tensor/ops/modules/repeat_dim.rs | 4 +- .../src/tensor/ops/modules/unfold.rs | 6 +- crates/burn-tensor/src/tensor/ops/tensor.rs | 27 +- crates/burn-tensor/src/tests/mod.rs | 4 + crates/burn-tensor/src/tests/primitive.rs | 45 ++ examples/custom-cubecl-kernel/src/backward.rs | 12 +- examples/custom-cubecl-kernel/src/forward.rs | 17 +- .../custom-image-dataset/src/inference.rs | 31 + examples/custom-wgpu-kernel/src/backward.rs | 12 +- examples/custom-wgpu-kernel/src/forward.rs | 11 +- 149 files changed, 2225 insertions(+), 1627 deletions(-) create mode 100644 crates/burn-tensor/src/tests/primitive.rs create mode 100644 examples/custom-image-dataset/src/inference.rs diff --git a/burn-book/src/advanced/backend-extension/custom-cubecl-kernel.md b/burn-book/src/advanced/backend-extension/custom-cubecl-kernel.md index 5c2dfd868e..4dad8cb7e9 100644 --- a/burn-book/src/advanced/backend-extension/custom-cubecl-kernel.md +++ b/burn-book/src/advanced/backend-extension/custom-cubecl-kernel.md @@ -171,8 +171,14 @@ impl Backend for JitBackend()); // Create the output tensor primitive. - let output = - JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer); + // Create the output tensor primitive. + let output = JitTensor::new_contiguous( + lhs.client.clone(), + lhs.device.clone(), + shape_out, + buffer, + F::dtype(), + ); // Declare the wgsl workgroup with the number of cubes in x, y and z. let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32; @@ -186,10 +192,10 @@ impl Backend for JitBackend(1), + rhs.as_tensor_arg::(1), + bias.as_tensor_arg::(1), + output.as_tensor_arg::(1), ); // Return the output tensor. @@ -251,12 +257,12 @@ impl Backend for Autodiff { // Set our state. let (lhs_state, rhs_state, output, shape_bias) = ops.state; - let lhs = checkpointer.retrieve_node_output(lhs_state); - let rhs = checkpointer.retrieve_node_output(rhs_state); + let lhs: FloatTensor = checkpointer.retrieve_node_output(lhs_state); + let rhs: FloatTensor = checkpointer.retrieve_node_output(rhs_state); // Fetch shapes of our tensor to support broadcasting. - let shape_lhs = B::float_shape(&lhs); - let shape_rhs = B::float_shape(&rhs); + let shape_lhs = lhs.shape(); + let shape_rhs = rhs.shape(); // Compute the gradient of the output using the already existing `relu_backward` // function in the basic Burn backend trait. @@ -314,7 +320,7 @@ impl Backend for Autodiff { // compute bound operation. let lhs_state = prep.checkpoint(&lhs); let rhs_state = prep.checkpoint(&rhs); - let bias_shape = B::float_shape(&bias.primitive); + let bias_shape = bias.primitive.shape(); let output = B::fused_matmul_add_relu( lhs.primitive.clone(), diff --git a/burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md b/burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md index c4c6b95016..69ca45ed97 100644 --- a/burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md +++ b/burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md @@ -239,14 +239,19 @@ impl Backend for JitBackend { .empty(shape_out.num_elements() * core::mem::size_of::()); // Create the output tensor primitive. 
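The custom-kernel book examples in this patch now pass the element data type explicitly when building the output primitive, since the element type generic was removed from JitTensor. The snippet below is a minimal sketch of that allocation pattern, assuming the JitTensor::new_contiguous signature and the public client/device fields shown in the hunks; empty_output is a hypothetical helper and the module paths may need adjusting.

use burn_jit::{tensor::JitTensor, FloatElement, JitRuntime};
use burn_tensor::Shape;

// Hypothetical helper mirroring the dtype-aware output allocation used in the
// book examples of this patch.
fn empty_output<R: JitRuntime, F: FloatElement>(
    reference: &JitTensor<R>,
    shape_out: Shape,
) -> JitTensor<R> {
    // Reserve an uninitialized buffer sized for `F` elements on the same client.
    let buffer = reference
        .client
        .empty(shape_out.num_elements() * core::mem::size_of::<F>());

    // The output primitive now records its element dtype explicitly.
    JitTensor::new_contiguous(
        reference.client.clone(),
        reference.device.clone(),
        shape_out,
        buffer,
        F::dtype(),
    )
}

Recording the dtype on the primitive is what lets the same JitTensor type back tensors of different float precisions, which the new float cast op relies on.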
- let output = - JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer); + let output = JitTensor::new_contiguous( + lhs.client.clone(), + lhs.device.clone(), + shape_out, + buffer, + F::dtype(), + ); // Create the kernel. let kernel = FusedMatmulAddRelu::::new(cube_dim); // Build info buffer with tensor information needed by the kernel, such as shapes and strides. - let info = build_info(&[&lhs, &rhs, &output]); + let info = build_info::<_, F>(&[&lhs, &rhs, &output]); let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); // Declare the wgsl workgroup with the number of cubes in x, y and z. @@ -331,12 +336,12 @@ impl Backend for Autodiff { // Set our state. let (lhs_state, rhs_state, output, shape_bias) = ops.state; - let lhs = checkpointer.retrieve_node_output(lhs_state); - let rhs = checkpointer.retrieve_node_output(rhs_state); + let lhs: FloatTensor = checkpointer.retrieve_node_output(lhs_state); + let rhs: FloatTensor = checkpointer.retrieve_node_output(rhs_state); // Fetch shapes of our tensor to support broadcasting. - let shape_lhs = B::float_shape(&lhs); - let shape_rhs = B::float_shape(&rhs); + let shape_lhs = lhs.shape(); + let shape_rhs = rhs.shape(); // Compute the gradient of the output using the already existing `relu_backward` // function in the basic Burn backend trait. @@ -392,7 +397,7 @@ impl Backend for Autodiff { // during the backward pass. Here we choose to save it in the state because it's a compute bound operation. let lhs_state = prep.checkpoint(&lhs); let rhs_state = prep.checkpoint(&rhs); - let bias_shape = B::float_shape(&bias.primitive); + let bias_shape = bias.primitive.shape(); let output = B::fused_matmul_add_relu( lhs.primitive.clone(), diff --git a/crates/burn-autodiff/src/grads.rs b/crates/burn-autodiff/src/grads.rs index b1547ed656..dd4d5613a4 100644 --- a/crates/burn-autodiff/src/grads.rs +++ b/crates/burn-autodiff/src/grads.rs @@ -1,4 +1,4 @@ -use burn_tensor::{backend::Backend, container::TensorContainer, ops::FloatTensor}; +use burn_tensor::{backend::Backend, container::TensorContainer, ops::FloatTensor, TensorMetadata}; use crate::{ graph::{NodeRef, Requirement}, @@ -22,7 +22,7 @@ impl Gradients { }; gradients.register::( root_node.id, - B::float_ones(B::float_shape(&root_tensor), &B::float_device(&root_tensor)), + B::float_ones(root_tensor.shape(), &B::float_device(&root_tensor)), ); gradients } diff --git a/crates/burn-autodiff/src/ops/base.rs b/crates/burn-autodiff/src/ops/base.rs index aed4ab8171..575c8b71fc 100644 --- a/crates/burn-autodiff/src/ops/base.rs +++ b/crates/burn-autodiff/src/ops/base.rs @@ -10,7 +10,7 @@ use crate::{ graph::{ComputingProperty, NodeID, NodeRef, Requirement, Step}, tensor::AutodiffTensor, }; -use burn_tensor::{backend::Backend, ops::FloatTensor, Shape}; +use burn_tensor::{backend::Backend, ops::FloatTensor, Shape, TensorMetadata}; use std::marker::PhantomData; /// Operation in preparation. @@ -292,7 +292,7 @@ impl Step for UntrackedOpsStep { /// If broadcasting happened during the forward pass, the gradients will be sum along the /// broadcasted dimension. 
pub fn broadcast_shape(mut grad: FloatTensor, shape: &Shape) -> FloatTensor { - let shape_grad = B::float_shape(&grad); + let shape_grad = grad.shape(); let ndims = shape_grad.num_dims(); for i in 0..ndims { diff --git a/crates/burn-autodiff/src/ops/bool_tensor.rs b/crates/burn-autodiff/src/ops/bool_tensor.rs index 1b40a1af93..ef9f4c73df 100644 --- a/crates/burn-autodiff/src/ops/bool_tensor.rs +++ b/crates/burn-autodiff/src/ops/bool_tensor.rs @@ -11,10 +11,6 @@ impl BoolTensorOps for Autodiff { B::bool_from_data(data, device) } - fn bool_shape(tensor: &BoolTensor) -> Shape { - B::bool_shape(tensor) - } - async fn bool_into_data(tensor: BoolTensor) -> TensorData { B::bool_into_data(tensor).await } diff --git a/crates/burn-autodiff/src/ops/int_tensor.rs b/crates/burn-autodiff/src/ops/int_tensor.rs index 5f3c80199e..4aad98bb46 100644 --- a/crates/burn-autodiff/src/ops/int_tensor.rs +++ b/crates/burn-autodiff/src/ops/int_tensor.rs @@ -11,10 +11,6 @@ impl IntTensorOps for Autodiff { B::int_from_data(data, device) } - fn int_shape(tensor: &IntTensor) -> Shape { - B::int_shape(tensor) - } - async fn int_into_data(tensor: IntTensor) -> TensorData { B::int_into_data(tensor).await } diff --git a/crates/burn-autodiff/src/ops/tensor.rs b/crates/burn-autodiff/src/ops/tensor.rs index 5764f3a003..12dc8cf90d 100644 --- a/crates/burn-autodiff/src/ops/tensor.rs +++ b/crates/burn-autodiff/src/ops/tensor.rs @@ -18,7 +18,7 @@ use crate::{ use burn_tensor::{ backend::Backend, ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor}, - Device, ElementConversion, Shape, TensorData, + Device, ElementConversion, Shape, TensorData, TensorMetadata, }; use super::maxmin::MaxMinDim; @@ -40,7 +40,7 @@ fn unsqueeze_like( */ let ndims_out = shape.num_dims(); - let shape = B::float_shape(&tensor); + let shape = tensor.shape(); let ndims_in = shape.num_dims(); let mut dims = vec![1; ndims_out]; @@ -71,10 +71,6 @@ impl FloatTensorOps for Autodiff AutodiffTensor::new(B::float_ones(shape, device)) } - fn float_shape(tensor: &FloatTensor) -> Shape { - B::float_shape(&tensor.primitive) - } - async fn float_into_data(tensor: FloatTensor) -> TensorData { B::float_into_data(tensor.primitive).await } @@ -154,10 +150,7 @@ impl FloatTensorOps for Autodiff .stateful() { OpsKind::Tracked(preps) => preps.finish( - ( - B::float_shape(&lhs.primitive), - B::float_shape(&rhs.primitive), - ), + (lhs.primitive.shape(), rhs.primitive.shape()), B::float_add(lhs.primitive, rhs.primitive), ), OpsKind::UnTracked(preps) => preps.finish(B::float_add(lhs.primitive, rhs.primitive)), @@ -226,10 +219,7 @@ impl FloatTensorOps for Autodiff .stateful() { OpsKind::Tracked(preps) => preps.finish( - ( - B::float_shape(&lhs.primitive), - B::float_shape(&rhs.primitive), - ), + (lhs.primitive.shape(), rhs.primitive.shape()), B::float_sub(lhs.primitive, rhs.primitive), ), OpsKind::UnTracked(preps) => preps.finish(B::float_sub(lhs.primitive, rhs.primitive)), @@ -864,7 +854,7 @@ impl FloatTensorOps for Autodiff let ndims_out = shape.num_dims(); unary::(ops.parents, ops.node, grads, |grad| { - let shape_grad = B::float_shape(&grad); + let shape_grad = grad.shape(); let mut grad = grad; for i in 0..ndims_out { @@ -886,7 +876,7 @@ impl FloatTensorOps for Autodiff .stateful() { OpsKind::Tracked(prep) => prep.finish( - (B::float_shape(&tensor.primitive), shape.clone()), + (tensor.primitive.shape(), shape.clone()), B::float_reshape(tensor.primitive, shape), ), OpsKind::UnTracked(prep) => prep.finish(B::float_reshape(tensor.primitive, shape)), @@ -928,7 +918,7 
@@ impl FloatTensorOps for Autodiff ( dim, indices.clone(), - B::float_shape(&tensor.primitive), + tensor.primitive.shape(), B::float_device(&tensor.primitive), ), B::float_gather(dim, tensor.primitive, indices), @@ -985,8 +975,8 @@ impl FloatTensorOps for Autodiff ( dim, indices.clone(), - B::float_shape(&tensor.primitive), - B::float_shape(&value.primitive), + tensor.primitive.shape(), + value.primitive.shape(), B::float_device(&value.primitive), ), B::float_scatter(dim, tensor.primitive, indices, value.primitive), @@ -1052,7 +1042,7 @@ impl FloatTensorOps for Autodiff ( dim, indices.clone(), - B::float_shape(&tensor.primitive), + tensor.primitive.shape(), B::float_device(&tensor.primitive), ), B::float_select(tensor.primitive, dim, indices), @@ -1185,7 +1175,7 @@ impl FloatTensorOps for Autodiff OpsKind::Tracked(prep) => prep.finish( ( ranges.to_vec(), - B::float_shape(&tensor.primitive), + tensor.primitive.shape(), B::float_device(&tensor.primitive), ), B::float_slice(tensor.primitive, ranges), @@ -1258,7 +1248,7 @@ impl FloatTensorOps for Autodiff OpsKind::Tracked(prep) => prep.finish( ( ranges.to_vec(), - B::float_shape(&value.primitive), + value.primitive.shape(), B::float_device(&value.primitive), ), B::float_slice_assign(tensor.primitive, ranges, value.primitive), @@ -1319,8 +1309,8 @@ impl FloatTensorOps for Autodiff OpsKind::Tracked(prep) => prep.finish( ( mask.clone(), - B::float_shape(&tensor.primitive), - B::float_shape(&source.primitive), + tensor.primitive.shape(), + source.primitive.shape(), B::float_device(&source.primitive), ), B::float_mask_where(tensor.primitive, mask, source.primitive), @@ -1454,17 +1444,16 @@ impl FloatTensorOps for Autodiff let ones = B::float_ones(shape, &B::float_device(&grad)); let val = B::float_mul_scalar(ones, val.elem()); - let grad = unsqueeze_like::(grad, B::float_shape(&val)); + let grad = unsqueeze_like::(grad, val.shape()); B::float_mul(val, grad) }); } } match Mean.prepare::([tensor.node]).compute_bound().stateful() { - OpsKind::Tracked(prep) => prep.finish( - B::float_shape(&tensor.primitive), - B::float_mean(tensor.primitive), - ), + OpsKind::Tracked(prep) => { + prep.finish(tensor.primitive.shape(), B::float_mean(tensor.primitive)) + } OpsKind::UnTracked(prep) => prep.finish(B::float_mean(tensor.primitive)), } } @@ -1485,17 +1474,16 @@ impl FloatTensorOps for Autodiff unary::(ops.parents, ops.node, grads, |grad| { let val = B::float_ones(ops.state, &B::float_device(&grad)); - let grad = unsqueeze_like::(grad, B::float_shape(&val)); + let grad = unsqueeze_like::(grad, val.shape()); B::float_mul(val, grad) }); } } match Sum.prepare::([tensor.node]).compute_bound().stateful() { - OpsKind::Tracked(prep) => prep.finish( - B::float_shape(&tensor.primitive), - B::float_sum(tensor.primitive), - ), + OpsKind::Tracked(prep) => { + prep.finish(tensor.primitive.shape(), B::float_sum(tensor.primitive)) + } OpsKind::UnTracked(prep) => prep.finish(B::float_sum(tensor.primitive)), } } @@ -1532,7 +1520,7 @@ impl FloatTensorOps for Autodiff .stateful() { OpsKind::Tracked(prep) => prep.finish( - (B::float_shape(&tensor.primitive), dim), + (tensor.primitive.shape(), dim), B::float_mean_dim(tensor.primitive, dim), ), OpsKind::UnTracked(prep) => prep.finish(B::float_mean_dim(tensor.primitive, dim)), @@ -1569,7 +1557,7 @@ impl FloatTensorOps for Autodiff .stateful() { OpsKind::Tracked(prep) => prep.finish( - (B::float_shape(&tensor.primitive), dim), + (tensor.primitive.shape(), dim), B::float_sum_dim(tensor.primitive, dim), ), OpsKind::UnTracked(prep) => 
prep.finish(B::float_sum_dim(tensor.primitive, dim)), @@ -1980,10 +1968,7 @@ impl FloatTensorOps for Autodiff .stateful() { OpsKind::Tracked(preps) => preps.finish( - ( - B::float_shape(&tensor.primitive), - B::float_device(&tensor.primitive), - ), + (tensor.primitive.shape(), B::float_device(&tensor.primitive)), B::float_round(tensor.primitive), ), OpsKind::UnTracked(preps) => preps.finish(B::float_round(tensor.primitive)), @@ -2019,10 +2004,7 @@ impl FloatTensorOps for Autodiff .stateful() { OpsKind::Tracked(preps) => preps.finish( - ( - B::float_shape(&tensor.primitive), - B::float_device(&tensor.primitive), - ), + (tensor.primitive.shape(), B::float_device(&tensor.primitive)), B::float_floor(tensor.primitive), ), OpsKind::UnTracked(preps) => preps.finish(B::float_floor(tensor.primitive)), @@ -2058,10 +2040,7 @@ impl FloatTensorOps for Autodiff .stateful() { OpsKind::Tracked(preps) => preps.finish( - ( - B::float_shape(&tensor.primitive), - B::float_device(&tensor.primitive), - ), + (tensor.primitive.shape(), B::float_device(&tensor.primitive)), B::float_floor(tensor.primitive), ), OpsKind::UnTracked(preps) => preps.finish(B::float_floor(tensor.primitive)), @@ -2125,7 +2104,7 @@ impl FloatTensorOps for Autodiff impl Step for CatStep { fn step(self: Box, grads: &mut Gradients, _checkpointer: &mut Checkpointer) { let grad = grads.consume::(&self.output); - let ranges: Vec<_> = B::float_shape(&grad).dims.iter().map(|v| 0..*v).collect(); + let ranges: Vec<_> = grad.shape().dims.iter().map(|v| 0..*v).collect(); let mut current_index = 0; @@ -2162,7 +2141,7 @@ impl FloatTensorOps for Autodiff let mut dim_sizes = Vec::with_capacity(tensors.len()); tensors.into_iter().for_each(|tensor| { - dim_sizes.push(B::float_shape(&tensor.primitive).dims[dim]); + dim_sizes.push(tensor.primitive.shape().dims[dim]); nodes.push(tensor.node); primitives.push(tensor.primitive); }); @@ -2201,7 +2180,7 @@ impl FloatTensorOps for Autodiff .stateful() { OpsKind::Tracked(prep) => { - let shape = B::float_shape(&tensor.primitive); + let shape = tensor.primitive.shape(); let (tensor, index) = B::float_max_dim_with_indices(tensor.primitive, dim); prep.finish((index, shape), tensor) } @@ -2218,7 +2197,7 @@ impl FloatTensorOps for Autodiff .stateful() { OpsKind::Tracked(prep) => { - let shape = B::float_shape(&tensor.primitive); + let shape = tensor.primitive.shape(); let (tensor, index) = B::float_max_dim_with_indices(tensor.primitive, dim); let tensor = prep.finish((index.clone(), shape), tensor); @@ -2239,7 +2218,7 @@ impl FloatTensorOps for Autodiff .stateful() { OpsKind::Tracked(prep) => { - let shape = B::float_shape(&tensor.primitive); + let shape = tensor.primitive.shape(); let (tensor, index) = B::float_min_dim_with_indices(tensor.primitive, dim); prep.finish((index, shape), tensor) } @@ -2256,7 +2235,7 @@ impl FloatTensorOps for Autodiff .stateful() { OpsKind::Tracked(prep) => { - let shape = B::float_shape(&tensor.primitive); + let shape = tensor.primitive.shape(); let (tensor, index) = B::float_min_dim_with_indices(tensor.primitive, dim); let tensor = prep.finish((index.clone(), shape), tensor); @@ -2425,7 +2404,7 @@ impl FloatTensorOps for Autodiff } unary::(ops.parents, ops.node, grads, |grad| { - let shape_grad = B::float_shape(&grad); + let shape_grad = grad.shape(); let mut grad = grad; #[allow(clippy::needless_range_loop)] @@ -2448,7 +2427,7 @@ impl FloatTensorOps for Autodiff .stateful() { OpsKind::Tracked(prep) => prep.finish( - (B::float_shape(&tensor.primitive), shape.clone()), + 
(tensor.primitive.shape(), shape.clone()), B::float_expand(tensor.primitive, shape), ), OpsKind::UnTracked(prep) => prep.finish(B::float_expand(tensor.primitive, shape)), @@ -2462,7 +2441,7 @@ impl FloatTensorOps for Autodiff .stateful() { OpsKind::Tracked(prep) => { - let shape = B::float_shape(&tensor.primitive); + let shape = tensor.primitive.shape(); let (tensor, indices) = B::float_sort_with_indices(tensor.primitive, dim, descending); prep.finish((indices, shape), tensor) @@ -2484,7 +2463,7 @@ impl FloatTensorOps for Autodiff .stateful() { OpsKind::Tracked(prep) => { - let shape = B::float_shape(&tensor.primitive); + let shape = tensor.primitive.shape(); let (tensor, indices) = B::float_sort_with_indices(tensor.primitive, dim, descending); let tensor = prep.finish((indices.clone(), shape), tensor); @@ -2574,8 +2553,8 @@ enum BinaryOpsBroadcast { impl BinaryOpsBroadcast { fn new(lhs: &B::FloatTensorPrimitive, rhs: &B::FloatTensorPrimitive) -> Self { - let shape_lhs = B::float_shape(lhs); - let shape_rhs = B::float_shape(rhs); + let shape_lhs = lhs.shape(); + let shape_rhs = rhs.shape(); let ndims = shape_lhs.num_dims(); for i in 0..ndims { diff --git a/crates/burn-autodiff/src/tensor.rs b/crates/burn-autodiff/src/tensor.rs index b51d75d06a..1fa7fe0c8e 100644 --- a/crates/burn-autodiff/src/tensor.rs +++ b/crates/burn-autodiff/src/tensor.rs @@ -6,7 +6,7 @@ use crate::{ graph::{ComputingProperty, Node, NodeID, NodeRef, Requirement, Step}, runtime::{AutodiffClient, AutodiffClientImpl}, }; -use burn_tensor::backend::Backend; +use burn_tensor::{backend::Backend, TensorMetadata}; #[derive(Debug, Clone)] pub struct AutodiffTensor { @@ -15,6 +15,16 @@ pub struct AutodiffTensor { pub rc: NodeRefCount, } +impl TensorMetadata for AutodiffTensor { + fn dtype(&self) -> burn_tensor::DType { + self.primitive.dtype() + } + + fn shape(&self) -> burn_tensor::Shape { + self.primitive.shape() + } +} + pub type NodeRefCount = Arc; #[derive(new, Debug)] diff --git a/crates/burn-candle/src/ops/base.rs b/crates/burn-candle/src/ops/base.rs index 22bd30a873..bd817a1809 100644 --- a/crates/burn-candle/src/ops/base.rs +++ b/crates/burn-candle/src/ops/base.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use burn_tensor::{backend::Backend, Element, Shape, TensorData}; +use burn_tensor::{backend::Backend, Element, Shape, TensorData, TensorMetadata}; use candle_core::WithDType; use half::{bf16, f16}; diff --git a/crates/burn-candle/src/ops/bool_tensor.rs b/crates/burn-candle/src/ops/bool_tensor.rs index 8c5b6cb135..f03490e94a 100644 --- a/crates/burn-candle/src/ops/bool_tensor.rs +++ b/crates/burn-candle/src/ops/bool_tensor.rs @@ -1,6 +1,6 @@ use burn_tensor::{ ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor}, - Device, Shape, TensorData, + Device, Shape, TensorData, TensorMetadata, }; use crate::{ @@ -15,10 +15,6 @@ impl BoolTensorOps for Candle< super::base::empty(shape, device, candle_core::DType::U8) } - fn bool_shape(tensor: &BoolTensor) -> Shape { - super::base::shape(tensor) - } - async fn bool_into_data(tensor: BoolTensor) -> TensorData { let x: Vec = tensor.tensor.flatten_all().unwrap().to_vec1().unwrap(); let y = x.iter().map(|b| !matches!(b, 0)).collect(); diff --git a/crates/burn-candle/src/ops/int_tensor.rs b/crates/burn-candle/src/ops/int_tensor.rs index 43998abaac..4ae0c53de7 100644 --- a/crates/burn-candle/src/ops/int_tensor.rs +++ b/crates/burn-candle/src/ops/int_tensor.rs @@ -15,10 +15,6 @@ impl IntTensorOps for Candle) -> Shape { - super::base::shape(tensor) - } - async fn 
int_into_data(tensor: IntTensor) -> TensorData { super::base::into_data(tensor) } diff --git a/crates/burn-candle/src/ops/tensor.rs b/crates/burn-candle/src/ops/tensor.rs index bedbf5b4d7..144feb7459 100644 --- a/crates/burn-candle/src/ops/tensor.rs +++ b/crates/burn-candle/src/ops/tensor.rs @@ -52,10 +52,6 @@ impl FloatTensorOps for Candle } } - fn float_shape(tensor: &CandleTensor) -> Shape { - super::base::shape(tensor) - } - async fn float_into_data(tensor: CandleTensor) -> TensorData { super::base::into_data(tensor) } diff --git a/crates/burn-candle/src/tensor.rs b/crates/burn-candle/src/tensor.rs index e512823d2b..974ca0382c 100644 --- a/crates/burn-candle/src/tensor.rs +++ b/crates/burn-candle/src/tensor.rs @@ -1,6 +1,6 @@ use burn_tensor::{ quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy}, - Element, Shape, TensorData, + DType, Element, Shape, TensorData, TensorMetadata, }; use crate::{element::CandleElement, CandleDevice}; @@ -11,6 +11,26 @@ pub struct CandleTensor { pub(crate) tensor: candle_core::Tensor, } +impl TensorMetadata for CandleTensor { + fn dtype(&self) -> DType { + match self.tensor.dtype() { + // NOTE: bool tensors are stored as u32, we currently make this assumption + // since `TensorMetadata::dtype()` is used for display purposes only at this time. + candle_core::DType::U8 => DType::Bool, + candle_core::DType::U32 => DType::U32, + candle_core::DType::I64 => DType::I64, + candle_core::DType::BF16 => DType::BF16, + candle_core::DType::F16 => DType::F16, + candle_core::DType::F32 => DType::F32, + candle_core::DType::F64 => DType::F64, + } + } + + fn shape(&self) -> Shape { + Shape::from(self.tensor.dims().to_vec()) + } +} + impl CandleTensor { /// Create a new tensor. pub fn new(tensor: candle_core::Tensor) -> Self { @@ -36,10 +56,6 @@ impl CandleTensor { ); Self::new(tensor.unwrap()) } - - pub(crate) fn shape(&self) -> Shape { - Shape::from(self.tensor.dims().to_vec()) - } } /// A quantized tensor for the candle backend. @@ -61,3 +77,13 @@ impl QTensorPrimitive for CandleQTensor { todo!() } } + +impl TensorMetadata for CandleQTensor { + fn dtype(&self) -> DType { + DType::QFloat(self.scheme) + } + + fn shape(&self) -> Shape { + self.qtensor.shape() + } +} diff --git a/crates/burn-fusion/src/ops/binary.rs b/crates/burn-fusion/src/ops/binary.rs index 4f2dbb9865..a09c192d83 100644 --- a/crates/burn-fusion/src/ops/binary.rs +++ b/crates/burn-fusion/src/ops/binary.rs @@ -1,3 +1,37 @@ +use burn_tensor::repr::{BinaryOperationDescription, TensorDescription}; + +#[derive(Debug)] +pub enum BinaryOpError { + #[allow(dead_code)] + /// Binary op data type mismatch. + DTypeMismatch { + lhs: burn_tensor::DType, + rhs: burn_tensor::DType, + }, +} + +// Until we have floating point type promotion, check that lhs and rhs dtypes are the same. +pub(crate) fn check_binary_op( + desc: BinaryOperationDescription, +) -> Result { + check_binary_op_types(&desc.lhs, &desc.rhs)?; + Ok(desc) +} + +pub(crate) fn check_binary_op_types( + lhs: &TensorDescription, + rhs: &TensorDescription, +) -> Result<(), BinaryOpError> { + if lhs.dtype != rhs.dtype { + Err(BinaryOpError::DTypeMismatch { + lhs: lhs.dtype, + rhs: rhs.dtype, + }) + } else { + Ok(()) + } +} + #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! binary_float_ops { @@ -5,12 +39,20 @@ macro_rules! 
binary_float_ops { $name:ident, $ops:expr ) => { - #[derive(new)] struct $name { desc: BinaryOperationDescription, _b: PhantomData, } + impl $name { + fn new(desc: BinaryOperationDescription) -> Self { + Self { + desc: $crate::ops::binary::check_binary_op(desc).unwrap(), + _b: PhantomData, + } + } + } + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); @@ -55,12 +97,20 @@ macro_rules! binary_int_cmp_ops { $name:ident, $ops:expr ) => { - #[derive(new)] struct $name { desc: BinaryOperationDescription, _b: PhantomData, } + impl $name { + fn new(desc: BinaryOperationDescription) -> Self { + Self { + desc: $crate::ops::binary::check_binary_op(desc).unwrap(), + _b: PhantomData, + } + } + } + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_int_tensor::(&self.desc.lhs); diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index 510c0254ba..baa5169db3 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -35,10 +35,6 @@ impl BoolTensorOps for Fusion { ) } - fn bool_shape(tensor: &BoolTensor) -> Shape { - tensor.shape() - } - async fn bool_into_data(tensor: BoolTensor) -> TensorData { tensor.bool_into_data::().await } @@ -46,7 +42,7 @@ impl BoolTensorOps for Fusion { fn bool_from_data(data: burn_tensor::TensorData, device: &Device) -> BoolTensor { let client = get_client::(&device.clone()); let tensor = B::bool_from_data(data, device); - let shape = B::bool_shape(&tensor); + let shape = burn_tensor::TensorMetadata::shape(&tensor); client.register_tensor( B::bool_tensor_handle(tensor), @@ -194,7 +190,7 @@ impl BoolTensorOps for Fusion { } } - let ndims = tensor.shape().num_dims(); + let ndims = burn_tensor::TensorMetadata::shape(&tensor).num_dims(); let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); for i in shape.len()..ndims { diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index 0b447da05e..12bc65534e 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -1,7 +1,9 @@ use crate::{ binary_float_cmp_ops, binary_float_ops, client::FusionClient, - get_client, scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops, + get_client, + ops::binary::check_binary_op_types, + scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops, stream::{execution::Operation, StreamId}, unary_float_ops, Fusion, FusionBackend, }; @@ -16,7 +18,7 @@ impl FloatTensorOps for Fusion { fn float_from_data(data: TensorData, device: &Device) -> FloatTensor { let client = get_client::(&device.clone()); let tensor = B::float_from_data(data, device); - let shape = B::float_shape(&tensor); + let shape = burn_tensor::TensorMetadata::shape(&tensor); client.register_tensor( B::float_tensor_handle(tensor), @@ -168,10 +170,6 @@ impl FloatTensorOps for Fusion { out } - fn float_shape(tensor: &FloatTensor) -> Shape { - tensor.shape() - } - async fn float_into_data(tensor: FloatTensor) -> TensorData { tensor.into_data::().await } @@ -216,6 +214,7 @@ impl FloatTensorOps for Fusion { } let stream = tensor.stream; + let dtype = tensor.dtype; let out = tensor .client .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype()); @@ -226,10 +225,7 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float( - FloatElem::::dtype(), - 
FloatOperationDescription::IntoInt(desc.clone()), - ), + OperationDescription::Float(dtype, FloatOperationDescription::IntoInt(desc.clone())), IntoIntOps::::new(desc), ); @@ -254,10 +250,10 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs.client.tensor_uninitialized( - binary_ops_shape(&lhs.shape, &rhs.shape), - B::FloatElem::dtype(), - ); + let dtype = lhs.dtype; + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), lhs.dtype); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -268,7 +264,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Add(desc.clone()), ), AddOps::::new(desc), @@ -281,6 +277,7 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(AddOps, B::float_add_scalar); let stream = lhs.stream; + let dtype = lhs.dtype; let out = lhs .client .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); @@ -293,7 +290,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::AddScalar(desc.clone()), ), AddOps::::new(desc), @@ -323,6 +320,7 @@ impl FloatTensorOps for Fusion { } let stream = tensor.stream; + let dtype = tensor.dtype; let out = tensor .client .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); @@ -336,7 +334,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Clamp(desc.clone()), ), ClampOps::::new(desc), @@ -350,10 +348,10 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs.client.tensor_uninitialized( - binary_ops_shape(&lhs.shape, &rhs.shape), - B::FloatElem::dtype(), - ); + let dtype = lhs.dtype; + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), lhs.dtype); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -363,7 +361,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Sub(desc.clone()), ), SubOps::::new(desc), @@ -376,6 +374,7 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(SubOps, B::float_sub_scalar); let stream = lhs.stream; + let dtype = lhs.dtype; let out = lhs .client .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); @@ -388,7 +387,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::SubScalar(desc.clone()), ), SubOps::::new(desc), @@ -402,10 +401,10 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs.client.tensor_uninitialized( - binary_ops_shape(&lhs.shape, &rhs.shape), - B::FloatElem::dtype(), - ); + let dtype = lhs.dtype; + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), lhs.dtype); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -415,7 +414,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Mul(desc.clone()), ), MulOps::::new(desc), 
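Throughout the fusion float ops in this patch, the registered output descriptor now inherits the dtype of the left-hand input rather than the backend's default FloatElem dtype, and mixed float dtypes are rejected until type promotion is implemented (the warning mentioned in the commit message). The standalone sketch below only illustrates that rule; binary_out_dtype is a hypothetical helper, not part of Burn.

use burn_tensor::DType;

// Hypothetical helper mirroring the rule used by the fusion binary float ops in
// this patch: the output keeps the lhs dtype, and a mismatch is an error
// (see `BinaryOpError::DTypeMismatch` in `burn-fusion/src/ops/binary.rs`).
fn binary_out_dtype(lhs: DType, rhs: DType) -> Result<DType, String> {
    if lhs != rhs {
        return Err(format!("float dtype mismatch: {lhs:?} vs {rhs:?}"));
    }
    Ok(lhs)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rejects_mixed_float_dtypes() {
        assert!(binary_out_dtype(DType::F32, DType::F16).is_err());
        assert!(binary_out_dtype(DType::F32, DType::F32).is_ok());
    }
}

In the actual fusion ops this check is performed by check_binary_op_types on the tensor descriptions before the operation is registered.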
@@ -428,6 +427,7 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(MulOps, B::float_mul_scalar); let stream = lhs.stream; + let dtype = lhs.dtype; let out = lhs .client .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); @@ -440,7 +440,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::MulScalar(desc.clone()), ), MulOps::::new(desc), @@ -454,10 +454,10 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs.client.tensor_uninitialized( - binary_ops_shape(&lhs.shape, &rhs.shape), - B::FloatElem::dtype(), - ); + let dtype = lhs.dtype; + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), lhs.dtype); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -467,7 +467,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Div(desc.clone()), ), DivOps::::new(desc), @@ -480,6 +480,7 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(DivOps, B::float_div_scalar); let stream = lhs.stream; + let dtype = lhs.dtype; let out = lhs .client .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); @@ -492,7 +493,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::DivScalar(desc.clone()), ), DivOps::::new(desc), @@ -506,10 +507,10 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs.client.tensor_uninitialized( - binary_ops_shape(&lhs.shape, &rhs.shape), - B::FloatElem::dtype(), - ); + let dtype = lhs.dtype; + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), lhs.dtype); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -519,7 +520,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Rem(desc.clone()), ), ModOps::::new(desc), @@ -532,6 +533,7 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(ModOps, B::float_remainder_scalar); let stream = lhs.stream; + let dtype = lhs.dtype; let out = lhs .client .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); @@ -544,7 +546,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::RemScalar(desc.clone()), ), ModOps::::new(desc), @@ -558,8 +560,9 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; + let dtype = lhs.dtype; let mut shape = binary_ops_shape(&lhs.shape, &rhs.shape); - let ndims = lhs.shape().num_dims(); + let ndims = burn_tensor::TensorMetadata::shape(&lhs).num_dims(); shape[ndims - 2] = lhs.shape[ndims - 2]; shape[ndims - 1] = rhs.shape[ndims - 1]; @@ -575,10 +578,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], - OperationDescription::Float( - FloatElem::::dtype(), - FloatOperationDescription::Matmul(desc.clone()), - ), + OperationDescription::Float(dtype, FloatOperationDescription::Matmul(desc.clone())), MatmulOps::::new(desc), ); @@ -681,10 +681,9 @@ impl FloatTensorOps for Fusion { let stream_1 = tensor.stream; let 
stream_2 = indices.stream; + let dtype = tensor.dtype; let shape: Vec = indices.shape.clone(); - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = GatherOperationDescription { tensor: tensor.into_description(), @@ -695,7 +694,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Gather(desc.clone()), ), GatherOps::::new(desc), @@ -731,10 +730,9 @@ impl FloatTensorOps for Fusion { let stream_1 = tensor.stream; let stream_2 = indices.stream; let stream_3 = value.stream; + let dtype = tensor.dtype; let shape: Vec = tensor.shape.clone(); - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = ScatterOperationDescription { tensor: tensor.into_description(), @@ -743,11 +741,12 @@ impl FloatTensorOps for Fusion { value: value.into_description(), out: out.to_description_out(), }; - + // Check that both float tensors have the same type + check_binary_op_types(&desc.tensor, &desc.value).unwrap(); out.client.register( vec![stream_1, stream_2, stream_3], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Scatter(desc.clone()), ), ScatterOps::::new(desc), @@ -780,11 +779,10 @@ impl FloatTensorOps for Fusion { let stream_1 = tensor.stream; let stream_2 = indices.stream; + let dtype = tensor.dtype; let mut shape: Vec = tensor.shape.clone(); shape[dim] = indices.shape[0]; - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = SelectOperationDescription { tensor: tensor.into_description(), dim, @@ -794,7 +792,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Select(desc.clone()), ), SelectOps::::new(desc), @@ -830,10 +828,9 @@ impl FloatTensorOps for Fusion { let stream_1 = tensor.stream; let stream_2 = indices.stream; let stream_3 = value.stream; + let dtype = tensor.dtype; let shape: Vec = tensor.shape.clone(); - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = SelectAssignOperationDescription { tensor: tensor.into_description(), @@ -842,10 +839,12 @@ impl FloatTensorOps for Fusion { value: value.into_description(), out: out.to_description_out(), }; + // Check that both float tensors have the same type + check_binary_op_types(&desc.tensor, &desc.value).unwrap(); out.client.register( vec![stream_1, stream_2, stream_3], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::SelectAssign(desc.clone()), ), SelectAssignOps::::new(desc), @@ -871,16 +870,15 @@ impl FloatTensorOps for Fusion { } } let stream = tensor.stream; - let ndims = tensor.shape().num_dims(); + let dtype = tensor.dtype; + let ndims = burn_tensor::TensorMetadata::shape(&tensor).num_dims(); let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); for i in shape.len()..ndims { shape.push(tensor.shape[i]); } - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = 
SliceOperationDescription { tensor: tensor.into_description(), @@ -920,10 +918,9 @@ impl FloatTensorOps for Fusion { let stream_1 = tensor.stream; let stream_2 = value.stream; + let dtype = tensor.dtype; let shape: Vec = tensor.shape.clone(); - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = SliceAssignOperationDescription { tensor: tensor.into_description(), @@ -931,6 +928,8 @@ impl FloatTensorOps for Fusion { value: value.into_description(), out: out.to_description_out(), }; + // Check that both float tensors have the same type + check_binary_op_types(&desc.tensor, &desc.value).unwrap(); out.client.register( vec![stream_1, stream_2], OperationDescription::BaseFloat(BaseOperationDescription::SliceAssign(desc.clone())), @@ -966,10 +965,9 @@ impl FloatTensorOps for Fusion { let stream_1 = tensor.stream; let stream_2 = mask.stream; let stream_3 = value.stream; + let dtype = tensor.dtype; let shape = binary_ops_shape(&tensor.shape, &mask.shape); - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = MaskWhereOperationDescription { tensor: tensor.into_description(), @@ -977,10 +975,12 @@ impl FloatTensorOps for Fusion { mask: mask.into_description(), out: out.to_description_out(), }; + // Check that both float tensors have the same type + check_binary_op_types(&desc.tensor, &desc.value).unwrap(); out.client.register( vec![stream_1, stream_2, stream_3], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::MaskWhere(desc.clone()), ), MaskWhereOps::::new(desc), @@ -1013,10 +1013,9 @@ impl FloatTensorOps for Fusion { let stream_1 = tensor.stream; let stream_2 = mask.stream; + let dtype = tensor.dtype; let shape: Vec = tensor.shape.clone(); - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = MaskFillOperationDescription { tensor: tensor.into_description(), value: value.elem(), @@ -1026,7 +1025,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::MaskFill(desc.clone()), ), MaskFillOps::::new(desc), @@ -1062,6 +1061,7 @@ impl FloatTensorOps for Fusion { scalar_float_cmp_ops!(EqualElemOps, B::float_equal_elem); let stream = lhs.stream; + let dtype = lhs.dtype; let out = lhs .client .tensor_uninitialized(lhs.shape.clone(), DType::Bool); @@ -1074,7 +1074,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::EqualElem(desc.clone()), ), EqualElemOps::::new(desc), @@ -1088,6 +1088,7 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; + let dtype = lhs.dtype; let out = lhs .client .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); @@ -1100,7 +1101,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Greater(desc.clone()), ), GreaterOps::::new(desc), @@ -1113,6 +1114,7 @@ impl FloatTensorOps for Fusion { scalar_float_cmp_ops!(GreaterElemOps, B::float_greater_elem); let stream = lhs.stream; + let dtype = lhs.dtype; let out 
= lhs .client .tensor_uninitialized(lhs.shape.clone(), DType::Bool); @@ -1125,7 +1127,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::GreaterElem(desc.clone()), ), GreaterElemOps::::new(desc), @@ -1139,6 +1141,7 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; + let dtype = lhs.dtype; let out = lhs .client .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); @@ -1151,7 +1154,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::GreaterEqual(desc.clone()), ), GreaterEqualOps::::new(desc), @@ -1164,6 +1167,7 @@ impl FloatTensorOps for Fusion { scalar_float_cmp_ops!(GreaterEqualElemOps, B::float_greater_equal_elem); let stream = lhs.stream; + let dtype = lhs.dtype; let out = lhs .client .tensor_uninitialized(lhs.shape.clone(), DType::Bool); @@ -1176,7 +1180,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::GreaterEqualElem(desc.clone()), ), GreaterEqualElemOps::::new(desc), @@ -1190,6 +1194,7 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; + let dtype = lhs.dtype; let out = lhs .client .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); @@ -1202,7 +1207,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Lower(desc.clone()), ), LowerOps::::new(desc), @@ -1215,6 +1220,7 @@ impl FloatTensorOps for Fusion { scalar_float_cmp_ops!(LowerElemOps, B::float_lower_elem); let stream = lhs.stream; + let dtype = lhs.dtype; let out = lhs .client .tensor_uninitialized(lhs.shape.clone(), DType::Bool); @@ -1227,7 +1233,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::LowerElem(desc.clone()), ), LowerElemOps::::new(desc), @@ -1241,6 +1247,7 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; + let dtype = lhs.dtype; let out = lhs .client .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); @@ -1253,7 +1260,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::LowerEqual(desc.clone()), ), LowerEqualOps::::new(desc), @@ -1266,6 +1273,7 @@ impl FloatTensorOps for Fusion { scalar_float_cmp_ops!(LowerEqualElemOps, B::float_lower_equal_elem); let stream = lhs.stream; + let dtype = lhs.dtype; let out = lhs .client .tensor_uninitialized(lhs.shape.clone(), DType::Bool); @@ -1278,7 +1286,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::LowerEqualElem(desc.clone()), ), LowerEqualElemOps::::new(desc), @@ -1291,6 +1299,7 @@ impl FloatTensorOps for Fusion { unary_float_ops!(SumOps, B::float_sum, reduce); let stream = tensor.stream; + let dtype = tensor.dtype; let out = tensor .client .tensor_uninitialized(vec![1], B::FloatElem::dtype()); @@ -1302,7 +1311,7 @@ impl 
FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Sum(desc.clone()), ), SumOps::::new(desc), @@ -1315,6 +1324,7 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(SumDimOps, B::float_sum_dim, usize, noconvert); let stream = tensor.stream; + let dtype = tensor.dtype; let mut shape = tensor.shape.clone(); shape[dim] = 1; let out = tensor @@ -1329,7 +1339,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::SumDim(desc.clone()), ), SumDimOps::::new(desc), @@ -1393,6 +1403,7 @@ impl FloatTensorOps for Fusion { unary_float_ops!(MeanOps, B::float_mean, reduce); let stream = tensor.stream; + let dtype = tensor.dtype; let out = tensor .client .tensor_uninitialized(vec![1], B::FloatElem::dtype()); @@ -1404,7 +1415,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Mean(desc.clone()), ), MeanOps::::new(desc), @@ -1417,6 +1428,7 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(MeanDimOps, B::float_mean_dim, usize, noconvert); let stream = tensor.stream; + let dtype = tensor.dtype; let mut shape = tensor.shape.clone(); shape[dim] = 1; let out = tensor @@ -1431,7 +1443,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::MeanDim(desc.clone()), ), MeanDimOps::::new(desc), @@ -1444,9 +1456,8 @@ impl FloatTensorOps for Fusion { unary_float_ops!(ExpOps, B::float_exp); let stream = lhs.stream; - let out = lhs - .client - .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); + let dtype = lhs.dtype; + let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype); let desc = UnaryOperationDescription { input: lhs.into_description(), @@ -1454,10 +1465,7 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float( - FloatElem::::dtype(), - FloatOperationDescription::Exp(desc.clone()), - ), + OperationDescription::Float(dtype, FloatOperationDescription::Exp(desc.clone())), ExpOps::::new(desc), ); @@ -1468,9 +1476,10 @@ impl FloatTensorOps for Fusion { unary_float_ops!(LogOps, B::float_log); let stream = tensor.stream; + let dtype = tensor.dtype; let out = tensor .client - .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + .tensor_uninitialized(tensor.shape.clone(), dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1478,10 +1487,7 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float( - FloatElem::::dtype(), - FloatOperationDescription::Log(desc.clone()), - ), + OperationDescription::Float(dtype, FloatOperationDescription::Log(desc.clone())), LogOps::::new(desc), ); @@ -1492,9 +1498,10 @@ impl FloatTensorOps for Fusion { unary_float_ops!(Log1pOps, B::float_log1p); let stream = tensor.stream; + let dtype = tensor.dtype; let out = tensor .client - .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + .tensor_uninitialized(tensor.shape.clone(), dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1502,10 +1509,7 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float( - FloatElem::::dtype(), - 
FloatOperationDescription::Log1p(desc.clone()), - ), + OperationDescription::Float(dtype, FloatOperationDescription::Log1p(desc.clone())), Log1pOps::::new(desc), ); @@ -1516,9 +1520,8 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(PowfOps, B::float_powf_scalar, f32); let stream = lhs.stream; - let out = lhs - .client - .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); + let dtype = lhs.dtype; + let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -1527,10 +1530,7 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float( - FloatElem::::dtype(), - FloatOperationDescription::PowfScalar(desc.clone()), - ), + OperationDescription::Float(dtype, FloatOperationDescription::PowfScalar(desc.clone())), PowfOps::::new(desc), ); @@ -1541,9 +1541,10 @@ impl FloatTensorOps for Fusion { unary_float_ops!(SqrtOps, B::float_sqrt); let stream = tensor.stream; + let dtype = tensor.dtype; let out = tensor .client - .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + .tensor_uninitialized(tensor.shape.clone(), dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1551,10 +1552,7 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float( - FloatElem::::dtype(), - FloatOperationDescription::Sqrt(desc.clone()), - ), + OperationDescription::Float(dtype, FloatOperationDescription::Sqrt(desc.clone())), SqrtOps::::new(desc), ); @@ -1565,9 +1563,10 @@ impl FloatTensorOps for Fusion { unary_float_ops!(AbsOps, B::float_abs); let stream = tensor.stream; + let dtype = tensor.dtype; let out = tensor .client - .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + .tensor_uninitialized(tensor.shape.clone(), dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1576,7 +1575,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Abs(desc.clone()), ), AbsOps::::new(desc), @@ -1589,9 +1588,10 @@ impl FloatTensorOps for Fusion { unary_float_ops!(CosOps, B::float_cos); let stream = tensor.stream; + let dtype = tensor.dtype; let out = tensor .client - .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + .tensor_uninitialized(tensor.shape.clone(), dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1599,10 +1599,7 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float( - FloatElem::::dtype(), - FloatOperationDescription::Cos(desc.clone()), - ), + OperationDescription::Float(dtype, FloatOperationDescription::Cos(desc.clone())), CosOps::::new(desc), ); @@ -1613,9 +1610,10 @@ impl FloatTensorOps for Fusion { unary_float_ops!(SinOps, B::float_sin); let stream = tensor.stream; + let dtype = tensor.dtype; let out = tensor .client - .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + .tensor_uninitialized(tensor.shape.clone(), dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1623,10 +1621,7 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float( - FloatElem::::dtype(), - FloatOperationDescription::Sin(desc.clone()), - ), + OperationDescription::Float(dtype, FloatOperationDescription::Sin(desc.clone())), SinOps::::new(desc), ); @@ 
-1637,9 +1632,10 @@ impl FloatTensorOps for Fusion { unary_float_ops!(TanhOps, B::float_tanh); let stream = tensor.stream; + let dtype = tensor.dtype; let out = tensor .client - .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + .tensor_uninitialized(tensor.shape.clone(), dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1647,10 +1643,7 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float( - FloatElem::::dtype(), - FloatOperationDescription::Tanh(desc.clone()), - ), + OperationDescription::Float(dtype, FloatOperationDescription::Tanh(desc.clone())), TanhOps::::new(desc), ); @@ -1661,19 +1654,17 @@ impl FloatTensorOps for Fusion { unary_float_ops!(Recip, B::float_recip); let stream = tensor.stream; + let dtype = tensor.dtype; let out = tensor .client - .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + .tensor_uninitialized(tensor.shape.clone(), dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), out: out.to_description_out(), }; out.client.register( vec![stream], - OperationDescription::Float( - FloatElem::::dtype(), - FloatOperationDescription::Recip(desc.clone()), - ), + OperationDescription::Float(dtype, FloatOperationDescription::Recip(desc.clone())), Recip::::new(desc), ); @@ -1684,9 +1675,10 @@ impl FloatTensorOps for Fusion { unary_float_ops!(TanhOps, B::float_erf); let stream = tensor.stream; + let dtype = tensor.dtype; let out = tensor .client - .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + .tensor_uninitialized(tensor.shape.clone(), dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1694,10 +1686,7 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float( - FloatElem::::dtype(), - FloatOperationDescription::Erf(desc.clone()), - ), + OperationDescription::Float(dtype, FloatOperationDescription::Erf(desc.clone())), TanhOps::::new(desc), ); @@ -1739,11 +1728,15 @@ impl FloatTensorOps for Fusion { let out = client.tensor_uninitialized(shape, B::FloatElem::dtype()); + // TODO: check dtype let desc = CatOperationDescription { tensors: tensors.into_iter().map(|t| t.into_description()).collect(), dim, out: out.to_description_out(), }; + desc.tensors + .windows(2) + .for_each(|desc| check_binary_op_types(&desc[0], &desc[1]).unwrap()); client.register( streams, OperationDescription::BaseFloat(BaseOperationDescription::Cat(desc.clone())), @@ -1757,6 +1750,7 @@ impl FloatTensorOps for Fusion { scalar_float2int_ops!(ArgMaxOps, B::float_argmax, usize); let stream = tensor.stream; + let dtype = tensor.dtype; let mut shape = tensor.shape.clone(); shape[dim] = 1; let out = tensor @@ -1771,7 +1765,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::ArgMax(desc.clone()), ), ArgMaxOps::::new(desc), @@ -1800,9 +1794,7 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] *= times; - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, tensor.dtype); let desc = RepeatDimOperationDescription { tensor: tensor.into_description(), @@ -1825,6 +1817,7 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; + let dtype = tensor.dtype; let 
out = tensor .client .tensor_uninitialized(shape, B::IntElem::dtype()); @@ -1837,7 +1830,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::ArgMin(desc.clone()), ), ArgMinOps::::new(desc), @@ -1850,9 +1843,8 @@ impl FloatTensorOps for Fusion { unary_float_ops!(MaxOps, B::float_max, reduce); let stream = tensor.stream; - let out = tensor - .client - .tensor_uninitialized(vec![1], B::FloatElem::dtype()); + let dtype = tensor.dtype; + let out = tensor.client.tensor_uninitialized(vec![1], dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1861,7 +1853,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Max(desc.clone()), ), MaxOps::::new(desc), @@ -1875,10 +1867,9 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); + let dtype = tensor.dtype; shape[dim] = 1; - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1888,7 +1879,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::MaxDim(desc.clone()), ), MaxDimOps::::new(desc), @@ -1920,8 +1911,9 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; + let dtype = tensor.dtype; let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone(), B::FloatElem::dtype()); + let out = client.tensor_uninitialized(shape.clone(), dtype); let out_indices = client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ReduceDimWithIndicesDescription { @@ -1933,7 +1925,7 @@ impl FloatTensorOps for Fusion { client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::MaxDimWithIndices(desc.clone()), ), MaxDimWithIndicesOps::::new(desc), @@ -1946,9 +1938,8 @@ impl FloatTensorOps for Fusion { unary_float_ops!(MinOps, B::float_min, reduce); let stream = tensor.stream; - let out = tensor - .client - .tensor_uninitialized(vec![1], B::FloatElem::dtype()); + let dtype = tensor.dtype; + let out = tensor.client.tensor_uninitialized(vec![1], dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1957,7 +1948,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Min(desc.clone()), ), MinOps::::new(desc), @@ -1970,11 +1961,10 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(MinDimOps, B::float_min_dim, usize, noconvert); let stream = tensor.stream; + let dtype = tensor.dtype; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1984,7 +1974,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::MinDim(desc.clone()), ), 
MinDimOps::::new(desc), @@ -2014,10 +2004,11 @@ impl FloatTensorOps for Fusion { } let stream = tensor.stream; + let dtype = tensor.dtype; let mut shape = tensor.shape.clone(); shape[dim] = 1; let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone(), B::FloatElem::dtype()); + let out = client.tensor_uninitialized(shape.clone(), dtype); let out_indices = client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ReduceDimWithIndicesDescription { @@ -2029,7 +2020,7 @@ impl FloatTensorOps for Fusion { client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::MinDimWithIndices(desc.clone()), ), MinDimWithIndicesOps::::new(desc), @@ -2042,11 +2033,11 @@ impl FloatTensorOps for Fusion { binary_float_ops!(PowOps, B::float_powf); let stream_1 = lhs.stream; let stream_2 = rhs.stream; + let dtype = lhs.dtype; - let out = lhs.client.tensor_uninitialized( - binary_ops_shape(&lhs.shape, &rhs.shape), - B::FloatElem::dtype(), - ); + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -2056,7 +2047,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Powf(desc.clone()), ), PowOps::::new(desc), @@ -2085,9 +2076,7 @@ impl FloatTensorOps for Fusion { // Change the shape of the tensor to match the new axes let shape = axes.iter().map(|x| tensor.shape[*x]).collect(); - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, tensor.dtype); let desc = PermuteOperationDescription { input: tensor.into_description(), @@ -2124,7 +2113,7 @@ impl FloatTensorOps for Fusion { let out = tensor .client - .tensor_uninitialized(shape.dims.clone(), B::FloatElem::dtype()); + .tensor_uninitialized(shape.dims.clone(), tensor.dtype); let desc = ExpandOperationDescription { input: tensor.into_description(), @@ -2159,7 +2148,7 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let out = tensor .client - .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + .tensor_uninitialized(tensor.shape.clone(), tensor.dtype); let desc = FlipOperationDescription { input: tensor.into_description(), @@ -2180,9 +2169,10 @@ impl FloatTensorOps for Fusion { unary_float_ops!(RoundOps, B::float_round); let stream = tensor.stream; + let dtype = tensor.dtype; let out = tensor .client - .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + .tensor_uninitialized(tensor.shape.clone(), dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -2190,10 +2180,7 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float( - FloatElem::::dtype(), - FloatOperationDescription::Round(desc.clone()), - ), + OperationDescription::Float(dtype, FloatOperationDescription::Round(desc.clone())), RoundOps::::new(desc), ); @@ -2204,9 +2191,10 @@ impl FloatTensorOps for Fusion { unary_float_ops!(FloorOps, B::float_floor); let stream = tensor.stream; + let dtype = tensor.dtype; let out = tensor .client - .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + .tensor_uninitialized(tensor.shape.clone(), dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -2214,10 +2202,7 @@ impl 
FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float( - FloatElem::::dtype(), - FloatOperationDescription::Floor(desc.clone()), - ), + OperationDescription::Float(dtype, FloatOperationDescription::Floor(desc.clone())), FloorOps::::new(desc), ); @@ -2228,9 +2213,10 @@ impl FloatTensorOps for Fusion { unary_float_ops!(CeilOps, B::float_ceil); let stream = tensor.stream; + let dtype = tensor.dtype; let out = tensor .client - .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + .tensor_uninitialized(tensor.shape.clone(), dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -2238,20 +2224,44 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float( - FloatElem::::dtype(), - FloatOperationDescription::Ceil(desc.clone()), - ), + OperationDescription::Float(dtype, FloatOperationDescription::Ceil(desc.clone())), CeilOps::::new(desc), ); out } - fn float_cast( - _tensor: FloatTensor, - _dtype: burn_tensor::FloatDType, - ) -> FloatTensor { - todo!() + fn float_cast(tensor: FloatTensor, dtype: burn_tensor::FloatDType) -> FloatTensor { + #[derive(new)] + struct CastOps { + desc: UnaryOperationDescription, + dtype: burn_tensor::FloatDType, + _b: PhantomData, + } + + impl Operation for CastOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.input); + let output: B::FloatTensorPrimitive = B::float_cast(input, self.dtype); + handles.register_float_tensor::(&self.desc.out.id, output); + } + } + + let stream = tensor.stream; + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), dtype.clone().into()); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + OperationDescription::BaseFloat(BaseOperationDescription::Cast(desc.clone())), + CastOps::::new(desc, dtype), + ); + + out } } diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index 806f749f50..bdb47df02c 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -27,10 +27,6 @@ impl IntTensorOps for Fusion { ) } - fn int_shape(tensor: &IntTensor) -> Shape { - tensor.shape() - } - async fn int_into_data(tensor: IntTensor) -> TensorData { tensor.int_into_data::().await } @@ -38,7 +34,7 @@ impl IntTensorOps for Fusion { fn int_from_data(data: TensorData, device: &Device) -> IntTensor { let client = get_client::(&device.clone()); let tensor = B::int_from_data(data, device); - let shape = B::int_shape(&tensor); + let shape = burn_tensor::TensorMetadata::shape(&tensor); let stream = StreamId::current(); client.register_tensor( @@ -121,7 +117,7 @@ impl IntTensorOps for Fusion { } let stream = tensor.stream; - let ndims = tensor.shape().num_dims(); + let ndims = burn_tensor::TensorMetadata::shape(&tensor).num_dims(); let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); for i in shape.len()..ndims { diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs index 5df3544781..c8b1c09610 100644 --- a/crates/burn-fusion/src/ops/qtensor.rs +++ b/crates/burn-fusion/src/ops/qtensor.rs @@ -212,7 +212,9 @@ impl QTensorOps for Fusion { } fn q_shape(tensor: &QuantizedTensor) -> Shape { - tensor.qtensor.shape() + // Conflicting `dtype()` when both `Element` and `TensorMetadata` traits are in + // scope so we use the fully qualified 
syntax + burn_tensor::TensorMetadata::shape(tensor) } fn q_device(tensor: &QuantizedTensor) -> Device { diff --git a/crates/burn-fusion/src/tensor.rs b/crates/burn-fusion/src/tensor.rs index 0144152efb..44d71abd4c 100644 --- a/crates/burn-fusion/src/tensor.rs +++ b/crates/burn-fusion/src/tensor.rs @@ -7,7 +7,7 @@ use burn_tensor::{ QuantizationParametersDescription, QuantizedTensorDescription, TensorDescription, TensorId, TensorStatus, }, - DType, Shape, TensorData, + DType, Shape, TensorData, TensorMetadata, }; use std::sync::Arc; @@ -58,6 +58,16 @@ impl core::fmt::Debug for FusionTensor { } } +impl TensorMetadata for FusionTensor { + fn dtype(&self) -> DType { + self.dtype + } + + fn shape(&self) -> Shape { + Shape::from(self.shape.clone()) + } +} + impl FusionTensor { pub(crate) fn new( id: Arc, @@ -75,9 +85,6 @@ impl FusionTensor { stream, } } - pub(crate) fn shape(&self) -> Shape { - Shape::from(self.shape.clone()) - } fn status(&self) -> TensorStatus { if Arc::strong_count(&self.id) <= 1 { @@ -197,6 +204,16 @@ impl Clone for QFusionTensor { } } +impl TensorMetadata for QFusionTensor { + fn dtype(&self) -> DType { + DType::QFloat(self.scheme) + } + + fn shape(&self) -> Shape { + self.qtensor.shape() + } +} + impl QFusionTensor { pub(crate) async fn into_data(self) -> TensorData where diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs index b234790758..d83f58f3fe 100644 --- a/crates/burn-jit/src/backend.rs +++ b/crates/burn-jit/src/backend.rs @@ -38,10 +38,10 @@ where type FloatElem = F; type IntElem = I; - type FloatTensorPrimitive = JitTensor; - type IntTensorPrimitive = JitTensor; - type BoolTensorPrimitive = JitTensor; - type QuantizedTensorPrimitive = QJitTensor; + type FloatTensorPrimitive = JitTensor; + type IntTensorPrimitive = JitTensor; + type BoolTensorPrimitive = JitTensor; + type QuantizedTensorPrimitive = QJitTensor; type QuantizedEncoding = u32; fn name() -> String { diff --git a/crates/burn-jit/src/bridge.rs b/crates/burn-jit/src/bridge.rs index e751075382..eb215fdcad 100644 --- a/crates/burn-jit/src/bridge.rs +++ b/crates/burn-jit/src/bridge.rs @@ -37,8 +37,13 @@ where >(tensor); // The line below does the backend type cast. - let tensor = - JitTensor::new_contiguous(tensor.client, tensor.device, tensor.shape, tensor.handle); + let tensor = JitTensor::new_contiguous( + tensor.client, + tensor.device, + tensor.shape, + tensor.handle, + FTarget::dtype(), + ); if let Some(device) = &device { to_device(tensor, device) @@ -57,8 +62,13 @@ where FloatElem>, >(tensor); // The line below does the backend type cast. 
- let tensor = - JitTensor::new_contiguous(tensor.client, tensor.device, tensor.shape, tensor.handle); + let tensor = JitTensor::new_contiguous( + tensor.client, + tensor.device, + tensor.shape, + tensor.handle, + FOrigin::dtype(), + ); if let Some(device) = &device { to_device(tensor, device) diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 227b663e39..7968626e89 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -1,10 +1,7 @@ use super::elemwise::optimization::{ElemwiseOptimization, ElemwiseOptimizationState}; use crate::fusion::elemwise::builder::ElementWiseBuilder; use crate::tensor::{JitQuantizationParameters, QJitTensor}; -use crate::{ - element::JitElement, kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend, - JitRuntime, -}; +use crate::{kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend, JitRuntime}; use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime}; use burn_tensor::quantization::QuantizationScheme; use burn_tensor::repr::{QuantizedKind, TensorHandle}; @@ -155,7 +152,7 @@ impl FusionBackend for JitBackend dtype: burn_tensor::DType, ) -> Self::Handle { fn cast( - tensor: JitTensor, + tensor: JitTensor, ) -> JitFusionHandle { JitFusionHandle::from(kernel::cast::(tensor)) } @@ -219,14 +216,14 @@ unsafe impl Send for JitFusionHandle {} unsafe impl Sync for JitFusionHandle {} impl JitFusionHandle { - pub(crate) fn into_tensor(self, shape: Shape) -> JitTensor { + pub(crate) fn into_tensor(self, shape: Shape) -> JitTensor { JitTensor { client: self.client, handle: self.handle, device: self.device, shape, strides: self.strides, - elem: PhantomData, + dtype: self.dtype, } } /// Return the reference to a tensor handle. 
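The fusion-handle hunks above replace the compile-time element parameter (the old `PhantomData` element marker) with a runtime `dtype` field that is simply carried through conversions such as `into_tensor` and the `From` impl, while kernels receive the element type explicitly at the call site. A minimal sketch of that pattern, using simplified stand-in types rather than the actual burn-jit/cubecl ones:

#[derive(Clone, Copy, Debug, PartialEq)]
enum DType { F32, F64, U32 }

// After the refactor the tensor records its element type at runtime
// instead of encoding it in a generic parameter.
struct Tensor { shape: Vec<usize>, dtype: DType }

struct Handle { dtype: DType }

impl Handle {
    // Analogous to `into_tensor` above: the dtype is copied from the handle
    // rather than derived from a `PhantomData` element marker.
    fn into_tensor(self, shape: Vec<usize>) -> Tensor {
        Tensor { shape, dtype: self.dtype }
    }
}

// Kernels stay generic over the element type; the caller selects `E` by
// matching on `tensor.dtype` (roughly what an `execute_with_dtype!`-style
// dispatch does).
fn launch<E: Default + std::fmt::Debug>(tensor: &Tensor) {
    println!("element = {:?}, shape = {:?}", E::default(), tensor.shape);
}

fn dispatch(tensor: &Tensor) {
    match tensor.dtype {
        DType::F32 => launch::<f32>(tensor),
        DType::F64 => launch::<f64>(tensor),
        DType::U32 => launch::<u32>(tensor),
    }
}

fn main() {
    let t = Handle { dtype: DType::F32 }.into_tensor(vec![2, 3]);
    dispatch(&t);
}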
@@ -255,14 +252,14 @@ impl JitFusionHandle { } } -impl From> for JitFusionHandle { - fn from(value: JitTensor) -> Self { +impl From> for JitFusionHandle { + fn from(value: JitTensor) -> Self { Self { client: value.client, handle: value.handle, device: value.device, strides: value.strides, - dtype: E::dtype(), + dtype: value.dtype, } } } diff --git a/crates/burn-jit/src/kernel/binary.rs b/crates/burn-jit/src/kernel/binary.rs index d327b0a786..87d570c0ed 100644 --- a/crates/burn-jit/src/kernel/binary.rs +++ b/crates/burn-jit/src/kernel/binary.rs @@ -118,9 +118,9 @@ pub(crate) fn kernel_binop>( } pub(crate) fn launch_binop>( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { let ndims = lhs.shape.num_dims(); let vectorization_factor_lhs = tensor_vectorization_factor(&[4, 2], &lhs.shape.dims, &lhs.strides, ndims - 1); @@ -152,8 +152,8 @@ pub(crate) fn launch_binop>( &client, cube_count, cube_dim, - lhs.as_tensor_arg(vectorization_factor), - rhs.as_tensor_arg(vectorization_factor), + lhs.as_tensor_arg::(vectorization_factor), + rhs.as_tensor_arg::(vectorization_factor), TensorArg::alias(0), None, false, @@ -166,8 +166,8 @@ pub(crate) fn launch_binop>( &client, cube_count, cube_dim, - lhs.as_tensor_arg(vectorization_factor), - rhs.as_tensor_arg(vectorization_factor), + lhs.as_tensor_arg::(vectorization_factor), + rhs.as_tensor_arg::(vectorization_factor), TensorArg::alias(1), None, rhs.strides != lhs.strides || rhs.shape != lhs.shape, @@ -184,9 +184,9 @@ pub(crate) fn launch_binop>( &client, cube_count, cube_dim, - lhs.as_tensor_arg(vectorization_factor), - rhs.as_tensor_arg(vectorization_factor), - output.as_tensor_arg(vectorization_factor), + lhs.as_tensor_arg::(vectorization_factor), + rhs.as_tensor_arg::(vectorization_factor), + output.as_tensor_arg::(vectorization_factor), None, to_contiguous_lhs, to_contiguous_rhs, @@ -197,9 +197,9 @@ pub(crate) fn launch_binop>( } pub(crate) fn launch_scalar_binop>( - mut tensor: JitTensor, + mut tensor: JitTensor, scalar: E, -) -> JitTensor { +) -> JitTensor { if !tensor.is_contiguous_buffer() { tensor = into_contiguous(tensor); } @@ -220,14 +220,14 @@ pub(crate) fn launch_scalar_binop>( &client, cube_count, cube_dim, - tensor.as_tensor_arg(vectorization_factor), + tensor.as_tensor_arg::(vectorization_factor), ScalarArg::new(scalar), TensorArg::alias(0), ); tensor } else { - let output = empty_device( + let output = empty_device::( tensor.client.clone(), tensor.device.clone(), tensor.shape.clone(), @@ -237,9 +237,9 @@ pub(crate) fn launch_scalar_binop>( &client, cube_count, CubeDim::default(), - tensor.as_tensor_arg(vectorization_factor), + tensor.as_tensor_arg::(vectorization_factor), ScalarArg::new(scalar), - output.as_tensor_arg(vectorization_factor), + output.as_tensor_arg::(vectorization_factor), ); output diff --git a/crates/burn-jit/src/kernel/cast/base.rs b/crates/burn-jit/src/kernel/cast/base.rs index 0516fbc195..798b79a0f0 100644 --- a/crates/burn-jit/src/kernel/cast/base.rs +++ b/crates/burn-jit/src/kernel/cast/base.rs @@ -30,11 +30,15 @@ pub(crate) fn cast_element( /// Cast a tensor to the given element type. /// /// Note: When input element is semantically a boolean, prefer bool_cast function. 
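The cast kernel that follows keeps its `TypeId` short-circuit after the signature change: when the source and target element types coincide, the existing handle is reused and only the recorded dtype matters; otherwise a conversion kernel fills a freshly allocated output tagged with the target dtype. A rough, self-contained sketch of that check, with made-up helper types (not the real burn-jit API):

use std::any::TypeId;

// Hypothetical stand-ins for the element trait and the dtype-carrying tensor.
trait Element: 'static { fn dtype() -> &'static str; }
impl Element for f32 { fn dtype() -> &'static str { "f32" } }
impl Element for f64 { fn dtype() -> &'static str { "f64" } }

struct Tensor { data: Vec<u8>, dtype: &'static str }

fn cast<EI: Element, EO: Element>(input: Tensor) -> Tensor {
    // Same element type: no kernel launch, just reuse the buffer and keep the dtype.
    if TypeId::of::<EI>() == TypeId::of::<EO>() {
        return Tensor { data: input.data, dtype: input.dtype };
    }
    // A real implementation would launch a conversion kernel into a new buffer;
    // this sketch only tags the output with the target dtype.
    Tensor { data: input.data, dtype: EO::dtype() }
}

fn main() {
    let t = Tensor { data: vec![0; 8], dtype: "f32" };
    let out = cast::<f32, f64>(t);
    println!("cast to {}", out.dtype);
}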
-pub fn cast( - input: JitTensor, -) -> JitTensor { +pub fn cast(input: JitTensor) -> JitTensor { if TypeId::of::() == TypeId::of::() { - return JitTensor::new_contiguous(input.client, input.device, input.shape, input.handle); + return JitTensor::new_contiguous( + input.client, + input.device, + input.shape, + input.handle, + input.dtype, + ); } // Vectorization is only enabled when the last dimension is contiguous. @@ -54,14 +58,15 @@ pub fn cast( input.device.clone(), input.shape.clone(), handle, + EO::dtype(), ); cast_element::launch::( &client, cube_count, cube_dim, - input.as_tensor_arg(vectorization_factor), - output.as_tensor_arg(vectorization_factor), + input.as_tensor_arg::(vectorization_factor), + output.as_tensor_arg::(vectorization_factor), Some(rank as u32), ); diff --git a/crates/burn-jit/src/kernel/cast/bool_cast.rs b/crates/burn-jit/src/kernel/cast/bool_cast.rs index ddf33902aa..07a915ee1f 100644 --- a/crates/burn-jit/src/kernel/cast/bool_cast.rs +++ b/crates/burn-jit/src/kernel/cast/bool_cast.rs @@ -16,7 +16,7 @@ fn bool_cast_kernel(input: &Tensor, output: &mut Tensor) { /// where any non-zero value means true. Depending how it was created /// it may hold an uncanny bit combination. Naively casting it would not /// necessarily yield 0 or 1. -pub fn bool_cast(tensor: JitTensor) -> JitTensor { +pub fn bool_cast(tensor: JitTensor) -> JitTensor { let num_elems = tensor.shape.num_elements(); let buffer = tensor.client.empty(num_elems * core::mem::size_of::()); let output = JitTensor::new_contiguous( @@ -24,6 +24,7 @@ pub fn bool_cast(tensor: JitTensor) -> Ji tensor.device.clone(), tensor.shape.clone(), buffer, + EO::dtype(), ); let cube_dim = CubeDim::default(); @@ -33,8 +34,8 @@ pub fn bool_cast(tensor: JitTensor) -> Ji &tensor.client, cube_count, cube_dim, - tensor.as_tensor_arg(1), - output.as_tensor_arg(1), + tensor.as_tensor_arg::(1), + output.as_tensor_arg::(1), ); output diff --git a/crates/burn-jit/src/kernel/clamp.rs b/crates/burn-jit/src/kernel/clamp.rs index 6b7c20c612..683e8aff8f 100644 --- a/crates/burn-jit/src/kernel/clamp.rs +++ b/crates/burn-jit/src/kernel/clamp.rs @@ -10,10 +10,10 @@ struct Options { } pub(crate) fn clamp( - input: JitTensor, + input: JitTensor, min_value: E, max_value: E, -) -> JitTensor { +) -> JitTensor { struct ClampOp; impl UnaryOp for ClampOp { diff --git a/crates/burn-jit/src/kernel/comparison.rs b/crates/burn-jit/src/kernel/comparison.rs index fd64df3bff..420a74d81b 100644 --- a/crates/burn-jit/src/kernel/comparison.rs +++ b/crates/burn-jit/src/kernel/comparison.rs @@ -1,5 +1,5 @@ use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; -use burn_tensor::Shape; +use burn_tensor::{DType, Shape}; use cubecl::{ calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, tensor_vectorization_factor, @@ -112,9 +112,9 @@ pub(crate) fn kernel_cmp>( } pub(crate) fn launch_cmp>( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { let ndims = lhs.shape.num_dims(); let vectorization_factor_lhs = tensor_vectorization_factor(&[4, 2], &lhs.shape.dims, &lhs.strides, ndims - 1); @@ -147,31 +147,45 @@ pub(crate) fn launch_cmp>( &client, cube_count, cube_dim, - lhs.as_tensor_arg(vectorization_factor), - rhs.as_tensor_arg(vectorization_factor), + lhs.as_tensor_arg::(vectorization_factor), + rhs.as_tensor_arg::(vectorization_factor), TensorArg::alias(0), None, false, rhs.strides != lhs.strides || rhs.shape != lhs.shape, ); - 
JitTensor::new(lhs.client, lhs.handle, lhs.shape, lhs.device, lhs.strides) + JitTensor::new( + lhs.client, + lhs.handle, + lhs.shape, + lhs.device, + lhs.strides, + DType::U32, + ) } else if same_tensor_type && rhs.can_mut_broadcast(&lhs) { kernel_cmp::launch::( &client, cube_count, CubeDim::default(), - lhs.as_tensor_arg(vectorization_factor), - rhs.as_tensor_arg(vectorization_factor), + lhs.as_tensor_arg::(vectorization_factor), + rhs.as_tensor_arg::(vectorization_factor), TensorArg::alias(1), None, rhs.strides != lhs.strides || rhs.shape != lhs.shape, false, ); - JitTensor::new(rhs.client, rhs.handle, rhs.shape, rhs.device, rhs.strides) + JitTensor::new( + rhs.client, + rhs.handle, + rhs.shape, + rhs.device, + rhs.strides, + DType::U32, + ) } else { - let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out); + let output = empty_device::(lhs.client.clone(), lhs.device.clone(), shape_out); let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape; let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape; @@ -179,9 +193,9 @@ pub(crate) fn launch_cmp>( &client, cube_count, CubeDim::default(), - lhs.as_tensor_arg(vectorization_factor), - rhs.as_tensor_arg(vectorization_factor), - output.as_tensor_arg(vectorization_factor), + lhs.as_tensor_arg::(vectorization_factor), + rhs.as_tensor_arg::(vectorization_factor), + output.as_tensor_arg::(vectorization_factor), None, to_contiguous_lhs, to_contiguous_rhs, @@ -192,9 +206,9 @@ pub(crate) fn launch_cmp>( } pub(crate) fn launch_scalar_cmp>( - mut tensor: JitTensor, + mut tensor: JitTensor, scalar: E, -) -> JitTensor { +) -> JitTensor { if !tensor.is_contiguous_buffer() { tensor = into_contiguous(tensor); } @@ -216,7 +230,7 @@ pub(crate) fn launch_scalar_cmp &client, cube_count, cube_dim, - tensor.as_tensor_arg(vectorization_factor), + tensor.as_tensor_arg::(vectorization_factor), ScalarArg::new(scalar), TensorArg::alias(0), ); @@ -227,9 +241,10 @@ pub(crate) fn launch_scalar_cmp tensor.shape, tensor.device, tensor.strides, + DType::U32, ) } else { - let output = empty_device( + let output = empty_device::( tensor.client.clone(), tensor.device.clone(), tensor.shape.clone(), @@ -239,75 +254,57 @@ pub(crate) fn launch_scalar_cmp &client, cube_count, CubeDim::default(), - tensor.as_tensor_arg(vectorization_factor), + tensor.as_tensor_arg::(vectorization_factor), ScalarArg::new(scalar), - output.as_tensor_arg(vectorization_factor), + output.as_tensor_arg::(vectorization_factor), ); output } } -pub fn equal( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub fn equal(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_cmp::(lhs, rhs) } -pub fn greater( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub fn greater(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_cmp::(lhs, rhs) } pub fn greater_equal( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { launch_cmp::(lhs, rhs) } -pub fn lower( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub fn lower(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_cmp::(lhs, rhs) } pub fn lower_equal( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { launch_cmp::(lhs, rhs) } -pub fn equal_elem(lhs: JitTensor, rhs: E) -> JitTensor { +pub fn equal_elem(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_cmp::(lhs, rhs) } -pub fn greater_elem( - lhs: JitTensor, - rhs: E, -) -> JitTensor { +pub fn 
greater_elem(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_cmp::(lhs, rhs) } -pub fn lower_elem(lhs: JitTensor, rhs: E) -> JitTensor { +pub fn lower_elem(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_cmp::(lhs, rhs) } -pub fn greater_equal_elem( - lhs: JitTensor, - rhs: E, -) -> JitTensor { +pub fn greater_equal_elem(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_cmp::(lhs, rhs) } -pub fn lower_equal_elem( - lhs: JitTensor, - rhs: E, -) -> JitTensor { +pub fn lower_equal_elem(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_cmp::(lhs, rhs) } diff --git a/crates/burn-jit/src/kernel/contiguous.rs b/crates/burn-jit/src/kernel/contiguous.rs index 9764cef170..170f202e76 100644 --- a/crates/burn-jit/src/kernel/contiguous.rs +++ b/crates/burn-jit/src/kernel/contiguous.rs @@ -1,19 +1,22 @@ -use crate::{tensor::JitTensor, JitElement, JitRuntime}; +use crate::{execute_with_dtype, tensor::JitTensor, JitRuntime}; /// Make a jit tensor contiguous. -pub fn into_contiguous(tensor: JitTensor) -> JitTensor { +pub fn into_contiguous(tensor: JitTensor) -> JitTensor { if tensor.is_contiguous() { return tensor; } - let output = - cubecl::linalg::tensor::into_contiguous::(&tensor.client, tensor.as_handle_ref()); + execute_with_dtype!(tensor.dtype, E, { + let output = + cubecl::linalg::tensor::into_contiguous::(&tensor.client, tensor.as_handle_ref()); - JitTensor::new( - tensor.client, - output.handle, - output.shape.into(), - tensor.device, - output.strides, - ) + JitTensor::new( + tensor.client, + output.handle, + output.shape.into(), + tensor.device, + output.strides, + tensor.dtype, + ) + }) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/base.rs index 84396d5cd4..1796389157 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/base.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/base.rs @@ -70,12 +70,12 @@ impl Default for ConvTranspose2dStrategy { /// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option. /// pub fn conv2d( - input: JitTensor, - weight: JitTensor, - bias: Option>, + input: JitTensor, + weight: JitTensor, + bias: Option>, options: ConvOptions<2>, strategy: Conv2dStrategy, -) -> JitTensor { +) -> JitTensor { match strategy { Conv2dStrategy::Direct => conv2d_direct::(input, weight, bias, options), #[cfg(feature = "autotune")] @@ -96,12 +96,12 @@ pub fn conv2d( /// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option. 
/// pub fn conv_transpose2d( - input: JitTensor, - weight: JitTensor, - bias: Option>, + input: JitTensor, + weight: JitTensor, + bias: Option>, options: ConvTransposeOptions<2>, strategy: ConvTranspose2dStrategy, -) -> JitTensor { +) -> JitTensor { match strategy { ConvTranspose2dStrategy::Direct => { conv_transpose2d_direct::(input, weight, bias, options) @@ -117,7 +117,7 @@ pub fn conv_transpose2d( } #[allow(unused)] -pub(crate) fn debug_data(tensor: JitTensor) -> TensorData { +pub(crate) fn debug_data(tensor: JitTensor) -> TensorData { let bytes = tensor.client.read_one(tensor.handle.binding()); TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs index 0de4275101..846aa3d8dd 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs @@ -21,11 +21,11 @@ use super::batches_per_run; /// * `options` - The options to use for the convolution /// pub fn conv_transpose2d_col2im( - input: JitTensor, - weight: JitTensor, - bias: Option>, + input: JitTensor, + weight: JitTensor, + bias: Option>, options: ConvTransposeOptions<2>, -) -> JitTensor { +) -> JitTensor { let [input_channels, im_ch_per_group, kernel_h, kernel_w] = weight.shape.dims(); let [batch_size, _, input_h, input_w] = input.shape.dims(); let groups = options.groups; @@ -70,7 +70,7 @@ pub fn conv_transpose2d_col2im( let runs = batch_size / batches_per_run; let im_shape = Shape::new([runs, batches_per_run, im_channels, im_h, im_w]); - let image = empty_device(input.client.clone(), input.device.clone(), im_shape); + let image = empty_device::(input.client.clone(), input.device.clone(), im_shape); let input_shape = Shape::new([runs, batches_per_run, input_channels, input_h, input_w]); let input = reshape(input, input_shape); @@ -95,7 +95,7 @@ pub fn conv_transpose2d_col2im( reshape(image, Shape::new([batch_size, im_channels, im_h, im_w])) } else { let im_shape = Shape::new([batches_per_run, im_channels, im_h, im_w]); - let image = empty_device(input.client.clone(), input.device.clone(), im_shape); + let image = empty_device::(input.client.clone(), input.device.clone(), im_shape); execute::( input, weight, @@ -111,10 +111,10 @@ pub fn conv_transpose2d_col2im( #[allow(clippy::too_many_arguments)] fn execute( - input: JitTensor, - weight: JitTensor, - bias: Option>, - image: JitTensor, + input: JitTensor, + weight: JitTensor, + bias: Option>, + image: JitTensor, options: ConvTransposeOptions<2>, kernel_h: usize, kernel_w: usize, @@ -131,16 +131,16 @@ fn execute( let columns = JitBackend::::float_matmul(weight, input); let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1])); - col2im( + col2im::( columns, bias, image, kernel_h, kernel_w, input_h, input_w, options, ); } #[allow(clippy::too_many_arguments)] fn col2im( - columns: JitTensor, - bias: Option>, - out: JitTensor, + columns: JitTensor, + bias: Option>, + out: JitTensor, kernel_h: usize, kernel_w: usize, out_h: usize, @@ -152,7 +152,7 @@ fn col2im( let columns = into_contiguous(columns); let has_bias = bias.is_some(); let bias = bias.map(into_contiguous).unwrap_or_else(|| { - empty_device( + empty_device::( columns.client.clone(), columns.device.clone(), Shape::new([1]), @@ -170,9 +170,9 @@ fn col2im( &columns.client, cube_count, cube_dim, - columns.as_tensor_arg(vectorization), - bias.as_tensor_arg(vectorization), - out.as_tensor_arg(vectorization), + 
columns.as_tensor_arg::(vectorization), + bias.as_tensor_arg::(vectorization), + out.as_tensor_arg::(vectorization), Col2ImArgsLaunch::new( ScalarArg::new(out_h as u32), ScalarArg::new(out_w as u32), diff --git a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs index 57ee29650c..9a65b6ae51 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs @@ -122,11 +122,11 @@ fn direct_conv2d_kernel( /// #[allow(clippy::extra_unused_type_parameters)] pub fn conv2d_direct( - input: JitTensor, - weight: JitTensor, - bias: Option>, + input: JitTensor, + weight: JitTensor, + bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> JitTensor { let [batch_size, _, in_height, in_width] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); let channels_per_group = out_channels / options.groups; @@ -153,7 +153,7 @@ pub fn conv2d_direct( let weight = into_contiguous(weight); let shape_out = Shape::new([batch_size, out_channels, out_h, out_w]); - let output = empty_device( + let output = empty_device::( input.client.clone(), input.device.clone(), shape_out.clone(), @@ -166,7 +166,7 @@ pub fn conv2d_direct( } None => { let shape = Shape::from([output.shape.dims[0], 1, 1, 1]); - zeros_device(input.client.clone(), input.device.clone(), shape) + zeros_device::(input.client.clone(), input.device.clone(), shape) } }; @@ -178,10 +178,10 @@ pub fn conv2d_direct( &input.client, cube_count, cube_dim, - input.as_tensor_arg(1), - weight.as_tensor_arg(1), - bias.as_tensor_arg(1), - output.as_tensor_arg(1), + input.as_tensor_arg::(1), + weight.as_tensor_arg::(1), + bias.as_tensor_arg::(1), + output.as_tensor_arg::(1), Conv2dArgsLaunch::new( ScalarArg::new(options.stride[0] as u32), ScalarArg::new(options.stride[1] as u32), diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index bde1edb13b..88125f0463 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -115,20 +115,20 @@ pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> } fn im2col( - input: JitTensor, + input: JitTensor, options: ConvOptions<2>, kernel_h: usize, kernel_w: usize, out_h: usize, out_w: usize, -) -> JitTensor { +) -> JitTensor { let input = into_contiguous(input); let [batch_size, in_channels, _, _] = input.shape.dims(); let col_shape_0 = in_channels * kernel_h * kernel_w; let col_shape_1 = batch_size * out_h * out_w; let shape_col = Shape::new([col_shape_0, col_shape_1]); - let columns = empty_device( + let columns = empty_device::( input.client.clone(), input.device.clone(), shape_col.clone(), @@ -179,11 +179,11 @@ fn im2col( /// * `options` - The options to use for the convolution /// pub fn conv2d_im2col( - input: JitTensor, - weight: JitTensor, - bias: Option>, + input: JitTensor, + weight: JitTensor, + bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> JitTensor { let [batch_size, in_channels, in_height, in_width] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); let groups = options.groups; @@ -216,7 +216,7 @@ pub fn conv2d_im2col( let mut out = if batches_per_run != batch_size { let runs = batch_size / batches_per_run; let out_shape = Shape::new([runs, out_channels, batches_per_run, out_h, out_w]); - let out = empty_device(input.client.clone(), input.device.clone(), out_shape); + let out = 
empty_device::(input.client.clone(), input.device.clone(), out_shape); let in_shape = Shape::new([runs, batches_per_run, in_channels, in_height, in_width]); let input = reshape(input, in_shape); let in_shape_run = Shape::new([batches_per_run, in_channels, in_height, in_width]); @@ -225,7 +225,7 @@ pub fn conv2d_im2col( let input = reshape(input, in_shape_run.clone()); let out_slice = JitBackend::::float_narrow(out.clone(), 0, run, 1); let out_slice = reshape(out_slice, matmul_shape.clone()); - execute( + execute::( input, weight.clone(), out_slice, @@ -237,8 +237,8 @@ pub fn conv2d_im2col( let out = swap_dims(out, 1, 2); reshape(out, Shape::new([batch_size, out_channels, out_h, out_w])) } else { - let out = empty_device(input.client.clone(), input.device.clone(), matmul_shape); - execute(input, weight, out.clone(), options, out_h, out_w); + let out = empty_device::(input.client.clone(), input.device.clone(), matmul_shape); + execute::(input, weight, out.clone(), options, out_h, out_w); let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w])); swap_dims(out, 0, 1) }; @@ -251,11 +251,11 @@ pub fn conv2d_im2col( } fn execute_1x1_kernel( - input: JitTensor, - weight: JitTensor, - bias: Option>, + input: JitTensor, + weight: JitTensor, + bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> JitTensor { let [batch_size, _, height, width] = input.shape.dims(); let [out_channels, in_c_per_grp, _, _] = weight.shape.dims(); let groups = options.groups; @@ -278,9 +278,9 @@ fn execute_1x1_kernel( } fn execute( - input: JitTensor, - weight: JitTensor, - out: JitTensor, + input: JitTensor, + weight: JitTensor, + out: JitTensor, options: ConvOptions<2>, out_h: usize, out_w: usize, @@ -289,7 +289,7 @@ fn execute( let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); let groups = options.groups; - let columns = im2col(input, options.clone(), kernel_h, kernel_w, out_h, out_w); + let columns = im2col::(input, options.clone(), kernel_h, kernel_w, out_h, out_w); let [col_shape_0, col_shape_1] = columns.shape.dims(); let col_shape_0 = col_shape_0 / groups; let out_c_per_group = out_channels / groups; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs index ae74b4129d..fd5d3857f2 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs @@ -30,11 +30,11 @@ use crate::{ /// #[allow(clippy::extra_unused_type_parameters)] pub fn conv2d_implicit_gemm( - input: JitTensor, - weight: JitTensor, - bias: Option>, + input: JitTensor, + weight: JitTensor, + bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> JitTensor { let is_tf32 = F::as_elem() == Elem::Float(FloatKind::F32) && input .client @@ -88,7 +88,7 @@ pub fn conv2d_implicit_gemm( let weight = into_contiguous(permute(weight, &[2, 3, 1, 0])); let out_shape = Shape::new([padded_batch_size, out_h, out_w, padded_out_channels]); - let out = empty_device(input.client.clone(), input.device.clone(), out_shape); + let out = empty_device::(input.client.clone(), input.device.clone(), out_shape); // Implicit GEMM matrix size let gemm_m = (padded_batch_size * out_h * out_w) as u32; @@ -124,11 +124,11 @@ pub fn conv2d_implicit_gemm( Some(bias) if out_channels == padded_out_channels => bias, Some(bias) => { let shape = Shape::new([padded_out_channels]); - let padded_bias = zeros_device(bias.client.clone(), bias.device.clone(), shape); + let padded_bias = 
zeros_device::(bias.client.clone(), bias.device.clone(), shape); #[allow(clippy::single_range_in_vec_init)] - slice_assign(padded_bias, &[0..out_channels], bias) + slice_assign::(padded_bias, &[0..out_channels], bias) } - None => empty_device(input.client.clone(), input.device.clone(), Shape::new([1])), + None => empty_device::(input.client.clone(), input.device.clone(), Shape::new([1])), }; let settings = GemmSettings { @@ -169,10 +169,10 @@ pub fn conv2d_implicit_gemm( &input.client, cube_count, cube_dim, - input.as_tensor_arg(input_vectorization), - weight.as_tensor_arg(weight_vectorization), - bias.as_tensor_arg(1), - out.as_tensor_arg(1), + input.as_tensor_arg::(input_vectorization), + weight.as_tensor_arg::(weight_vectorization), + bias.as_tensor_arg::(1), + out.as_tensor_arg::(1), DimensionsLaunch::new( ScalarArg::new(gemm_m), ScalarArg::new(gemm_n), @@ -202,7 +202,7 @@ pub fn conv2d_implicit_gemm( }, ); - let out = slice(out, &[0..batch_size, 0..out_h, 0..out_w, 0..out_channels]); + let out = slice::(out, &[0..batch_size, 0..out_h, 0..out_w, 0..out_channels]); // Reset to NCHW permute(out, &[0, 3, 1, 2]) diff --git a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs index dd21c84969..1062241d75 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs @@ -123,11 +123,11 @@ fn conv_transpose2d_direct_kernel( /// #[allow(clippy::extra_unused_type_parameters)] pub fn conv_transpose2d_direct( - input: JitTensor, - weight: JitTensor, - bias: Option>, + input: JitTensor, + weight: JitTensor, + bias: Option>, options: ConvTransposeOptions<2>, -) -> JitTensor { +) -> JitTensor { let input = into_contiguous(input); let weight = into_contiguous(weight); let [batch_size, _, in_height, in_width] = input.shape.dims(); @@ -146,7 +146,7 @@ pub fn conv_transpose2d_direct( let shape_out = Shape::new([batch_size, out_channels * options.groups, out_0, out_1]); - let output = empty_device( + let output = empty_device::( input.client.clone(), input.device.clone(), shape_out.clone(), @@ -159,7 +159,7 @@ pub fn conv_transpose2d_direct( } None => { let shape = Shape::from([output.shape.dims[0], 1, 1, 1]); - zeros_device(input.client.clone(), input.device.clone(), shape) + zeros_device::(input.client.clone(), input.device.clone(), shape) } }; @@ -170,10 +170,10 @@ pub fn conv_transpose2d_direct( &input.client, cube_count, cube_dim, - input.as_tensor_arg(1), - weight.as_tensor_arg(1), - bias.as_tensor_arg(1), - output.as_tensor_arg(1), + input.as_tensor_arg::(1), + weight.as_tensor_arg::(1), + bias.as_tensor_arg::(1), + output.as_tensor_arg::(1), ConvArgsLaunch::new( ScalarArg::new(options.stride[0] as u32), ScalarArg::new(options.stride[1] as u32), diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs index 143d398c43..05ec7fd960 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs @@ -23,11 +23,11 @@ use super::Conv2dAutotuneKey; /// Executes autotune on conv2d operations pub fn conv2d_autotune( - input: JitTensor, - weights: JitTensor, - bias: Option>, + input: JitTensor, + weights: JitTensor, + bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> JitTensor { let client = input.client.clone(); static TUNER: LocalTuner = local_tuner!(); @@ -43,16 +43,16 @@ pub fn conv2d_autotune( #[tune( 
operations(conv2d_direct, conv2d_im2col, conv2d_implicit_gemm), - create_key = create_key, + create_key = create_key::, should_run = should_run )] pub fn conv2d_operations( key: JitAutotuneKey, - input: JitTensor, - weights: JitTensor, - bias: Option>, + input: JitTensor, + weights: JitTensor, + bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> JitTensor { let device = &input.device; let key = match key { JitAutotuneKey::Conv2d(key) => key, @@ -118,9 +118,9 @@ fn should_run( } fn create_key( - input: &JitTensor, - weights: &JitTensor, - bias: &Option>, + input: &JitTensor, + weights: &JitTensor, + bias: &Option>, options: &ConvOptions<2>, ) -> JitAutotuneKey { let [batch_size, in_channels, height, width] = input.shape.dims(); diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs index aa0f0972ad..3a8c1d04f2 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs @@ -17,11 +17,11 @@ use super::ConvTranspose2dAutotuneKey; /// Executes autotune on conv2d operations pub fn conv_transpose2d_autotune( - input: JitTensor, - weights: JitTensor, - bias: Option>, + input: JitTensor, + weights: JitTensor, + bias: Option>, options: ConvTransposeOptions<2>, -) -> JitTensor { +) -> JitTensor { let client = input.client.clone(); static TUNER: LocalTuner = local_tuner!(); @@ -35,14 +35,14 @@ pub fn conv_transpose2d_autotune( ) } -#[tune(operations(conv_transpose2d_direct, conv_transpose2d_col2im), create_key = create_key, should_run = should_run)] +#[tune(operations(conv_transpose2d_direct, conv_transpose2d_col2im), create_key = create_key::, should_run = should_run)] pub fn conv_transpose2d_operations( key: JitAutotuneKey, - input: JitTensor, - weights: JitTensor, - bias: Option>, + input: JitTensor, + weights: JitTensor, + bias: Option>, options: ConvTransposeOptions<2>, -) -> JitTensor { +) -> JitTensor { let key = match key { JitAutotuneKey::ConvTranspose2d(key) => key, _ => unreachable!(), @@ -64,9 +64,9 @@ pub fn conv_transpose2d_operations( - input: &JitTensor, - weights: &JitTensor, - bias: &Option>, + input: &JitTensor, + weights: &JitTensor, + bias: &Option>, options: &ConvTransposeOptions<2>, ) -> JitAutotuneKey { let [batch_size, in_channels, height, width] = input.shape.dims(); diff --git a/crates/burn-jit/src/kernel/conv/conv3d.rs b/crates/burn-jit/src/kernel/conv/conv3d.rs index 7bf0d7e280..157610794b 100644 --- a/crates/burn-jit/src/kernel/conv/conv3d.rs +++ b/crates/burn-jit/src/kernel/conv/conv3d.rs @@ -140,11 +140,11 @@ fn conv3d_kernel( } pub(crate) fn conv3d( - input: JitTensor, - weight: JitTensor, - bias: Option>, + input: JitTensor, + weight: JitTensor, + bias: Option>, options: ConvOptions<3>, -) -> JitTensor { +) -> JitTensor { let input = into_contiguous(input); let weight = into_contiguous(weight); let [batch_size, _, in_depth, in_height, in_width] = input.shape.dims(); @@ -174,7 +174,7 @@ pub(crate) fn conv3d( let shape_out = Shape::new([batch_size, out_channels, out_0, out_1, out_2]); - let output = empty_device( + let output = empty_device::( input.client.clone(), input.device.clone(), shape_out.clone(), @@ -187,7 +187,7 @@ pub(crate) fn conv3d( } None => { let shape = Shape::from([output.shape.dims[0], 1, 1, 1, 1]); - zeros_device(input.client.clone(), input.device.clone(), shape) + zeros_device::(input.client.clone(), input.device.clone(), shape) } }; @@ -198,10 +198,10 @@ 
pub(crate) fn conv3d( &input.client, cube_count, cube_dim, - input.as_tensor_arg(1), - weight.as_tensor_arg(1), - bias.as_tensor_arg(1), - output.as_tensor_arg(1), + input.as_tensor_arg::(1), + weight.as_tensor_arg::(1), + bias.as_tensor_arg::(1), + output.as_tensor_arg::(1), Conv3dArgsLaunch::new( ScalarArg::new(options.stride[0] as u32), ScalarArg::new(options.stride[1] as u32), diff --git a/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs b/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs index 8633a06048..860b14ae6a 100644 --- a/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs +++ b/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs @@ -146,11 +146,11 @@ fn conv_transpose3d_kernel( } pub(crate) fn conv_transpose3d( - input: JitTensor, - weight: JitTensor, - bias: Option>, + input: JitTensor, + weight: JitTensor, + bias: Option>, options: ConvTransposeOptions<3>, -) -> JitTensor { +) -> JitTensor { let input = into_contiguous(input); let weight = into_contiguous(weight); let [batch_size, _, in_depth, in_height, in_width] = input.shape.dims(); @@ -180,7 +180,7 @@ pub(crate) fn conv_transpose3d( out_2, ]); - let output = empty_device( + let output = empty_device::( input.client.clone(), input.device.clone(), shape_out.clone(), @@ -193,7 +193,7 @@ pub(crate) fn conv_transpose3d( } None => { let shape = Shape::from([output.shape.dims[0], 1, 1, 1, 1]); - zeros_device(input.client.clone(), input.device.clone(), shape) + zeros_device::(input.client.clone(), input.device.clone(), shape) } }; @@ -204,10 +204,10 @@ pub(crate) fn conv_transpose3d( &input.client, cube_count, cube_dim, - input.as_tensor_arg(1), - weight.as_tensor_arg(1), - bias.as_tensor_arg(1), - output.as_tensor_arg(1), + input.as_tensor_arg::(1), + weight.as_tensor_arg::(1), + bias.as_tensor_arg::(1), + output.as_tensor_arg::(1), ConvArgsLaunch::new( ScalarArg::new(options.stride[0] as u32), ScalarArg::new(options.stride[1] as u32), diff --git a/crates/burn-jit/src/kernel/conv/deform_conv2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs index 30391276fb..b005a2384c 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs @@ -183,13 +183,13 @@ pub(crate) fn bilinear_interpolate( } pub(crate) fn deform_im2col( - input: JitTensor, - offset: JitTensor, - mask: Option>, + input: JitTensor, + offset: JitTensor, + mask: Option>, options: DeformConvOptions<2>, out_dims: (usize, usize), kernel_dims: (usize, usize), -) -> JitTensor { +) -> JitTensor { let client = input.client.clone(); let device = input.device.clone(); @@ -202,10 +202,10 @@ pub(crate) fn deform_im2col( batch_size * out_height * out_width, ]); - let output = zeros_device(client.clone(), device.clone(), shape_out.clone()); + let output = zeros_device::(client.clone(), device.clone(), shape_out.clone()); let use_mask = mask.is_some(); let mask = mask.unwrap_or_else(|| { - ones_device( + ones_device::( client.clone(), device.clone(), Shape::new([ @@ -252,13 +252,13 @@ pub(crate) fn deform_im2col( } pub(crate) fn deform_conv2d( - input: JitTensor, - offset: JitTensor, - weight: JitTensor, - mask: Option>, - bias: Option>, + input: JitTensor, + offset: JitTensor, + weight: JitTensor, + mask: Option>, + bias: Option>, options: DeformConvOptions<2>, -) -> JitTensor { +) -> JitTensor { let input = into_contiguous(input); let offset = into_contiguous(offset); let weight = into_contiguous(weight); @@ -285,7 +285,8 @@ pub(crate) fn deform_conv2d( ); let out_dims = (out_h, out_w); - let columns = 
deform_im2col(input, offset, mask, options, out_dims, (kernel_h, kernel_w)); + let columns = + deform_im2col::(input, offset, mask, options, out_dims, (kernel_h, kernel_w)); let [col_size_0, col_size_1] = columns.shape.dims(); let col_size_0 = col_size_0 / groups; @@ -307,9 +308,9 @@ pub(crate) fn deform_conv2d( } pub(crate) fn index( - tensor: JitTensor, + tensor: JitTensor, index: usize, -) -> JitTensor { +) -> JitTensor { let [_, shape_0, shape_1] = tensor.shape.dims(); let tensor = JitBackend::::float_narrow(tensor, 0, index, 1); reshape(tensor, Shape::new([shape_0, shape_1])) diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs index 89080fecfb..4022a0bbe2 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -19,12 +19,12 @@ use super::{bilinear_interpolate, deform_im2col, index}; /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions. #[allow(clippy::single_range_in_vec_init)] pub(crate) fn deform_conv2d_backward( - input: JitTensor, - offset: JitTensor, - weight: JitTensor, - mask: Option>, - bias: Option>, - out_grad: JitTensor, + input: JitTensor, + offset: JitTensor, + weight: JitTensor, + mask: Option>, + bias: Option>, + out_grad: JitTensor, options: DeformConvOptions<2>, ) -> DeformConv2dBackward> { let [_, _, out_h, out_w] = out_grad.shape.dims(); @@ -72,14 +72,14 @@ pub(crate) fn deform_conv2d_backward( - input: JitTensor, - offset: JitTensor, - mask: Option>, - out_grad: JitTensor, + input: JitTensor, + offset: JitTensor, + mask: Option>, + out_grad: JitTensor, options: DeformConvOptions<2>, kernel_dims: (usize, usize), out_dims: (usize, usize), -) -> JitTensor { +) -> JitTensor { let [_, in_channels, _, _] = input.shape.dims(); let [_, out_channels, _, _] = out_grad.shape.dims(); let (kernel_h, kernel_w) = kernel_dims; @@ -88,7 +88,7 @@ fn compute_weight_grad( let in_c_per_group = in_channels / groups; let out_c_per_group = out_channels / groups; - let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims); + let columns = deform_im2col::(input, offset, mask, options, out_dims, kernel_dims); let [col_size_0, col_size_1] = columns.shape.dims(); let col_size_0 = col_size_0 / groups; @@ -106,17 +106,17 @@ fn compute_weight_grad( ) } -type InputGradients = (JitTensor, JitTensor, Option>); +type InputGradients = (JitTensor, JitTensor, Option>); fn backward_gradient_inputs( - image: JitTensor, - weight: JitTensor, - offset: JitTensor, - mask: Option>, - out_grad: JitTensor, + image: JitTensor, + weight: JitTensor, + offset: JitTensor, + mask: Option>, + out_grad: JitTensor, options: &DeformConvOptions<2>, kernel_dims: (usize, usize), -) -> InputGradients { +) -> InputGradients { let client = out_grad.client.clone(); let device = out_grad.device.clone(); @@ -129,7 +129,7 @@ fn backward_gradient_inputs( let col_shape_0 = in_c_per_group * kernel_h * kernel_w; let col_shape_1 = batch_size * out_h * out_w; let col_shape = Shape::new([groups, col_shape_0, col_shape_1]); - let mut columns = empty_device(client, device, col_shape); + let mut columns = empty_device::(client, device, col_shape); let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0])); @@ -168,13 +168,13 @@ fn backward_gradient_inputs( } fn compute_offset_and_mask_gradient( - columns: JitTensor, - image: JitTensor, - offset: JitTensor, - mask: Option>, 
+ columns: JitTensor, + image: JitTensor, + offset: JitTensor, + mask: Option>, options: &DeformConvOptions<2>, kernel_dims: (usize, usize), -) -> (JitTensor, Option>) { +) -> (JitTensor, Option>) { let client = offset.client.clone(); let device = offset.device.clone(); let (kernel_height, kernel_width) = kernel_dims; @@ -182,7 +182,7 @@ fn compute_offset_and_mask_gradient( let use_mask = mask.is_some(); let mask = mask.unwrap_or_else(|| { - ones_device( + ones_device::( client.clone(), device.clone(), Shape::new([ @@ -194,8 +194,8 @@ fn compute_offset_and_mask_gradient( ) }); - let grad_offset = empty_device(client.clone(), device.clone(), offset.shape.clone()); - let grad_mask = empty_device(client.clone(), device.clone(), mask.shape.clone()); + let grad_offset = empty_device::(client.clone(), device.clone(), offset.shape.clone()); + let grad_mask = empty_device::(client.clone(), device.clone(), mask.shape.clone()); let num_elements_offset = offset.shape.num_elements(); let cube_dim = CubeDim::default(); @@ -413,13 +413,13 @@ fn get_coordinate_weight( } fn compute_input_grad( - columns: JitTensor, - offset: JitTensor, - mask: Option>, + columns: JitTensor, + offset: JitTensor, + mask: Option>, options: &DeformConvOptions<2>, kernel_dims: (usize, usize), input_shape: Shape, -) -> JitTensor { +) -> JitTensor { let client = offset.client.clone(); let device = offset.device.clone(); @@ -434,8 +434,9 @@ fn compute_input_grad( ); let use_mask = mask.is_some(); - let mask = mask - .unwrap_or_else(|| ones_device(client.clone(), device.clone(), Shape::new([1, 1, 1, 1]))); + let mask = mask.unwrap_or_else(|| { + ones_device::(client.clone(), device.clone(), Shape::new([1, 1, 1, 1])) + }); let num_elements = columns.shape.num_elements(); let cube_dim = CubeDim::default(); @@ -445,10 +446,10 @@ fn compute_input_grad( &offset.client, cube_count, cube_dim, - offset.as_tensor_arg(1), - mask.as_tensor_arg(1), - columns.as_tensor_arg(1), - grad_in.as_tensor_arg(1), + offset.as_tensor_arg::(1), + mask.as_tensor_arg::(1), + columns.as_tensor_arg::(1), + grad_in.as_tensor_arg::(1), DeformConv2dCol2ImgArgsLaunch::new( ScalarArg::new(options.stride[0] as u32), ScalarArg::new(options.stride[1] as u32), @@ -467,7 +468,7 @@ fn compute_input_grad( use_mask, ); - cast(grad_in) + cast::(grad_in) } #[derive(CubeLaunch)] diff --git a/crates/burn-jit/src/kernel/index/flip.rs b/crates/burn-jit/src/kernel/index/flip.rs index 7fb1f0a3c8..e35cac8b2c 100644 --- a/crates/burn-jit/src/kernel/index/flip.rs +++ b/crates/burn-jit/src/kernel/index/flip.rs @@ -32,22 +32,22 @@ fn flip_kernel( } pub(crate) fn flip( - tensor: JitTensor, + tensor: JitTensor, indices: &[usize], -) -> JitTensor { - let output = empty_device( +) -> JitTensor { + let output = empty_device::( tensor.client.clone(), tensor.device.clone(), tensor.shape.clone(), ); - flip_on_output(tensor, output, indices) + flip_on_output::(tensor, output, indices) } pub(crate) fn flip_on_output( - tensor: JitTensor, - output: JitTensor, + tensor: JitTensor, + output: JitTensor, indices: &[usize], -) -> JitTensor { +) -> JitTensor { let ndims = tensor.shape.num_dims(); let mut indices_sequence = SequenceArg::<'_, R, u32>::new(); @@ -63,8 +63,8 @@ pub(crate) fn flip_on_output( &tensor.client, cube_count, cube_dim, - tensor.as_tensor_arg(1), - output.as_tensor_arg(1), + tensor.as_tensor_arg::(1), + output.as_tensor_arg::(1), indices_sequence, ndims as u32, ); diff --git a/crates/burn-jit/src/kernel/index/gather.rs b/crates/burn-jit/src/kernel/index/gather.rs index 
14e36db94d..9e9b5685bb 100644 --- a/crates/burn-jit/src/kernel/index/gather.rs +++ b/crates/burn-jit/src/kernel/index/gather.rs @@ -34,12 +34,12 @@ fn gather_kernel( pub(crate) fn gather( dim: usize, - tensor: JitTensor, - indices: JitTensor, -) -> JitTensor { + tensor: JitTensor, + indices: JitTensor, +) -> JitTensor { let shape_output = indices.shape.clone(); let total_elem = shape_output.num_elements(); - let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); + let output = empty_device::(tensor.client.clone(), tensor.device.clone(), shape_output); let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim); @@ -48,9 +48,9 @@ pub(crate) fn gather( &tensor.client, cube_count, cube_dim, - tensor.as_tensor_arg(1), - indices.as_tensor_arg(1), - output.as_tensor_arg(1), + tensor.as_tensor_arg::(1), + indices.as_tensor_arg::(1), + output.as_tensor_arg::(1), ScalarArg::new(dim as u32), ) } diff --git a/crates/burn-jit/src/kernel/index/repeat_dim.rs b/crates/burn-jit/src/kernel/index/repeat_dim.rs index 059635e662..3887bfbd8b 100644 --- a/crates/burn-jit/src/kernel/index/repeat_dim.rs +++ b/crates/burn-jit/src/kernel/index/repeat_dim.rs @@ -20,15 +20,15 @@ fn repeat_dim_kernel(input: &Tensor, output: &mut Tensor } pub(crate) fn repeat_dim( - input: JitTensor, + input: JitTensor, dim: usize, times: usize, -) -> JitTensor { +) -> JitTensor { let mut shape = input.shape.clone(); // Create output handle shape.dims[dim] *= times; - let output = empty_device(input.client.clone(), input.device.clone(), shape); + let output = empty_device::(input.client.clone(), input.device.clone(), shape); let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); @@ -38,8 +38,8 @@ pub(crate) fn repeat_dim( &input.client, cube_count, cube_dim, - input.as_tensor_arg(1), - output.as_tensor_arg(1), + input.as_tensor_arg::(1), + output.as_tensor_arg::(1), ScalarArg::new(dim as u32), ) }; diff --git a/crates/burn-jit/src/kernel/index/scatter.rs b/crates/burn-jit/src/kernel/index/scatter.rs index a6dd72c022..4ddd9c00fb 100644 --- a/crates/burn-jit/src/kernel/index/scatter.rs +++ b/crates/burn-jit/src/kernel/index/scatter.rs @@ -67,10 +67,10 @@ fn scatter_kernel( pub(crate) fn scatter( dim: usize, - tensor: JitTensor, - indices: JitTensor, - value: JitTensor, -) -> JitTensor { + tensor: JitTensor, + indices: JitTensor, + value: JitTensor, +) -> JitTensor { let ndims = tensor.shape.num_dims(); let mut indices = kernel::into_contiguous(indices); let tensor = kernel::into_contiguous(tensor); @@ -109,9 +109,9 @@ pub(crate) fn scatter( &indices.client.clone(), cube_count, cube_dim, - tensor.as_tensor_arg(1), - indices.as_tensor_arg(1), - value.as_tensor_arg(1), + tensor.as_tensor_arg::(1), + indices.as_tensor_arg::(1), + value.as_tensor_arg::(1), ScalarArg::new(dim as u32), ) } diff --git a/crates/burn-jit/src/kernel/index/select.rs b/crates/burn-jit/src/kernel/index/select.rs index 545bbe6f01..b104bf504f 100644 --- a/crates/burn-jit/src/kernel/index/select.rs +++ b/crates/burn-jit/src/kernel/index/select.rs @@ -29,16 +29,16 @@ fn select_kernel( } pub(crate) fn select( - tensor: JitTensor, + tensor: JitTensor, dim: usize, - indices: JitTensor, -) -> JitTensor { + indices: JitTensor, +) -> JitTensor { let ndims = tensor.shape.num_dims(); let mut shape_output = tensor.shape.clone(); shape_output.dims[dim] = indices.shape.dims[0]; let total_elem = shape_output.num_elements(); - let output = 
empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); + let output = empty_device::(tensor.client.clone(), tensor.device.clone(), shape_output); let dummy_array = vec![1; ndims]; let cube_dim = CubeDim::default(); @@ -49,10 +49,10 @@ pub(crate) fn select( &tensor.client, cube_count, cube_dim, - tensor.as_tensor_arg(1), + tensor.as_tensor_arg::(1), // Ignore shape and stride TensorArg::from_raw_parts::(&indices.handle, &dummy_array, &dummy_array, 1), - output.as_tensor_arg(1), + output.as_tensor_arg::(1), ScalarArg::new(dim as u32), ) }; diff --git a/crates/burn-jit/src/kernel/index/select_assign.rs b/crates/burn-jit/src/kernel/index/select_assign.rs index 37a39a6331..a0fed49dbd 100644 --- a/crates/burn-jit/src/kernel/index/select_assign.rs +++ b/crates/burn-jit/src/kernel/index/select_assign.rs @@ -45,11 +45,11 @@ fn select_assign_kernel( } pub(crate) fn select_assign( - tensor: JitTensor, + tensor: JitTensor, dim: usize, - indices: JitTensor, - value: JitTensor, -) -> JitTensor { + indices: JitTensor, + value: JitTensor, +) -> JitTensor { let ndims = tensor.shape.num_dims(); let tensor = match tensor.can_mut() { true => tensor, @@ -80,10 +80,10 @@ pub(crate) fn select_assign( &tensor.client, cube_count, cube_dim, - tensor.as_tensor_arg(1), + tensor.as_tensor_arg::(1), // Ignored shape + custom strides. TensorArg::from_raw_parts::(&indices.handle, &strides, &strides, 1), - value.as_tensor_arg(1), + value.as_tensor_arg::(1), ScalarArg::new(dim as u32), ); }; diff --git a/crates/burn-jit/src/kernel/index/slice.rs b/crates/burn-jit/src/kernel/index/slice.rs index fa0a82973c..7f20f033b8 100644 --- a/crates/burn-jit/src/kernel/index/slice.rs +++ b/crates/burn-jit/src/kernel/index/slice.rs @@ -4,9 +4,9 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use std::ops::Range; pub(crate) fn slice( - tensor: JitTensor, + tensor: JitTensor, indices: &[Range], -) -> JitTensor { +) -> JitTensor { let mut dims = tensor.shape.dims.clone(); let mut offset_start = 0u64; let mut offset_end = 0u64; @@ -34,11 +34,13 @@ pub(crate) fn slice( Shape::from(dims), tensor.device, tensor.strides, + tensor.dtype, ) } else { let shape_output = Shape::from(dims); - let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); - slice_on_output(tensor, output, indices) + let output = + empty_device::(tensor.client.clone(), tensor.device.clone(), shape_output); + slice_on_output::(tensor, output, indices) } } @@ -67,10 +69,10 @@ fn slice_kernel( } pub(crate) fn slice_on_output( - tensor: JitTensor, - output: JitTensor, + tensor: JitTensor, + output: JitTensor, indices: &[Range], -) -> JitTensor { +) -> JitTensor { let ndims = tensor.shape.num_dims(); let mut indices_sequence = SequenceArg::::new(); @@ -87,8 +89,8 @@ pub(crate) fn slice_on_output( &tensor.client, cube_count, cube_dim, - tensor.as_tensor_arg(1), - output.as_tensor_arg(1), + tensor.as_tensor_arg::(1), + output.as_tensor_arg::(1), indices_sequence, ndims as u32, ) diff --git a/crates/burn-jit/src/kernel/index/slice_assign.rs b/crates/burn-jit/src/kernel/index/slice_assign.rs index 29ddde64c0..ca3c9adf6e 100644 --- a/crates/burn-jit/src/kernel/index/slice_assign.rs +++ b/crates/burn-jit/src/kernel/index/slice_assign.rs @@ -26,10 +26,10 @@ fn slice_assign_kernel( } pub(crate) fn slice_assign( - tensor: JitTensor, + tensor: JitTensor, indices: &[Range], - value: JitTensor, -) -> JitTensor { + value: JitTensor, +) -> JitTensor { let tensor = match tensor.can_mut() { true => tensor, false => tensor.copy(), @@ 
-49,8 +49,8 @@ pub(crate) fn slice_assign( &tensor.client, cube_count, cube_dim, - tensor.as_tensor_arg(1), - value.as_tensor_arg(1), + tensor.as_tensor_arg::(1), + value.as_tensor_arg::(1), indices_sequence, ndims as u32, ); diff --git a/crates/burn-jit/src/kernel/interpolate/base.rs b/crates/burn-jit/src/kernel/interpolate/base.rs index 4cebb6602e..c3d3a51b21 100644 --- a/crates/burn-jit/src/kernel/interpolate/base.rs +++ b/crates/burn-jit/src/kernel/interpolate/base.rs @@ -16,21 +16,21 @@ use super::{ /// /// Supports nearest, bilinear and bicubic modes pub fn interpolate( - input: JitTensor, + input: JitTensor, output_size: [usize; 2], options: InterpolateOptions, -) -> JitTensor { +) -> JitTensor { let input = into_contiguous(input); let [batch_size, channels, _, _] = input.shape.dims(); let [out_height, out_width] = output_size; let shape_out = Shape::new([batch_size, channels, out_height, out_width]); - let output = empty_device(input.client.clone(), input.device.clone(), shape_out); + let output = empty_device::(input.client.clone(), input.device.clone(), shape_out); match options.mode { - InterpolateMode::Nearest => interpolate_nearest_launch(input, output), - InterpolateMode::Bilinear => interpolate_bilinear_launch(input, output), - InterpolateMode::Bicubic => interpolate_bicubic_launch(input, output), + InterpolateMode::Nearest => interpolate_nearest_launch::(input, output), + InterpolateMode::Bilinear => interpolate_bilinear_launch::(input, output), + InterpolateMode::Bicubic => interpolate_bicubic_launch::(input, output), } } @@ -38,11 +38,11 @@ pub fn interpolate( /// /// Note: only nearest mode is supported pub fn interpolate_backward( - input: JitTensor, - out_grad: JitTensor, + input: JitTensor, + out_grad: JitTensor, _output_size: [usize; 2], options: InterpolateOptions, -) -> JitTensor { +) -> JitTensor { let out_grad = into_contiguous(out_grad); let output_shape = input.shape.clone(); let num_elems = input.shape.num_elements(); @@ -52,10 +52,11 @@ pub fn interpolate_backward( input.device.clone(), output_shape, buffer, + input.dtype, ); match options.mode { - InterpolateMode::Nearest => interpolate_nearest_backward_launch(out_grad, output), + InterpolateMode::Nearest => interpolate_nearest_backward_launch::(out_grad, output), InterpolateMode::Bilinear => { panic!("bilinear interpolation backward is not supported by JIT backend") } diff --git a/crates/burn-jit/src/kernel/interpolate/bicubic.rs b/crates/burn-jit/src/kernel/interpolate/bicubic.rs index 2e554bf647..1d545d79c7 100644 --- a/crates/burn-jit/src/kernel/interpolate/bicubic.rs +++ b/crates/burn-jit/src/kernel/interpolate/bicubic.rs @@ -122,9 +122,9 @@ fn cubic_convolution_2(x: F, a: F) -> F { } pub(crate) fn interpolate_bicubic_launch( - input: JitTensor, - output: JitTensor, -) -> JitTensor { + input: JitTensor, + output: JitTensor, +) -> JitTensor { let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); @@ -132,8 +132,8 @@ pub(crate) fn interpolate_bicubic_launch( &input.client, cube_count, cube_dim, - input.as_tensor_arg(1), - output.as_tensor_arg(1), + input.as_tensor_arg::(1), + output.as_tensor_arg::(1), ); output diff --git a/crates/burn-jit/src/kernel/interpolate/bilinear.rs b/crates/burn-jit/src/kernel/interpolate/bilinear.rs index 0314c77544..3557fcdbb8 100644 --- a/crates/burn-jit/src/kernel/interpolate/bilinear.rs +++ b/crates/burn-jit/src/kernel/interpolate/bilinear.rs @@ -80,9 +80,9 @@ fn interpolate_bilinear_kernel(input: &Tensor, 
output: &mut Tensor< } pub(crate) fn interpolate_bilinear_launch( - input: JitTensor, - output: JitTensor, -) -> JitTensor { + input: JitTensor, + output: JitTensor, +) -> JitTensor { let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); @@ -90,8 +90,8 @@ pub(crate) fn interpolate_bilinear_launch( &input.client, cube_count, cube_dim, - input.as_tensor_arg(1), - output.as_tensor_arg(1), + input.as_tensor_arg::(1), + output.as_tensor_arg::(1), ); output diff --git a/crates/burn-jit/src/kernel/interpolate/nearest.rs b/crates/burn-jit/src/kernel/interpolate/nearest.rs index 5aa671319e..0743a13567 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest.rs @@ -32,9 +32,9 @@ fn interpolate_nearest_kernel(input: &Tensor, output: &mut Tensor( - input: JitTensor, - output: JitTensor, -) -> JitTensor { + input: JitTensor, + output: JitTensor, +) -> JitTensor { let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); @@ -43,8 +43,8 @@ pub(crate) fn interpolate_nearest_launch( &input.client, cube_count, cube_dim, - input.as_tensor_arg(1), - output.as_tensor_arg(1), + input.as_tensor_arg::(1), + output.as_tensor_arg::(1), ) }; diff --git a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs index 7a90958177..5ea860a7ae 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs @@ -56,9 +56,9 @@ fn end_index(input_index: u32, output_size: u32, input_size: u32) -> u } pub(crate) fn interpolate_nearest_backward_launch( - out_grad: JitTensor, - output: JitTensor, -) -> JitTensor { + out_grad: JitTensor, + output: JitTensor, +) -> JitTensor { let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); @@ -67,8 +67,8 @@ pub(crate) fn interpolate_nearest_backward_launch(1), + output.as_tensor_arg::(1), ) }; diff --git a/crates/burn-jit/src/kernel/mask/base.rs b/crates/burn-jit/src/kernel/mask/base.rs index 45af236252..2140972326 100644 --- a/crates/burn-jit/src/kernel/mask/base.rs +++ b/crates/burn-jit/src/kernel/mask/base.rs @@ -3,10 +3,10 @@ use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; /// Execute the mask fill kernel. pub(crate) fn mask_fill_auto( - tensor: JitTensor, - mask: JitTensor, + tensor: JitTensor, + mask: JitTensor, value: E, -) -> JitTensor { +) -> JitTensor { let strategy = if tensor.can_mut() { MaskFillStrategy::Inplace } else { @@ -18,10 +18,10 @@ pub(crate) fn mask_fill_auto( /// Execute the mask where kernel. 
pub(crate) fn mask_where_auto( - tensor: JitTensor, - mask: JitTensor, - value: JitTensor, -) -> JitTensor { + tensor: JitTensor, + mask: JitTensor, + value: JitTensor, +) -> JitTensor { let strategy = if tensor.can_mut_broadcast(&value) { MaskWhereStrategy::InplaceLhs } else if value.can_mut_broadcast(&tensor) { @@ -30,5 +30,5 @@ pub(crate) fn mask_where_auto( MaskWhereStrategy::Readonly }; - super::mask_where(tensor, mask, value, strategy) + super::mask_where::(tensor, mask, value, strategy) } diff --git a/crates/burn-jit/src/kernel/mask/mask_fill.rs b/crates/burn-jit/src/kernel/mask/mask_fill.rs index a8efb8fb87..e8b3f814d9 100644 --- a/crates/burn-jit/src/kernel/mask/mask_fill.rs +++ b/crates/burn-jit/src/kernel/mask/mask_fill.rs @@ -57,24 +57,24 @@ pub enum MaskFillStrategy { /// Execute the mask fill kernel with the given strategy. pub fn mask_fill( - input: JitTensor, - mask: JitTensor, + input: JitTensor, + mask: JitTensor, value: E, strategy: MaskFillStrategy, -) -> JitTensor { +) -> JitTensor { match strategy { - MaskFillStrategy::Readonly => mask_fill_readonly(input, mask, value), - MaskFillStrategy::Inplace => mask_fill_inplace(input, mask, value), + MaskFillStrategy::Readonly => mask_fill_readonly::(input, mask, value), + MaskFillStrategy::Inplace => mask_fill_inplace::(input, mask, value), } } fn mask_fill_readonly( - input: JitTensor, - mask: JitTensor, + input: JitTensor, + mask: JitTensor, value: EI, -) -> JitTensor { +) -> JitTensor { let ndims = input.shape.num_dims(); - let output = empty_device( + let output = empty_device::( input.client.clone(), input.device.clone(), input.shape.clone(), @@ -87,9 +87,9 @@ fn mask_fill_readonly( &input.client, cube_count, cube_dim, - input.as_tensor_arg(1), - mask.as_tensor_arg(1), - output.as_tensor_arg(1), + input.as_tensor_arg::(1), + mask.as_tensor_arg::(1), + output.as_tensor_arg::(1), ScalarArg::new(value), ndims as u32, ); @@ -98,10 +98,10 @@ fn mask_fill_readonly( } fn mask_fill_inplace( - input: JitTensor, - mask: JitTensor, + input: JitTensor, + mask: JitTensor, value: EI, -) -> JitTensor { +) -> JitTensor { let ndims = input.shape.num_dims(); let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim); @@ -110,8 +110,8 @@ fn mask_fill_inplace( &input.client, cube_count, cube_dim, - input.as_tensor_arg(1), - mask.as_tensor_arg(1), + input.as_tensor_arg::(1), + mask.as_tensor_arg::(1), ScalarArg::new(value), ndims as u32, ); diff --git a/crates/burn-jit/src/kernel/mask/mask_where.rs b/crates/burn-jit/src/kernel/mask/mask_where.rs index fadad4ad7a..73c7c8fcf1 100644 --- a/crates/burn-jit/src/kernel/mask/mask_where.rs +++ b/crates/burn-jit/src/kernel/mask/mask_where.rs @@ -62,25 +62,25 @@ pub enum MaskWhereStrategy { /// Execute the mask where kernel with the given strategy. 
pub fn mask_where( - input: JitTensor, - mask: JitTensor, - value: JitTensor, + input: JitTensor, + mask: JitTensor, + value: JitTensor, strategy: MaskWhereStrategy, -) -> JitTensor { +) -> JitTensor { match strategy { - MaskWhereStrategy::Readonly => mask_where_readonly(input, mask, value), - MaskWhereStrategy::InplaceLhs => mask_where_inplace(input, mask, value, false), - MaskWhereStrategy::InplaceRhs => mask_where_inplace(value, mask, input, true), + MaskWhereStrategy::Readonly => mask_where_readonly::(input, mask, value), + MaskWhereStrategy::InplaceLhs => mask_where_inplace::(input, mask, value, false), + MaskWhereStrategy::InplaceRhs => mask_where_inplace::(value, mask, input, true), } } fn mask_where_readonly( - input: JitTensor, - mask: JitTensor, - value: JitTensor, -) -> JitTensor { + input: JitTensor, + mask: JitTensor, + value: JitTensor, +) -> JitTensor { let ndims = input.shape.num_dims(); - let output = empty_device( + let output = empty_device::( input.client.clone(), input.device.clone(), input.shape.clone(), @@ -93,10 +93,10 @@ fn mask_where_readonly( &input.client, cube_count, cube_dim, - input.as_tensor_arg(1), - mask.as_tensor_arg(1), - value.as_tensor_arg(1), - output.as_tensor_arg(1), + input.as_tensor_arg::(1), + mask.as_tensor_arg::(1), + value.as_tensor_arg::(1), + output.as_tensor_arg::(1), ndims as u32, ); @@ -104,11 +104,11 @@ fn mask_where_readonly( } fn mask_where_inplace( - input: JitTensor, - mask: JitTensor, - value: JitTensor, + input: JitTensor, + mask: JitTensor, + value: JitTensor, reverse: bool, -) -> JitTensor { +) -> JitTensor { let ndims = input.shape.num_dims(); let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim); @@ -117,9 +117,9 @@ fn mask_where_inplace( &input.client, cube_count, cube_dim, - input.as_tensor_arg(1), - mask.as_tensor_arg(1), - value.as_tensor_arg(1), + input.as_tensor_arg::(1), + mask.as_tensor_arg::(1), + value.as_tensor_arg::(1), ScalarArg::new(reverse as u32), ndims as u32, ); diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index f5c6cec830..562197647f 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -35,14 +35,14 @@ impl Default for MatmulStrategy { /// Launch a matmul kernel using the given strategy. 
pub fn matmul( - lhs: JitTensor, - rhs: JitTensor, + lhs: JitTensor, + rhs: JitTensor, strategy: MatmulStrategy, -) -> JitTensor { +) -> JitTensor { match strategy { MatmulStrategy::Simple { grid_x, grid_y } => { - let out = init_matmul_output(&lhs, &rhs); - matmul_simple(lhs, rhs, out, grid_x, grid_y) + let out = init_matmul_output::(&lhs, &rhs); + matmul_simple::(lhs, rhs, out, grid_x, grid_y) } MatmulStrategy::Cube => { let out = init_matmul_output::(&lhs, &rhs); @@ -56,7 +56,7 @@ pub fn matmul( out } #[cfg(feature = "autotune")] - MatmulStrategy::Autotune => matmul_autotune(lhs, rhs), + MatmulStrategy::Autotune => matmul_autotune::(lhs, rhs), } } diff --git a/crates/burn-jit/src/kernel/matmul/simple.rs b/crates/burn-jit/src/kernel/matmul/simple.rs index 06fa1a58f0..7d75b30395 100644 --- a/crates/burn-jit/src/kernel/matmul/simple.rs +++ b/crates/burn-jit/src/kernel/matmul/simple.rs @@ -84,21 +84,21 @@ fn matmul_kernel( /// Matrix multiplication using memory coalescing algorithm with cube dimensions of size 16 pub fn matmul_mem_coalescing_default( - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, -) -> JitTensor { + lhs: JitTensor, + rhs: JitTensor, + out: JitTensor, +) -> JitTensor { matmul_simple::(lhs, rhs, out, PLANE_DIM_APPROX, PLANE_DIM_APPROX) } /// Matrix multiplication using memory coalescing algorithm with custom cube dimensions pub fn matmul_simple( - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, + lhs: JitTensor, + rhs: JitTensor, + out: JitTensor, cube_dim_x: usize, cube_dim_y: usize, -) -> JitTensor { +) -> JitTensor { lhs.assert_is_on_same_device(&rhs); let ndims = lhs.shape.num_dims(); let lhs = into_contiguous(lhs); @@ -127,14 +127,14 @@ pub fn matmul_simple( &lhs.client, cube_count, CubeDim::new(cube_dim_x as u32, cube_dim_y as u32, 1), - lhs.as_tensor_arg(vectorization_factor), + lhs.as_tensor_arg::(vectorization_factor), TensorArg::from_raw_parts::( &rhs.handle, &rhs.strides, &rhs_original_shape.dims, // We need the original shape. 
vectorization_factor, ), - out.as_tensor_arg(1), + out.as_tensor_arg::(1), Some(ndims as u32 - 2), ); }; diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index b5804f9c36..38df9f7fe1 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -1,3 +1,5 @@ +use core::marker::PhantomData; + use burn_tensor::{Element, ElementConversion}; use cubecl::tune::{local_tuner, AutotuneOperation, AutotuneOperationSet, LocalTuner}; @@ -16,17 +18,19 @@ use super::key::MatmulAutotuneKey; /// Autotune key is given by concatenating the closest upper power of 2 of m, k and n pub struct MatmulAutotuneOperationSet { key: JitAutotuneKey, - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, + lhs: JitTensor, + rhs: JitTensor, + out: JitTensor, + _e: PhantomData, } impl MatmulAutotuneOperationSet { - fn new(lhs: JitTensor, rhs: JitTensor, out: JitTensor) -> Self { + fn new(lhs: JitTensor, rhs: JitTensor, out: JitTensor) -> Self { Self { key: JitAutotuneKey::Matmul(MatmulAutotuneKey::new(&lhs.shape, &rhs.shape, E::dtype())), lhs, rhs, out, + _e: PhantomData, } } } @@ -43,28 +47,36 @@ impl AutotuneOperationSet let lhs = random_like_uniform(&self.lhs, random_bounds.0, random_bounds.1); let rhs = random_like_uniform(&self.rhs, random_bounds.0, random_bounds.1); - let out = empty_device( + let out = empty_device::( self.out.client.clone(), self.out.device.clone(), self.out.shape.clone(), ); vec![ - Box::new(SimpleMatmul::new(lhs.clone(), rhs.clone(), out.clone())), - Box::new(SimpleMatmul16x16::new( + Box::new(SimpleMatmul::::new( + lhs.clone(), + rhs.clone(), + out.clone(), + )), + Box::new(SimpleMatmul16x16::::new( + lhs.clone(), + rhs.clone(), + out.clone(), + )), + Box::new(MatmulCube::::new( lhs.clone(), rhs.clone(), out.clone(), )), - Box::new(MatmulCube::new(lhs.clone(), rhs.clone(), out.clone())), ] } fn fastest(self: Box, fastest_index: usize) -> Box { match fastest_index { - 0 => Box::new(SimpleMatmul::new(self.lhs, self.rhs, self.out)), - 1 => Box::new(SimpleMatmul16x16::new(self.lhs, self.rhs, self.out)), - 2 => Box::new(MatmulCube::new(self.lhs, self.rhs, self.out)), + 0 => Box::new(SimpleMatmul::::new(self.lhs, self.rhs, self.out)), + 1 => Box::new(SimpleMatmul16x16::::new(self.lhs, self.rhs, self.out)), + 2 => Box::new(MatmulCube::::new(self.lhs, self.rhs, self.out)), _ => panic!("Fastest index is out of bound"), } } @@ -72,19 +84,23 @@ impl AutotuneOperationSet /// Executes autotune on matmul operations pub fn matmul_autotune( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { let client = lhs.client.clone(); - let output = init_matmul_output(&lhs, &rhs); + let output = init_matmul_output::(&lhs, &rhs); static TUNER: LocalTuner = local_tuner!(); TUNER.execute( &JitTuneId::new::(&lhs.device), &client, - Box::new(MatmulAutotuneOperationSet::new(lhs, rhs, output.clone())), + Box::new(MatmulAutotuneOperationSet::::new( + lhs, + rhs, + output.clone(), + )), ); output @@ -94,9 +110,10 @@ macro_rules! matmul_tune_ops { ($name:ident, $func:expr) => { #[derive(new, Debug)] pub(crate) struct $name { - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, + lhs: JitTensor, + rhs: JitTensor, + out: JitTensor, + _e: PhantomData, } impl AutotuneOperation for $name { @@ -110,6 +127,7 @@ macro_rules! matmul_tune_ops { lhs: self.lhs.clone(), rhs: self.rhs.clone(), out: self.out.clone(), + _e: self._e, }) } } @@ -119,18 +137,18 @@ macro_rules! 
matmul_tune_ops { // Potentially better for small matrices. matmul_tune_ops!( SimpleMatmul, - crate::kernel::matmul::matmul_mem_coalescing_default + crate::kernel::matmul::matmul_mem_coalescing_default:: ); // Potentially better for small matrices. matmul_tune_ops!(SimpleMatmul16x16, |lhs, rhs, out| { - crate::kernel::matmul::matmul_simple(lhs, rhs, out, 16, 16) + crate::kernel::matmul::matmul_simple::(lhs, rhs, out, 16, 16) }); // Probably the fastest in the general case, without loop unrolling matmul_tune_ops!( MatmulCube, - |lhs: JitTensor, rhs: JitTensor, out: JitTensor| { + |lhs: JitTensor, rhs: JitTensor, out: JitTensor| { cubecl::linalg::matmul::launch_ref::( &lhs.client, lhs.as_handle_ref(), diff --git a/crates/burn-jit/src/kernel/matmul/utils.rs b/crates/burn-jit/src/kernel/matmul/utils.rs index 9722a088f4..fa65ce60d3 100644 --- a/crates/burn-jit/src/kernel/matmul/utils.rs +++ b/crates/burn-jit/src/kernel/matmul/utils.rs @@ -3,16 +3,13 @@ use burn_tensor::Shape; /// Creates an empty output tensor with matmul output shape pub fn init_matmul_output( - lhs: &JitTensor, - rhs: &JitTensor, -) -> JitTensor { - empty_device(lhs.client.clone(), lhs.device.clone(), shape_out(lhs, rhs)) + lhs: &JitTensor, + rhs: &JitTensor, +) -> JitTensor { + empty_device::(lhs.client.clone(), lhs.device.clone(), shape_out(lhs, rhs)) } -pub(crate) fn shape_out( - lhs: &JitTensor, - rhs: &JitTensor, -) -> Shape { +pub(crate) fn shape_out(lhs: &JitTensor, rhs: &JitTensor) -> Shape { let ndims = lhs.shape.num_dims(); let mut shape_out = vec![0; ndims]; lhs.shape diff --git a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d.rs b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d.rs index 233ccd2707..a41e9b7fd1 100644 --- a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d.rs +++ b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d.rs @@ -75,14 +75,14 @@ fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 { } pub(crate) fn adaptive_avg_pool2d( - input: JitTensor, + input: JitTensor, output_size: [usize; 2], -) -> JitTensor { +) -> JitTensor { let [batch_size, channels, _, _] = input.shape.dims(); let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]); let num_elems: usize = output_shape.num_elements(); - let output = empty_device(input.client.clone(), input.device.clone(), output_shape); + let output = empty_device::(input.client.clone(), input.device.clone(), output_shape); let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim); @@ -91,8 +91,8 @@ pub(crate) fn adaptive_avg_pool2d( &input.client, cube_count, cube_dim, - input.as_tensor_arg(1), - output.as_tensor_arg(1), + input.as_tensor_arg::(1), + output.as_tensor_arg::(1), ); output diff --git a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs index f9eaa72bc4..a22bf4b68a 100644 --- a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs @@ -79,9 +79,9 @@ fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 { } pub(crate) fn adaptive_avg_pool2d_backward( - x: JitTensor, - out_grad: JitTensor, -) -> JitTensor { + x: JitTensor, + out_grad: JitTensor, +) -> JitTensor { let output_shape = x.shape.clone(); let num_elems = output_shape.num_elements(); let output_buffer = x.client.empty(num_elems * core::mem::size_of::()); @@ -90,6 +90,7 @@ pub(crate) fn 
adaptive_avg_pool2d_backward( x.device.clone(), output_shape, output_buffer, + x.dtype, ); let cube_dim = CubeDim::default(); @@ -99,8 +100,8 @@ pub(crate) fn adaptive_avg_pool2d_backward( &x.client, cube_count, cube_dim, - out_grad.as_tensor_arg(1), - output.as_tensor_arg(1), + out_grad.as_tensor_arg::(1), + output.as_tensor_arg::(1), ); output diff --git a/crates/burn-jit/src/kernel/pool/avg_pool2d.rs b/crates/burn-jit/src/kernel/pool/avg_pool2d.rs index 0cda7292b2..eb900f7f5c 100644 --- a/crates/burn-jit/src/kernel/pool/avg_pool2d.rs +++ b/crates/burn-jit/src/kernel/pool/avg_pool2d.rs @@ -58,12 +58,12 @@ impl Pool2dDirectStrategy for AvgPoolStrategy { } pub(crate) fn avg_pool2d( - x: JitTensor, + x: JitTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, -) -> JitTensor { +) -> JitTensor { let [batch_size, channels, _, _] = x.shape.dims(); let dilation = 1; @@ -83,7 +83,7 @@ pub(crate) fn avg_pool2d( ); let shape_out = Shape::new([batch_size, channels, size_0, size_1]); - let output = empty_device(x.client.clone(), x.device.clone(), shape_out); + let output = empty_device::(x.client.clone(), x.device.clone(), shape_out); let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); @@ -92,8 +92,8 @@ pub(crate) fn avg_pool2d( &x.client, cube_count, cube_dim, - x.as_tensor_arg(1), - output.as_tensor_arg(1), + x.as_tensor_arg::(1), + output.as_tensor_arg::(1), (), Pool2dDirectArgsLaunch::new( ScalarArg::new(stride[0] as u32), diff --git a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs index cd2ab49531..bba68c7166 100644 --- a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs @@ -109,17 +109,17 @@ fn loop_ranges( } pub(crate) fn avg_pool2d_backward( - x: JitTensor, - grad: JitTensor, + x: JitTensor, + grad: JitTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, -) -> JitTensor { +) -> JitTensor { let grad = into_contiguous(grad); let dilation = 1; - let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone()); + let output = empty_device::(x.client.clone(), x.device.clone(), x.shape.clone()); let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); @@ -128,8 +128,8 @@ pub(crate) fn avg_pool2d_backward( &grad.client, cube_count, cube_dim, - grad.as_tensor_arg(1), - output.as_tensor_arg(1), + grad.as_tensor_arg::(1), + output.as_tensor_arg::(1), PoolBackwardArgsLaunch::new( ScalarArg::new(stride[0] as i32), ScalarArg::new(stride[1] as i32), diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d.rs b/crates/burn-jit/src/kernel/pool/max_pool2d.rs index c7da70f3d1..3c24721f6b 100644 --- a/crates/burn-jit/src/kernel/pool/max_pool2d.rs +++ b/crates/burn-jit/src/kernel/pool/max_pool2d.rs @@ -73,12 +73,12 @@ impl Pool2dDirectStrategy for MaxPoolWithIndicesStrategy { } pub(crate) fn max_pool2d( - x: JitTensor, + x: JitTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], -) -> JitTensor { +) -> JitTensor { let [batch_size, channels, _, _] = x.shape.dims(); let size_0 = calculate_pool_output_size( @@ -97,7 +97,7 @@ pub(crate) fn max_pool2d( ); let shape_out = Shape::new([batch_size, channels, size_0, size_1]); - let output = empty_device(x.client.clone(), x.device.clone(), shape_out); + let output = 
empty_device::(x.client.clone(), x.device.clone(), shape_out); let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); @@ -106,8 +106,8 @@ pub(crate) fn max_pool2d( &x.client, cube_count, cube_dim, - x.as_tensor_arg(1), - output.as_tensor_arg(1), + x.as_tensor_arg::(1), + output.as_tensor_arg::(1), (), Pool2dDirectArgsLaunch::new( ScalarArg::new(stride[0] as u32), @@ -125,12 +125,12 @@ pub(crate) fn max_pool2d( } pub(crate) fn max_pool2d_with_indices( - x: JitTensor, + x: JitTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], -) -> (JitTensor, JitTensor) { +) -> (JitTensor, JitTensor) { let [batch_size, channels, _, _] = x.shape.dims(); let size_0 = calculate_pool_output_size( @@ -149,8 +149,8 @@ pub(crate) fn max_pool2d_with_indices(x.client.clone(), x.device.clone(), shape_out.clone()); + let indices = empty_device::(x.client.clone(), x.device.clone(), shape_out); let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); @@ -159,9 +159,9 @@ pub(crate) fn max_pool2d_with_indices(1), + output.as_tensor_arg::(1), + indices.as_tensor_arg::(1), Pool2dDirectArgsLaunch::new( ScalarArg::new(stride[0] as u32), ScalarArg::new(stride[1] as u32), diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs index 2f18a98d25..6da6e2b37c 100644 --- a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs @@ -74,18 +74,18 @@ fn loop_ranges( } pub(crate) fn max_pool2d_with_indices_backward( - x: JitTensor, - grad: JitTensor, - indices: JitTensor, + x: JitTensor, + grad: JitTensor, + indices: JitTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], -) -> JitTensor { +) -> JitTensor { let grad = into_contiguous(grad); let indices = into_contiguous(indices); - let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone()); + let output = empty_device::(x.client.clone(), x.device.clone(), x.shape.clone()); let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); @@ -94,9 +94,9 @@ pub(crate) fn max_pool2d_with_indices_backward(1), + indices.as_tensor_arg::(1), + output.as_tensor_arg::(1), PoolBackwardArgsLaunch::new( ScalarArg::new(stride[0] as i32), ScalarArg::new(stride[1] as i32), diff --git a/crates/burn-jit/src/kernel/prng/base.rs b/crates/burn-jit/src/kernel/prng/base.rs index e9a45f3446..7ea4dc832c 100644 --- a/crates/burn-jit/src/kernel/prng/base.rs +++ b/crates/burn-jit/src/kernel/prng/base.rs @@ -12,9 +12,9 @@ pub(crate) fn random, R: JitRuntime, E: JitElement>( shape: Shape, device: &R::Device, prng: P, -) -> JitTensor { +) -> JitTensor { let client = R::client(device); - let output = empty_device(client.clone(), device.clone(), shape); + let output = empty_device::(client.clone(), device.clone(), shape); let seeds = get_seeds(); let args = prng.args(); @@ -25,7 +25,7 @@ pub(crate) fn random, R: JitRuntime, E: JitElement>( &client, cube_count, cube_dim, - output.as_tensor_arg(1), + output.as_tensor_arg::(1), ScalarArg::new(seeds[0]), ScalarArg::new(seeds[1]), ScalarArg::new(seeds[2]), diff --git a/crates/burn-jit/src/kernel/prng/bernoulli.rs b/crates/burn-jit/src/kernel/prng/bernoulli.rs index 2c4afb0811..6a54097d3f 100644 --- a/crates/burn-jit/src/kernel/prng/bernoulli.rs +++ 
b/crates/burn-jit/src/kernel/prng/bernoulli.rs @@ -59,6 +59,6 @@ pub fn random_bernoulli( shape: Shape, device: &R::Device, probability: E, -) -> JitTensor { +) -> JitTensor { random(shape, device, Bernoulli { probability }) } diff --git a/crates/burn-jit/src/kernel/prng/normal.rs b/crates/burn-jit/src/kernel/prng/normal.rs index cb599b8cc8..1453ce9715 100644 --- a/crates/burn-jit/src/kernel/prng/normal.rs +++ b/crates/burn-jit/src/kernel/prng/normal.rs @@ -87,6 +87,6 @@ pub fn random_normal( device: &R::Device, mean: E, std: E, -) -> JitTensor { +) -> JitTensor { random(shape, device, Normal { mean, std }) } diff --git a/crates/burn-jit/src/kernel/prng/uniform.rs b/crates/burn-jit/src/kernel/prng/uniform.rs index c1f80b5c61..18e8975377 100644 --- a/crates/burn-jit/src/kernel/prng/uniform.rs +++ b/crates/burn-jit/src/kernel/prng/uniform.rs @@ -71,7 +71,7 @@ pub fn random_uniform( device: &R::Device, lower_bound: E, upper_bound: E, -) -> JitTensor { +) -> JitTensor { random( shape, device, @@ -84,10 +84,10 @@ pub fn random_uniform( /// Pseudo-random generator for uniform distribution, based on /// another tensor. pub fn random_like_uniform( - tensor: &JitTensor, + tensor: &JitTensor, lower_bound: E, upper_bound: E, -) -> JitTensor { +) -> JitTensor { random_uniform( tensor.shape.clone(), &tensor.device, diff --git a/crates/burn-jit/src/kernel/quantization/dequantize.rs b/crates/burn-jit/src/kernel/quantization/dequantize.rs index 61dc0733d8..4e2aa89cf7 100644 --- a/crates/burn-jit/src/kernel/quantization/dequantize.rs +++ b/crates/burn-jit/src/kernel/quantization/dequantize.rs @@ -126,10 +126,10 @@ pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel( } pub(crate) fn dequantize_per_tensor( - tensor: JitTensor, - scale: JitTensor, - offset: Option>, -) -> JitTensor + tensor: JitTensor, + scale: JitTensor, + offset: Option>, +) -> JitTensor where R: JitRuntime, F: JitElement, @@ -158,8 +158,13 @@ where let shape_output = tensor.shape.clone(); let client = tensor.client.clone(); let handle = client.empty(num_out_elems * core::mem::size_of::()); - let output = - JitTensor::new_contiguous(client.clone(), tensor.device.clone(), shape_output, handle); + let output = JitTensor::new_contiguous( + client.clone(), + tensor.device.clone(), + shape_output, + handle, + F::dtype(), + ); let dummy_array = vec![1; ndims]; if let Some(offset) = offset { @@ -168,11 +173,11 @@ where &client, cube_count, cube_dim, - tensor.as_tensor_arg(vectorization_factor), + tensor.as_tensor_arg::(vectorization_factor), // Ignore shape and stride TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), TensorArg::from_raw_parts::(&offset.handle, &dummy_array, &dummy_array, 1), - output.as_tensor_arg(1), + output.as_tensor_arg::(1), vectorization_factor > 1, ) }; @@ -182,10 +187,10 @@ where &client, cube_count, cube_dim, - tensor.as_tensor_arg(vectorization_factor), + tensor.as_tensor_arg::(vectorization_factor), // Ignore shape and stride TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), - output.as_tensor_arg(1), + output.as_tensor_arg::(1), vectorization_factor > 1, ) }; @@ -195,7 +200,7 @@ where } /// Convert the tensor back to a higher precision data type. 
-pub fn dequantize(tensor: QJitTensor) -> JitTensor +pub fn dequantize(tensor: QJitTensor) -> JitTensor where R: JitRuntime, F: FloatElement, @@ -204,9 +209,11 @@ where match tensor.scheme { QuantizationScheme::PerTensorAffine(dtype) | QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { - QuantizationType::QInt8 => { - dequantize_per_tensor(tensor.qtensor, tensor.qparams.scale, tensor.qparams.offset) - } + QuantizationType::QInt8 => dequantize_per_tensor::( + tensor.qtensor, + tensor.qparams.scale, + tensor.qparams.offset, + ), }, } } diff --git a/crates/burn-jit/src/kernel/quantization/quantize.rs b/crates/burn-jit/src/kernel/quantization/quantize.rs index 45119f7c5a..256ae0f418 100644 --- a/crates/burn-jit/src/kernel/quantization/quantize.rs +++ b/crates/burn-jit/src/kernel/quantization/quantize.rs @@ -131,10 +131,10 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( } pub(crate) fn quantize_per_tensor( - tensor: JitTensor, - scale: JitTensor, - offset: Option>, -) -> JitTensor + tensor: JitTensor, + scale: JitTensor, + offset: Option>, +) -> JitTensor where R: JitRuntime, F: JitElement, @@ -146,8 +146,13 @@ where let client = tensor.client.clone(); // Output tensor contains 4x less elements (four int8 values packed in a single u32) let handle = client.empty(usize::div_ceil(num_elems, 4) * core::mem::size_of::()); - let output = - JitTensor::new_contiguous(client.clone(), tensor.device.clone(), shape_output, handle); + let output = JitTensor::new_contiguous( + client.clone(), + tensor.device.clone(), + shape_output, + handle, + burn_tensor::DType::U32, + ); // Force vectorization to process 4 quantized values packed for 1 output value let vectorization_factor: u8 = if num_elems < 4 { 1 } else { 4 }; @@ -162,13 +167,13 @@ where &client, cube_count, cube_dim, - tensor.as_tensor_arg(vectorization_factor), + tensor.as_tensor_arg::(vectorization_factor), // Ignore shape and stride TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), TensorArg::from_raw_parts::(&offset.handle, &dummy_array, &dummy_array, 1), ScalarArg::new(i8::MIN as f32), ScalarArg::new(i8::MAX as f32), - output.as_tensor_arg(1), + output.as_tensor_arg::(1), vectorization_factor > 1, ) }; @@ -178,12 +183,12 @@ where &client, cube_count, cube_dim, - tensor.as_tensor_arg(vectorization_factor), + tensor.as_tensor_arg::(vectorization_factor), // Ignore shape and stride TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), ScalarArg::new(-i8::MAX as f32), ScalarArg::new(i8::MAX as f32), - output.as_tensor_arg(1), + output.as_tensor_arg::(1), vectorization_factor > 1, ) }; @@ -194,10 +199,10 @@ where /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters. 
pub fn quantize( - tensor: JitTensor, + tensor: JitTensor, scheme: &QuantizationScheme, - qparams: JitQuantizationParameters, -) -> QJitTensor + qparams: JitQuantizationParameters, +) -> QJitTensor where R: JitRuntime, F: FloatElement, @@ -206,9 +211,11 @@ where let qtensor = match scheme { QuantizationScheme::PerTensorAffine(dtype) | QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { - QuantizationType::QInt8 => { - quantize_per_tensor(tensor, qparams.scale.clone(), qparams.offset.clone()) - } + QuantizationType::QInt8 => quantize_per_tensor::( + tensor, + qparams.scale.clone(), + qparams.offset.clone(), + ), }, }; diff --git a/crates/burn-jit/src/kernel/reduce/base.rs b/crates/burn-jit/src/kernel/reduce/base.rs index 163fcbffe2..5cb9ebfb7a 100644 --- a/crates/burn-jit/src/kernel/reduce/base.rs +++ b/crates/burn-jit/src/kernel/reduce/base.rs @@ -18,13 +18,13 @@ pub(crate) trait ReduceDimAlgorithm: /// Creates an empty output tensor with reduce output shape pub fn init_reduce_output( - input: &JitTensor, + input: &JitTensor, reduce_dim: usize, -) -> JitTensor { +) -> JitTensor { let mut shape_out = input.shape.clone(); shape_out.dims[reduce_dim] = 1; - empty_device(input.client.clone(), input.device.clone(), shape_out) + empty_device::(input.client.clone(), input.device.clone(), shape_out) } #[derive(Copy, Clone, Debug)] @@ -60,10 +60,10 @@ macro_rules! reduce_operation { /// Executes the reduce operation with the given strategy. pub fn $name( - tensor: JitTensor, + tensor: JitTensor, dim: usize, strategy: ReduceStrategy, - ) -> JitTensor { + ) -> JitTensor { match strategy { ReduceStrategy::Naive => reduce_dim_naive::<$ops, R, EI, EO>(tensor, dim), ReduceStrategy::SharedMemory => reduce_dim_shared::<$ops, R, EI, EO>(tensor, dim), diff --git a/crates/burn-jit/src/kernel/reduce/naive/kernel.rs b/crates/burn-jit/src/kernel/reduce/naive/kernel.rs index d4936168f3..c001edca85 100644 --- a/crates/burn-jit/src/kernel/reduce/naive/kernel.rs +++ b/crates/burn-jit/src/kernel/reduce/naive/kernel.rs @@ -38,9 +38,9 @@ pub(crate) fn naive_reduce_dim_kernel, EI: Numeric, EO: N /// Executes the naive kernel for reduce dim pub fn reduce_dim_naive, R: JitRuntime, EI: JitElement, EO: JitElement>( - input: JitTensor, + input: JitTensor, dim: usize, -) -> JitTensor { +) -> JitTensor { let output = init_reduce_output::(&input, dim); let cube_dim = CubeDim::default(); @@ -51,8 +51,8 @@ pub fn reduce_dim_naive, R: JitRuntime, EI: JitElement, E &input.client, cube_count, cube_dim, - input.as_tensor_arg(1), - output.as_tensor_arg(1), + input.as_tensor_arg::(1), + output.as_tensor_arg::(1), ScalarArg::new(dim as u32), ); } diff --git a/crates/burn-jit/src/kernel/reduce/prod.rs b/crates/burn-jit/src/kernel/reduce/prod.rs index 577feffc35..77227bae6f 100644 --- a/crates/burn-jit/src/kernel/reduce/prod.rs +++ b/crates/burn-jit/src/kernel/reduce/prod.rs @@ -5,11 +5,11 @@ use super::{prod_dim, ReduceStrategy}; /// Multiply all elements in the input buffer. 
pub fn prod( - input: JitTensor, + input: JitTensor, strategy: ReduceStrategy, -) -> JitTensor { +) -> JitTensor { let shape = Shape::new([input.shape.num_elements()]); - let input: JitTensor = - JitTensor::new_contiguous(input.client, input.device, shape, input.handle); - prod_dim(input, 0, strategy) + let input: JitTensor = + JitTensor::new_contiguous(input.client, input.device, shape, input.handle, input.dtype); + prod_dim::(input, 0, strategy) } diff --git a/crates/burn-jit/src/kernel/reduce/shared/kernel.rs b/crates/burn-jit/src/kernel/reduce/shared/kernel.rs index dbf8ef7a65..1b2dcb356e 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/kernel.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/kernel.rs @@ -83,9 +83,9 @@ pub fn reduce_dim_shared< EI: JitElement, EO: JitElement, >( - input: JitTensor, + input: JitTensor, dim: usize, -) -> JitTensor { +) -> JitTensor { let output = init_reduce_output::(&input, dim); let num_elems_output = output.shape.num_elements(); @@ -105,8 +105,8 @@ pub fn reduce_dim_shared< &input.client, cube_count, cube_dim, - input.as_tensor_arg(1), - output.as_tensor_arg(1), + input.as_tensor_arg::(1), + output.as_tensor_arg::(1), dim as u32, cube_dim.num_elems(), elems_per_thread, diff --git a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs b/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs index d6b7d15f7e..4e783e74e9 100644 --- a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs +++ b/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs @@ -86,9 +86,9 @@ pub fn reduce_dim_subcube< EI: JitElement, EO: JitElement, >( - input: JitTensor, + input: JitTensor, dim: usize, -) -> JitTensor { +) -> JitTensor { let topology = input.client.properties().hardware_properties(); if !input.client.properties().feature_enabled(Feature::Plane) @@ -122,8 +122,8 @@ pub fn reduce_dim_subcube< &input.client, cube_count, cube_dim, - input.as_tensor_arg(1), - output.as_tensor_arg(1), + input.as_tensor_arg::(1), + output.as_tensor_arg::(1), dim as u32, subcube_size, elems_per_thread, diff --git a/crates/burn-jit/src/kernel/reduce/sum.rs b/crates/burn-jit/src/kernel/reduce/sum.rs index 9e7a5fe84d..fea80bccf0 100644 --- a/crates/burn-jit/src/kernel/reduce/sum.rs +++ b/crates/burn-jit/src/kernel/reduce/sum.rs @@ -5,11 +5,11 @@ use super::{sum_dim, ReduceStrategy}; /// Sum all elements in the input buffer. 
pub fn sum( - input: JitTensor, + input: JitTensor, strategy: ReduceStrategy, -) -> JitTensor { +) -> JitTensor { let shape = Shape::new([input.shape.num_elements()]); - let input: JitTensor = - JitTensor::new_contiguous(input.client, input.device, shape, input.handle); - sum_dim(input, 0, strategy) + let input: JitTensor = + JitTensor::new_contiguous(input.client, input.device, shape, input.handle, input.dtype); + sum_dim::(input, 0, strategy) } diff --git a/crates/burn-jit/src/kernel/reduce/tune/base.rs b/crates/burn-jit/src/kernel/reduce/tune/base.rs index cf8a51f16d..f52bfd7ca0 100644 --- a/crates/burn-jit/src/kernel/reduce/tune/base.rs +++ b/crates/burn-jit/src/kernel/reduce/tune/base.rs @@ -23,7 +23,7 @@ use super::create_key; /// dim to reduce, and product of others #[tune( operations(reduce_dim_naive, reduce_dim_shared, reduce_dim_subcube), - create_key = create_key, + create_key = create_key::, should_run = should_run )] pub fn reduce_dim_operations< @@ -33,9 +33,9 @@ pub fn reduce_dim_operations< EO: JitElement + Element, >( key: JitAutotuneKey, - input: JitTensor, + input: JitTensor, reduce_dim: usize, -) -> JitTensor { +) -> JitTensor { let random_bounds: (EI, EI) = ((-10.0).elem::(), (10.0).elem::()); let input = random_like_uniform(input, random_bounds.0, random_bounds.1); @@ -49,9 +49,9 @@ pub(crate) fn reduce_dim_autotune< EI: JitElement + Element, EO: JitElement + Element, >( - input: JitTensor, + input: JitTensor, reduce_dim: usize, -) -> JitTensor { +) -> JitTensor { let client = input.client.clone(); let id = JitTuneId::new::(&input.device); diff --git a/crates/burn-jit/src/kernel/reduce/tune/key.rs b/crates/burn-jit/src/kernel/reduce/tune/key.rs index 82209c3934..3634022bc7 100644 --- a/crates/burn-jit/src/kernel/reduce/tune/key.rs +++ b/crates/burn-jit/src/kernel/reduce/tune/key.rs @@ -18,7 +18,7 @@ pub struct ReduceAutotuneKey { } pub(crate) fn create_key( - input: &JitTensor, + input: &JitTensor, reduce_dim: &usize, ) -> JitAutotuneKey { let dims = &input.shape.dims; diff --git a/crates/burn-jit/src/kernel/unary.rs b/crates/burn-jit/src/kernel/unary.rs index 1a767666e3..09f9c77689 100644 --- a/crates/burn-jit/src/kernel/unary.rs +++ b/crates/burn-jit/src/kernel/unary.rs @@ -45,9 +45,9 @@ pub(crate) fn unary_kernel>( } pub(crate) fn launch_unary, F>( - tensor: JitTensor, + tensor: JitTensor, options: F, -) -> JitTensor +) -> JitTensor where // Magic fix for lifetime, the closure is supposed to capture everything required to create the // argument. 
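The changes in this hunk follow the pattern used throughout the patch: JitTensor no longer carries a compile-time element generic, so each kernel launch names the element type explicitly (for example tensor.as_tensor_arg::<E>(1)) while the tensor itself only records a runtime dtype; the float_ops.rs hunks later in the patch reconcile the two by matching on dtype through execute_with_dtype!. What follows is a minimal, self-contained sketch of that runtime-dtype dispatch idea under simplified assumptions (a two-variant DType, a toy Tensor, and a hypothetical with_float_dtype! macro standing in for the crate's execute_with_dtype!); it is illustrative only, not the macro this patch actually adds.

// Illustrative sketch only (not part of the patch): dispatch a generic op on a
// runtime dtype, in the spirit of execute_with_dtype! in float_ops.rs.
// Assumes a reduced DType (F32/F64 only) and a toy Tensor so the example
// compiles on its own without extra crates.

#[derive(Clone, Copy)]
enum DType {
    F32,
    F64,
}

// Stand-in tensor: the element type is only known at runtime via `dtype`.
struct Tensor {
    dtype: DType,
}

// Hypothetical dispatch macro: binds the concrete element type to `$ty` and
// expands the same expression once per supported float dtype.
macro_rules! with_float_dtype {
    ($dtype:expr, $ty:ident, $expr:expr) => {
        match $dtype {
            DType::F32 => {
                type $ty = f32;
                $expr
            }
            DType::F64 => {
                type $ty = f64;
                $expr
            }
        }
    };
}

// A generic op, monomorphized per element type like the kernels above.
fn mean<E: Copy + Into<f64>>(values: &[E]) -> f64 {
    let sum: f64 = values.iter().copied().map(Into::into).sum();
    sum / values.len() as f64
}

fn main() {
    for tensor in [Tensor { dtype: DType::F32 }, Tensor { dtype: DType::F64 }] {
        // One call site covers every supported dtype; E is concrete inside
        // the expansion, so the generic op compiles for each variant.
        let m = with_float_dtype!(tensor.dtype, E, {
            let data: Vec<E> = vec![E::from(1u8); 4];
            mean(&data)
        });
        println!("mean = {m}");
    }
}

Judging from the float_ops.rs hunks later in the patch, the crate's actual execute_with_dtype! takes a float(dtype, ...) selector and also covers the half-precision types imported there (f16, bf16); the sketch above only shows the shape of the dispatch, not those details.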
@@ -71,7 +71,7 @@ where &client, cube_count, cube_dim, - tensor.as_tensor_arg(vectorization_factor), + tensor.as_tensor_arg::(vectorization_factor), TensorArg::alias(0), options(&()), None, @@ -80,7 +80,7 @@ where tensor } else { - let output = empty_device( + let output = empty_device::( tensor.client.clone(), tensor.device.clone(), tensor.shape.clone(), @@ -90,8 +90,8 @@ where &client, cube_count, CubeDim::default(), - tensor.as_tensor_arg(vectorization_factor), - output.as_tensor_arg(vectorization_factor), + tensor.as_tensor_arg::(vectorization_factor), + output.as_tensor_arg::(vectorization_factor), options(&()), Some(ndims as u32), !is_contiguous, diff --git a/crates/burn-jit/src/ops/base.rs b/crates/burn-jit/src/ops/base.rs index 740e4b3858..9f6b8f2234 100644 --- a/crates/burn-jit/src/ops/base.rs +++ b/crates/burn-jit/src/ops/base.rs @@ -1,20 +1,19 @@ use crate::{element::JitElement, kernel, tensor::JitTensor, JitRuntime}; use burn_tensor::{Shape, TensorData}; use cubecl::CubeElement; -use std::marker::PhantomData; pub(crate) fn from_data( data: TensorData, device: &R::Device, -) -> JitTensor { +) -> JitTensor { let shape: Shape = (&data.shape).into(); let client = R::client(device); let buffer = client.create(data.convert::().as_bytes()); - JitTensor::new_contiguous(client, device.clone(), shape, buffer) + JitTensor::new_contiguous(client, device.clone(), shape, buffer, E::dtype()) } -pub(crate) async fn into_data(tensor: JitTensor) -> TensorData { +pub(crate) async fn into_data(tensor: JitTensor) -> TensorData { let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one_async(tensor.handle.binding()).await; @@ -22,14 +21,14 @@ pub(crate) async fn into_data(tensor: JitTensor(tensor: JitTensor) -> TensorData { +pub(crate) fn into_data_sync(tensor: JitTensor) -> TensorData { let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one(tensor.handle.binding()); TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape) } -pub(crate) async fn bool_into_data(tensor: JitTensor) -> TensorData { +pub(crate) async fn bool_into_data(tensor: JitTensor) -> TensorData { let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one_async(tensor.handle.binding()).await; TensorData::new( @@ -38,10 +37,7 @@ pub(crate) async fn bool_into_data(tensor: JitTensor) -> ) } -pub(crate) fn to_device( - tensor: JitTensor, - device: &R::Device, -) -> JitTensor { +pub(crate) fn to_device(tensor: JitTensor, device: &R::Device) -> JitTensor { if &tensor.device == device { return tensor; } @@ -53,28 +49,25 @@ pub(crate) fn to_device( pub(crate) fn empty( shape: Shape, device: &R::Device, -) -> JitTensor { +) -> JitTensor { let client = R::client(device); let buffer = client.empty(shape.num_elements() * core::mem::size_of::()); - JitTensor::new_contiguous(client, device.clone(), shape, buffer) + JitTensor::new_contiguous(client, device.clone(), shape, buffer, E::dtype()) } -pub(crate) fn swap_dims( - mut tensor: JitTensor, +pub(crate) fn swap_dims( + mut tensor: JitTensor, dim1: usize, dim2: usize, -) -> JitTensor { +) -> JitTensor { tensor.strides.swap(dim1, dim2); tensor.shape.dims.swap(dim1, dim2); tensor } -pub(crate) fn permute( - mut tensor: JitTensor, - axes: &[usize], -) -> JitTensor { +pub(crate) fn permute(mut tensor: JitTensor, axes: &[usize]) -> JitTensor { // remap strides tensor.strides = axes.iter().map(|i| tensor.strides[*i]).collect(); @@ -83,10 +76,7 @@ pub(crate) fn permute( tensor } -pub(crate) fn expand( - tensor: JitTensor, - 
target_shape: Shape, -) -> JitTensor { +pub(crate) fn expand(tensor: JitTensor, target_shape: Shape) -> JitTensor { let ndims_in = tensor.shape.num_dims(); let ndims_out = target_shape.num_dims(); @@ -132,16 +122,19 @@ pub(crate) fn expand( shape: target_shape, strides: new_strides, handle: tensor.handle, - elem: PhantomData, + dtype: tensor.dtype, } } -pub(crate) fn reshape( - tensor: JitTensor, - shape: Shape, -) -> JitTensor { +pub(crate) fn reshape(tensor: JitTensor, shape: Shape) -> JitTensor { // TODO: Not force standard layout all the time (improve performance). let tensor = kernel::into_contiguous(tensor); - JitTensor::new_contiguous(tensor.client, tensor.device, shape, tensor.handle) + JitTensor::new_contiguous( + tensor.client, + tensor.device, + shape, + tensor.handle, + tensor.dtype, + ) } diff --git a/crates/burn-jit/src/ops/bool_ops.rs b/crates/burn-jit/src/ops/bool_ops.rs index 8fcc12536d..036913e88d 100644 --- a/crates/burn-jit/src/ops/bool_ops.rs +++ b/crates/burn-jit/src/ops/bool_ops.rs @@ -12,11 +12,7 @@ where I: IntElement, { fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { - super::empty(shape, device) - } - - fn bool_shape(tensor: &BoolTensor) -> Shape { - tensor.shape.clone() + super::empty::(shape, device) } async fn bool_into_data(tensor: BoolTensor) -> TensorData { @@ -25,11 +21,11 @@ where fn bool_from_data(data: TensorData, device: &Device) -> BoolTensor { let data: TensorData = TensorData::new(data.iter::().collect(), data.shape); - super::from_data(data, device) + super::from_data::(data, device) } fn bool_into_int(tensor: BoolTensor) -> IntTensor { - kernel::bool_cast(tensor) + kernel::bool_cast::(tensor) } fn bool_device(tensor: &BoolTensor) -> Device { @@ -45,7 +41,7 @@ where } fn bool_slice(tensor: BoolTensor, ranges: &[Range]) -> BoolTensor { - kernel::slice(tensor, ranges) + kernel::slice::(tensor, ranges) } fn bool_slice_assign( @@ -53,19 +49,19 @@ where ranges: &[Range], value: BoolTensor, ) -> BoolTensor { - kernel::slice_assign(tensor, ranges, value) + kernel::slice_assign::(tensor, ranges, value) } fn bool_equal(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { - kernel::equal(lhs, rhs) + kernel::equal::(lhs, rhs) } fn bool_not(tensor: BoolTensor) -> BoolTensor { - kernel::equal_elem(tensor, 0) + kernel::equal_elem::(tensor, 0) } fn bool_into_float(tensor: BoolTensor) -> FloatTensor { - kernel::bool_cast(tensor) + kernel::bool_cast::(tensor) } fn bool_swap_dims(mut tensor: BoolTensor, dim1: usize, dim2: usize) -> BoolTensor { @@ -76,7 +72,7 @@ where } fn bool_repeat_dim(tensor: BoolTensor, dim: usize, times: usize) -> BoolTensor { - kernel::repeat_dim(tensor, dim, times) + kernel::repeat_dim::(tensor, dim, times) } fn bool_permute(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { @@ -88,6 +84,6 @@ where } fn bool_flip(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { - kernel::flip(tensor, axes) + kernel::flip::(tensor, axes) } } diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index 32282b705a..52b013ec0e 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -2,12 +2,13 @@ use super::{expand, numeric, permute}; use crate::kernel::matmul::{matmul, MatmulStrategy}; use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform}; use crate::kernel::{self, launch_unary, reduce, unary_op, UnaryOp}; -use crate::JitBackend; +use crate::{execute_with_dtype, JitBackend}; use crate::{FloatElement, IntElement, JitRuntime}; use 
burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor}; -use burn_tensor::ElementConversion; use burn_tensor::{ops::FloatTensorOps, Distribution, Shape, TensorData}; +use burn_tensor::{DType, ElementConversion, FloatDType}; use cubecl::prelude::*; +use half::{bf16, f16}; use std::ops::Range; impl FloatTensorOps for JitBackend @@ -17,7 +18,7 @@ where I: IntElement, { fn float_from_data(data: TensorData, device: &Device) -> FloatTensor { - super::from_data(data, device) + super::from_data::(data, device) } fn float_random( @@ -26,23 +27,23 @@ where device: &Device, ) -> FloatTensor { match distribution { - Distribution::Default => random_uniform(shape, device, 0.elem(), 1.elem()), + Distribution::Default => random_uniform(shape, device, 0.elem::(), 1.elem()), Distribution::Uniform(low, high) => { - random_uniform(shape, device, low.elem(), high.elem()) + random_uniform(shape, device, low.elem::(), high.elem()) } - Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob.elem()), + Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob.elem::()), Distribution::Normal(mean, std) => { - random_normal(shape, device, mean.elem(), std.elem()) + random_normal(shape, device, mean.elem::(), std.elem()) } } } - fn float_shape(tensor: &FloatTensor) -> Shape { - tensor.shape.clone() - } - async fn float_into_data(tensor: FloatTensor) -> TensorData { - super::into_data(tensor).await + execute_with_dtype!( + float(tensor.dtype), + E, + super::into_data::(tensor).await + ) } fn float_device(tensor: &FloatTensor) -> Device { @@ -54,19 +55,27 @@ where } fn float_empty(shape: Shape, device: &Device) -> FloatTensor { - super::empty(shape, device) + super::empty::(shape, device) } fn float_add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { - numeric::add(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype, rhs.dtype), + E, + numeric::add::(lhs, rhs) + ) } fn float_add_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor { - numeric::add_scalar(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype), + E, + numeric::add_scalar::(lhs, rhs.elem()) + ) } fn float_zeros(shape: Shape, device: &Device) -> FloatTensor { - numeric::zeros(shape, device) + numeric::zeros::(shape, device) } fn float_full( @@ -74,47 +83,83 @@ where fill_value: FloatElem, device: &R::Device, ) -> FloatTensor { - numeric::full(shape, device, fill_value) + numeric::full::(shape, device, fill_value) } fn float_ones(shape: Shape, device: &Device) -> FloatTensor { - numeric::ones(shape, device) + numeric::ones::(shape, device) } fn float_sub(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { - numeric::sub(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype, rhs.dtype), + E, + numeric::sub::(lhs, rhs) + ) } fn float_sub_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor { - numeric::sub_scalar(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype), + E, + numeric::sub_scalar::(lhs, rhs.elem()) + ) } fn float_mul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { - numeric::mul(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype, rhs.dtype), + E, + numeric::mul::(lhs, rhs) + ) } fn float_mul_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor { - numeric::mul_scalar(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype), + E, + numeric::mul_scalar::(lhs, rhs.elem()) + ) } fn float_div(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { - numeric::div(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype, rhs.dtype), + E, + numeric::div::(lhs, rhs) + ) } fn float_div_scalar(lhs: FloatTensor, 
rhs: FloatElem) -> FloatTensor { - numeric::div_scalar(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype), + E, + numeric::div_scalar::(lhs, rhs.elem()) + ) } fn float_remainder(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { - numeric::remainder(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype, rhs.dtype), + E, + numeric::remainder::(lhs, rhs) + ) } fn float_remainder_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor { - numeric::remainder_scalar(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype), + E, + numeric::remainder_scalar::(lhs, rhs.elem()) + ) } fn float_matmul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { - matmul(lhs, rhs, MatmulStrategy::default()) + execute_with_dtype!( + float(lhs.dtype, rhs.dtype), + E, + matmul::(lhs, rhs, MatmulStrategy::default()) + ) } fn float_swap_dims(tensor: FloatTensor, dim1: usize, dim2: usize) -> FloatTensor { @@ -130,7 +175,11 @@ where tensor: FloatTensor, indices: IntTensor, ) -> FloatTensor { - kernel::gather(dim, tensor, indices) + execute_with_dtype!( + float(tensor.dtype), + E, + kernel::gather::(dim, tensor, indices) + ) } fn float_scatter( @@ -139,7 +188,11 @@ where indices: IntTensor, value: FloatTensor, ) -> FloatTensor { - kernel::scatter(dim, tensor, indices, value) + execute_with_dtype!( + float(tensor.dtype, value.dtype), + E, + kernel::scatter::(dim, tensor, indices, value) + ) } fn float_select( @@ -147,7 +200,11 @@ where dim: usize, indices: IntTensor, ) -> FloatTensor { - kernel::select(tensor, dim, indices) + execute_with_dtype!( + float(tensor.dtype), + E, + kernel::select::(tensor, dim, indices) + ) } fn float_select_assign( @@ -156,11 +213,19 @@ where indices: IntTensor, value: FloatTensor, ) -> FloatTensor { - kernel::select_assign(tensor, dim, indices, value) + execute_with_dtype!( + float(tensor.dtype, value.dtype), + E, + kernel::select_assign::(tensor, dim, indices, value) + ) } fn float_slice(tensor: FloatTensor, ranges: &[Range]) -> FloatTensor { - kernel::slice(tensor, ranges) + execute_with_dtype!( + float(tensor.dtype), + E, + kernel::slice::(tensor, ranges) + ) } fn float_slice_assign( @@ -168,7 +233,11 @@ where ranges: &[Range], value: FloatTensor, ) -> FloatTensor { - kernel::slice_assign(tensor, ranges, value) + execute_with_dtype!( + float(tensor.dtype, value.dtype), + E, + kernel::slice_assign::(tensor, ranges, value) + ) } fn float_mask_where( @@ -176,7 +245,11 @@ where mask: BoolTensor, value: FloatTensor, ) -> FloatTensor { - kernel::mask_where_auto(tensor, mask, value) + execute_with_dtype!( + float(tensor.dtype, value.dtype), + E, + kernel::mask_where_auto::(tensor, mask, value) + ) } fn float_mask_fill( @@ -184,209 +257,333 @@ where mask: BoolTensor, value: FloatElem, ) -> FloatTensor { - kernel::mask_fill_auto(tensor, mask, value) + execute_with_dtype!( + float(tensor.dtype), + E, + kernel::mask_fill_auto::(tensor, mask, value.elem()) + ) } fn float_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { - kernel::equal(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype, rhs.dtype), + E, + kernel::equal::(lhs, rhs) + ) } fn float_equal_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor { - kernel::equal_elem(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype), + E, + kernel::equal_elem::(lhs, rhs.elem()) + ) } fn float_greater(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { - kernel::greater(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype, rhs.dtype), + E, + kernel::greater::(lhs, rhs) + ) } fn float_greater_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor { - 
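Every binary float op above now matches on the runtime dtype of both operands before calling the dtype-generic kernel, and mismatched operands are rejected outright since there is no automatic type promotion yet. A rough, standalone sketch of that guard-then-dispatch shape; all types and names are stand-ins rather than Burn's API.

```rust
// Guard-then-dispatch sketch for a binary float op: both operands must share
// a dtype, then a dtype-generic kernel runs for the matching element type.
use std::ops::Add;

#[derive(Clone, Copy, Debug, PartialEq)]
enum FloatDType {
    F32,
    F64,
}

struct FloatTensor {
    dtype: FloatDType,
    data: Vec<f64>, // stored widened here purely to keep the sketch short
}

trait FloatElem: Copy + Add<Output = Self> {
    fn from_f64(v: f64) -> Self;
    fn to_f64(self) -> f64;
}
impl FloatElem for f32 {
    fn from_f64(v: f64) -> Self { v as f32 }
    fn to_f64(self) -> f64 { self as f64 }
}
impl FloatElem for f64 {
    fn from_f64(v: f64) -> Self { v }
    fn to_f64(self) -> f64 { self }
}

// Stand-in for a dtype-generic kernel such as a numeric `add`.
fn add_kernel<E: FloatElem>(lhs: &FloatTensor, rhs: &FloatTensor) -> FloatTensor {
    let data = lhs
        .data
        .iter()
        .zip(&rhs.data)
        .map(|(&a, &b)| (E::from_f64(a) + E::from_f64(b)).to_f64())
        .collect();
    FloatTensor { dtype: lhs.dtype, data }
}

fn float_add(lhs: &FloatTensor, rhs: &FloatTensor) -> FloatTensor {
    // No automatic type promotion yet: mismatched dtypes are a hard error.
    if lhs.dtype != rhs.dtype {
        panic!("Data type mismatch (lhs: {:?}, rhs: {:?})", lhs.dtype, rhs.dtype);
    }
    match lhs.dtype {
        FloatDType::F32 => add_kernel::<f32>(lhs, rhs),
        FloatDType::F64 => add_kernel::<f64>(lhs, rhs),
    }
}

fn main() {
    let a = FloatTensor { dtype: FloatDType::F32, data: vec![1.0, 2.0] };
    let b = FloatTensor { dtype: FloatDType::F32, data: vec![3.0, 4.0] };
    assert_eq!(float_add(&a, &b).data, vec![4.0, 6.0]);
}
```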
kernel::greater_elem(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype), + E, + kernel::greater_elem::(lhs, rhs.elem()) + ) } fn float_greater_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { - kernel::greater_equal(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype, rhs.dtype), + E, + kernel::greater_equal::(lhs, rhs) + ) } fn float_greater_equal_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor { - kernel::greater_equal_elem(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype), + E, + kernel::greater_equal_elem::(lhs, rhs.elem()) + ) } fn float_lower(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { - kernel::lower(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype, rhs.dtype), + E, + kernel::lower::(lhs, rhs) + ) } fn float_lower_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor { - kernel::lower_elem(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype), + E, + kernel::lower_elem::(lhs, rhs.elem()) + ) } fn float_lower_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { - kernel::lower_equal(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype, rhs.dtype), + E, + kernel::lower_equal::(lhs, rhs) + ) } fn float_lower_equal_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor { - kernel::lower_equal_elem(lhs, rhs) + execute_with_dtype!( + float(lhs.dtype), + E, + kernel::lower_equal_elem::(lhs, rhs.elem()) + ) } fn float_sum(tensor: FloatTensor) -> FloatTensor { - reduce::sum(tensor, Default::default()) + execute_with_dtype!( + float(tensor.dtype), + E, + reduce::sum::(tensor, Default::default()) + ) } fn float_sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - reduce::sum_dim(tensor, dim, Default::default()) + execute_with_dtype!( + float(tensor.dtype), + E, + reduce::sum_dim::(tensor, dim, Default::default()) + ) } fn float_mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - reduce::mean_dim(tensor, dim, Default::default()) + execute_with_dtype!( + float(tensor.dtype), + E, + reduce::mean_dim::(tensor, dim, Default::default()) + ) } fn float_prod(tensor: FloatTensor) -> FloatTensor { - reduce::prod(tensor, Default::default()) + execute_with_dtype!( + float(tensor.dtype), + E, + reduce::prod::(tensor, Default::default()) + ) } fn float_prod_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - reduce::prod_dim(tensor, dim, Default::default()) + execute_with_dtype!( + float(tensor.dtype), + E, + reduce::prod_dim::(tensor, dim, Default::default()) + ) } fn float_exp(tensor: FloatTensor) -> FloatTensor { - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::exp(input) - } - execute::expand::(context, tensor) - }) + execute_with_dtype!( + float(tensor.dtype), + F, + unary_op!(float(tensor) => |context, tensor| { + #[cube] + fn execute(input: Line) -> Line { + Line::exp(input) + } + execute::expand::(context, tensor) + }) + ) } fn float_log(tensor: FloatTensor) -> FloatTensor { - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::log(input) - } - execute::expand::(context, tensor) - }) + execute_with_dtype!( + float(tensor.dtype), + F, + unary_op!(float(tensor) => |context, tensor| { + #[cube] + fn execute(input: Line) -> Line { + Line::log(input) + } + execute::expand::(context, tensor) + }) + ) } fn float_log1p(tensor: FloatTensor) -> FloatTensor { - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::log1p(input) - } - execute::expand::(context, tensor) - }) + execute_with_dtype!( + float(tensor.dtype), + 
F, + unary_op!(float(tensor) => |context, tensor| { + #[cube] + fn execute(input: Line) -> Line { + Line::log1p(input) + } + execute::expand::(context, tensor) + }) + ) } fn float_powf_scalar(lhs: FloatTensor, rhs: f32) -> FloatTensor { - unary_op!(float(lhs, rhs.elem::()) => |context, tensor, scalar| { - #[cube] - fn execute(input: Line, scalar: C) -> Line { - Line::powf(input, Line::new(scalar)) - } - execute::expand::(context, tensor, scalar) - }) + execute_with_dtype!( + float(lhs.dtype), + F, + unary_op!(float(lhs, rhs.elem::()) => |context, tensor, scalar| { + #[cube] + fn execute(input: Line, scalar: C) -> Line { + Line::powf(input, Line::new(scalar)) + } + execute::expand::(context, tensor, scalar) + }) + ) } fn float_sqrt(tensor: FloatTensor) -> FloatTensor { - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::sqrt(input) - } - execute::expand::(context, tensor) - }) + execute_with_dtype!( + float(tensor.dtype), + F, + unary_op!(float(tensor) => |context, tensor| { + #[cube] + fn execute(input: Line) -> Line { + Line::sqrt(input) + } + execute::expand::(context, tensor) + }) + ) } fn float_abs(tensor: FloatTensor) -> FloatTensor { - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::abs(input) - } - execute::expand::(context, tensor) - }) + execute_with_dtype!( + float(tensor.dtype), + F, + unary_op!(float(tensor) => |context, tensor| { + #[cube] + fn execute(input: Line) -> Line { + Line::abs(input) + } + execute::expand::(context, tensor) + }) + ) } fn float_cos(tensor: FloatTensor) -> FloatTensor { - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::cos(input) - } - execute::expand::(context, tensor) - }) + execute_with_dtype!( + float(tensor.dtype), + F, + unary_op!(float(tensor) => |context, tensor| { + #[cube] + fn execute(input: Line) -> Line { + Line::cos(input) + } + execute::expand::(context, tensor) + }) + ) } fn float_sin(tensor: FloatTensor) -> FloatTensor { - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::sin(input) - } - execute::expand::(context, tensor) - }) + execute_with_dtype!( + float(tensor.dtype), + F, + unary_op!(float(tensor) => |context, tensor| { + #[cube] + fn execute(input: Line) -> Line { + Line::sin(input) + } + execute::expand::(context, tensor) + }) + ) } fn float_tanh(tensor: FloatTensor) -> FloatTensor { - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::tanh(input) - } - execute::expand::(context, tensor) - }) + execute_with_dtype!( + float(tensor.dtype), + F, + unary_op!(float(tensor) => |context, tensor| { + #[cube] + fn execute(input: Line) -> Line { + Line::tanh(input) + } + execute::expand::(context, tensor) + }) + ) } fn float_round(tensor: FloatTensor) -> FloatTensor { - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::round(input) - } - execute::expand::(context, tensor) - }) + execute_with_dtype!( + float(tensor.dtype), + F, + unary_op!(float(tensor) => |context, tensor| { + #[cube] + fn execute(input: Line) -> Line { + Line::round(input) + } + execute::expand::(context, tensor) + }) + ) } fn float_floor(tensor: FloatTensor) -> FloatTensor { - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::floor(input) - } - execute::expand::(context, tensor) - }) + execute_with_dtype!( + 
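The unary ops above all follow the same pattern: a tiny element-wise body written once over a generic float line type, then expanded for whichever element type the dtype dispatch selected. A plain-Rust analogue of that shape, using `num_traits::Float` as a stand-in for the GPU-side float bound; this is not the cubecl API.

```rust
// Plain-Rust analogue of the unary kernels: written once, generic over the
// float element, instantiated for the dtype selected at runtime.
use num_traits::Float;

fn sqrt_inplace<F: Float>(data: &mut [F]) {
    for x in data.iter_mut() {
        *x = x.sqrt(); // same role as `Line::sqrt` in the #[cube] helper
    }
}

fn main() {
    // The same kernel body serves every float dtype picked by the dispatch.
    let mut xs = [1.0f32, 4.0, 9.0];
    sqrt_inplace(&mut xs);
    assert_eq!(xs, [1.0, 2.0, 3.0]);

    let mut ys = [16.0f64, 25.0];
    sqrt_inplace(&mut ys);
    assert_eq!(ys, [4.0, 5.0]);
}
```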
float(tensor.dtype), + F, + unary_op!(float(tensor) => |context, tensor| { + #[cube] + fn execute(input: Line) -> Line { + Line::floor(input) + } + execute::expand::(context, tensor) + }) + ) } fn float_ceil(tensor: FloatTensor) -> FloatTensor { - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::ceil(input) - } - execute::expand::(context, tensor) - }) + execute_with_dtype!( + float(tensor.dtype), + F, + unary_op!(float(tensor) => |context, tensor| { + #[cube] + fn execute(input: Line) -> Line { + Line::ceil(input) + } + execute::expand::(context, tensor) + }) + ) } fn float_erf(tensor: FloatTensor) -> FloatTensor { - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::erf(input) - } - execute::expand::(context, tensor) - }) + execute_with_dtype!( + float(tensor.dtype), + F, + unary_op!(float(tensor) => |context, tensor| { + #[cube] + fn execute(input: Line) -> Line { + Line::erf(input) + } + execute::expand::(context, tensor) + }) + ) } fn float_argmax(tensor: FloatTensor, dim: usize) -> IntTensor { - reduce::argmax(tensor, dim, Default::default()) + execute_with_dtype!( + float(tensor.dtype), + E, + reduce::argmax::(tensor, dim, Default::default()) + ) } fn float_argmin(tensor: FloatTensor, dim: usize) -> IntTensor { - reduce::argmin(tensor, dim, Default::default()) + execute_with_dtype!( + float(tensor.dtype), + E, + reduce::argmin::(tensor, dim, Default::default()) + ) } fn float_into_int(tensor: FloatTensor) -> IntTensor { - kernel::cast(tensor) + execute_with_dtype!(float(tensor.dtype), E, kernel::cast::(tensor)) } fn float_clamp( @@ -394,25 +591,37 @@ where min: FloatElem, max: FloatElem, ) -> FloatTensor { - kernel::clamp(tensor, min, max) + execute_with_dtype!( + float(tensor.dtype), + E, + kernel::clamp::(tensor, min.elem(), max.elem()) + ) } fn float_recip(tensor: FloatTensor) -> FloatTensor { - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::recip(input) - } - execute::expand::(context, tensor) - }) + execute_with_dtype!( + float(tensor.dtype), + F, + unary_op!(float(tensor) => |context, tensor| { + #[cube] + fn execute(input: Line) -> Line { + Line::recip(input) + } + execute::expand::(context, tensor) + }) + ) } fn float_repeat_dim(tensor: FloatTensor, dim: usize, times: usize) -> FloatTensor { - kernel::repeat_dim(tensor, dim, times) + execute_with_dtype!( + float(tensor.dtype), + E, + kernel::repeat_dim::(tensor, dim, times) + ) } fn float_powf(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { - numeric::pow(lhs, rhs) + execute_with_dtype!(float(lhs.dtype), E, numeric::pow::(lhs, rhs)) } fn float_permute(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { @@ -424,13 +633,28 @@ where } fn float_flip(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { - kernel::flip(tensor, axes) - } - - fn float_cast( - _tensor: FloatTensor, - _dtype: burn_tensor::FloatDType, - ) -> FloatTensor { - todo!() + execute_with_dtype!(float(tensor.dtype), E, kernel::flip::(tensor, axes)) + } + + fn float_cast(tensor: FloatTensor, dtype: FloatDType) -> FloatTensor { + match (tensor.dtype, dtype) { + (DType::F64, FloatDType::F64) + | (DType::F32, FloatDType::F32) + | (DType::BF16, FloatDType::BF16) + | (DType::F16, FloatDType::F16) => tensor, + (DType::F64, FloatDType::F32) => kernel::cast::(tensor), + (DType::F64, FloatDType::F16) => kernel::cast::(tensor), + (DType::F64, FloatDType::BF16) => kernel::cast::(tensor), + (DType::F32, 
FloatDType::F64) => kernel::cast::(tensor), + (DType::F32, FloatDType::F16) => kernel::cast::(tensor), + (DType::F32, FloatDType::BF16) => kernel::cast::(tensor), + (DType::F16, FloatDType::F64) => kernel::cast::(tensor), + (DType::F16, FloatDType::F32) => kernel::cast::(tensor), + (DType::F16, FloatDType::BF16) => kernel::cast::(tensor), + (DType::BF16, FloatDType::F64) => kernel::cast::(tensor), + (DType::BF16, FloatDType::F32) => kernel::cast::(tensor), + (DType::BF16, FloatDType::F16) => kernel::cast::(tensor), + _ => unimplemented!("Unsupported floating point type cast"), + } } } diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index 9a7dc5f9f1..cb6603bf80 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -15,19 +15,15 @@ where I: IntElement, { fn int_empty(shape: Shape, device: &Device) -> IntTensor { - super::empty(shape, device) - } - - fn int_shape(tensor: &IntTensor) -> Shape { - tensor.shape.clone() + super::empty::(shape, device) } async fn int_into_data(tensor: IntTensor) -> TensorData { - super::into_data(tensor).await + super::into_data::(tensor).await } fn int_from_data(data: TensorData, device: &Device) -> IntTensor { - super::from_data(data, device) + super::from_data::(data, device) } fn int_device(tensor: &IntTensor) -> Device { @@ -43,7 +39,7 @@ where } fn int_slice(tensor: IntTensor, ranges: &[Range]) -> IntTensor { - kernel::slice(tensor, ranges) + kernel::slice::(tensor, ranges) } fn int_slice_assign( @@ -51,7 +47,7 @@ where ranges: &[Range], value: IntTensor, ) -> IntTensor { - kernel::slice_assign(tensor, ranges, value) + kernel::slice_assign::(tensor, ranges, value) } fn int_mask_where( @@ -59,7 +55,7 @@ where mask: BoolTensor, value: IntTensor, ) -> IntTensor { - kernel::mask_where_auto(tensor, mask, value) + kernel::mask_where_auto::(tensor, mask, value) } fn int_mask_fill( @@ -75,7 +71,7 @@ where tensor: IntTensor, indices: IntTensor, ) -> IntTensor { - kernel::gather(dim, tensor, indices) + kernel::gather::(dim, tensor, indices) } fn int_scatter( @@ -84,7 +80,7 @@ where indices: IntTensor, value: IntTensor, ) -> IntTensor { - kernel::scatter(dim, tensor, indices, value) + kernel::scatter::(dim, tensor, indices, value) } fn int_select( @@ -92,7 +88,7 @@ where dim: usize, indices: IntTensor, ) -> IntTensor { - kernel::select(tensor, dim, indices) + kernel::select::(tensor, dim, indices) } fn int_select_assign( @@ -101,123 +97,123 @@ where indices: IntTensor, value: IntTensor, ) -> IntTensor { - kernel::select_assign(tensor, dim, indices, value) + kernel::select_assign::(tensor, dim, indices, value) } fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - kernel::equal(lhs, rhs) + kernel::equal::(lhs, rhs) } fn int_equal_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { - kernel::equal_elem(lhs, rhs) + kernel::equal_elem::(lhs, rhs) } fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - kernel::greater(lhs, rhs) + kernel::greater::(lhs, rhs) } fn int_greater_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { - kernel::greater_elem(lhs, rhs) + kernel::greater_elem::(lhs, rhs) } fn int_greater_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - kernel::greater_equal(lhs, rhs) + kernel::greater_equal::(lhs, rhs) } fn int_greater_equal_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { - kernel::greater_equal_elem(lhs, rhs) + kernel::greater_equal_elem::(lhs, rhs) } fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - kernel::lower(lhs, rhs) + 
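`float_cast` above returns the tensor untouched when the source and target precisions already agree and otherwise selects a concrete cast instantiation from the (source, target) pair, with the `half` crate supplying the `f16`/`bf16` element types. A simplified CPU-side version of that selection, restricted to f32/f16; the enum, storage type, and function are illustrative, only `half::f16` is the real crate type.

```rust
// Simplified cast selection: identity when source and target already match,
// otherwise re-encode every element for the requested precision.
use half::f16;

#[derive(Clone, Copy, Debug, PartialEq)]
enum FloatDType {
    F32,
    F16,
}

enum Storage {
    F32(Vec<f32>),
    F16(Vec<f16>),
}

fn cast(values: Storage, target: FloatDType) -> Storage {
    match (values, target) {
        // Same dtype: hand the tensor back untouched.
        (Storage::F32(v), FloatDType::F32) => Storage::F32(v),
        (Storage::F16(v), FloatDType::F16) => Storage::F16(v),
        // Otherwise pick the concrete conversion for the (source, target) pair.
        (Storage::F32(v), FloatDType::F16) => {
            Storage::F16(v.into_iter().map(f16::from_f32).collect())
        }
        (Storage::F16(v), FloatDType::F32) => {
            Storage::F32(v.into_iter().map(f16::to_f32).collect())
        }
    }
}

fn main() {
    let out = cast(Storage::F32(vec![1.5, -2.25]), FloatDType::F16);
    if let Storage::F16(v) = out {
        assert_eq!(v[0].to_f32(), 1.5); // exactly representable in f16
    }
}
```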
kernel::lower::(lhs, rhs) } fn int_lower_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { - kernel::lower_elem(lhs, rhs) + kernel::lower_elem::(lhs, rhs) } fn int_lower_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - kernel::lower_equal(lhs, rhs) + kernel::lower_equal::(lhs, rhs) } fn int_lower_equal_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { - kernel::lower_equal_elem(lhs, rhs) + kernel::lower_equal_elem::(lhs, rhs) } fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - numeric::add(lhs, rhs) + numeric::add::(lhs, rhs) } fn int_add_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { - numeric::add_scalar(lhs, rhs) + numeric::add_scalar::(lhs, rhs) } fn int_sub(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - numeric::sub(lhs, rhs) + numeric::sub::(lhs, rhs) } fn int_sub_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { - numeric::sub_scalar(lhs, rhs) + numeric::sub_scalar::(lhs, rhs) } fn int_mul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - numeric::mul(lhs, rhs) + numeric::mul::(lhs, rhs) } fn int_mul_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { - numeric::mul_scalar(lhs, rhs) + numeric::mul_scalar::(lhs, rhs) } fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - numeric::div(lhs, rhs) + numeric::div::(lhs, rhs) } fn int_div_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { - numeric::div_scalar(lhs, rhs) + numeric::div_scalar::(lhs, rhs) } fn int_remainder(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - numeric::remainder(lhs, rhs) + numeric::remainder::(lhs, rhs) } fn int_remainder_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { - numeric::remainder_scalar(lhs, rhs) + numeric::remainder_scalar::(lhs, rhs) } fn int_zeros(shape: Shape, device: &Device) -> IntTensor { - numeric::zeros(shape, device) + numeric::zeros::(shape, device) } fn int_ones(shape: Shape, device: &Device) -> IntTensor { - numeric::ones(shape, device) + numeric::ones::(shape, device) } fn int_sum(tensor: IntTensor) -> IntTensor { - kernel::reduce::sum(tensor, Default::default()) + kernel::reduce::sum::(tensor, Default::default()) } fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::sum_dim(tensor, dim, Default::default()) + kernel::reduce::sum_dim::(tensor, dim, Default::default()) } fn int_prod(tensor: IntTensor) -> IntTensor { - kernel::reduce::prod(tensor, Default::default()) + kernel::reduce::prod::(tensor, Default::default()) } fn int_prod_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::prod_dim(tensor, dim, Default::default()) + kernel::reduce::prod_dim::(tensor, dim, Default::default()) } fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::mean_dim(tensor, dim, Default::default()) + kernel::reduce::mean_dim::(tensor, dim, Default::default()) } fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::argmax(tensor, dim, Default::default()) + kernel::reduce::argmax::(tensor, dim, Default::default()) } fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::argmin(tensor, dim, Default::default()) + kernel::reduce::argmin::(tensor, dim, Default::default()) } fn int_clamp( @@ -225,7 +221,7 @@ where min: IntElem, max: IntElem, ) -> IntTensor { - kernel::clamp(tensor, min, max) + kernel::clamp::(tensor, min, max) } fn int_abs(tensor: IntTensor) -> IntTensor { @@ -239,7 +235,7 @@ where } fn int_into_float(tensor: IntTensor) -> FloatTensor { - kernel::cast(tensor) + kernel::cast::(tensor) } fn int_swap_dims(mut tensor: IntTensor, dim1: usize, dim2: 
usize) -> IntTensor { @@ -250,7 +246,7 @@ where } fn int_repeat_dim(tensor: IntTensor, dim: usize, times: usize) -> IntTensor { - kernel::repeat_dim(tensor, dim, times) + kernel::repeat_dim::(tensor, dim, times) } fn int_random( @@ -259,17 +255,17 @@ where device: &Device, ) -> IntTensor { let float_tensor = match distribution { - Distribution::Default => random_uniform(shape, device, 0.elem::(), 255.elem()), + Distribution::Default => random_uniform(shape, device, 0.elem::(), 255.elem()), Distribution::Uniform(low, high) => { - random_uniform(shape, device, low.elem(), high.elem()) + random_uniform(shape, device, low.elem::(), high.elem()) } - Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob.elem()), + Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob.elem::()), Distribution::Normal(mean, std) => { - random_normal(shape, device, mean.elem(), std.elem()) + random_normal(shape, device, mean.elem::(), std.elem()) } }; - kernel::cast(float_tensor) + kernel::cast::(float_tensor) } fn int_permute(tensor: IntTensor, axes: &[usize]) -> IntTensor { @@ -281,6 +277,6 @@ where } fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { - kernel::flip(tensor, axes) + kernel::flip::(tensor, axes) } } diff --git a/crates/burn-jit/src/ops/module_ops.rs b/crates/burn-jit/src/ops/module_ops.rs index 1b6a43a4ca..5539dfc9f2 100644 --- a/crates/burn-jit/src/ops/module_ops.rs +++ b/crates/burn-jit/src/ops/module_ops.rs @@ -46,7 +46,15 @@ where output_grad: FloatTensor, options: DeformConvOptions<2>, ) -> DeformConv2dBackward { - kernel::conv::deform_conv2d_backward(x, offset, weight, mask, bias, output_grad, options) + kernel::conv::deform_conv2d_backward::( + x, + offset, + weight, + mask, + bias, + output_grad, + options, + ) } fn conv3d( @@ -55,7 +63,7 @@ where bias: Option>, options: ConvOptions<3>, ) -> FloatTensor { - kernel::conv::conv3d(x, weight, bias, options) + kernel::conv::conv3d::(x, weight, bias, options) } fn conv_transpose2d( @@ -79,7 +87,7 @@ where bias: Option>, options: ConvTransposeOptions<3>, ) -> FloatTensor { - kernel::conv::conv_transpose3d(x, weight, bias, options) + kernel::conv::conv_transpose3d::(x, weight, bias, options) } fn avg_pool2d( @@ -89,7 +97,7 @@ where padding: [usize; 2], count_include_pad: bool, ) -> FloatTensor { - kernel::pool::avg_pool2d(x, kernel_size, stride, padding, count_include_pad) + kernel::pool::avg_pool2d::(x, kernel_size, stride, padding, count_include_pad) } fn avg_pool2d_backward( @@ -100,7 +108,14 @@ where padding: [usize; 2], count_include_pad: bool, ) -> FloatTensor { - kernel::pool::avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad) + kernel::pool::avg_pool2d_backward::( + x, + grad, + kernel_size, + stride, + padding, + count_include_pad, + ) } fn max_pool2d( @@ -110,7 +125,7 @@ where padding: [usize; 2], dilation: [usize; 2], ) -> FloatTensor { - kernel::pool::max_pool2d(x, kernel_size, stride, padding, dilation) + kernel::pool::max_pool2d::(x, kernel_size, stride, padding, dilation) } fn max_pool2d_with_indices( @@ -120,8 +135,13 @@ where padding: [usize; 2], dilation: [usize; 2], ) -> MaxPool2dWithIndices { - let (output, indices) = - kernel::pool::max_pool2d_with_indices(x, kernel_size, stride, padding, dilation); + let (output, indices) = kernel::pool::max_pool2d_with_indices::( + x, + kernel_size, + stride, + padding, + dilation, + ); MaxPool2dWithIndices::new(output, indices) } @@ -135,7 +155,7 @@ where output_grad: FloatTensor, indices: IntTensor, ) -> MaxPool2dBackward 
{ - MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward( + MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward::( x, output_grad, indices, @@ -147,14 +167,14 @@ where } fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor { - kernel::pool::adaptive_avg_pool2d(x, output_size) + kernel::pool::adaptive_avg_pool2d::(x, output_size) } fn adaptive_avg_pool2d_backward( x: FloatTensor, grad: FloatTensor, ) -> FloatTensor { - kernel::pool::adaptive_avg_pool2d_backward(x, grad) + kernel::pool::adaptive_avg_pool2d_backward::(x, grad) } fn interpolate( @@ -162,7 +182,7 @@ where output_size: [usize; 2], options: InterpolateOptions, ) -> FloatTensor { - kernel::interpolate::interpolate(x, output_size, options) + kernel::interpolate::interpolate::(x, output_size, options) } fn interpolate_backward( @@ -171,6 +191,6 @@ where output_size: [usize; 2], options: InterpolateOptions, ) -> FloatTensor { - kernel::interpolate::interpolate_backward(x, grad, output_size, options) + kernel::interpolate::interpolate_backward::(x, grad, output_size, options) } } diff --git a/crates/burn-jit/src/ops/numeric.rs b/crates/burn-jit/src/ops/numeric.rs index e786403a16..2519060cba 100644 --- a/crates/burn-jit/src/ops/numeric.rs +++ b/crates/burn-jit/src/ops/numeric.rs @@ -12,7 +12,7 @@ pub fn full( shape: Shape, device: &R::Device, value: E, -) -> JitTensor { +) -> JitTensor { let client = R::client(device); full_device::(client, shape, device.clone(), value) @@ -23,9 +23,9 @@ pub fn full_device( shape: Shape, device: R::Device, value: E, -) -> JitTensor { +) -> JitTensor { let ndims = shape.num_dims(); - let empty = empty_device(client, device, shape); + let empty = empty_device::(client, device, shape); #[cube(launch)] pub fn full_kernel(tensor: &mut Tensor, value: C) { @@ -48,28 +48,28 @@ pub fn full_device( &empty.client, cube_count, cube_dim, - empty.as_tensor_arg(vectorization_factor), + empty.as_tensor_arg::(vectorization_factor), ScalarArg::new(value), ); empty } -pub fn zeros(shape: Shape, device: &R::Device) -> JitTensor { +pub fn zeros(shape: Shape, device: &R::Device) -> JitTensor { let client = R::client(device); - zeros_device(client, device.clone(), shape) + zeros_device::(client, device.clone(), shape) } pub fn zeros_device( client: ComputeClient, device: R::Device, shape: Shape, -) -> JitTensor { +) -> JitTensor { full_device::(client, shape, device, 0.elem()) } -pub fn ones(shape: Shape, device: &R::Device) -> JitTensor { +pub fn ones(shape: Shape, device: &R::Device) -> JitTensor { let client = R::client(device); ones_device::(client, device.clone(), shape) @@ -79,7 +79,7 @@ pub fn ones_device( client: ComputeClient, device: R::Device, shape: Shape, -) -> JitTensor { +) -> JitTensor { full_device::(client, shape, device, 1.elem()) } @@ -87,73 +87,55 @@ pub fn empty_device( client: ComputeClient, device: R::Device, shape: Shape, -) -> JitTensor { +) -> JitTensor { let buffer = client.empty(shape.num_elements() * core::mem::size_of::()); - JitTensor::new_contiguous(client, device, shape, buffer) + JitTensor::new_contiguous(client, device, shape, buffer, E::dtype()) } -pub fn add( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub fn add(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::(lhs, rhs) } -pub fn add_scalar(lhs: JitTensor, rhs: E) -> JitTensor { +pub fn add_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop::(lhs, rhs) } -pub fn sub( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub fn sub(lhs: 
JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::(lhs, rhs) } -pub fn sub_scalar(lhs: JitTensor, rhs: E) -> JitTensor { +pub fn sub_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop::(lhs, rhs) } -pub fn mul( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub fn mul(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::(lhs, rhs) } -pub fn mul_scalar(lhs: JitTensor, rhs: E) -> JitTensor { +pub fn mul_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop::(lhs, rhs) } -pub fn div( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub fn div(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::(lhs, rhs) } -pub fn div_scalar(lhs: JitTensor, rhs: E) -> JitTensor { +pub fn div_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop::(lhs, rhs) } pub fn remainder( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { launch_binop::(lhs, rhs) } -pub fn remainder_scalar( - lhs: JitTensor, - rhs: E, -) -> JitTensor { +pub fn remainder_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop::(lhs, rhs) } -pub fn pow( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub fn pow(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::(lhs, rhs) } diff --git a/crates/burn-jit/src/ops/qtensor.rs b/crates/burn-jit/src/ops/qtensor.rs index c6fc21e180..e5eb4005a6 100644 --- a/crates/burn-jit/src/ops/qtensor.rs +++ b/crates/burn-jit/src/ops/qtensor.rs @@ -20,11 +20,11 @@ fn packed_tensor>( data: &[u8], shape: S, device: &R::Device, -) -> JitTensor { +) -> JitTensor { let client = R::client(device); let buffer = client.create(data); - JitTensor::new_contiguous(client, device.clone(), shape.into(), buffer) + JitTensor::new_contiguous(client, device.clone(), shape.into(), buffer, DType::U32) } impl QTensorOps for JitBackend @@ -39,7 +39,7 @@ where QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) | QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => { // Convert quantized values to packed u32s - let qparams = data.get_q_params().unwrap(); + let qparams = data.get_q_params::().unwrap(); QJitTensor { qtensor: packed_tensor(data.values_as_bytes(), data.shape.clone(), device), scheme, @@ -63,11 +63,11 @@ where scheme: &QuantizationScheme, qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { - kernel::quantization::quantize(tensor, scheme, qparams.into()) + kernel::quantization::quantize::(tensor, scheme, qparams.into()) } fn dequantize(tensor: QuantizedTensor) -> FloatTensor { - kernel::quantization::dequantize(tensor) + kernel::quantization::dequantize::(tensor) } fn q_shape(tensor: &QuantizedTensor) -> Shape { diff --git a/crates/burn-jit/src/template/base.rs b/crates/burn-jit/src/template/base.rs index 7545a27379..61c181cbb4 100644 --- a/crates/burn-jit/src/template/base.rs +++ b/crates/burn-jit/src/template/base.rs @@ -70,7 +70,7 @@ macro_rules! 
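With the element generic gone from the tensor type, the numeric helpers above take the element as a function-level parameter, size the allocation from it, and record the dtype on the handle; later checks such as `is_contiguous_buffer` then work purely from `dtype.size()`. A small, self-contained sketch of that bookkeeping with illustrative types only.

```rust
// Size bookkeeping sketch: allocate from element count and element size,
// record the dtype so later checks need no generic parameter.
#[derive(Clone, Copy, Debug, PartialEq)]
enum DType {
    F64,
    F32,
    F16,
}

impl DType {
    fn size(self) -> usize {
        match self {
            DType::F64 => 8,
            DType::F32 => 4,
            DType::F16 => 2,
        }
    }
}

struct Allocation {
    bytes: Vec<u8>,
    num_elements: usize,
    dtype: DType,
}

fn empty(shape: &[usize], dtype: DType) -> Allocation {
    let num_elements: usize = shape.iter().product();
    Allocation {
        bytes: vec![0u8; num_elements * dtype.size()],
        num_elements,
        dtype,
    }
}

// Analogue of `is_contiguous_buffer`: the handle covers exactly the shape.
fn is_contiguous_buffer(a: &Allocation) -> bool {
    a.num_elements * a.dtype.size() == a.bytes.len()
}

fn main() {
    let a = empty(&[2, 3], DType::F16);
    assert_eq!(a.bytes.len(), 12);
    assert!(is_contiguous_buffer(&a));
}
```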
kernel_source { /// | (D + 1)..(2 * D + 1) | rhs strides | /// | (2 * D + 1)..(3 * D + 1) | lhs shape | /// | (3 * D + 1)..(4 * D + 1) | rhs shape | -pub fn build_info(tensors: &[&JitTensor]) -> Vec { +pub fn build_info(tensors: &[&JitTensor]) -> Vec { let ndims = tensors[0].shape.num_dims(); let mut info: Vec = vec![0; tensors.len() * 2 * ndims + 1]; info[0] = ndims as u32; diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs index 8ee142d25e..dca2c78c29 100644 --- a/crates/burn-jit/src/tensor/base.rs +++ b/crates/burn-jit/src/tensor/base.rs @@ -1,7 +1,7 @@ use crate::element::JitElement; use crate::kernel::{launch_unary, unary_op, UnaryOp}; use crate::JitRuntime; -use burn_tensor::Shape; +use burn_tensor::{DType, Shape, TensorMetadata}; use cubecl::client::ComputeClient; use cubecl::frontend::Numeric; use cubecl::linalg::tensor::TensorHandle; @@ -11,11 +11,7 @@ use std::marker::PhantomData; /// The basic tensor primitive struct. #[derive(new)] -pub struct JitTensor -where - R: JitRuntime, - E: JitElement, -{ +pub struct JitTensor { /// Compute client for the [runtime](JitRuntime). pub client: ComputeClient, /// The buffer where the data are stored. @@ -26,19 +22,18 @@ where pub device: R::Device, /// The strides of the tensor. pub strides: Vec, - pub(crate) elem: PhantomData, + pub(crate) dtype: DType, } -impl From> for TensorHandle { - fn from(val: JitTensor) -> Self { +impl From> for TensorHandle { + fn from(val: JitTensor) -> Self { TensorHandle::new(val.shape.dims.to_vec(), val.strides.to_vec(), val.handle) } } -impl core::fmt::Debug for JitTensor +impl core::fmt::Debug for JitTensor where R: JitRuntime, - E: JitElement, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!( @@ -46,16 +41,15 @@ where self.shape, self.device, self.strides, - E::type_name(), + self.dtype.name(), R::name(), )) } } -impl Clone for JitTensor +impl Clone for JitTensor where R: JitRuntime, - E: JitElement, { fn clone(&self) -> Self { Self { @@ -64,15 +58,132 @@ where shape: self.shape.clone(), device: self.device.clone(), strides: self.strides.clone(), - elem: PhantomData, + dtype: self.dtype, } } } -impl JitTensor +impl TensorMetadata for JitTensor { + fn dtype(&self) -> DType { + match self.dtype { + // NOTE: bool tensors are stored as u32, we currently make this assumption + // since `TensorMetadata::dtype()` is used for display purposes only at this time. + DType::U32 => DType::Bool, + _ => self.dtype, + } + } + + fn shape(&self) -> Shape { + self.shape.clone() + } +} + +/// Macro to execute a kernel/operation for a given element type. +/// +/// # Panics +/// Since there is no automatic type cast at this time, binary operations for different +/// floating point precision data types will panic with a data type mismatch. +#[macro_export] +macro_rules! execute_with_dtype { + (float($dtype:expr), $element:ident, $op:expr) => {{ + match $dtype { + burn_tensor::DType::F64 => { + type $element = f64; + $op + } + burn_tensor::DType::F32 => { + type $element = f32; + $op + } + burn_tensor::DType::F16 => { + type $element = half::f16; + $op + } + burn_tensor::DType::BF16 => { + type $element = half::bf16; + $op + } + _ => unimplemented!("Unsupported dtype"), + } + }}; + + (float($lhs_dtype:expr, $rhs_dtype:expr), $element:ident, $op:expr) => {{ + // NOTE: might be better for floating point binary operations to return a Result instead? 
+ if $lhs_dtype != $rhs_dtype { + panic!( + "Data type mismatch (lhs: {:?}, rhs: {:?})", + $lhs_dtype, $rhs_dtype + ); + } + execute_with_dtype!(float($lhs_dtype), $element, $op) + }}; + ($dtype:expr, $element:ident, $op:expr) => {{ + match $dtype { + burn_tensor::DType::F64 => { + type $element = f64; + $op + } + burn_tensor::DType::F32 => { + type $element = f32; + $op + } + burn_tensor::DType::F16 => { + type $element = half::f16; + $op + } + burn_tensor::DType::BF16 => { + type $element = half::bf16; + $op + } + burn_tensor::DType::U64 => { + type $element = u64; + $op + } + burn_tensor::DType::U32 => { + type $element = u32; + $op + } + burn_tensor::DType::U16 => { + type $element = u16; + $op + } + burn_tensor::DType::U8 => { + type $element = u8; + $op + } + burn_tensor::DType::I64 => { + type $element = i64; + $op + } + burn_tensor::DType::I32 => { + type $element = i32; + $op + } + burn_tensor::DType::I16 => { + type $element = i16; + $op + } + burn_tensor::DType::I8 => { + type $element = i8; + $op + } + // NOTE: bool and qfloat dtypes are actually represented as u32 + // burn_tensor::DType::Bool => { + // type $element = u32; + // $op + // } + // burn_tensor::DType::QFloat(_) => { + // type $element = u32; + // $op + // } + _ => unimplemented!("Unsupported dtype"), + } + }}; +} + +impl JitTensor where R: JitRuntime, - E: JitElement, { /// Create a new tensor with a contiguous memory layout. pub fn new_contiguous( @@ -80,6 +191,7 @@ where device: R::Device, shape: Shape, handle: Handle, + dtype: DType, ) -> Self { let ndims = shape.num_dims(); let mut strides = vec![0; ndims]; @@ -101,7 +213,7 @@ where shape, strides, device, - elem: PhantomData, + dtype, } } @@ -123,7 +235,7 @@ where shape: self.shape.clone(), strides: self.strides.clone(), device, - elem: PhantomData, + dtype: self.dtype, } } @@ -134,12 +246,12 @@ where strides: &self.strides, shape: &self.shape.dims, runtime: PhantomData, - elem_size: E::dtype().size(), + elem_size: self.dtype.size(), } } /// Return the reference to a tensor argument. - pub fn as_tensor_arg<'a>(&'a self, vectorisation: u8) -> TensorArg<'a, R> { + pub fn as_tensor_arg<'a, E: JitElement>(&'a self, vectorisation: u8) -> TensorArg<'a, R> { let handle: TensorHandleRef<'a, R> = self.as_handle_ref(); unsafe { @@ -173,12 +285,14 @@ where /// Copy the current tensor. pub fn copy(&self) -> Self { - unary_op!(numeric(self.clone()) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - input - } - execute::expand::(context, tensor) + execute_with_dtype!(self.dtype, E, { + unary_op!(numeric(self.clone()) => |context, tensor| { + #[cube] + fn execute(input: Line) -> Line { + input + } + execute::expand::(context, tensor) + }) }) } @@ -205,7 +319,7 @@ where /// Check if the current tensor has a contiguous backing buffer (no overlap and no empty memory /// regions within the shape). 
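The `execute_with_dtype!` macro above is the core of the change: it matches the runtime `DType`, introduces a type alias for the corresponding element type, and evaluates the wrapped expression with that alias in scope. A trimmed, self-contained version of the same trick with only two arms; all names here are illustrative.

```rust
// Trimmed dtype-dispatch macro: a type alias for the matched dtype is
// introduced, then the wrapped expression is evaluated with it.
#[derive(Clone, Copy, Debug)]
enum DType {
    F32,
    F64,
}

macro_rules! with_dtype {
    ($dtype:expr, $element:ident, $op:expr) => {{
        match $dtype {
            DType::F32 => {
                type $element = f32;
                $op
            }
            DType::F64 => {
                type $element = f64;
                $op
            }
        }
    }};
}

// A dtype-generic "kernel" stand-in for the macro to instantiate.
fn describe<E>() -> String {
    format!(
        "{} ({} bytes)",
        std::any::type_name::<E>(),
        std::mem::size_of::<E>()
    )
}

fn main() {
    for dtype in [DType::F32, DType::F64] {
        // `E` names f32 or f64 inside the wrapped expression.
        let info = with_dtype!(dtype, E, describe::<E>());
        println!("{dtype:?} -> {info}");
    }
}
```

The binary-float variant in the patch layers the dtype-equality panic on top of this single-dtype dispatch before delegating to it.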
pub fn is_contiguous_buffer(&self) -> bool { - self.shape.num_elements() * E::as_elem().size() == self.handle.size() as usize + self.shape.num_elements() * self.dtype.size() == self.handle.size() as usize } } diff --git a/crates/burn-jit/src/tensor/qtensor.rs b/crates/burn-jit/src/tensor/qtensor.rs index 3b5d67c40a..fdf7068e1a 100644 --- a/crates/burn-jit/src/tensor/qtensor.rs +++ b/crates/burn-jit/src/tensor/qtensor.rs @@ -3,7 +3,7 @@ use burn_tensor::{ AffineQuantization, QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy, QuantizationType, SymmetricQuantization, }, - read_sync, TensorData, + read_sync, DType, TensorData, TensorMetadata, }; use crate::{ops::into_data, FloatElement, IntElement, JitBackend, JitRuntime}; @@ -12,17 +12,17 @@ use super::JitTensor; /// A quantized tensor primitive. #[derive(Debug)] -pub struct QJitTensor { +pub struct QJitTensor { /// The quantized tensor. /// Values are stored as multiple packed quantized values in u32. - pub qtensor: JitTensor, + pub qtensor: JitTensor, /// The quantization scheme. pub scheme: QuantizationScheme, /// The quantization parameters. - pub qparams: JitQuantizationParameters, + pub qparams: JitQuantizationParameters, } -impl QTensorPrimitive for QJitTensor { +impl QTensorPrimitive for QJitTensor { fn scheme(&self) -> &QuantizationScheme { &self.scheme } @@ -31,14 +31,15 @@ impl QTensorPrimitive for QJitTen match &self.scheme { QuantizationScheme::PerTensorAffine(dtype) => match dtype { QuantizationType::QInt8 => { - let scale = read_sync(into_data(self.qparams.scale.clone())) - .iter() - .next() - .unwrap(); - let offset = read_sync(into_data(self.qparams.offset.clone().unwrap())) + let scale = read_sync(into_data::(self.qparams.scale.clone())) .iter() .next() .unwrap(); + let offset = + read_sync(into_data::(self.qparams.offset.clone().unwrap())) + .iter() + .next() + .unwrap(); QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init( scale, offset, )) @@ -46,7 +47,7 @@ impl QTensorPrimitive for QJitTen }, QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { QuantizationType::QInt8 => { - let scale = read_sync(into_data(self.qparams.scale.clone())) + let scale = read_sync(into_data::(self.qparams.scale.clone())) .iter() .next() .unwrap(); @@ -57,7 +58,7 @@ impl QTensorPrimitive for QJitTen } } -impl Clone for QJitTensor { +impl Clone for QJitTensor { fn clone(&self) -> Self { Self { qtensor: self.qtensor.clone(), @@ -67,16 +68,26 @@ impl Clone for QJitTensor TensorMetadata for QJitTensor { + fn dtype(&self) -> DType { + DType::QFloat(self.scheme) + } + + fn shape(&self) -> burn_tensor::Shape { + self.qtensor.shape() + } +} + /// The quantization parameters. #[derive(Debug)] -pub struct JitQuantizationParameters { +pub struct JitQuantizationParameters { /// The scaling factor. - pub scale: JitTensor, + pub scale: JitTensor, /// The zero-point offset. 
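The quantized JIT tensor above keeps its values as multiple packed quantized values per `u32` word and reports `DType::QFloat(scheme)` through `TensorMetadata`. A standalone sketch of the four-`i8`-per-word packing that the QInt8 schemes imply; the little-endian byte order here is an assumption for illustration, not a statement about Burn's layout.

```rust
// Standalone sketch: pack four quantized i8 values into one u32 word and
// recover them again. Byte order is illustrative.
fn pack_i8x4(values: [i8; 4]) -> u32 {
    u32::from_le_bytes(values.map(|v| v as u8))
}

fn unpack_i8x4(word: u32) -> [i8; 4] {
    word.to_le_bytes().map(|b| b as i8)
}

fn main() {
    let qs = [-128i8, -1, 0, 127];
    let word = pack_i8x4(qs);
    assert_eq!(unpack_i8x4(word), qs);
    // A whole tensor is then a sequence of such words plus the quantization
    // scheme and parameters (scale and optional zero-point offset).
    let packed: Vec<u32> = [[1i8, 2, 3, 4], [5, 6, 7, 8]].map(pack_i8x4).to_vec();
    assert_eq!(packed.len(), 2);
}
```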
- pub offset: Option>, + pub offset: Option>, } -impl Clone for JitQuantizationParameters { +impl Clone for JitQuantizationParameters { fn clone(&self) -> Self { Self { scale: self.scale.clone(), @@ -86,8 +97,7 @@ impl Clone for JitQuantizationPar } impl - From>> - for JitQuantizationParameters + From>> for JitQuantizationParameters { fn from(value: QuantizationParametersPrimitive>) -> Self { JitQuantizationParameters { @@ -97,11 +107,16 @@ impl } } -impl JitQuantizationParameters { - pub fn new(scale: F, offset: Option, device: &R::Device) -> Self { +impl JitQuantizationParameters { + pub fn new( + scale: F, + offset: Option, + device: &R::Device, + ) -> Self { Self { - scale: crate::ops::from_data(TensorData::new(vec![scale], [1]), device), - offset: offset.map(|o| crate::ops::from_data(TensorData::new(vec![o], [1]), device)), + scale: crate::ops::from_data::(TensorData::new(vec![scale], [1]), device), + offset: offset + .map(|o| crate::ops::from_data::(TensorData::new(vec![o], [1]), device)), } } } diff --git a/crates/burn-jit/src/tests/mask_fill.rs b/crates/burn-jit/src/tests/mask_fill.rs index b13a4d29ca..4542bbe3f1 100644 --- a/crates/burn-jit/src/tests/mask_fill.rs +++ b/crates/burn-jit/src/tests/mask_fill.rs @@ -2,13 +2,16 @@ mod tests { use super::*; use burn_jit::kernel::{mask_fill, MaskFillStrategy}; - use burn_tensor::{Bool, Distribution, Tensor, TensorPrimitive}; + use burn_tensor::{backend::Backend, Bool, Distribution, Tensor, TensorPrimitive}; #[test] fn mask_fill_should_match_reference_backend() { let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill(); - let actual = Tensor::::from_primitive(TensorPrimitive::Float(mask_fill( + let actual = Tensor::::from_primitive(TensorPrimitive::Float(mask_fill::< + _, + ::FloatElem, + >( tensor.into_primitive().tensor(), mask.into_primitive(), 4.0, @@ -25,7 +28,10 @@ mod tests { fn mask_fill_inplace_should_match_reference_backend() { let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill(); - let actual = Tensor::::from_primitive(TensorPrimitive::Float(mask_fill( + let actual = Tensor::::from_primitive(TensorPrimitive::Float(mask_fill::< + _, + ::FloatElem, + >( tensor.into_primitive().tensor(), mask.into_primitive(), 4.0, diff --git a/crates/burn-jit/src/tests/mask_where.rs b/crates/burn-jit/src/tests/mask_where.rs index e49c3ef0a4..befdb76af6 100644 --- a/crates/burn-jit/src/tests/mask_where.rs +++ b/crates/burn-jit/src/tests/mask_where.rs @@ -19,12 +19,16 @@ mod tests { fn mask_where_inplace_lhs_should_match_reference_backend() { let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); - let actual = Tensor::::from_primitive(TensorPrimitive::Float(mask_where( - tensor.into_primitive().tensor(), - mask.into_primitive(), - value.into_primitive().tensor(), - MaskWhereStrategy::InplaceLhs, - ))); + let actual = + Tensor::::from_primitive(TensorPrimitive::Float(mask_where::< + _, + ::FloatElem, + >( + tensor.into_primitive().tensor(), + mask.into_primitive(), + value.into_primitive().tensor(), + MaskWhereStrategy::InplaceLhs, + ))); let expected = tensor_ref.mask_where(mask_ref, value_ref); expected @@ -36,12 +40,16 @@ mod tests { fn mask_where_inplace_rhs_should_match_reference_backend() { let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); - let actual = Tensor::::from_primitive(TensorPrimitive::Float(mask_where( - tensor.into_primitive().tensor(), - mask.into_primitive(), - value.into_primitive().tensor(), - MaskWhereStrategy::InplaceRhs, - ))); + let actual = + 
Tensor::::from_primitive(TensorPrimitive::Float(mask_where::< + _, + ::FloatElem, + >( + tensor.into_primitive().tensor(), + mask.into_primitive(), + value.into_primitive().tensor(), + MaskWhereStrategy::InplaceRhs, + ))); let expected = tensor_ref.mask_where(mask_ref, value_ref); expected diff --git a/crates/burn-jit/src/tests/reduce.rs b/crates/burn-jit/src/tests/reduce.rs index dae808796d..3e8f81fa8c 100644 --- a/crates/burn-jit/src/tests/reduce.rs +++ b/crates/burn-jit/src/tests/reduce.rs @@ -449,7 +449,10 @@ mod reduction { let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum( + let val = Tensor::::from_primitive(TensorPrimitive::Float(sum::< + _, + ::FloatElem, + >( tensor.into_primitive().tensor(), ReduceStrategy::default(), ))); @@ -465,7 +468,10 @@ mod reduction { let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); - let val = Tensor::::from_primitive(TensorPrimitive::Float(prod( + let val = Tensor::::from_primitive(TensorPrimitive::Float(prod::< + _, + ::FloatElem, + >( tensor.into_primitive().tensor(), ReduceStrategy::default(), ))); diff --git a/crates/burn-ndarray/src/ops/adaptive_avgpool.rs b/crates/burn-ndarray/src/ops/adaptive_avgpool.rs index b0f926ebda..1aa3f78844 100644 --- a/crates/burn-ndarray/src/ops/adaptive_avgpool.rs +++ b/crates/burn-ndarray/src/ops/adaptive_avgpool.rs @@ -1,6 +1,6 @@ use crate::{element::FloatNdArrayElement, sharing::UnsafeSharedRef, tensor::NdArrayTensor}; use burn_common::{iter_range_par, run_par}; -use burn_tensor::ElementConversion; +use burn_tensor::{ElementConversion, TensorMetadata}; use ndarray::Array4; #[cfg(not(feature = "std"))] diff --git a/crates/burn-ndarray/src/ops/avgpool.rs b/crates/burn-ndarray/src/ops/avgpool.rs index 0e31834cff..1b10012b05 100644 --- a/crates/burn-ndarray/src/ops/avgpool.rs +++ b/crates/burn-ndarray/src/ops/avgpool.rs @@ -1,7 +1,7 @@ use crate::{element::FloatNdArrayElement, sharing::UnsafeSharedRef, tensor::NdArrayTensor}; use burn_common::{iter_range_par, run_par}; -use burn_tensor::ElementConversion; +use burn_tensor::{ElementConversion, TensorMetadata}; use ndarray::Array4; pub(crate) fn avg_pool2d( diff --git a/crates/burn-ndarray/src/ops/base.rs b/crates/burn-ndarray/src/ops/base.rs index 92e85b7387..104b973018 100644 --- a/crates/burn-ndarray/src/ops/base.rs +++ b/crates/burn-ndarray/src/ops/base.rs @@ -1,6 +1,7 @@ use alloc::{vec, vec::Vec}; use burn_tensor::ElementConversion; use burn_tensor::TensorData; +use burn_tensor::TensorMetadata; use core::fmt::Debug; use core::{marker::PhantomData, ops::Range}; use ndarray::s; @@ -34,7 +35,7 @@ pub(crate) struct NdArrayMathOps { impl NdArrayOps where - E: Copy + Debug, + E: Copy + Debug + burn_tensor::Element, { pub fn slice(tensor: NdArrayTensor, ranges: &[Range]) -> NdArrayTensor { let slices = Self::to_slice_args(ranges, tensor.shape().num_dims()); diff --git a/crates/burn-ndarray/src/ops/bool_tensor.rs b/crates/burn-ndarray/src/ops/bool_tensor.rs index 6dbd02e305..07b36c42c4 100644 --- a/crates/burn-ndarray/src/ops/bool_tensor.rs +++ b/crates/burn-ndarray/src/ops/bool_tensor.rs @@ -2,7 +2,7 @@ use alloc::vec; use alloc::vec::Vec; use burn_tensor::ops::{BoolTensorOps, IntTensorOps}; -use burn_tensor::ElementConversion; +use burn_tensor::{ElementConversion, TensorMetadata}; use core::ops::Range; use ndarray::{IntoDimension, Zip}; @@ -23,10 +23,6 @@ impl BoolTensorOp NdArrayTensor::from_data(data) } - fn bool_shape(tensor: &NdArrayTensor) -> Shape 
{ - tensor.shape() - } - async fn bool_into_data(tensor: NdArrayTensor) -> TensorData { let shape = tensor.shape(); let values = tensor.array.into_iter().collect(); diff --git a/crates/burn-ndarray/src/ops/conv.rs b/crates/burn-ndarray/src/ops/conv.rs index ff84834735..429618826a 100644 --- a/crates/burn-ndarray/src/ops/conv.rs +++ b/crates/burn-ndarray/src/ops/conv.rs @@ -4,7 +4,7 @@ use burn_tensor::{ conv::{calculate_conv_output_size, calculate_conv_transpose_output_size}, ConvOptions, ConvTransposeOptions, }, - ElementConversion, + ElementConversion, TensorMetadata, }; use ndarray::{ s, Array3, Array4, Array5, ArrayView2, ArrayView3, ArrayViewMut2, ArrayViewMut3, Axis, Dim, diff --git a/crates/burn-ndarray/src/ops/deform_conv.rs b/crates/burn-ndarray/src/ops/deform_conv.rs index 2a5f8d2db2..12f3aad1e0 100644 --- a/crates/burn-ndarray/src/ops/deform_conv.rs +++ b/crates/burn-ndarray/src/ops/deform_conv.rs @@ -1,5 +1,8 @@ use burn_common::{iter_par, run_par}; -use burn_tensor::ops::{conv::calculate_conv_output_size, DeformConvOptions}; +use burn_tensor::{ + ops::{conv::calculate_conv_output_size, DeformConvOptions}, + TensorMetadata, +}; use core::ops::AddAssign; use ndarray::{ s, Array2, Array4, ArrayView2, ArrayView3, ArrayView4, ArrayView6, ArrayViewMut2, Axis, Dim, diff --git a/crates/burn-ndarray/src/ops/int_tensor.rs b/crates/burn-ndarray/src/ops/int_tensor.rs index 91d4346d2f..cdbff9d91d 100644 --- a/crates/burn-ndarray/src/ops/int_tensor.rs +++ b/crates/burn-ndarray/src/ops/int_tensor.rs @@ -6,6 +6,7 @@ use burn_tensor::ops::IntTensorOps; use burn_tensor::Distribution; use burn_tensor::ElementConversion; +use burn_tensor::TensorMetadata; use core::ops::Range; use ndarray::IntoDimension; use ndarray::Zip; @@ -29,10 +30,6 @@ impl IntTensorOps NdArrayTensor::from_data(data) } - fn int_shape(tensor: &NdArrayTensor) -> Shape { - tensor.shape() - } - async fn int_into_data(tensor: NdArrayTensor) -> TensorData { let shape = tensor.shape(); let values = tensor.array.into_iter().collect(); diff --git a/crates/burn-ndarray/src/ops/interpolate.rs b/crates/burn-ndarray/src/ops/interpolate.rs index 2a04291d32..b93c021da9 100644 --- a/crates/burn-ndarray/src/ops/interpolate.rs +++ b/crates/burn-ndarray/src/ops/interpolate.rs @@ -1,5 +1,5 @@ use burn_common::{iter_range_par, run_par}; -use burn_tensor::ElementConversion; +use burn_tensor::{ElementConversion, TensorMetadata}; use ndarray::Array4; #[cfg(not(feature = "std"))] use num_traits::Float; diff --git a/crates/burn-ndarray/src/ops/matmul.rs b/crates/burn-ndarray/src/ops/matmul.rs index 59ce40c6c1..f15a4a0f80 100644 --- a/crates/burn-ndarray/src/ops/matmul.rs +++ b/crates/burn-ndarray/src/ops/matmul.rs @@ -2,8 +2,8 @@ use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray, Unsafe use alloc::{vec, vec::Vec}; use burn_common::{iter_range_par, run_par}; -use burn_tensor::ElementConversion; use burn_tensor::{ops::FloatTensorOps, Shape}; +use burn_tensor::{ElementConversion, TensorMetadata}; use ndarray::s; pub(crate) fn matmul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor diff --git a/crates/burn-ndarray/src/ops/maxpool.rs b/crates/burn-ndarray/src/ops/maxpool.rs index 09db5fec2a..90ffe30a95 100644 --- a/crates/burn-ndarray/src/ops/maxpool.rs +++ b/crates/burn-ndarray/src/ops/maxpool.rs @@ -6,7 +6,7 @@ use crate::{ }; use burn_common::{iter_range_par, run_par}; -use burn_tensor::ElementConversion; +use burn_tensor::{ElementConversion, TensorMetadata}; use ndarray::Array4; pub(crate) fn max_pool2d( diff --git 
a/crates/burn-ndarray/src/ops/padding.rs b/crates/burn-ndarray/src/ops/padding.rs index c5879c045d..bd06a1f2d1 100644 --- a/crates/burn-ndarray/src/ops/padding.rs +++ b/crates/burn-ndarray/src/ops/padding.rs @@ -3,7 +3,7 @@ use crate::{ tensor::NdArrayTensor, NdArray, }; -use burn_tensor::ops::FloatTensorOps; +use burn_tensor::{ops::FloatTensorOps, TensorMetadata}; use ndarray::{Array4, Array5}; pub(crate) fn apply_padding_4d( diff --git a/crates/burn-ndarray/src/ops/qtensor.rs b/crates/burn-ndarray/src/ops/qtensor.rs index fab7b0c5bb..dd83d7e69f 100644 --- a/crates/burn-ndarray/src/ops/qtensor.rs +++ b/crates/burn-ndarray/src/ops/qtensor.rs @@ -6,7 +6,7 @@ use burn_tensor::{ AffineQuantization, QParams, QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy, QuantizationType, SymmetricQuantization, }, - DType, Shape, TensorData, + DType, Shape, TensorData, TensorMetadata, }; use crate::{ diff --git a/crates/burn-ndarray/src/ops/tensor.rs b/crates/burn-ndarray/src/ops/tensor.rs index d03b58a200..9ffc603da4 100644 --- a/crates/burn-ndarray/src/ops/tensor.rs +++ b/crates/burn-ndarray/src/ops/tensor.rs @@ -11,8 +11,8 @@ use crate::{NdArrayDevice, SEED}; // Workspace crates use burn_common::rand::get_seeded_rng; -use burn_tensor::Distribution; use burn_tensor::{backend::Backend, ops::FloatTensorOps, ElementConversion, Shape, TensorData}; +use burn_tensor::{Distribution, TensorMetadata}; #[cfg(not(feature = "std"))] #[allow(unused_imports)] @@ -62,10 +62,6 @@ impl FloatTensorO tensor } - fn float_shape(tensor: &NdArrayTensor) -> Shape { - tensor.shape() - } - async fn float_into_data(tensor: NdArrayTensor) -> TensorData { let shape = tensor.shape(); let values = tensor.array.into_iter().collect(); diff --git a/crates/burn-ndarray/src/tensor.rs b/crates/burn-ndarray/src/tensor.rs index 56adbf14f6..1846f314b1 100644 --- a/crates/burn-ndarray/src/tensor.rs +++ b/crates/burn-ndarray/src/tensor.rs @@ -3,7 +3,7 @@ use burn_tensor::{ AffineQuantization, QParams, QTensorPrimitive, QuantizationScheme, QuantizationStrategy, QuantizationType, SymmetricQuantization, }, - Element, Shape, TensorData, + DType, Element, Shape, TensorData, TensorMetadata, }; use ndarray::{ArcArray, Array, Dim, IxDyn}; @@ -17,8 +17,12 @@ pub struct NdArrayTensor { pub array: ArcArray, } -impl NdArrayTensor { - pub(crate) fn shape(&self) -> Shape { +impl TensorMetadata for NdArrayTensor { + fn dtype(&self) -> DType { + E::dtype() + } + + fn shape(&self) -> Shape { Shape::from(self.array.shape().to_vec()) } } @@ -211,6 +215,16 @@ impl QTensorPrimitive for NdArrayQTensor { } } +impl TensorMetadata for NdArrayQTensor { + fn dtype(&self) -> DType { + DType::QFloat(self.scheme) + } + + fn shape(&self) -> Shape { + self.qtensor.shape() + } +} + #[cfg(test)] mod tests { use crate::NdArray; diff --git a/crates/burn-remote/src/client/runner.rs b/crates/burn-remote/src/client/runner.rs index 97c031cb2b..6fed1c22be 100644 --- a/crates/burn-remote/src/client/runner.rs +++ b/crates/burn-remote/src/client/runner.rs @@ -60,7 +60,7 @@ impl RunnerClient for WsClient { fn register_float_tensor( &self, shape: Vec, - _full_precision: bool, + _dtype: burn_tensor::FloatDType, ) -> RouterTensor { self.register_empty_tensor(shape, DType::F32) } diff --git a/crates/burn-router/src/backend.rs b/crates/burn-router/src/backend.rs index 86338f9b57..ef5c4baa38 100644 --- a/crates/burn-router/src/backend.rs +++ b/crates/burn-router/src/backend.rs @@ -6,7 +6,7 @@ use burn_tensor::{ ops::FloatTensor, 
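Across the ndarray backend above, the per-kind `*_shape` ops are removed; the tensor primitives now implement `TensorMetadata` themselves, reporting both shape and dtype (from the element type for `NdArrayTensor`, or `DType::QFloat` for the quantized tensor). A standalone sketch of such a trait and an element-generic impl; the types are stand-ins, not Burn's actual definitions.

```rust
// TensorMetadata-style sketch: every primitive reports its own dtype and
// shape, so per-kind shape ops on the backend trait are no longer needed.
#[derive(Clone, Copy, Debug, PartialEq)]
enum DType {
    F32,
    I64,
}

trait Element: Copy {
    fn dtype() -> DType;
}
impl Element for f32 {
    fn dtype() -> DType {
        DType::F32
    }
}
impl Element for i64 {
    fn dtype() -> DType {
        DType::I64
    }
}

trait TensorMetadata {
    fn dtype(&self) -> DType;
    fn shape(&self) -> Vec<usize>;
}

struct CpuTensor<E> {
    data: Vec<E>,
    shape: Vec<usize>,
}

// Like the ndarray impl: the dtype comes straight from the element type.
impl<E: Element> TensorMetadata for CpuTensor<E> {
    fn dtype(&self) -> DType {
        E::dtype()
    }
    fn shape(&self) -> Vec<usize> {
        self.shape.clone()
    }
}

fn main() {
    let t = CpuTensor { data: vec![0i64; 6], shape: vec![2, 3] };
    assert_eq!(t.dtype(), DType::I64);
    assert_eq!(t.shape(), vec![2, 3]);
    assert_eq!(t.data.len(), 6);
}
```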
quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy}, repr::{BaseOperationDescription, OperationDescription, UnaryOperationDescription}, - Device, + Device, Element, }; use super::{get_client, set_seed, RouterTensor, RunnerChannel, RunnerClient}; @@ -90,7 +90,10 @@ impl BackendBridge> for PrecisionBridge { _device: Option>, ) -> FloatTensor { let client = tensor.client.clone(); - let out = client.register_float_tensor(tensor.shape.clone(), true); + let out = client.register_float_tensor( + tensor.shape.clone(), + ::FloatElem::dtype().into(), + ); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -109,7 +112,7 @@ impl BackendBridge> for PrecisionBridge { _device: Option>>, ) -> FloatTensor> { let client = tensor.client.clone(); - let out = client.register_float_tensor(tensor.shape.clone(), false); + let out = client.register_float_tensor(tensor.shape.clone(), R::FloatElem::dtype().into()); let desc = UnaryOperationDescription { input: tensor.into_description(), diff --git a/crates/burn-router/src/client/base.rs b/crates/burn-router/src/client/base.rs index 35723000b3..b131195e33 100644 --- a/crates/burn-router/src/client/base.rs +++ b/crates/burn-router/src/client/base.rs @@ -11,7 +11,7 @@ use spin::Mutex; use burn_tensor::{ backend::{DeviceId, DeviceOps}, repr::{OperationDescription, TensorDescription, TensorId}, - DType, TensorData, + DType, FloatDType, TensorData, }; use crate::{RouterTensor, RunnerChannel}; @@ -37,7 +37,7 @@ pub trait RunnerClient: Clone + Send + Sync + Sized { /// Create a new [RouterTensor] with no resources associated. fn register_empty_tensor(&self, shape: Vec, dtype: DType) -> RouterTensor; /// Create a new float [RouterTensor] with no resources associated. - fn register_float_tensor(&self, shape: Vec, full_precision: bool) -> RouterTensor; + fn register_float_tensor(&self, shape: Vec, dtype: FloatDType) -> RouterTensor; /// Get the current device used by all operations handled by this client. fn device(&self) -> Self::Device; /// Drop the tensor with the given [tensor id](TensorId). 
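The router and remote clients above now receive the concrete target precision as a `FloatDType` instead of a `full_precision` flag, and the precision bridge derives that value from the backend's float element type. A rough sketch of the caller's side of that signature change; the trait and function names are stand-ins, with `half::f16` standing in for the half-precision element.

```rust
// Sketch of the new calling convention: pass the concrete float precision,
// typically derived from the element type, rather than a boolean flag.
#[derive(Clone, Copy, Debug)]
enum FloatDType {
    F32,
    F16,
}

trait FloatElem {
    fn dtype() -> FloatDType;
}
impl FloatElem for f32 {
    fn dtype() -> FloatDType {
        FloatDType::F32
    }
}
impl FloatElem for half::f16 {
    fn dtype() -> FloatDType {
        FloatDType::F16
    }
}

fn register_float_tensor(shape: Vec<usize>, dtype: FloatDType) {
    println!("register float tensor {shape:?} as {dtype:?}");
}

fn main() {
    // The precision bridge picks the dtype from the element type it targets
    // rather than toggling a boolean.
    register_float_tensor(vec![2, 2], <f32 as FloatElem>::dtype());
    register_float_tensor(vec![2, 2], <half::f16 as FloatElem>::dtype());
}
```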
diff --git a/crates/burn-router/src/ops/op_bool.rs b/crates/burn-router/src/ops/op_bool.rs index e582631302..25c46ae854 100644 --- a/crates/burn-router/src/ops/op_bool.rs +++ b/crates/burn-router/src/ops/op_bool.rs @@ -8,7 +8,7 @@ use burn_tensor::repr::{ ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription, }; -use burn_tensor::{DType, Device, Element, Shape, TensorData}; +use burn_tensor::{DType, Device, Element, Shape, TensorData, TensorMetadata}; use crate::{get_client, BackendRouter, RunnerChannel, RunnerClient}; @@ -25,10 +25,6 @@ impl BoolTensorOps for BackendRouter { out } - fn bool_shape(tensor: &BoolTensor) -> Shape { - Shape::from(tensor.shape.clone()) - } - async fn bool_into_data(tensor: BoolTensor) -> TensorData { tensor.into_data().await } diff --git a/crates/burn-router/src/ops/op_float.rs b/crates/burn-router/src/ops/op_float.rs index 12ddfa1cb7..dda01990e0 100644 --- a/crates/burn-router/src/ops/op_float.rs +++ b/crates/burn-router/src/ops/op_float.rs @@ -16,7 +16,9 @@ use burn_tensor::repr::{ SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription, }; -use burn_tensor::{DType, Device, Distribution, Element, ElementConversion, Shape, TensorData}; +use burn_tensor::{ + DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata, +}; use crate::{get_client, BackendRouter, RunnerChannel, RunnerClient}; @@ -93,10 +95,6 @@ impl FloatTensorOps for BackendRouter { out } - fn float_shape(tensor: &FloatTensor) -> Shape { - tensor.shape() - } - async fn float_into_data(tensor: FloatTensor) -> TensorData { tensor .into_data() @@ -1478,10 +1476,19 @@ impl FloatTensorOps for BackendRouter { out } - fn float_cast( - _tensor: FloatTensor, - _dtype: burn_tensor::FloatDType, - ) -> FloatTensor { - todo!() + fn float_cast(tensor: FloatTensor, dtype: burn_tensor::FloatDType) -> FloatTensor { + let client = tensor.client.clone(); + let out = client.register_float_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseFloat( + BaseOperationDescription::Cast(desc), + )); + + out } } diff --git a/crates/burn-router/src/ops/op_int.rs b/crates/burn-router/src/ops/op_int.rs index a4e4d575d0..db81602d4f 100644 --- a/crates/burn-router/src/ops/op_int.rs +++ b/crates/burn-router/src/ops/op_int.rs @@ -16,7 +16,9 @@ use burn_tensor::repr::{ SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription, }; -use burn_tensor::{DType, Device, Distribution, Element, ElementConversion, Shape, TensorData}; +use burn_tensor::{ + DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata, +}; use crate::{get_client, BackendRouter, RunnerChannel, RunnerClient}; @@ -33,10 +35,6 @@ impl IntTensorOps for BackendRouter { out } - fn int_shape(tensor: &IntTensor) -> Shape { - tensor.shape() - } - async fn int_into_data(tensor: IntTensor) -> TensorData { tensor .into_data() diff --git a/crates/burn-router/src/runner.rs b/crates/burn-router/src/runner.rs index f0df2a025a..a11c568509 100644 --- a/crates/burn-router/src/runner.rs +++ b/crates/burn-router/src/runner.rs @@ -9,7 +9,7 @@ use burn_tensor::{ NumericOperationDescription, OperationDescription, ReprBackend, TensorDescription, TensorId, TensorStatus, }, - DType, Element, ElementConversion, 
Shape, TensorData, + DType, Element, ElementConversion, FloatDType, Shape, TensorData, }; use core::future::Future; @@ -158,14 +158,9 @@ where pub(crate) fn register_float_tensor_desc( &self, shape: Vec, - full_precision: bool, + dtype: FloatDType, ) -> TensorDescription { - let dtype = if full_precision { - as Backend>::FloatElem::dtype() - } else { - B::FloatElem::dtype() - }; - self.register_empty_tensor_desc(shape, dtype) + self.register_empty_tensor_desc(shape, dtype.into()) } } @@ -1249,8 +1244,8 @@ where RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone()) } - fn register_float_tensor(&self, shape: Vec, full_precision: bool) -> RouterTensor { - let desc = self.register_float_tensor_desc(shape, full_precision); + fn register_float_tensor(&self, shape: Vec, dtype: FloatDType) -> RouterTensor { + let desc = self.register_float_tensor_desc(shape, dtype); RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone()) } diff --git a/crates/burn-router/src/tensor.rs b/crates/burn-router/src/tensor.rs index 15faf90817..45d463cc6f 100644 --- a/crates/burn-router/src/tensor.rs +++ b/crates/burn-router/src/tensor.rs @@ -3,7 +3,7 @@ use alloc::{sync::Arc, vec::Vec}; use super::RunnerClient; use burn_tensor::{ repr::{TensorDescription, TensorId, TensorStatus}, - DType, Shape, TensorData, + DType, Shape, TensorData, TensorMetadata, }; /// Tensor primitive for the [router backend](crate::BackendRouter). @@ -20,6 +20,16 @@ pub struct RouterTensor { pub(crate) is_orphan: bool, } +impl TensorMetadata for RouterTensor { + fn dtype(&self) -> DType { + self.dtype + } + + fn shape(&self) -> Shape { + Shape::from(self.shape.clone()) + } +} + impl RouterTensor { /// Create a new router tensor. pub fn new(id: Arc, shape: Vec, dtype: DType, client: C) -> Self { @@ -65,10 +75,6 @@ impl RouterTensor { } } - pub(crate) fn shape(&self) -> Shape { - Shape::from(self.shape.clone()) - } - pub(crate) fn status(&self) -> TensorStatus { if Arc::strong_count(&self.id) <= 1 { TensorStatus::ReadWrite diff --git a/crates/burn-router/src/types.rs b/crates/burn-router/src/types.rs index f5c552abf6..0f57f3b195 100644 --- a/crates/burn-router/src/types.rs +++ b/crates/burn-router/src/types.rs @@ -146,15 +146,15 @@ macro_rules! 
impl_multi_backend_types { } } - fn register_float_tensor(&self, shape: Vec, full_precision: bool) -> RouterTensor { + fn register_float_tensor(&self, shape: Vec, dtype: burn_tensor::FloatDType) -> RouterTensor { match self { Self::$DefaultBackend(runner) => { - let desc = runner.register_float_tensor_desc(shape, full_precision); + let desc = runner.register_float_tensor_desc(shape, dtype); RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone()) } $( Self::$OtherBackend(runner) => { - let desc = runner.register_float_tensor_desc(shape, full_precision); + let desc = runner.register_float_tensor_desc(shape, dtype); RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone()) } )+ diff --git a/crates/burn-tch/src/ops/base.rs b/crates/burn-tch/src/ops/base.rs index 19459731b4..2e8da4c948 100644 --- a/crates/burn-tch/src/ops/base.rs +++ b/crates/burn-tch/src/ops/base.rs @@ -1,4 +1,4 @@ -use burn_tensor::Shape; +use burn_tensor::{Shape, TensorMetadata}; use tch::Scalar; use crate::{LibTorchDevice, TchShape, TchTensor}; diff --git a/crates/burn-tch/src/ops/bool_tensor.rs b/crates/burn-tch/src/ops/bool_tensor.rs index b31ef57356..66ba524d8e 100644 --- a/crates/burn-tch/src/ops/bool_tensor.rs +++ b/crates/burn-tch/src/ops/bool_tensor.rs @@ -1,6 +1,6 @@ use super::TchOps; use crate::{element::TchElement, LibTorch, LibTorchDevice, QuantElement, TchShape, TchTensor}; -use burn_tensor::{backend::Backend, ops::BoolTensorOps, Shape, TensorData}; +use burn_tensor::{backend::Backend, ops::BoolTensorOps, Shape, TensorData, TensorMetadata}; use std::ops::Range; impl BoolTensorOps for LibTorch { @@ -8,16 +8,12 @@ impl BoolTensorOps for LibTorch { TchTensor::from_data::(data, (*device).into()) } - fn bool_shape(tensor: &TchTensor) -> Shape { - tensor.shape() - } - fn bool_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor { TchOps::repeat_dim(tensor, dim, times) } async fn bool_into_data(tensor: TchTensor) -> TensorData { - let shape = Self::bool_shape(&tensor); + let shape = tensor.shape(); let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()])); let values: Result, tch::TchError> = tensor.tensor.shallow_clone().try_into(); TensorData::new(values.unwrap(), shape) diff --git a/crates/burn-tch/src/ops/int_tensor.rs b/crates/burn-tch/src/ops/int_tensor.rs index 5ddb9c21f2..0da31fe430 100644 --- a/crates/burn-tch/src/ops/int_tensor.rs +++ b/crates/burn-tch/src/ops/int_tensor.rs @@ -3,7 +3,7 @@ use std::ops::Range; use burn_tensor::{ backend::Backend, ops::{IntTensor, IntTensorOps}, - Distribution, Shape, TensorData, + Distribution, Shape, TensorData, TensorMetadata, }; use crate::{element::TchElement, LibTorch, LibTorchDevice, QuantElement, TchShape, TchTensor}; @@ -15,16 +15,12 @@ impl IntTensorOps for LibTorch { TchTensor::from_data::(data, (*device).into()) } - fn int_shape(tensor: &TchTensor) -> Shape { - tensor.shape() - } - fn int_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor { TchOps::repeat_dim(tensor, dim, times) } async fn int_into_data(tensor: TchTensor) -> TensorData { - let shape = Self::int_shape(&tensor); + let shape = tensor.shape(); let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()])); let values: Result, tch::TchError> = tensor.tensor.shallow_clone().try_into(); TensorData::new(values.unwrap(), shape) diff --git a/crates/burn-tch/src/ops/module.rs b/crates/burn-tch/src/ops/module.rs index 334e64183d..bb9b50bd6f 100644 --- a/crates/burn-tch/src/ops/module.rs +++ 
b/crates/burn-tch/src/ops/module.rs @@ -1,7 +1,11 @@ use crate::{element::TchElement, LibTorch, QuantElement, TchTensor}; -use burn_tensor::ops::{ - ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateMode, - InterpolateOptions, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, +use burn_tensor::{ + ops::{ + ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, + InterpolateMode, InterpolateOptions, MaxPool1dWithIndices, MaxPool2dBackward, + MaxPool2dWithIndices, ModuleOps, + }, + TensorMetadata, }; impl ModuleOps for LibTorch { diff --git a/crates/burn-tch/src/ops/qtensor.rs b/crates/burn-tch/src/ops/qtensor.rs index 1decd1bcdc..ec644d066c 100644 --- a/crates/burn-tch/src/ops/qtensor.rs +++ b/crates/burn-tch/src/ops/qtensor.rs @@ -6,7 +6,7 @@ use burn_tensor::{ QParams, QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, QuantizationType, }, - DType, Shape, TensorData, + DType, Shape, TensorData, TensorMetadata, }; use crate::{LibTorch, LibTorchDevice, QuantElement, TchElement, TchQTensor, TchShape, TchTensor}; diff --git a/crates/burn-tch/src/ops/tensor.rs b/crates/burn-tch/src/ops/tensor.rs index 8f460fccaa..22d63b9d64 100644 --- a/crates/burn-tch/src/ops/tensor.rs +++ b/crates/burn-tch/src/ops/tensor.rs @@ -3,7 +3,7 @@ use crate::{element::TchElement, LibTorch, LibTorchDevice, QuantElement, TchShap use burn_tensor::{ backend::Backend, ops::{FloatTensorOps, IntTensor}, - Distribution, ElementConversion, FloatDType, Shape, TensorData, + Distribution, ElementConversion, FloatDType, Shape, TensorData, TensorMetadata, }; use half::{bf16, f16}; use std::ops::Range; @@ -60,12 +60,8 @@ impl FloatTensorOps for LibTorch { TchTensor::new(tch::Tensor::ones(shape.dims, (E::KIND, device))) } - fn float_shape(tensor: &TchTensor) -> Shape { - tensor.shape() - } - async fn float_into_data(tensor: TchTensor) -> TensorData { - let shape = Self::float_shape(&tensor); + let shape = tensor.shape(); let tensor = Self::float_reshape(tensor.clone(), Shape::new([shape.num_elements()])); match tensor.tensor.kind() { tch::Kind::Half => { diff --git a/crates/burn-tch/src/tensor.rs b/crates/burn-tch/src/tensor.rs index c64c45a807..b634954c8e 100644 --- a/crates/burn-tch/src/tensor.rs +++ b/crates/burn-tch/src/tensor.rs @@ -4,7 +4,7 @@ use burn_tensor::{ AffineQuantization, QTensorPrimitive, QuantizationScheme, QuantizationStrategy, QuantizationType, SymmetricQuantization, }, - Shape, TensorData, + DType, Shape, TensorData, TensorMetadata, }; use libc::c_void; use std::sync::Arc; @@ -70,6 +70,30 @@ pub struct TchTensor { pub storage: Storage, } +impl TensorMetadata for TchTensor { + fn dtype(&self) -> DType { + match self.tensor.kind() { + tch::Kind::Uint8 => DType::U8, + tch::Kind::Int8 => DType::I8, + tch::Kind::Int16 => DType::I16, + tch::Kind::Int => DType::I32, + tch::Kind::Int64 => DType::I64, + tch::Kind::Half => DType::F16, + tch::Kind::Float => DType::F32, + tch::Kind::Double => DType::F64, + tch::Kind::Bool => DType::Bool, + tch::Kind::QUInt8 => DType::U8, + tch::Kind::BFloat16 => DType::BF16, + // Complex and quantization types are not valid/implemented. + _ => unimplemented!(), + } + } + + fn shape(&self) -> Shape { + Shape::from(self.tensor.size()) + } +} + impl TchTensor { /// Create a new tensor. /// @@ -133,12 +157,6 @@ impl TchTensor { } } -impl TchTensor { - pub(crate) fn shape(&self) -> Shape { - Shape::from(self.tensor.size()) - } -} - // This is safe since we don't use autodiff from LibTorch. 
// Also, atomic reference counting is used to know if the tensor's data can be reused. // If there are multiple reference on the same tensor, it becomes read only. @@ -310,6 +328,16 @@ pub struct TchQTensor { pub scheme: QuantizationScheme, } +impl TensorMetadata for TchQTensor { + fn dtype(&self) -> DType { + DType::QFloat(self.scheme) + } + + fn shape(&self) -> Shape { + self.qtensor.shape() + } +} + impl QTensorPrimitive for TchQTensor { fn scheme(&self) -> &QuantizationScheme { &self.scheme diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index c16c732caa..d634ee375c 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -21,7 +21,7 @@ use crate::{ }; use crate::{DType, Element, TensorPrimitive}; -use super::Transaction; +use super::{TensorMetadata, Transaction}; /// A tensor with a given backend, shape and data type. /// @@ -161,7 +161,7 @@ where /// } /// ``` pub fn shape(&self) -> Shape { - K::shape(&self.primitive) + self.primitive.shape() } /// Reshape the tensor to have the given shape. @@ -1873,7 +1873,7 @@ where writeln!(f, " device: {:?},", self.device())?; writeln!(f, " backend: {:?},", B::name())?; writeln!(f, " kind: {:?},", K::name())?; - writeln!(f, " dtype: {:?},", K::elem_type_name())?; + writeln!(f, " dtype: {:?},", self.primitive.dtype().name())?; write!(f, "}}") } } @@ -1928,26 +1928,6 @@ pub trait BasicOps: TensorKind { /// which is more high-level and designed for public use. fn empty(shape: Shape, device: &B::Device) -> Self::Primitive; - /// Returns the shape of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The shape of the tensor. - /// - /// # Remarks - /// - /// This is a low-level function used internally by the library to call different backend functions - /// with static dispatch. It is not designed for direct usage by users, and not recommended to import - /// or use this function directly. - /// - /// For getting the shape of a tensor, users should prefer the [Tensor::shape](Tensor::shape) function, - /// which is more high-level and designed for public use. - fn shape(tensor: &Self::Primitive) -> Shape; - /// Reshapes the tensor. /// /// # Arguments @@ -2318,6 +2298,11 @@ pub trait BasicOps: TensorKind { core::any::type_name::() } + /// Returns the tensor data type. + fn dtype(tensor: &Self::Primitive) -> DType { + tensor.dtype() + } + /// Tests if any element in the `tensor` evaluates to True. 
/// /// # Arguments @@ -2417,13 +2402,6 @@ impl BasicOps for Float { tr.register_float(tensor); } - fn shape(tensor: &Self::Primitive) -> Shape { - match tensor { - TensorPrimitive::Float(tensor) => B::float_shape(tensor), - TensorPrimitive::QFloat(tensor) => B::q_shape(tensor), - } - } - fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => { @@ -2640,10 +2618,6 @@ impl BasicOps for Int { tr.register_int(tensor); } - fn shape(tensor: &Self::Primitive) -> Shape { - B::int_shape(tensor) - } - fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive { B::int_reshape(tensor, shape) } @@ -2756,10 +2730,6 @@ impl BasicOps for Bool { tr.register_bool(tensor); } - fn shape(tensor: &Self::Primitive) -> Shape { - B::bool_shape(tensor) - } - fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive { B::bool_reshape(tensor, shape) } diff --git a/crates/burn-tensor/src/tensor/api/chunk.rs b/crates/burn-tensor/src/tensor/api/chunk.rs index c3735f8a57..524e68bfa9 100644 --- a/crates/burn-tensor/src/tensor/api/chunk.rs +++ b/crates/burn-tensor/src/tensor/api/chunk.rs @@ -1,4 +1,4 @@ -use super::narrow::narrow; +use super::{narrow::narrow, TensorMetadata}; use crate::{backend::Backend, BasicOps, TensorKind}; use alloc::vec::Vec; @@ -25,7 +25,7 @@ pub fn chunk + BasicOps>( chunks: usize, dim: usize, ) -> Vec { - let size = K::shape(&tensor).dims[dim]; + let size = tensor.shape().dims[dim]; if size < chunks { return (0..size) .map(|i| narrow::(tensor.clone(), dim, i, 1)) diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index 2242c5d66e..870f3819ce 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -244,6 +244,10 @@ where } /// Converts a tensor to the specified floating point data type. + /// + /// # Warning + /// Most backends don't have automatic type promotion at this time, so make sure that all tensors + /// have the same floating point precision data type for operations multiple input tensors (e.g., binary ops). pub fn cast>(self, dtype: F) -> Tensor { Tensor::new(TensorPrimitive::Float(B::float_cast( self.primitive.tensor(), diff --git a/crates/burn-tensor/src/tensor/api/kind.rs b/crates/burn-tensor/src/tensor/api/kind.rs index f1aed5e9c6..0930d10cf6 100644 --- a/crates/burn-tensor/src/tensor/api/kind.rs +++ b/crates/burn-tensor/src/tensor/api/kind.rs @@ -1,4 +1,4 @@ -use crate::backend::Backend; +use crate::{backend::Backend, DType, Shape}; /// A type-level representation of the kind of a float tensor #[derive(Clone, Debug)] @@ -31,10 +31,35 @@ impl TensorPrimitive { } } +impl TensorMetadata for TensorPrimitive { + fn dtype(&self) -> DType { + match self { + TensorPrimitive::Float(tensor) => tensor.dtype(), + TensorPrimitive::QFloat(tensor) => tensor.dtype(), + } + } + + fn shape(&self) -> Shape { + match self { + TensorPrimitive::Float(tensor) => tensor.shape(), + TensorPrimitive::QFloat(tensor) => tensor.shape(), + } + } +} + +/// Tensor metadata trait for tensor primitive. +pub trait TensorMetadata: Clone + Send + Sync + core::fmt::Debug { + /// The dtype of the tensor. + fn dtype(&self) -> DType; + /// The shape of the tensor. + fn shape(&self) -> Shape; +} + /// A type-level representation of the kind of a tensor. +/// Metadata access is lazy. pub trait TensorKind: Clone + core::fmt::Debug { /// The primitive type of the tensor. 
- type Primitive: Clone + core::fmt::Debug + Send + Sync; + type Primitive: TensorMetadata; /// The name of the tensor kind. fn name() -> &'static str; diff --git a/crates/burn-tensor/src/tensor/api/narrow.rs b/crates/burn-tensor/src/tensor/api/narrow.rs index 910c781411..b4a7d04aa3 100644 --- a/crates/burn-tensor/src/tensor/api/narrow.rs +++ b/crates/burn-tensor/src/tensor/api/narrow.rs @@ -1,6 +1,8 @@ use crate::{backend::Backend, BasicOps, TensorKind}; use alloc::vec::Vec; +use super::TensorMetadata; + /// Returns a new tensor with the given dimension narrowed to the given range. /// /// # Arguments @@ -23,7 +25,7 @@ pub fn narrow + BasicOps>( start: usize, length: usize, ) -> K::Primitive { - let shape = K::shape(&tensor); + let shape = tensor.shape(); let ranges: Vec<_> = shape .dims diff --git a/crates/burn-tensor/src/tensor/api/split.rs b/crates/burn-tensor/src/tensor/api/split.rs index b316faa334..7a4b38b80c 100644 --- a/crates/burn-tensor/src/tensor/api/split.rs +++ b/crates/burn-tensor/src/tensor/api/split.rs @@ -1,4 +1,4 @@ -use super::narrow::narrow; +use super::{narrow::narrow, TensorMetadata}; use crate::{backend::Backend, BasicOps, TensorKind}; use alloc::vec::Vec; @@ -27,7 +27,7 @@ pub fn split + BasicOps>( split_size: usize, dim: usize, ) -> Vec { - let size = K::shape(&tensor).dims[dim]; + let size = tensor.shape().dims[dim]; let mut tensors = Vec::new(); let mut start = 0; diff --git a/crates/burn-tensor/src/tensor/backend/base.rs b/crates/burn-tensor/src/tensor/backend/base.rs index 2e0c45f7cb..43c9f53214 100644 --- a/crates/burn-tensor/src/tensor/backend/base.rs +++ b/crates/burn-tensor/src/tensor/backend/base.rs @@ -1,6 +1,7 @@ use alloc::string::String; use crate::tensor::Element; +use crate::TensorMetadata; use crate::{ops::*, quantization::QTensorPrimitive}; use super::{BackendBridge, DeviceOps}; @@ -74,25 +75,20 @@ pub trait Backend: type FullPrecisionBridge: BackendBridge + 'static; /// Tensor primitive to be used for all float operations. - type FloatTensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; - /// Float element type. + type FloatTensorPrimitive: TensorMetadata + 'static; + /// Default float element type. type FloatElem: Element; /// Tensor primitive to be used for all int operations. - type IntTensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; + type IntTensorPrimitive: TensorMetadata + 'static; /// Int element type. type IntElem: Element; /// Tensor primitive to be used for all bool operations. - type BoolTensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; + type BoolTensorPrimitive: TensorMetadata + 'static; /// Tensor primitive to be used for all quantized operations. - type QuantizedTensorPrimitive: QTensorPrimitive - + Clone - + Send - + Sync - + 'static - + core::fmt::Debug; + type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static; /// Quantized tensor encoding type. type QuantizedEncoding: Element; diff --git a/crates/burn-tensor/src/tensor/element/base.rs b/crates/burn-tensor/src/tensor/element/base.rs index 1f3d2f1678..bb12a67f72 100644 --- a/crates/burn-tensor/src/tensor/element/base.rs +++ b/crates/burn-tensor/src/tensor/element/base.rs @@ -324,6 +324,26 @@ impl DType { pub fn is_bool(&self) -> bool { matches!(self, DType::Bool) } + + /// Returns the data type name. 
+ pub fn name(&self) -> &'static str { + match self { + DType::F64 => "f64", + DType::F32 => "f32", + DType::F16 => "f16", + DType::BF16 => "bf16", + DType::I64 => "i64", + DType::I32 => "i32", + DType::I16 => "i16", + DType::I8 => "i8", + DType::U64 => "u64", + DType::U32 => "u32", + DType::U16 => "u16", + DType::U8 => "u8", + DType::Bool => "bool", + DType::QFloat(_) => "qfloat", + } + } } #[allow(missing_docs)] @@ -346,3 +366,14 @@ impl From for FloatDType { } } } + +impl From for DType { + fn from(value: FloatDType) -> Self { + match value { + FloatDType::F64 => DType::F64, + FloatDType::F32 => DType::F32, + FloatDType::F16 => DType::F16, + FloatDType::BF16 => DType::BF16, + } + } +} diff --git a/crates/burn-tensor/src/tensor/ops/activation.rs b/crates/burn-tensor/src/tensor/ops/activation.rs index 6c59b143b2..a35208b99b 100644 --- a/crates/burn-tensor/src/tensor/ops/activation.rs +++ b/crates/burn-tensor/src/tensor/ops/activation.rs @@ -1,4 +1,5 @@ use crate::tensor::ops::tensor::FloatTensorOps; +use crate::TensorMetadata; use crate::{backend::Backend, ElementConversion}; use core::f64::consts::SQRT_2; @@ -259,7 +260,7 @@ pub trait ActivationOps { // -max_derive - (z-1)/z if x is >= 0 // -max_derive + (z-1)/z if x is < 0 - let shape = B::float_shape(&x); + let shape = x.shape(); let device = B::float_device(&x); // max(-x, 0) diff --git a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs index 31ed6d1ede..c0ea76efa4 100644 --- a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs @@ -4,7 +4,7 @@ use super::{ }; use crate::{ argwhere_data, backend::Backend, chunk, narrow, split, split_with_sizes, tensor::Shape, Bool, - ElementConversion, TensorData, + ElementConversion, TensorData, TensorMetadata, }; use alloc::{vec, vec::Vec}; use core::{future::Future, ops::Range}; @@ -24,17 +24,6 @@ pub trait BoolTensorOps { /// The boolean tensor with the given shape. fn bool_empty(shape: Shape, device: &Device) -> BoolTensor; - /// Returns the shape of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The shape of the tensor. - fn bool_shape(tensor: &BoolTensor) -> Shape; - /// Converts the tensor to a data structure. /// /// # Arguments @@ -212,7 +201,7 @@ pub trait BoolTensorOps { /// /// The transposed tensor. fn bool_transpose(tensor: BoolTensor) -> BoolTensor { - let ndims = Self::bool_shape(&tensor).num_dims(); + let ndims = tensor.shape().num_dims(); Self::bool_swap_dims(tensor, ndims - 2, ndims - 1) } @@ -367,7 +356,7 @@ pub trait BoolTensorOps { /// A boolean tensor `Tensor` with a single element, True if all elements in the input tensor /// evaluate to True, False otherwise. fn bool_all(tensor: BoolTensor) -> BoolTensor { - let num_elems = B::bool_shape(&tensor).num_elements(); + let num_elems = tensor.shape().num_elements(); let sum = B::int_sum(B::bool_into_int(tensor)); B::int_equal_elem(sum, (num_elems as i32).elem()) } @@ -386,7 +375,7 @@ pub trait BoolTensorOps { /// evaluates to True, False otherwise. 
fn bool_all_dim(tensor: BoolTensor, dim: usize) -> BoolTensor { - let num_elems = B::bool_shape(&tensor).dims[dim]; + let num_elems = tensor.shape().dims[dim]; let sum = B::int_sum_dim(B::bool_into_int(tensor), dim); B::int_equal_elem(sum, (num_elems as i32).elem()) } @@ -427,12 +416,12 @@ pub trait BoolTensorOps { async { let indices = B::bool_argwhere(tensor).await; - if B::int_shape(&indices).num_elements() == 0 { + if indices.shape().num_elements() == 0 { // Return empty vec when all elements are zero return vec![]; } - let dims = B::int_shape(&indices).dims; + let dims = indices.shape().dims; B::int_chunk(indices, dims[1], 1) .into_iter() .map(|t| B::int_reshape(t, Shape::new([dims[0]]))) diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs index f62c06b467..abdd2e54ba 100644 --- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs @@ -8,7 +8,7 @@ use alloc::vec::Vec; use core::future::Future; use core::ops::Range; -use crate::{argsort, sort, sort_with_indices}; +use crate::{argsort, sort, sort_with_indices, TensorMetadata}; /// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor) /// for documentation on each function. @@ -25,17 +25,6 @@ pub trait IntTensorOps { /// The integer tensor with the given shape. fn int_empty(shape: Shape, device: &Device) -> IntTensor; - /// Returns the shape of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The shape of the tensor. - fn int_shape(tensor: &IntTensor) -> Shape; - /// Converts the tensor to a data structure. /// /// # Arguments @@ -727,7 +716,7 @@ pub trait IntTensorOps { /// /// The mean of all elements in the tensor. fn int_mean(tensor: IntTensor) -> IntTensor { - let num_elems = B::int_shape(&tensor).num_elements(); + let num_elems = tensor.shape().num_elements(); B::int_div_scalar(B::int_sum(tensor), (num_elems as i64).elem()) } @@ -776,7 +765,7 @@ pub trait IntTensorOps { /// /// The maximum element in the tensor. fn int_max(tensor: IntTensor) -> IntTensor { - let shape = B::int_shape(&tensor); + let shape = tensor.shape(); let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()])); B::int_max_dim(tensor, 0) @@ -794,7 +783,7 @@ pub trait IntTensorOps { /// The maximum element in the tensor along the dimension. fn int_max_dim(tensor: IntTensor, dim: usize) -> IntTensor { let index = B::int_argmax(tensor.clone(), dim); - let ndim = B::int_shape(&tensor).num_dims(); + let ndim = tensor.shape().num_dims(); B::int_gather(ndim - 1, tensor, index) } @@ -826,7 +815,7 @@ pub trait IntTensorOps { /// /// The minimum element in the tensor. fn int_min(tensor: IntTensor) -> IntTensor { - let shape = B::int_shape(&tensor); + let shape = tensor.shape(); let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()])); B::int_min_dim(tensor, 0) @@ -844,7 +833,7 @@ pub trait IntTensorOps { /// The minimum element in the tensor along the dimension. fn int_min_dim(tensor: IntTensor, dim: usize) -> IntTensor { let index = B::int_argmin(tensor.clone(), dim); - let ndim = B::int_shape(&tensor).num_dims(); + let ndim = tensor.shape().num_dims(); B::int_gather(ndim - 1, tensor, index) } @@ -861,7 +850,7 @@ pub trait IntTensorOps { /// The minimum elements and corresponding indices along the dimension. 
fn int_min_dim_with_indices(tensor: IntTensor, dim: usize) -> (IntTensor, IntTensor) { let indices = B::int_argmin(tensor.clone(), dim); - let ndim = B::int_shape(&tensor).num_dims(); + let ndim = tensor.shape().num_dims(); let values = B::int_gather(ndim - 1, tensor, indices.clone()); (values, indices) @@ -888,7 +877,7 @@ pub trait IntTensorOps { /// /// The transposed tensor. fn int_transpose(tensor: IntTensor) -> IntTensor { - let ndims = Self::int_shape(&tensor).num_dims(); + let ndims = tensor.shape().num_dims(); Self::int_swap_dims(tensor, ndims - 2, ndims - 1) } @@ -1092,7 +1081,7 @@ pub trait IntTensorOps { /// A boolean tensor `Tensor` with a single element, True if all elements in the input tensor /// evaluate to True, False otherwise. fn int_all(tensor: IntTensor) -> BoolTensor { - let num_elems = B::int_shape(&tensor).num_elements(); + let num_elems = tensor.shape().num_elements(); let bool_tensor = B::int_equal_elem(tensor, 0.elem()); let bool_tensor = B::bool_not(bool_tensor); let sum = B::int_sum(B::bool_into_int(bool_tensor)); @@ -1112,7 +1101,7 @@ pub trait IntTensorOps { /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input /// evaluates to True, False otherwise. fn int_all_dim(tensor: IntTensor, dim: usize) -> BoolTensor { - let num_elems = B::int_shape(&tensor).dims[dim]; + let num_elems = tensor.shape().dims[dim]; let bool_tensor = B::int_equal_elem(tensor, 0.elem()); let bool_tensor = B::bool_not(bool_tensor); let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim); @@ -1129,7 +1118,7 @@ pub trait IntTensorOps { /// /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`. fn int_sign(tensor: IntTensor) -> IntTensor { - let zeros = B::int_zeros(B::int_shape(&tensor), &B::int_device(&tensor)); + let zeros = B::int_zeros(tensor.shape(), &B::int_device(&tensor)); let less_than_zero = B::int_lower_elem(tensor.clone(), 0.0f32.elem()); let greater_than_zero = B::int_greater_elem(tensor, 0.0f32.elem()); diff --git a/crates/burn-tensor/src/tensor/ops/modules/base.rs b/crates/burn-tensor/src/tensor/ops/modules/base.rs index 38f9f3554b..d28de15662 100644 --- a/crates/burn-tensor/src/tensor/ops/modules/base.rs +++ b/crates/burn-tensor/src/tensor/ops/modules/base.rs @@ -2,7 +2,7 @@ use super::{conv, pool, unfold::unfold4d_using_conv2d}; use crate::{ backend::Backend, ops::{FloatTensor, IntTensor}, - Shape, + Shape, TensorMetadata, }; /// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d). @@ -195,8 +195,8 @@ pub trait ModuleOps { /// /// The output tensor. 
fn embedding(weights: FloatTensor, indices: IntTensor) -> FloatTensor { - let [batch_size, seq_length] = B::int_shape(&indices).dims(); - let [_, d_model] = B::float_shape(&weights).dims(); + let [batch_size, seq_length] = indices.shape().dims(); + let [_, d_model] = weights.shape().dims(); let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length])); let output = B::float_select(weights, 0, indices); @@ -220,8 +220,8 @@ pub trait ModuleOps { output_grad: FloatTensor, indices: IntTensor, ) -> FloatTensor { - let [batch_size, seq_length] = B::int_shape(&indices).dims(); - let [n_embeddings, d_model] = B::float_shape(&weights).dims(); + let [batch_size, seq_length] = indices.shape().dims(); + let [n_embeddings, d_model] = weights.shape().dims(); let device = B::float_device(&weights); let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length])); diff --git a/crates/burn-tensor/src/tensor/ops/modules/cat.rs b/crates/burn-tensor/src/tensor/ops/modules/cat.rs index 64416cdcf8..edfbba2d3f 100644 --- a/crates/burn-tensor/src/tensor/ops/modules/cat.rs +++ b/crates/burn-tensor/src/tensor/ops/modules/cat.rs @@ -1,4 +1,4 @@ -use crate::{backend::Backend, BasicOps, TensorKind}; +use crate::{backend::Backend, BasicOps, TensorKind, TensorMetadata}; use alloc::vec::Vec; pub(crate) fn cat_with_slice_assign + BasicOps>( @@ -6,13 +6,10 @@ pub(crate) fn cat_with_slice_assign + BasicOps>( dim: usize, ) -> K::Primitive { let first_tensor = tensors.first().expect("Tensors should not be empty"); - let mut shape = K::shape(first_tensor); + let mut shape = first_tensor.shape(); let device = K::device(first_tensor); - let output_dim_length: usize = tensors - .iter() - .map(|tensor| K::shape(tensor).dims[dim]) - .sum(); + let output_dim_length: usize = tensors.iter().map(|tensor| tensor.shape().dims[dim]).sum(); shape.dims[dim] = output_dim_length; let mut tensor_output = K::empty(shape.clone(), &device); @@ -22,7 +19,7 @@ pub(crate) fn cat_with_slice_assign + BasicOps>( let mut output_index = 0; for tensor in tensors { let mut indices = indices_select_all.clone(); - let tensor_dim_length = K::shape(&tensor).dims[dim]; + let tensor_dim_length = tensor.shape().dims[dim]; indices[dim] = output_index..output_index + tensor_dim_length; output_index += tensor_dim_length; diff --git a/crates/burn-tensor/src/tensor/ops/modules/conv.rs b/crates/burn-tensor/src/tensor/ops/modules/conv.rs index 9bdad3bea4..5cf949dff6 100644 --- a/crates/burn-tensor/src/tensor/ops/modules/conv.rs +++ b/crates/burn-tensor/src/tensor/ops/modules/conv.rs @@ -1,6 +1,6 @@ #![allow(clippy::single_range_in_vec_init)] use super::{ConvOptions, ConvTransposeOptions}; -use crate::{backend::Backend, ops::FloatTensor, Shape}; +use crate::{backend::Backend, ops::FloatTensor, Shape, TensorMetadata}; #[cfg(not(feature = "std"))] use num_traits::Float; @@ -64,10 +64,10 @@ pub(crate) fn conv1d_x_backward( output_grad: FloatTensor, options: ConvOptions<1>, ) -> FloatTensor { - let weight_shape = B::float_shape(&weight); + let weight_shape = weight.shape(); - let [_batch_size, _, length_in] = B::float_shape(&x).dims(); - let [_batch_size, _channels_out, length_out] = B::float_shape(&output_grad).dims(); + let [_batch_size, _, length_in] = x.shape().dims(); + let [_batch_size, _channels_out, length_out] = output_grad.shape().dims(); let [_, _, kernel_size] = weight_shape.dims(); let padding_out = calculate_padding_out( @@ -100,7 +100,7 @@ pub(crate) fn conv1d_weight_backward( output_grad: FloatTensor, options: ConvOptions<1>, ) -> 
FloatTensor { - let weight_shape = B::float_shape(&weight); + let weight_shape = weight.shape(); let weight_device = B::float_device(&weight); match options.groups == 1 { @@ -120,14 +120,14 @@ pub(crate) fn conv1d_bias_backward( bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { - let [batch_size, _, _length_in] = B::float_shape(&x).dims(); - let [_batch_size, channels_out, length_out] = B::float_shape(&output_grad).dims(); + let [batch_size, _, _length_in] = x.shape().dims(); + let [_batch_size, channels_out, length_out] = output_grad.shape().dims(); let grad = B::float_swap_dims(output_grad, 0, 1); let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out])); let grad = B::float_sum_dim(grad, 1); - B::float_reshape(grad, B::float_shape(&bias)) + B::float_reshape(grad, bias.shape()) } /// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `x`. @@ -137,10 +137,10 @@ pub(crate) fn conv2d_x_backward( output_grad: FloatTensor, options: ConvOptions<2>, ) -> FloatTensor { - let weight_shape = B::float_shape(&weight); + let weight_shape = weight.shape(); - let [_batch_size, _channels_in, height_in, width_in] = B::float_shape(&x).dims(); - let [_, _, height_out, width_out] = B::float_shape(&output_grad).dims(); + let [_batch_size, _channels_in, height_in, width_in] = x.shape().dims(); + let [_, _, height_out, width_out] = output_grad.shape().dims(); let [_channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims(); let padding_1_out = calculate_padding_out( @@ -181,7 +181,7 @@ pub(crate) fn conv2d_weight_backward( output_grad: FloatTensor, options: ConvOptions<2>, ) -> FloatTensor { - let weight_shape = B::float_shape(&weight); + let weight_shape = weight.shape(); let weight_device = B::float_device(&weight); match options.groups == 1 { @@ -202,10 +202,10 @@ pub(crate) fn conv2d_bias_backward( bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { - let weight_shape = B::float_shape(&weight); + let weight_shape = weight.shape(); - let [batch_size, _channels_in, _height_in, _width_in] = B::float_shape(&x).dims(); - let [_, _, height_out, width_out] = B::float_shape(&output_grad).dims(); + let [batch_size, _channels_in, _height_in, _width_in] = x.shape().dims(); + let [_, _, height_out, width_out] = output_grad.shape().dims(); let [channels_out, _, _kernel_size_1, _kernel_size_2] = weight_shape.dims(); let grad = B::float_swap_dims(output_grad, 0, 1); @@ -215,7 +215,7 @@ pub(crate) fn conv2d_bias_backward( ); let grad = B::float_sum_dim(grad, 1); - B::float_reshape(grad, B::float_shape(&bias)) + B::float_reshape(grad, bias.shape()) } /// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `x`. 
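The hunks above are representative of the rest of this patch: shape lookups that previously went through backend functions such as `B::float_shape(&x)` now call `.shape()` on the primitive itself through the new `TensorMetadata` trait. A minimal sketch (not from this patch) of backend-agnostic code relying only on that trait:

use burn_tensor::{DType, Shape, TensorMetadata};

// Works for any tensor primitive (float, int, bool or quantized), since they
// all implement TensorMetadata after this change.
fn describe<T: TensorMetadata>(tensor: &T) -> (Shape, DType) {
    (tensor.shape(), tensor.dtype())
}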
@@ -225,10 +225,10 @@ pub(crate) fn conv3d_x_backward( output_grad: FloatTensor, options: ConvOptions<3>, ) -> FloatTensor { - let weight_shape = B::float_shape(&weight); + let weight_shape = weight.shape(); - let [_batch_size, _channels_in, depth_in, height_in, width_in] = B::float_shape(&x).dims(); - let [_, _, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims(); + let [_batch_size, _channels_in, depth_in, height_in, width_in] = x.shape().dims(); + let [_, _, depth_out, height_out, width_out] = output_grad.shape().dims(); let [_channels_out, _, kernel_size_1, kernel_size_2, kernel_size_3] = weight_shape.dims(); let padding_1_out = calculate_padding_out( @@ -277,7 +277,7 @@ pub(crate) fn conv3d_weight_backward( output_grad: FloatTensor, options: ConvOptions<3>, ) -> FloatTensor { - let weight_shape = B::float_shape(&weight); + let weight_shape = weight.shape(); let weight_device = B::float_device(&weight); match options.groups == 1 { @@ -298,10 +298,10 @@ pub(crate) fn conv3d_bias_backward( bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { - let weight_shape = B::float_shape(&weight); + let weight_shape = weight.shape(); - let [batch_size, _channels_in, _depth_in, _height_in, _width_in] = B::float_shape(&x).dims(); - let [_, _, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims(); + let [batch_size, _channels_in, _depth_in, _height_in, _width_in] = x.shape().dims(); + let [_, _, depth_out, height_out, width_out] = output_grad.shape().dims(); let [channels_out, _, _kernel_size_1, _kernel_size_2, _kernel_size_3] = weight_shape.dims(); let grad = B::float_swap_dims(output_grad, 0, 1); @@ -314,7 +314,7 @@ pub(crate) fn conv3d_bias_backward( ); let grad = B::float_sum_dim(grad, 1); - B::float_reshape(grad, B::float_shape(&bias)) + B::float_reshape(grad, bias.shape()) } /// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `x`. @@ -343,7 +343,7 @@ pub(crate) fn conv_transpose1d_weight_backward( output_grad: FloatTensor, options: ConvTransposeOptions<1>, ) -> FloatTensor { - let weight_shape = B::float_shape(&weight); + let weight_shape = weight.shape(); let weight_device = B::float_device(&weight); match options.groups == 1 { @@ -363,14 +363,14 @@ pub(crate) fn conv_transpose1d_bias_backward( bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { - let [batch_size, _channels_in, _] = B::float_shape(&x).dims(); - let [_, channels_out, length_out] = B::float_shape(&output_grad).dims(); + let [batch_size, _channels_in, _] = x.shape().dims(); + let [_, channels_out, length_out] = output_grad.shape().dims(); let grad = B::float_swap_dims(output_grad, 0, 1); let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out])); let grad = B::float_sum_dim(grad, 1); - B::float_reshape(grad, B::float_shape(&bias)) + B::float_reshape(grad, bias.shape()) } /// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `x`. 
@@ -399,7 +399,7 @@ pub(crate) fn conv_transpose2d_weight_backward( output_grad: FloatTensor, options: ConvTransposeOptions<2>, ) -> FloatTensor { - let weight_shape = B::float_shape(&weight); + let weight_shape = weight.shape(); let weight_device = B::float_device(&weight); match options.groups == 1 { @@ -419,8 +419,8 @@ pub(crate) fn conv_transpose2d_bias_backward( bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { - let [batch_size, _channels_in, _, _] = B::float_shape(&x).dims(); - let [_, channels_out, height_out, width_out] = B::float_shape(&output_grad).dims(); + let [batch_size, _channels_in, _, _] = x.shape().dims(); + let [_, channels_out, height_out, width_out] = output_grad.shape().dims(); let grad = B::float_swap_dims(output_grad, 0, 1); let grad = B::float_reshape( @@ -429,7 +429,7 @@ pub(crate) fn conv_transpose2d_bias_backward( ); let grad = B::float_sum_dim(grad, 1); - B::float_reshape(grad, B::float_shape(&bias)) + B::float_reshape(grad, bias.shape()) } /// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `x`. @@ -458,7 +458,7 @@ pub(crate) fn conv_transpose3d_weight_backward( output_grad: FloatTensor, options: ConvTransposeOptions<3>, ) -> FloatTensor { - let weight_shape = B::float_shape(&weight); + let weight_shape = weight.shape(); let weight_device = B::float_device(&weight); match options.groups == 1 { @@ -478,8 +478,8 @@ pub(crate) fn conv_transpose3d_bias_backward( bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { - let [batch_size, _channels_in, _, _, _] = B::float_shape(&x).dims(); - let [_, channels_out, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims(); + let [batch_size, _channels_in, _, _, _] = x.shape().dims(); + let [_, channels_out, depth_out, height_out, width_out] = output_grad.shape().dims(); let grad = B::float_swap_dims(output_grad, 0, 1); let grad = B::float_reshape( @@ -491,7 +491,7 @@ pub(crate) fn conv_transpose3d_bias_backward( ); let grad = B::float_sum_dim(grad, 1); - B::float_reshape(grad, B::float_shape(&bias)) + B::float_reshape(grad, bias.shape()) } /// Execute a 1D convolution using a 2D convolution. 
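On the user-facing side, the `Tensor::cast` API and the type-promotion warning added earlier in this patch suggest a pattern like the following sketch (illustrative only; `mul_same_precision` is an invented helper): cast every operand to a common floating-point dtype before a binary operation, since most backends do not promote mixed precisions automatically.

use burn_tensor::{backend::Backend, FloatDType, Tensor};

// Bring both operands to the same precision before multiplying them.
fn mul_same_precision<B: Backend>(
    lhs: Tensor<B, 2>,
    rhs: Tensor<B, 2>,
    dtype: FloatDType,
) -> Tensor<B, 2> {
    lhs.cast(dtype.clone()) * rhs.cast(dtype)
}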
@@ -501,8 +501,8 @@ pub(crate) fn conv1d_from_conv2d( bias: Option>, options: ConvOptions<1>, ) -> FloatTensor { - let [channels_out, _channels_in, kernel_size] = B::float_shape(&weight).dims(); - let [batch_size, channels_in, length_in] = B::float_shape(&x).dims(); + let [channels_out, _channels_in, kernel_size] = weight.shape().dims(); + let [batch_size, channels_in, length_in] = x.shape().dims(); let weight = B::float_reshape( weight, @@ -521,7 +521,7 @@ pub(crate) fn conv1d_from_conv2d( options.groups, ), ); - let [batch_size, channels_out, height_out, _weight_out] = B::float_shape(&tensor).dims(); + let [batch_size, channels_out, height_out, _weight_out] = tensor.shape().dims(); B::float_reshape(tensor, Shape::from([batch_size, channels_out, height_out])) } @@ -532,8 +532,8 @@ pub(crate) fn conv_transpose1d_from_conv_transpose2d( bias: Option>, options: ConvTransposeOptions<1>, ) -> FloatTensor { - let [channels_in, channels_out, kernel_size] = B::float_shape(&weight).dims(); - let [batch_size, _channels_in, length_in] = B::float_shape(&x).dims(); + let [channels_in, channels_out, kernel_size] = weight.shape().dims(); + let [batch_size, _channels_in, length_in] = x.shape().dims(); let weight = B::float_reshape( weight, @@ -553,7 +553,7 @@ pub(crate) fn conv_transpose1d_from_conv_transpose2d( options.groups, ), ); - let [batch_size, channels_out, height_out, _weight_out] = B::float_shape(&tensor).dims(); + let [batch_size, channels_out, height_out, _weight_out] = tensor.shape().dims(); B::float_reshape(tensor, Shape::from([batch_size, channels_out, height_out])) } @@ -573,7 +573,7 @@ fn conv1d_weight_grad_no_groups( ); let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); - if B::float_shape(&weight_grad) != weight_shape { + if weight_grad.shape() != weight_shape { weight_grad = B::float_slice( weight_grad, &[ @@ -602,7 +602,7 @@ fn conv2d_weight_grad_no_groups( ); let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); - if B::float_shape(&weight_grad) != weight_shape { + if weight_grad.shape() != weight_shape { weight_grad = B::float_slice( weight_grad, &[ @@ -632,7 +632,7 @@ fn conv3d_weight_grad_no_groups( ); let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); - if B::float_shape(&weight_grad) != weight_shape { + if weight_grad.shape() != weight_shape { weight_grad = B::float_slice( weight_grad, &[ @@ -653,7 +653,7 @@ fn conv1d_weight_grad_groups( output_grad: FloatTensor, options: ConvOptions<1>, ) -> FloatTensor { - let [channels_out, increment_ci, kernel_size] = B::float_shape(&weight_grad).dims(); + let [channels_out, increment_ci, kernel_size] = weight_grad.shape().dims(); let increment_co = channels_out / options.groups; let x_swapped = B::float_swap_dims(x, 0, 1); @@ -690,8 +690,7 @@ fn conv2d_weight_grad_groups( output_grad: FloatTensor, options: ConvOptions<2>, ) -> FloatTensor { - let [channels_out, increment_ci, kernel_size_1, kernel_size_2] = - B::float_shape(&weight_grad).dims(); + let [channels_out, increment_ci, kernel_size_1, kernel_size_2] = weight_grad.shape().dims(); let increment_co = channels_out / options.groups; let x_swapped = B::float_swap_dims(x, 0, 1); @@ -712,7 +711,7 @@ fn conv2d_weight_grad_groups( ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1); - let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = B::float_shape(&weight_grad_tmp).dims(); + let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = weight_grad_tmp.shape().dims(); 
if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 { weight_grad_tmp = B::float_slice( @@ -748,7 +747,7 @@ fn conv3d_weight_grad_groups( options: ConvOptions<3>, ) -> FloatTensor { let [channels_out, increment_ci, kernel_size_1, kernel_size_2, kernel_size_3] = - B::float_shape(&weight_grad).dims(); + weight_grad.shape().dims(); let increment_co = channels_out / options.groups; let x_swapped = B::float_swap_dims(x, 0, 1); @@ -770,7 +769,7 @@ fn conv3d_weight_grad_groups( ); weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1); let [_, _, kernel_size_1_tmp, kernel_size_2_tmp, kernel_size_3_tmp] = - B::float_shape(&weight_grad_tmp).dims(); + weight_grad_tmp.shape().dims(); if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 @@ -820,7 +819,7 @@ fn conv_transpose1d_weight_grad_no_groups( ); let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); - let grad_shape = B::float_shape(&weight_grad); + let grad_shape = weight_grad.shape(); if grad_shape != weight_shape { weight_grad = B::float_slice( weight_grad, @@ -850,7 +849,7 @@ fn conv_transpose2d_weight_grad_no_groups( ); let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); - let grad_shape = B::float_shape(&weight_grad); + let grad_shape = weight_grad.shape(); if grad_shape != weight_shape { weight_grad = B::float_slice( weight_grad, @@ -881,7 +880,7 @@ fn conv_transpose3d_weight_grad_no_groups( ); let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); - let grad_shape = B::float_shape(&weight_grad); + let grad_shape = weight_grad.shape(); if grad_shape != weight_shape { weight_grad = B::float_slice( weight_grad, @@ -903,7 +902,7 @@ fn conv_transpose1d_weight_grad_groups( output_grad: FloatTensor, options: ConvTransposeOptions<1>, ) -> FloatTensor { - let [channels_in, increment_co, kernel_size] = B::float_shape(&weight_grad).dims(); + let [channels_in, increment_co, kernel_size] = weight_grad.shape().dims(); let increment_ci = channels_in / options.groups; let x_swapped = B::float_swap_dims(x, 0, 1); @@ -924,7 +923,7 @@ fn conv_transpose1d_weight_grad_groups( ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1); - let [_, _, kernel_size_tmp] = B::float_shape(&weight_grad_tmp).dims(); + let [_, _, kernel_size_tmp] = weight_grad_tmp.shape().dims(); if kernel_size_tmp != kernel_size { weight_grad_tmp = B::float_slice( @@ -949,8 +948,7 @@ fn conv_transpose2d_weight_grad_groups( output_grad: FloatTensor, options: ConvTransposeOptions<2>, ) -> FloatTensor { - let [channels_in, increment_co, kernel_size_1, kernel_size_2] = - B::float_shape(&weight_grad).dims(); + let [channels_in, increment_co, kernel_size_1, kernel_size_2] = weight_grad.shape().dims(); let increment_ci = channels_in / options.groups; let x_swapped = B::float_swap_dims(x, 0, 1); @@ -971,7 +969,7 @@ fn conv_transpose2d_weight_grad_groups( ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1); - let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = B::float_shape(&weight_grad_tmp).dims(); + let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = weight_grad_tmp.shape().dims(); if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 { weight_grad_tmp = B::float_slice( @@ -1007,7 +1005,7 @@ fn conv_transpose3d_weight_grad_groups( options: ConvTransposeOptions<3>, ) -> FloatTensor { let [channels_in, increment_co, 
kernel_size_1, kernel_size_2, kernel_size_3] =
-        B::float_shape(&weight_grad).dims();
+        weight_grad.shape().dims();
     let increment_ci = channels_in / options.groups;

     let x_swapped = B::float_swap_dims(x, 0, 1);
@@ -1029,7 +1027,7 @@ fn conv_transpose3d_weight_grad_groups(
     );
     weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
     let [_, _, kernel_size_1_tmp, kernel_size_2_tmp, kernel_size_3_tmp] =
-        B::float_shape(&weight_grad_tmp).dims();
+        weight_grad_tmp.shape().dims();

     if kernel_size_1_tmp != kernel_size_1
         || kernel_size_2_tmp != kernel_size_2
diff --git a/crates/burn-tensor/src/tensor/ops/modules/pool.rs b/crates/burn-tensor/src/tensor/ops/modules/pool.rs
index 59d74c1fad..48feb1e7cd 100644
--- a/crates/burn-tensor/src/tensor/ops/modules/pool.rs
+++ b/crates/burn-tensor/src/tensor/ops/modules/pool.rs
@@ -1,7 +1,7 @@
 use crate::{
     backend::Backend,
     ops::{FloatTensor, IntTensor},
-    Shape,
+    Shape, TensorMetadata,
 };

 use super::{MaxPool1dBackward, MaxPool1dWithIndices};
@@ -13,7 +13,7 @@ pub(crate) fn avg_pool1d_from_2d(
     padding: usize,
     count_include_pad: bool,
 ) -> FloatTensor<B> {
-    let [batch_size, channels, length] = B::float_shape(&x).dims();
+    let [batch_size, channels, length] = x.shape().dims();

     let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1]));
     let x = B::avg_pool2d(
@@ -24,7 +24,7 @@ pub(crate) fn avg_pool1d_from_2d(
         count_include_pad,
     );

-    let [batch_size, channels, length, _] = B::float_shape(&x).dims();
+    let [batch_size, channels, length, _] = x.shape().dims();

     B::float_reshape(x, Shape::from([batch_size, channels, length]))
 }
@@ -37,8 +37,8 @@ pub(crate) fn avg_pool1d_backward_from_2d(
     padding: usize,
     count_include_pad: bool,
 ) -> FloatTensor<B> {
-    let [batch_size, channels, length_in] = B::float_shape(&x).dims();
-    let [_, _, length_out] = B::float_shape(&grad).dims();
+    let [batch_size, channels, length_in] = x.shape().dims();
+    let [_, _, length_out] = grad.shape().dims();

     let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1]));
     let grad_x = B::float_reshape(grad, Shape::from([batch_size, channels, length_out, 1]));
@@ -59,12 +59,12 @@ pub(crate) fn adaptive_avg_pool1d_from_2d(
     x: FloatTensor<B>,
     output_size: usize,
 ) -> FloatTensor<B> {
-    let [batch_size, channels, length] = B::float_shape(&x).dims();
+    let [batch_size, channels, length] = x.shape().dims();

     let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1]));
     let x = B::adaptive_avg_pool2d(x, [output_size, 1]);

-    let [batch_size, channels, length, _] = B::float_shape(&x).dims();
+    let [batch_size, channels, length, _] = x.shape().dims();

     B::float_reshape(x, Shape::from([batch_size, channels, length]))
 }
@@ -73,8 +73,8 @@ pub(crate) fn adaptive_avg_pool1d_backward_from_2d(
     x: FloatTensor<B>,
     grad: FloatTensor<B>,
 ) -> FloatTensor<B> {
-    let [batch_size, channels, length_in] = B::float_shape(&x).dims();
-    let [_, _, length_out] = B::float_shape(&grad).dims();
+    let [batch_size, channels, length_in] = x.shape().dims();
+    let [_, _, length_out] = grad.shape().dims();

     let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1]));
     let grad_x = B::float_reshape(grad, Shape::from([batch_size, channels, length_out, 1]));
@@ -91,7 +91,7 @@ pub(crate) fn max_pool1d_from_2d(
     padding: usize,
     dilation: usize,
 ) -> FloatTensor<B> {
-    let [batch_size, channels, length] = B::float_shape(&x).dims();
+    let [batch_size, channels, length] = x.shape().dims();

     let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1]));
     let x = B::max_pool2d(
@@ -102,7 +102,7 @@ pub(crate) fn max_pool1d_from_2d(
         [dilation, 1],
     );

-    let [batch_size, channels, length, _] = B::float_shape(&x).dims();
+    let [batch_size, channels, length, _] = x.shape().dims();

     B::float_reshape(x, Shape::from([batch_size, channels, length]))
 }
@@ -114,7 +114,7 @@ pub(crate) fn max_pool1d_with_indices_from_2d(
     padding: usize,
     dilation: usize,
 ) -> MaxPool1dWithIndices<B> {
-    let [batch_size, channels, length] = B::float_shape(&x).dims();
+    let [batch_size, channels, length] = x.shape().dims();

     let x = B::float_reshape(x, Shape::from([batch_size, channels, 1, length]));
     let x = B::max_pool2d_with_indices(
@@ -124,7 +124,7 @@ pub(crate) fn max_pool1d_with_indices_from_2d(
         [0, padding],
         [1, dilation],
     );
-    let [batch_size, channels, _, length] = B::float_shape(&x.output).dims();
+    let [batch_size, channels, _, length] = x.output.shape().dims();
     let output = B::float_reshape(x.output, Shape::from([batch_size, channels, length]));
     let indices = B::int_reshape(x.indices, Shape::from([batch_size, channels, length]));
     MaxPool1dWithIndices::new(output, indices)
@@ -139,8 +139,8 @@ pub(crate) fn max_pool1d_with_indices_backward_from_2d(
     output_grad: FloatTensor<B>,
     indices: IntTensor<B>,
 ) -> MaxPool1dBackward<B> {
-    let [batch_size, channels, length_in] = B::float_shape(&x).dims();
-    let [_, _, length_out] = B::float_shape(&output_grad).dims();
+    let [batch_size, channels, length_in] = x.shape().dims();
+    let [_, _, length_out] = output_grad.shape().dims();

     let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1]));
     let grad_x = B::float_reshape(
diff --git a/crates/burn-tensor/src/tensor/ops/modules/repeat_dim.rs b/crates/burn-tensor/src/tensor/ops/modules/repeat_dim.rs
index 9deb736a41..847b2734d8 100644
--- a/crates/burn-tensor/src/tensor/ops/modules/repeat_dim.rs
+++ b/crates/burn-tensor/src/tensor/ops/modules/repeat_dim.rs
@@ -1,4 +1,4 @@
-use crate::{backend::Backend, BasicOps, TensorKind};
+use crate::{backend::Backend, BasicOps, TensorKind, TensorMetadata};
 use alloc::vec::Vec;

 pub(crate) fn repeat_with_slice_assign<B: Backend, K: TensorKind<B> + BasicOps<B>>(
@@ -6,7 +6,7 @@ pub(crate) fn repeat_with_slice_assign<B: Backend, K: TensorKind<B> + BasicOps<
     times: usize,
 ) -> K::Primitive {
-    let mut shape = K::shape(&tensor);
+    let mut shape = tensor.shape();
     let device = K::device(&tensor);

     let original_dim_length = shape.dims[dim];
diff --git a/crates/burn-tensor/src/tensor/ops/modules/unfold.rs b/crates/burn-tensor/src/tensor/ops/modules/unfold.rs
index 4c07a1830c..9bb8ff2055 100644
--- a/crates/burn-tensor/src/tensor/ops/modules/unfold.rs
+++ b/crates/burn-tensor/src/tensor/ops/modules/unfold.rs
@@ -1,6 +1,6 @@
 use crate::backend::Backend;
 use crate::ops::FloatTensor;
-use crate::{ElementConversion, Shape, TensorData};
+use crate::{ElementConversion, Shape, TensorData, TensorMetadata};
 use alloc::vec;
 use alloc::vec::Vec;

@@ -63,7 +63,7 @@ pub(crate) fn unfold4d_using_conv2d(
     kernel_size: [usize; 2],
     options: UnfoldOptions,
 ) -> FloatTensor<B> {
-    let [_batch_size, in_channels, _in_height, _in_width] = B::float_shape(&x).dims();
+    let [_batch_size, in_channels, _in_height, _in_width] = x.shape().dims();
     let weight = create_unfolding_weight::<B>(in_channels, kernel_size, &B::float_device(&x));
     let unfolded = B::conv2d(
         x,
@@ -77,7 +77,7 @@ pub(crate) fn unfold4d_using_conv2d(
         },
     );

-    let [batch_size, channels_out, out_height, out_width] = B::float_shape(&unfolded).dims();
+    let [batch_size, channels_out, out_height, out_width] = unfolded.shape().dims();

     B::float_reshape(
         unfolded,
diff --git a/crates/burn-tensor/src/tensor/ops/tensor.rs b/crates/burn-tensor/src/tensor/ops/tensor.rs
index e5f551dd76..5deb1ccfcb 100644
--- a/crates/burn-tensor/src/tensor/ops/tensor.rs
+++ b/crates/burn-tensor/src/tensor/ops/tensor.rs
@@ -6,7 +6,7 @@ use crate::tensor::cast::ToElement;
 use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, Float, TensorData};
 use crate::{
     tensor::api::chunk, tensor::api::narrow, tensor::api::split, tensor::api::split_with_sizes,
-    FloatDType, TensorPrimitive,
+    FloatDType, TensorMetadata, TensorPrimitive,
 };
 use alloc::vec::Vec;
 use core::future::Future;
@@ -85,17 +85,6 @@ pub trait FloatTensorOps<B: Backend> {
         Self::float_add_scalar(Self::float_zeros(shape, device), fill_value)
     }

-    /// Gets the shape of the tensor.
-    ///
-    /// # Arguments
-    ///
-    /// * `tensor` - The tensor.
-    ///
-    /// # Returns
-    ///
-    /// The shape of the tensor.
-    fn float_shape(tensor: &FloatTensor<B>) -> Shape;
-
     /// Converts the tensor to a data structure.
     ///
     /// # Arguments
@@ -357,7 +346,7 @@ pub trait FloatTensorOps<B: Backend> {
     ///
     /// The transposed tensor.
     fn float_transpose(tensor: FloatTensor<B>) -> FloatTensor<B> {
-        let ndims = Self::float_shape(&tensor).num_dims();
+        let ndims = tensor.shape().num_dims();
         Self::float_swap_dims(tensor, ndims - 2, ndims - 1)
     }

@@ -764,7 +753,7 @@ pub trait FloatTensorOps<B: Backend> {
     ///
     /// A scalar tensor with the mean of all elements in `tensor`.
     fn float_mean(tensor: FloatTensor<B>) -> FloatTensor<B> {
-        let num_elems = B::float_shape(&tensor).num_elements();
+        let num_elems = tensor.shape().num_elements();
         B::float_div_scalar(B::float_sum(tensor), (num_elems as i64).elem())
     }

@@ -1057,7 +1046,7 @@ pub trait FloatTensorOps<B: Backend> {
     ///
     /// A tensor with the maximum element of `tensor`.
     fn float_max(tensor: FloatTensor<B>) -> FloatTensor<B> {
-        let shape = B::float_shape(&tensor);
+        let shape = tensor.shape();
         let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));

         B::float_max_dim(tensor, 0)
@@ -1109,7 +1098,7 @@ pub trait FloatTensorOps<B: Backend> {
     ///
     /// A tensor with the minimum element of `tensor`.
     fn float_min(tensor: FloatTensor<B>) -> FloatTensor<B> {
-        let shape = B::float_shape(&tensor);
+        let shape = tensor.shape();
         let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));

         B::float_min_dim(tensor, 0)
@@ -1280,7 +1269,7 @@ pub trait FloatTensorOps<B: Backend> {
     /// A boolean tensor `Tensor` with a single element, True if all elements in the input tensor
     /// evaluate to True, False otherwise.
     fn float_all(tensor: FloatTensor<B>) -> BoolTensor<B> {
-        let num_elems = B::float_shape(&tensor).num_elements();
+        let num_elems = tensor.shape().num_elements();
         let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
         let bool_tensor = B::bool_not(bool_tensor);
         let sum = B::float_sum(B::bool_into_float(bool_tensor));
@@ -1300,7 +1289,7 @@ pub trait FloatTensorOps<B: Backend> {
     /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
     /// evaluates to True, False otherwise.
     fn float_all_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {
-        let num_elems = B::float_shape(&tensor).dims[dim];
+        let num_elems = tensor.shape().dims[dim];
         let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
         let bool_tensor = B::bool_not(bool_tensor);
         let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
@@ -1317,7 +1306,7 @@ pub trait FloatTensorOps<B: Backend> {
     ///
     /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
     fn float_sign(tensor: FloatTensor<B>) -> FloatTensor<B> {
-        let zeros = B::float_zeros(B::float_shape(&tensor), &B::float_device(&tensor));
+        let zeros = B::float_zeros(tensor.shape(), &B::float_device(&tensor));
         let less_than_zero = B::float_lower_elem(tensor.clone(), 0.0f32.elem());
         let greater_than_zero = B::float_greater_elem(tensor, 0.0f32.elem());
diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs
index a3f2c82b72..f35693d98f 100644
--- a/crates/burn-tensor/src/tests/mod.rs
+++ b/crates/burn-tensor/src/tests/mod.rs
@@ -2,6 +2,7 @@ mod activation;
 mod clone_invariance;
 mod module;
 mod ops;
+mod primitive;
 mod quantization;
 mod stats;

@@ -278,6 +279,9 @@ macro_rules! testgen_no_param {

         // test clone invariance
         burn_tensor::testgen_clone_invariance!();
+
+        // test primitive
+        burn_tensor::testgen_primitive!();
     };
 }
diff --git a/crates/burn-tensor/src/tests/primitive.rs b/crates/burn-tensor/src/tests/primitive.rs
new file mode 100644
index 0000000000..ca907e18a3
--- /dev/null
+++ b/crates/burn-tensor/src/tests/primitive.rs
@@ -0,0 +1,45 @@
+#[burn_tensor_testgen::testgen(primitive)]
+mod tests {
+    use super::*;
+    use burn_tensor::{backend::Backend, DType, Element, Shape};
+
+    #[test]
+    fn should_support_float_dtype() {
+        let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]]).into_primitive();
+
+        assert_eq!(
+            burn_tensor::TensorMetadata::shape(&tensor),
+            Shape::new([2, 3])
+        );
+        assert_eq!(
+            burn_tensor::TensorMetadata::dtype(&tensor),
+            <TestBackend as Backend>::FloatElem::dtype() // default float elem type
+        );
+    }
+
+    #[test]
+    fn should_support_int_dtype() {
+        let tensor = TestTensorInt::<2>::from([[0, -1, 2], [3, 4, -5]]).into_primitive();
+
+        assert_eq!(
+            burn_tensor::TensorMetadata::shape(&tensor),
+            Shape::new([2, 3])
+        );
+        assert_eq!(
+            burn_tensor::TensorMetadata::dtype(&tensor),
+            <TestBackend as Backend>::IntElem::dtype() // default int elem type
+        );
+    }
+
+    #[test]
+    fn should_support_bool_dtype() {
+        let tensor =
+            TestTensorBool::<2>::from([[false, true, true], [false, false, true]]).into_primitive();
+
+        assert_eq!(
+            burn_tensor::TensorMetadata::shape(&tensor),
+            Shape::new([2, 3])
+        );
+        assert_eq!(burn_tensor::TensorMetadata::dtype(&tensor), DType::Bool);
+    }
+}
diff --git a/examples/custom-cubecl-kernel/src/backward.rs b/examples/custom-cubecl-kernel/src/backward.rs
index 594906c98a..3c66ae8e0e 100644
--- a/examples/custom-cubecl-kernel/src/backward.rs
+++ b/examples/custom-cubecl-kernel/src/backward.rs
@@ -8,7 +8,7 @@ use burn::{
         ops::{broadcast_shape, Backward, Ops, OpsKind},
         Autodiff, NodeID,
     },
-    tensor::Shape,
+    tensor::{Shape, TensorMetadata},
 };
 use burn_jit::{FloatElement, IntElement, JitBackend, JitRuntime};

@@ -51,12 +51,12 @@ impl Backend for Autodiff {
                 // Set our state.
                 let (lhs_state, rhs_state, output, shape_bias) = ops.state;

-                let lhs = checkpointer.retrieve_node_output(lhs_state);
-                let rhs = checkpointer.retrieve_node_output(rhs_state);
+                let lhs: FloatTensor<B> = checkpointer.retrieve_node_output(lhs_state);
+                let rhs: FloatTensor<B> = checkpointer.retrieve_node_output(rhs_state);

                 // Fetch shapes of our tensor to support broadcasting.
-                let shape_lhs = B::float_shape(&lhs);
-                let shape_rhs = B::float_shape(&rhs);
+                let shape_lhs = lhs.shape();
+                let shape_rhs = rhs.shape();

                 // Compute the gradient of the output using the already existing `relu_backward`
                 // function in the basic Burn backend trait.
@@ -114,7 +114,7 @@ impl Backend for Autodiff {
             // compute bound operation.
             let lhs_state = prep.checkpoint(&lhs);
             let rhs_state = prep.checkpoint(&rhs);
-            let bias_shape = B::float_shape(&bias.primitive);
+            let bias_shape = bias.primitive.shape();

             let output = B::fused_matmul_add_relu(
                 lhs.primitive.clone(),
diff --git a/examples/custom-cubecl-kernel/src/forward.rs b/examples/custom-cubecl-kernel/src/forward.rs
index 886b7146ed..a8bf17fcd7 100644
--- a/examples/custom-cubecl-kernel/src/forward.rs
+++ b/examples/custom-cubecl-kernel/src/forward.rs
@@ -47,8 +47,13 @@ impl Backend for JitBackend {
             .empty(shape_out.num_elements() * core::mem::size_of::<F>());

         // Create the output tensor primitive.
-        let output =
-            JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);
+        let output = JitTensor::new_contiguous(
+            lhs.client.clone(),
+            lhs.device.clone(),
+            shape_out,
+            buffer,
+            F::dtype(),
+        );

         // Declare the wgsl workgroup with the number of cubes in x, y and z.
         let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
@@ -62,10 +67,10 @@ impl Backend for JitBackend {
+            lhs.as_tensor_arg::<F>(1),
+            rhs.as_tensor_arg::<F>(1),
+            bias.as_tensor_arg::<F>(1),
+            output.as_tensor_arg::<F>(1),
         );

         // Return the output tensor.
diff --git a/examples/custom-image-dataset/src/inference.rs b/examples/custom-image-dataset/src/inference.rs
new file mode 100644
index 0000000000..3f5983927b
--- /dev/null
+++ b/examples/custom-image-dataset/src/inference.rs
@@ -0,0 +1,31 @@
+use burn::{
+    data::{
+        dataloader::batcher::Batcher,
+        dataset::vision::{Annotation, ImageDatasetItem},
+    },
+    module::Module,
+    record::{CompactRecorder, Recorder},
+    tensor::backend::Backend,
+};
+
+use crate::{data::ClassificationBatcher, model::Cnn};
+
+const NUM_CLASSES: u8 = 10;
+
+pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: ImageDatasetItem) {
+    let record = CompactRecorder::new()
+        .load(format!("{artifact_dir}/model").into(), &device)
+        .expect("Trained model should exist");
+
+    let model: Cnn<B> = Cnn::new(NUM_CLASSES.into(), &device).load_record(record);
+
+    let mut label = 0;
+    if let Annotation::Label(category) = item.annotation {
+        label = category;
+    };
+    let batcher = ClassificationBatcher::new(device);
+    let batch = batcher.batch(vec![item]);
+    let output = model.forward(batch.images);
+    let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar();
+    println!("Predicted {} Expected {:?}", predicted, label);
+}
diff --git a/examples/custom-wgpu-kernel/src/backward.rs b/examples/custom-wgpu-kernel/src/backward.rs
index b4532cb5a4..b9032413bc 100644
--- a/examples/custom-wgpu-kernel/src/backward.rs
+++ b/examples/custom-wgpu-kernel/src/backward.rs
@@ -11,7 +11,7 @@ use burn::{
         },
         wgpu::{FloatElement, IntElement, JitBackend, WgpuRuntime},
     },
-    tensor::Shape,
+    tensor::{Shape, TensorMetadata},
 };

 impl<F: FloatElement, I: IntElement> AutodiffBackend for Autodiff<JitBackend<WgpuRuntime, F, I>> {}
@@ -54,12 +54,12 @@ impl Backend for Autodiff {
                 // Set our state.
                 let (lhs_state, rhs_state, output, shape_bias) = ops.state;

-                let lhs = checkpointer.retrieve_node_output(lhs_state);
-                let rhs = checkpointer.retrieve_node_output(rhs_state);
+                let lhs: FloatTensor<B> = checkpointer.retrieve_node_output(lhs_state);
+                let rhs: FloatTensor<B> = checkpointer.retrieve_node_output(rhs_state);

                 // Fetch shapes of our tensor to support broadcasting.
-                let shape_lhs = B::float_shape(&lhs);
-                let shape_rhs = B::float_shape(&rhs);
+                let shape_lhs = lhs.shape();
+                let shape_rhs = rhs.shape();

                 // Compute the gradient of the output using the already existing `relu_backward`
                 // function in the basic Burn backend trait.
@@ -115,7 +115,7 @@ impl Backend for Autodiff {
             // during the backward pass. Here we choose to save it in the state because it's a compute bound operation.
             let lhs_state = prep.checkpoint(&lhs);
             let rhs_state = prep.checkpoint(&rhs);
-            let bias_shape = B::float_shape(&bias.primitive);
+            let bias_shape = bias.primitive.shape();

             let output = B::fused_matmul_add_relu(
                 lhs.primitive.clone(),
diff --git a/examples/custom-wgpu-kernel/src/forward.rs b/examples/custom-wgpu-kernel/src/forward.rs
index 6acdd27350..c8476230d2 100644
--- a/examples/custom-wgpu-kernel/src/forward.rs
+++ b/examples/custom-wgpu-kernel/src/forward.rs
@@ -80,14 +80,19 @@ impl Backend for JitBackend {
             .empty(shape_out.num_elements() * core::mem::size_of::<F>());

         // Create the output tensor primitive.
-        let output =
-            JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);
+        let output = JitTensor::new_contiguous(
+            lhs.client.clone(),
+            lhs.device.clone(),
+            shape_out,
+            buffer,
+            F::dtype(),
+        );

         // Create the kernel.
         let kernel = FusedMatmulAddRelu::<F>::new(cube_dim);

         // Build info buffer with tensor information needed by the kernel, such as shapes and strides.
-        let info = build_info(&[&lhs, &rhs, &output]);
+        let info = build_info::<_, F>(&[&lhs, &rhs, &output]);
         let info_handle = lhs.client.create(bytemuck::cast_slice(&info));

         // Declare the wgsl workgroup with the number of cubes in x, y and z.