diff --git a/mithril/backends/backend.py b/mithril/backends/backend.py index d2562b7..3245bde 100644 --- a/mithril/backends/backend.py +++ b/mithril/backends/backend.py @@ -146,8 +146,30 @@ def arange( dtype: core.Dtype | None = None, ) -> DataType: ... - def arange(self, *args: int | float, **kwargs: Any) -> DataType: - raise NotImplementedError("arange is not implemented!") + def arange(self, *args: int | float, **kwargs) -> DataType: + """Generate an array of evenly spaced values within a specified range.""" + if len(args) == 0: + raise RuntimeError( + "arange() missing 1 required positional argument: 'stop'" + ) + elif len(args) == 1: + return self._arange(0, args[0], 1, **kwargs) # type: ignore + elif len(args) == 2: + if args[0] >= args[1]: + return self.array([]) + + return self._arange( # type: ignore + args[0], args[1], 1, **kwargs + ) + elif len(args) == 3: + return self._arange( # type: ignore + args[0], args[1], args[2], **kwargs + ) + else: + raise RuntimeError( + "arange() accepts 1 to 3 positional arguments," + " but `f{len(args)}` were provided" + ) def flatten( self, input: DataType, start_dim: int = 0, end_dim: int = -1 @@ -459,7 +481,7 @@ def linspace( self, start: int | float | bool | DataType, stop: int | float | bool | DataType, - steps: int | DataType, + steps: int, dtype: core.Dtype | None = None, ) -> DataType: """ @@ -1349,7 +1371,7 @@ def linspace( self, start: int | float | bool | DataType, stop: int | float | bool | DataType, - steps: int | DataType, + steps: int, dtype: core.Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> DataType: diff --git a/mithril/backends/with_autograd/jax_backend/backend.py b/mithril/backends/with_autograd/jax_backend/backend.py index 56ff6e3..8573674 100644 --- a/mithril/backends/with_autograd/jax_backend/backend.py +++ b/mithril/backends/with_autograd/jax_backend/backend.py @@ -15,7 +15,6 @@ import math import os from collections.abc import Callable, Sequence -from functools import partial from typing import Any, overload import jax @@ -141,101 +140,6 @@ def block_until_ready(self, data: jax.Array) -> jax.Array | None: """ return data.block_until_ready() - def _creation_fn_wrapper( - self, fn: Callable[..., jax.Array] - ) -> Callable[..., jax.Array]: - """ - Wrapper for array creation functions. - - Parameters - ---------- - fn: Callable - The original array creation function. - - Returns - ------- - Callable - A wrapped function that creates arrays with specified dtype and device. - - Notes - ----- - Ensures that arrays are created with the correct dtype and device. - """ - - array_conversion_fn = partial( - utils.creation_fn_wrapper, - fn=fn, - device=self._device, - precision=self.precision, - ) - array_conversion_fn = partial(self._parallelize, fn=array_conversion_fn) - - return array_conversion_fn - - def _conversion_fn_wrapper( - self, fn: Callable[..., jax.Array] - ) -> Callable[..., jax.Array]: - """ - Wrapper for array conversion functions. - - Parameters - ---------- - fn: Callable - The original array conversion function. - - Returns - ------- - Callable - A wrapped function that converts arrays with specified dtype and device. - - Notes - ----- - Handles the conversion of arrays between different dtypes and devices. - - If dtype is provided, it uses `utils._handle_dtype` to ensure a valid dtype. - If the input data is a JAX Array, it ensures it's on the specified device. - If dtype is not provided, uses the default device and handles data precision. 
- """ - array_conversion_fn = partial( - utils.conversion_fn_wrapper, - fn=fn, - device=self._device, - precision=self.precision, - ) - array_conversion_fn = partial(self._parallelize, fn=array_conversion_fn) - - return array_conversion_fn - - def _parallelize( - self, - *args: Any, - fn: Callable[..., jax.Array], - device_mesh: tuple[int, ...], - **kwargs: Any, - ) -> jax.Array: - """ - Parallelizes the function's return tensor across devices. - - Parameters - ---------- - fn : Callable - The function whose return tensor will be parallelized. - - device_mesh : tuple[int, ...], optional - A tuple specifying the device mesh for parallelization. - If not provided, the default device mesh is used. - - Returns - ------- - Callable - Return tensor parallelized across the specified device mesh. - """ - - tensor: jax.Array = fn(*args, **kwargs) - if self._parallel_manager is None: - return tensor - return self._parallel_manager.parallelize(tensor, device_mesh) - def _register_callable( self, fn: Callable[..., Any], fn_name: str, jit: bool = False ): @@ -264,13 +168,17 @@ def array( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> jax.Array: - _dtype: jax.numpy.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - result = self._conversion_fn_wrapper(jax.numpy.array)( - input, dtype=_dtype, device_mesh=device_mesh - ) - return result + _dtype = utils.determine_dtype(input, dtype, self.precision) + + with jax.default_device(self.device): + array = jax.numpy.array( + input, dtype=utils.dtype_map[_dtype], device=self.device + ) + + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize(array, device_mesh) + + return array def zeros( self, @@ -278,14 +186,16 @@ def zeros( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> jax.Array: - _dtype: jax.numpy.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - result = self._creation_fn_wrapper(jax.numpy.zeros)( - _shape, dtype=_dtype, device_mesh=device_mesh - ) - return result + + with jax.default_device(self.device): + array = jax.numpy.zeros(_shape, dtype=_dtype) + + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize(array, device_mesh) + + return array def ones( self, @@ -293,14 +203,16 @@ def ones( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> jax.Array: - _dtype: jax.numpy.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - result = self._creation_fn_wrapper(jax.numpy.ones)( - _shape, dtype=_dtype, device_mesh=device_mesh - ) - return result + + with jax.default_device(self.device): + array = jax.numpy.ones(_shape, dtype=_dtype) + + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize(array, device_mesh) + + return array def ones_like( self, @@ -309,13 +221,15 @@ def ones_like( dtype: Dtype | None = None, device_mesh: tuple[int, ...] 
| None = None, ) -> jax.Array: - _dtype: jax.numpy.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - result = self._creation_fn_wrapper(jax.numpy.ones_like)( - input, dtype=_dtype, device_mesh=device_mesh - ) - return result + _dtype = self._process_dtype(dtype) if dtype is not None else None + + with jax.default_device(self.device): + array = jax.numpy.ones_like(input, dtype=_dtype) + + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize(array, device_mesh) + + return array def zeros_like( self, @@ -324,13 +238,15 @@ def zeros_like( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> jax.Array: - _dtype: jax.numpy.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - result = self._creation_fn_wrapper(jax.numpy.zeros_like)( - input, dtype=_dtype, device_mesh=device_mesh - ) - return result + _dtype = self._process_dtype(dtype) if dtype is not None else None + + with jax.default_device(self.device): + array = jax.numpy.zeros_like(input, dtype=_dtype) + + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize(array, device_mesh) + + return array def randn( self, @@ -341,14 +257,17 @@ def randn( ) -> jax.Array: if prng_key is None: prng_key = self.prng_key - _dtype: jax.numpy.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - result = self._creation_fn_wrapper(jax.random.normal)( - prng_key, _shape, dtype=_dtype, device_mesh=device_mesh - ) - return result + + with jax.default_device(self.device): + array = jax.random.normal(prng_key, _shape, dtype=_dtype) + + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize(array, device_mesh) + + return array def rand( self, @@ -359,14 +278,17 @@ def rand( ) -> jax.Array: if prng_key is None: prng_key = self.prng_key - _dtype: jax.numpy.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - result = self._creation_fn_wrapper(jax.random.uniform)( - prng_key, _shape, dtype=_dtype, device_mesh=device_mesh - ) - return result + + with jax.default_device(self.device): + array = jax.random.normal(prng_key, _shape, dtype=_dtype) + + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize(array, device_mesh) + + return array def randint( self, @@ -379,19 +301,17 @@ def randint( ) -> jax.Array: if prng_key is None: prng_key = self.prng_key - _dtype: jax.numpy.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + + _dtype = self._process_dtype(dtype, int) _shape = process_shape(shape) - result = self._creation_fn_wrapper(jax.random.randint)( - prng_key, - _shape, - low, - high, - dtype=_dtype, - device_mesh=device_mesh, - ) - return result + + with jax.default_device(self.device): + array = jax.random.randint(prng_key, _shape, low, high, dtype=_dtype) + + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize(array, device_mesh) + + return array def rand_uniform( self, @@ -404,47 +324,56 @@ def rand_uniform( ) -> jax.Array: if prng_key is None: prng_key = self.prng_key - _dtype: jax.numpy.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + + _dtype = self._process_dtype(dtype) _shape = 
process_shape(shape) - return self._creation_fn_wrapper(jax.random.uniform)( - prng_key, - _shape, - dtype=_dtype, - minval=low, - maxval=high, - device_mesh=device_mesh, - ) + + with jax.default_device(self.device): + array = jax.random.uniform(prng_key, _shape, _dtype, low, high) + + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize(array, device_mesh) + + return array def _arange( self, - *args: int | float, + start: int | float, + stop: int | float, + step: int | float, dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, **kwargs: Any, ) -> jax.Array: - _dtype: jax.numpy.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(jax.numpy.arange)( - *args, dtype=_dtype, device_mesh=device_mesh + default_type = ( + float if any(isinstance(x, float) for x in (start, stop, step)) else int ) + _dtype = self._process_dtype(dtype, default_type) + + with jax.default_device(self.device): + array = jax.numpy.arange(start, stop, step, dtype=_dtype) + + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize(array, device_mesh) + + return array def linspace( self, start: int | float | bool | jax.numpy.ndarray, stop: int | float | bool | jax.numpy.ndarray, - steps: int | jax.numpy.ndarray, + steps: int, dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> jax.Array: - _dtype: jax.numpy.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(jax.numpy.linspace)( - start, stop, steps, dtype=_dtype, device_mesh=device_mesh - ) + _dtype = self._process_dtype(dtype) + with jax.default_device(self.device): + array = jax.numpy.linspace(start, stop, steps, dtype=_dtype) + + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize(array, device_mesh) + + return array def flatten( self, input: jax.Array, start_dim: int = 0, end_dim: int = -1 @@ -729,3 +658,18 @@ def jacfwd( self, fn: Callable[..., dict[str, jax.Array]] ) -> Callable[..., dict[str, jax.Array]]: return jax.jacfwd(fn) + + def _process_dtype( + self, + dtype: Dtype | None = None, + default_type: type[float] | type[int] | type[bool] = float, + ) -> jax.numpy.dtype[Any]: + if isinstance(dtype, Dtype): + return utils.dtype_map[dtype.name] + elif dtype is None: + return utils.dtype_map[default_type.__name__ + str(self.precision)] + else: + raise ValueError(f"Invalid dtype {dtype}") + + def _get_defualt_type(self): + return getattr(self, f"float{self.precision}") diff --git a/mithril/backends/with_autograd/jax_backend/utils.py b/mithril/backends/with_autograd/jax_backend/utils.py index 2e332b3..9c5139b 100644 --- a/mithril/backends/with_autograd/jax_backend/utils.py +++ b/mithril/backends/with_autograd/jax_backend/utils.py @@ -17,6 +17,7 @@ import jax import jax.numpy as jnp +import numpy as np from jax import vmap from .... 
import core @@ -332,66 +333,6 @@ def _parse_device_string(device: str): return backend, device_idx -def handle_dtype(dtype: str | core.Dtype | jnp.dtype[Any]) -> jnp.dtype[Any]: - if isinstance(dtype, core.Dtype): - return dtype_map[dtype.name] - elif isinstance(dtype, str) and dtype in dtype_map: - return dtype_map[dtype] - else: - try: - return jnp.dtype(dtype) - except TypeError as err: - raise TypeError(f"Provided data type '{dtype}' not understood") from err - - -def creation_fn_wrapper( - *args: Any, - fn: Callable[..., jax.Array], - dtype: core.Dtype | jnp.dtype[Any] | None = None, - device: str, - precision: int, - **kwargs: Any, -): - _device = get_device(device) - - if dtype is not None: - dtype = handle_dtype(dtype) - with jax.default_device(_device): - data = fn(*args, dtype=dtype, **kwargs) - else: - with jax.default_device(_device): - data = fn(*args, **kwargs) - data = handle_data_precision(data, precision) - return data - - -def conversion_fn_wrapper( - data: Any, - *args: Any, - fn: Callable[..., jax.Array], - device: str, - precision: int, - dtype: core.Dtype | jnp.dtype[Any] | None = None, - **kwargs: Any, -): - _device = get_device(device) - - if dtype is not None: - dtype = handle_dtype(dtype) - if isinstance(data, ArrayType): - if next(iter(data.devices())) != _device: - data = jax.device_put(data, _device) - if dtype is not None: - return data.astype(dtype) - return handle_data_precision(data, precision) - else: - with jax.default_device(_device): - _data = fn(data, *args, dtype=dtype, **kwargs) - if dtype is None: # User did not specify dtype explicitly - return handle_data_precision(_data, precision) - return _data - - def handle_data_precision(data: ArrayType, precision: int) -> ArrayType: _dtype = data.dtype # Do not make any changes to boolean types. 
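# Illustrative sketch (assumption: not part of the patch itself) of the creation path
# that replaces _creation_fn_wrapper/_conversion_fn_wrapper in the JAX backend above:
# resolve a dtype name from the backend precision, build the array under
# jax.default_device, then parallelize explicitly when a parallel manager is present.
# make_zeros and _resolve_dtype_name are hypothetical stand-ins, not mithril APIs.
import jax
import jax.numpy as jnp


def _resolve_dtype_name(default_type: type, precision: int) -> str:
    # mirrors the "<type><precision>" naming used by the new helpers, e.g. "float32"
    name = default_type.__name__
    return name if name == "bool" else f"{name}{precision}"


def make_zeros(shape, precision=32, device=None, parallel_manager=None, device_mesh=None):
    dtype = _resolve_dtype_name(float, precision)
    with jax.default_device(device if device is not None else jax.devices()[0]):
        array = jnp.zeros(shape, dtype=dtype)
    # parallelization now happens explicitly after creation instead of inside a wrapper
    if parallel_manager is not None:
        array = parallel_manager.parallelize(array, device_mesh)
    return array


print(make_zeros((2, 3)).dtype)  # float32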
@@ -504,3 +445,19 @@ def calculate_cross_entropy_class_weights( shape[1] = input.shape[1] _weights = _weights.reshape(shape) return _weights + + +def determine_dtype(input: Any, dtype: core.Dtype | None, precision: int) -> str: + if isinstance(dtype, core.Dtype): + return dtype.name + + if isinstance(input, jax.Array): + dtype_name = "".join( + char for char in input.dtype.__str__() if not char.isdigit() + ) + elif isinstance(input, (np.ndarray | np.generic)): + dtype_name = "".join(char for char in str(input.dtype) if not char.isdigit()) + else: + dtype_name = find_dominant_type(input).__name__ + + return dtype_name + str(precision) if dtype_name != "bool" else "bool" diff --git a/mithril/backends/with_autograd/mlx_backend/backend.py b/mithril/backends/with_autograd/mlx_backend/backend.py index ec2a1a8..41ea88f 100644 --- a/mithril/backends/with_autograd/mlx_backend/backend.py +++ b/mithril/backends/with_autograd/mlx_backend/backend.py @@ -93,24 +93,6 @@ def to_device( def block_until_ready(self, data: mx.array): mx.eval(data) - def _creation_fn_wrapper( - self, fn: Callable[..., mx.array] - ) -> Callable[..., mx.array]: - return partial( - utils.creation_fn_wrapper, - fn=fn, - precision=self.precision, - ) - - def _conversion_fn_wrapper( - self, fn: Callable[..., mx.array] - ) -> Callable[..., mx.array]: - return partial( - utils.conversion_fn_wrapper, - fn=fn, - precision=self.precision, - ) - def _handle_dict_type_fun( self, *inputs: mx.array, @@ -188,40 +170,34 @@ def _handle_sequence_type_fun( return [output] def array(self, input: Any, *, dtype: Dtype | None = None) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._conversion_fn_wrapper(mx.array)(input, dtype=_dtype) + _dtype = utils.determine_dtype(input, dtype, self.precision) + return mx.array(input, dtype=utils.dtype_map[_dtype]) def zeros( self, *shape: int | tuple[int, ...] | list[int], dtype: Dtype | None = None ) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(mx.zeros)(shape=_shape, dtype=_dtype) + return mx.zeros(_shape, dtype=_dtype) def ones( self, *shape: int | tuple[int, ...] 
| list[int], dtype: Dtype | None = None ) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(mx.ones)(shape=_shape, dtype=_dtype) + return mx.ones(_shape, dtype=_dtype) def ones_like(self, input: mx.array, *, dtype: Dtype | None = None) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(mx.ones_like)(input, dtype=_dtype) + if dtype is not None: + raise ValueError("dtype argument is not supported for ones_like") + + return mx.ones_like(input) def zeros_like(self, input: mx.array, *, dtype: Dtype | None = None) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(mx.zeros_like)(input, dtype=_dtype) + if dtype is not None: + raise ValueError("dtype argument is not supported for ones_like") + + return mx.zeros_like(input) def randn( self, @@ -229,11 +205,9 @@ def randn( dtype: Dtype | None = None, prng_key: Any = None, ) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(mx.random.normal)(shape=_shape, dtype=_dtype) + return mx.random.normal(shape=_shape, dtype=_dtype) def rand( self, @@ -241,11 +215,9 @@ def rand( dtype: Dtype | None = None, prng_key: Any = None, ) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(mx.random.uniform)(shape=_shape, dtype=_dtype) + return mx.random.uniform(shape=_shape, dtype=_dtype) def randint( self, @@ -255,13 +227,9 @@ def randint( dtype: Dtype | None = None, prng_key: Any = None, ) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype, int) _shape = process_shape(shape) - return self._creation_fn_wrapper(mx.random.randint)( - low=low, high=high, shape=_shape, dtype=_dtype - ) + return mx.random.randint(low, high, shape=_shape, dtype=_dtype) def rand_uniform( self, @@ -271,31 +239,33 @@ def rand_uniform( dtype: Dtype | None = None, prng_key: Any = None, ) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(mx.random.uniform)( - low=low, high=high, shape=_shape, dtype=_dtype + return mx.random.uniform(low, high, shape=_shape, dtype=_dtype) + + def _arange( + self, + start: int | float, + stop: int | float, + step: int | float, + dtype: Dtype | None = None, + ) -> mx.array: + default_type = ( + float if any(isinstance(x, float) for x in (start, stop, step)) else int ) + _dtype = self._process_dtype(dtype, default_type) - def arange(self, *args: float | int, dtype: Dtype | None = None) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(mx.arange)(*args, dtype=_dtype) + return mx.arange(start, stop, step, dtype=_dtype) def linspace( self, start: int | float | bool | mx.array, stop: int | float | bool | mx.array, - 
steps: int | mx.array, + steps: int, dtype: Dtype | None = None, ) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(mx.linspace)(start, stop, steps, dtype=_dtype) + _dtype = self._process_dtype(dtype) + return mx.linspace(start, stop, steps, dtype=_dtype) def flatten( self, input: mx.array, start_dim: int = 0, end_dim: int = -1 @@ -673,3 +643,15 @@ def vmap( # type: ignore #mypy bug self, fn: Callable[[mx.array], mx.array] ) -> Callable[[mx.array], mx.array]: return mx.vmap(fn) + + def _process_dtype( + self, + dtype: Dtype | None = None, + default_type: type[float] | type[int] | type[bool] = float, + ) -> mx.Dtype: + if isinstance(dtype, Dtype): + return utils.dtype_map[dtype.name] + elif dtype is None: + return utils.dtype_map[default_type.__name__ + str(self.precision)] + else: + raise ValueError(f"Invalid dtype {dtype}") diff --git a/mithril/backends/with_autograd/mlx_backend/utils.py b/mithril/backends/with_autograd/mlx_backend/utils.py index e388b7a..14d908d 100644 --- a/mithril/backends/with_autograd/mlx_backend/utils.py +++ b/mithril/backends/with_autograd/mlx_backend/utils.py @@ -52,54 +52,6 @@ def get_device(device: str): return mx.Device(getattr(mx, device), 0) -def creation_fn_wrapper( - *args: Any, - fn: Callable[..., mx.array], - dtype: core.Dtype | mx.Dtype | None = None, - precision: int, - **kwargs: Any, -): - if dtype is not None: - dtype = handle_dtype(dtype) - data = fn(*args, dtype=dtype, **kwargs) - else: - data = fn(*args, **kwargs) - data = handle_data_precision(data, precision) - return data - - -def conversion_fn_wrapper( - data: Any, - *args: Any, - fn: Callable[..., mx.array], - precision: int, - dtype: mx.Dtype | None = None, - **kwargs: Any, -): - if dtype is not None: - dtype = handle_dtype(dtype) - if isinstance(data, ArrayType): - if dtype is not None: - return data.astype(dtype) - return handle_data_precision(data, precision) - else: - _data = fn(data, *args, dtype=dtype, **kwargs) - if dtype is None: # User did not specify dtype explicitly - return handle_data_precision(_data, precision) - return _data - - -def handle_dtype(dtype: Any) -> Any: - if isinstance(dtype, core.Dtype): - return dtype_map[dtype.name] - elif isinstance(dtype, str) and dtype in dtype_map: - return dtype_map[dtype] - elif isinstance(dtype, mx.Dtype): - return dtype - else: - raise TypeError(f"Provided data type '{dtype}' not understood") - - def handle_data_precision(data: mx.array, precision: int) -> mx.array: _dtype = data.dtype # Do not make any changes to boolean types. 
@@ -420,6 +372,22 @@ def get_submatrices2d( ) +def determine_dtype(input: Any, dtype: core.Dtype | None, precision: int) -> str: + if isinstance(dtype, core.Dtype): + return dtype.name + + if isinstance(input, mx.array): + dtype_name = "".join( + char for char in input.dtype.__str__().split(".")[-1] if not char.isdigit() + ) + elif isinstance(input, (np.ndarray | np.generic)): + dtype_name = "".join(char for char in str(input.dtype) if not char.isdigit()) + else: + dtype_name = find_dominant_type(input).__name__ + + return dtype_name + str(precision) if dtype_name != "bool" else "bool" + + def get_type(input: int | float | bool | Sequence[Any], precision: int) -> mx.Dtype: type = find_dominant_type(input).__name__ if type == "bool": diff --git a/mithril/backends/with_autograd/torch_backend/backend.py b/mithril/backends/with_autograd/torch_backend/backend.py index 63e0fac..7fd130b 100644 --- a/mithril/backends/with_autograd/torch_backend/backend.py +++ b/mithril/backends/with_autograd/torch_backend/backend.py @@ -13,7 +13,6 @@ # limitations under the License. from collections.abc import Callable, Sequence -from functools import partial from typing import Any, overload import torch @@ -24,7 +23,6 @@ from torch._functorch.eager_transforms import jacfwd as torch_jacfwd from torch._functorch.eager_transforms import jacrev as torch_jacrev from torch._functorch.eager_transforms import vjp as torch_vjp -from torch.distributed._tensor import DTensor from ....core import Dtype from ...backend import PadWidthType, ParallelBackend @@ -139,101 +137,6 @@ def empty_cache(self) -> None: pass # print(f"Warning: empty_cache is not implemented for {self.device_type}") - def _creation_fn_wrapper( - self, fn: Callable[..., torch.Tensor] - ) -> Callable[..., torch.Tensor]: - """ - Wrapper for PyTorch tensor creation functions. - - Parameters - ---------- - fn: Callable - The original tensor creation function. - - Returns - ------- - Callable - A wrapped function that creates tensors with specified dtype and device. - - Notes - ----- - This wrapper ensures that tensors are created with the correct dtype - and on the specified device. - """ - - array_creation_fn = partial( - utils.creation_fn_wrapper_inner, - fn=fn, - device=self._device, - precision=self.precision, - ) - array_creation_fn = partial(self._parallelize, fn=array_creation_fn) - - return array_creation_fn - - def _conversion_fn_wrapper( - self, fn: Callable[..., torch.Tensor] - ) -> Callable[..., torch.Tensor]: - """ - Wrapper for PyTorch tensor conversion functions. - - Parameters - ---------- - fn: Callable - The original tensor conversion function. - - Returns - ------- - Callable - A wrapped function that converts tensors with specified dtype and device. - - Notes - ----- - Wrapper handles the conversion of tensors between different dtypes and devices. - """ - - array_conversion_fn = partial( - utils.conversion_fn_wrapper_inner, - fn=fn, - device=self._device, - precision=self.precision, - ) - array_conversion_fn = partial(self._parallelize, fn=array_conversion_fn) - - return array_conversion_fn - - def _parallelize( - self, - *args: Any, - fn: Callable[..., torch.Tensor], - device_mesh: tuple[int] | None, - **kwargs: Any, - ) -> DTensor | torch.Tensor: - """ - Parallelizes the function's return tensor across devices. - - Parameters - ---------- - fn : Callable - The function whose return tensor will be parallelized. - device_mesh : tuple[int, ...], optional - A tuple specifying the device mesh for parallelization. 
- If not provided, the default device mesh is used. - - Returns - ------- - Callable - Returns tensor parallelized across the specified device mesh. - """ - tensor: torch.Tensor = fn(*args, **kwargs) - if self._parallel_manager is None: - # TODO: raise device_mesh should be None - return tensor - - return self._parallel_manager.parallelize( - tensor, self.base_device_mesh, device_mesh - ) - def _register_callable( self, fn: Callable[..., torch.Tensor], fn_name: str, jit: bool = False ): @@ -298,12 +201,15 @@ def array( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._conversion_fn_wrapper(torch.tensor)( - input, dtype=_dtype, device_mesh=device_mesh - ) + _dtype = utils.determine_dtype(input, dtype, self.precision) + + array = torch.tensor(input, dtype=utils.dtype_map[_dtype], device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + + return array def zeros( self, @@ -311,13 +217,16 @@ def zeros( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(torch.zeros)( - _shape, dtype=_dtype, device_mesh=device_mesh - ) + + array = torch.zeros(_shape, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + + return array def ones( self, @@ -325,13 +234,15 @@ def ones( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(torch.ones)( - _shape, dtype=_dtype, device_mesh=device_mesh - ) + + array = torch.ones(_shape, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + return array def ones_like( self, @@ -340,12 +251,14 @@ def ones_like( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(torch.ones_like)( - input, dtype=_dtype, device_mesh=device_mesh - ) + _dtype = self._process_dtype(dtype) if dtype is not None else None + + array = torch.ones_like(input, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + return array def zeros_like( self, @@ -354,12 +267,14 @@ def zeros_like( dtype: Dtype | None = None, device_mesh: tuple[int, ...] 
| None = None, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(torch.zeros_like)( - input, dtype=_dtype, device_mesh=device_mesh - ) + _dtype = self._process_dtype(dtype) if dtype is not None else None + + array = torch.zeros_like(input, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + return array def randn( self, @@ -368,13 +283,16 @@ def randn( device_mesh: tuple[int, ...] | None = None, prng_key: Any = None, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(torch.randn)( - size=_shape, dtype=_dtype, device_mesh=device_mesh - ) + + # TODO: PRNG key is not used + array = torch.randn(_shape, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + return array def rand( self, @@ -383,13 +301,15 @@ def rand( device_mesh: tuple[int, ...] | None = None, prng_key: Any = None, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(torch.rand)( - size=_shape, dtype=_dtype, device_mesh=device_mesh - ) + + array = torch.rand(_shape, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + return array def randint( self, @@ -400,17 +320,15 @@ def randint( device_mesh: tuple[int, ...] | None = None, prng_key: Any = None, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype, int) _shape = process_shape(shape) - return self._creation_fn_wrapper(torch.randint)( - low, - high, - size=_shape, - dtype=_dtype, - device_mesh=device_mesh, - ) + + array = torch.randint(low, high, _shape, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + return array def rand_uniform( self, @@ -427,32 +345,42 @@ def rand_uniform( def _arange( self, - *args: int | float, + start: int | float, + stop: int | float, + step: int | float, dtype: Dtype | None = None, device_mesh: tuple[int, ...] 
| None = None, **kwargs: int | float, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(torch.arange)( - *args, dtype=_dtype, device_mesh=device_mesh + default_type = ( + float if any(isinstance(x, float) for x in (start, stop, step)) else int ) + _dtype = self._process_dtype(dtype, default_type) + + array = torch.arange(start, stop, step, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + + return array def linspace( self, start: int | float | bool | torch.Tensor, stop: int | float | bool | torch.Tensor, - steps: int | torch.Tensor, + steps: int, dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(torch.linspace)( - start, stop, steps, dtype=_dtype, device_mesh=device_mesh - ) + _dtype = self._process_dtype(dtype) + + array = torch.linspace(start, stop, steps, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + return array def flatten( self, input: torch.Tensor, start_dim: int = 0, end_dim: int = -1 @@ -718,3 +646,15 @@ def jacrev(self, fn: Callable[..., dict[str, torch.Tensor]]) -> Callable: def jacfwd(self, fn: Callable[..., dict[str, torch.Tensor]]) -> Callable: return torch_jacfwd(fn) + + def _process_dtype( + self, + dtype: Dtype | None = None, + default_type: type[float] | type[int] | type[bool] = float, + ) -> torch.dtype: + if isinstance(dtype, Dtype): + return utils.dtype_map[dtype.name] + elif dtype is None: + return utils.dtype_map[default_type.__name__ + str(self.precision)] + else: + raise ValueError(f"Invalid dtype {dtype}") diff --git a/mithril/backends/with_autograd/torch_backend/utils.py b/mithril/backends/with_autograd/torch_backend/utils.py index b525417..85b326e 100644 --- a/mithril/backends/with_autograd/torch_backend/utils.py +++ b/mithril/backends/with_autograd/torch_backend/utils.py @@ -36,6 +36,7 @@ AVAILABLE_BACKEND_TYPES = ["cpu", "cuda"] ArrayType = torch.Tensor +NestedTensorType = int | float | bool | Sequence["NestedTensorType"] dtype_map: dict[str, torch.dtype] = { "int16": torch.int16, "int32": torch.int32, @@ -186,77 +187,7 @@ def get_available_devices() -> list[str]: return devices -def handle_dtype(dtype: core.Dtype | torch.dtype | str) -> Any: - if isinstance(dtype, core.Dtype): - return dtype_map[dtype.name] - elif isinstance(dtype, torch.dtype): - return dtype - elif dtype in dtype_map: - return dtype_map[dtype] - raise TypeError(f"Provided data type '{dtype}' not understood") - - -def creation_fn_wrapper_inner( - *args: Any, - dtype: core.Dtype | torch.dtype | str | None = None, - fn: Callable[..., torch.Tensor], - device: str, - precision: int, - device_mesh: tuple[int, ...] 
| None = None, - **kwargs: Any, -): - _device = get_device(device) - if dtype is not None: - dtype = handle_dtype(dtype) - data = fn(*args, dtype=dtype, device=_device, **kwargs) - else: - data = fn(*args, device=_device, **kwargs) - data = handle_data_precision(data, precision=precision) - - return data - - -def conversion_fn_wrapper_inner( - data: Any, - *args: Any, - dtype: torch.dtype | str | None = None, - fn: Callable[..., torch.Tensor], - device: str, - precision: int, - **kwargs: Any, -) -> torch.Tensor: - _device = get_device(device) - if dtype is not None: - dtype = handle_dtype(dtype) - if isinstance(data, ArrayType): - if data.device != _device: - data = data.to(_device) - if dtype is not None: - return data.type(dtype) - return handle_data_precision(data, precision=precision) - elif isinstance(data, np.ndarray): - _data = fn(data, *args, dtype=dtype, device=_device, **kwargs) - if ( - dtype is None and _data.dtype != torch.bool - ): # User did not specify dtype explicitly - return handle_data_precision(_data, precision=precision) - return _data - else: - # To determine subtype we are creating tensor twice in worst case - _data = fn(data, *args, dtype=dtype, device=device, **kwargs) - if ( - dtype is None - and get_precision(_data) != precision - and _data.dtype != torch.bool - ): - subtype = get_subtype(_data) - _dtype = getattr(torch, f"{subtype}{precision}") - _data = fn(data, *args, dtype=_dtype, device=device, **kwargs) - return _data - return _data - - -def handle_data_precision(data: ArrayType, precision: int) -> ArrayType: +def handle_data_precision(data: torch.Tensor, precision: int) -> torch.Tensor: _dtype = data.dtype dtype: torch.dtype # Do not make any changes to boolean types. @@ -276,7 +207,7 @@ def handle_data_precision(data: ArrayType, precision: int) -> ArrayType: return data -def handle_data_dtype(data: ArrayType, dtype: core.Dtype | int) -> ArrayType: +def handle_data_dtype(data: torch.Tensor, dtype: core.Dtype | int) -> torch.Tensor: dtype = core.Dtype(dtype) if data.dtype != dtype_map[dtype.name]: @@ -286,11 +217,11 @@ def handle_data_dtype(data: ArrayType, dtype: core.Dtype | int) -> ArrayType: return data -def get_precision(data: ArrayType) -> int: +def get_precision(data: torch.Tensor) -> int: return data.dtype.itemsize * 8 -def get_subtype(data: ArrayType) -> str: +def get_subtype(data: torch.Tensor) -> str: # TODO: cover uint dtypes if not torch.is_floating_point(data) and not torch.is_complex(data): return "int" @@ -753,7 +684,20 @@ def check_device_mesh(base_mesh: DeviceMesh, device_mesh: tuple[int, ...]): ) -NestedTensorType = int | float | bool | Sequence["NestedTensorType"] +def determine_dtype(input: Any, dtype: core.Dtype | None, precision: int) -> str: + if isinstance(dtype, core.Dtype): + return dtype.name + + if isinstance(input, torch.Tensor): + dtype_name = "".join( + char for char in input.dtype.__str__().split(".")[1] if not char.isdigit() + ) + elif isinstance(input, (np.ndarray | np.generic)): + dtype_name = "".join(char for char in str(input.dtype) if not char.isdigit()) + else: + dtype_name = find_dominant_type(input).__name__ + + return dtype_name + str(precision) if dtype_name != "bool" else "bool" def get_type(input: NestedTensorType, precision: int): diff --git a/mithril/backends/with_manualgrad/numpy_backend/backend.py b/mithril/backends/with_manualgrad/numpy_backend/backend.py index ae870f2..d077173 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/backend.py +++ 
b/mithril/backends/with_manualgrad/numpy_backend/backend.py @@ -13,7 +13,6 @@ # limitations under the License. from collections.abc import Callable -from functools import partial from typing import Any import numpy as np @@ -101,52 +100,6 @@ def set_seed(self, seed: int): self.seed = seed np.random.seed(seed) - def _creation_fn_wrapper( - self, fn: Callable[..., np.ndarray[Any, Any]] - ) -> Callable[..., np.ndarray[Any, Any]]: - """ - Wrapper for NumPy array creation functions. - - Parameters - ---------- - fn: Callable - The original array creation function. - - Returns - ------- - Callable - A wrapped function that creates NumPy arrays with specified dtype. - - Notes - ----- - This wrapper ensures that NumPy arrays are created with the correct dtype. - """ - return partial(utils.creation_fn_wrapper, fn=fn, precision=self.precision) - - def _conversion_fn_wrapper( - self, fn: Callable[..., np.ndarray[Any, Any]] - ) -> Callable[..., np.ndarray[Any, Any]]: - """ - Wrapper for NumPy array conversion functions. - - Parameters - ---------- - fn: Callable - The original array conversion function. - - Returns - ------- - Callable - A wrapped function that converts arrays to NumPy arrays with - specified dtype. - - Notes - ----- - This wrapper handles the conversion of arrays to NumPy arrays with - different dtypes. - """ - return partial(utils.conversion_fn_wrapper, fn=fn, precision=self.precision) - def accumulate_grads( self, gradient: np.ndarray[Any, Any], @@ -156,45 +109,36 @@ def accumulate_grads( ) -> np.ndarray[Any, Any]: return utils.accumulate_grads(gradient, input, cache, idx) - def array(self, input: Any, *, dtype: Dtype | None = None) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._conversion_fn_wrapper(np.array)(input, dtype=_dtype) + def array(self, data: Any, *, dtype: Dtype | None = None) -> np.ndarray[Any, Any]: + _dtype = utils.determine_dtype(data, dtype, self.precision) + + return np.array(data, dtype=utils.dtype_map[_dtype]) def zeros( self, *shape: int | tuple[int, ...] | list[int], dtype: Dtype | None = None ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(np.zeros)(shape=_shape, dtype=_dtype) + return np.zeros(_shape, dtype=_dtype) def ones( self, *shape: int | tuple[int, ...] 
| list[int], dtype: Dtype | None = None ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(np.ones)(shape=_shape, dtype=_dtype) + return np.ones(_shape, dtype=_dtype) def ones_like( self, input: np.ndarray[Any, Any], *, dtype: Dtype | None = None ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(np.ones_like)(input, dtype=_dtype) + _dtype = self._process_dtype(dtype) + return np.ones_like(input, dtype=_dtype) def zeros_like( self, input: np.ndarray[Any, Any], *, dtype: Dtype | None = None ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(np.zeros_like)(input, dtype=_dtype) + _dtype = self._process_dtype(dtype) + return np.zeros_like(input, dtype=_dtype) def randn( self, @@ -202,11 +146,9 @@ def randn( dtype: Dtype | None = None, prng_key: Any = None, ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(np.random.randn)(*_shape, dtype=_dtype) + return np.array(np.random.randn(*_shape), dtype=_dtype) def rand( self, @@ -214,11 +156,9 @@ def rand( dtype: Dtype | None = None, prng_key: Any = None, ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(np.random.rand)(*_shape, dtype=_dtype) + return np.array(np.random.rand(*_shape), dtype=_dtype) def randint( self, @@ -228,13 +168,9 @@ def randint( dtype: Dtype | None = None, prng_key: Any = None, ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype, int) _shape = process_shape(shape) - return self._creation_fn_wrapper(np.random.randint)( - low=low, high=high, size=_shape, dtype=_dtype - ) + return np.random.randint(low, high, size=_shape).astype(_dtype) def rand_uniform( self, @@ -244,34 +180,33 @@ def rand_uniform( dtype: Dtype | None = None, prng_key: Any = None, ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(np.random.uniform)( - low=low, high=high, size=_shape, dtype=_dtype - ) + return np.array(np.random.uniform(low, high, size=_shape), dtype=_dtype) - def arange( - self, *args: int | float, dtype: Dtype | None = None + def _arange( + self, + start: int | float, + stop: int | float, + step: int | float, + dtype: Dtype | None = None, ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(np.arange)(*args, dtype=_dtype) + default_type = ( + float if any(isinstance(x, float) for x in (start, stop, step)) else int + ) + _dtype = self._process_dtype(dtype, default_type) + return np.arange(start, stop, step, dtype=_dtype) def linspace( self, start: 
int | float | bool | np.ndarray[Any, Any], stop: int | float | bool | np.ndarray[Any, Any], - steps: int | np.ndarray[Any, Any], + steps: int, dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(np.linspace)(start, stop, steps, dtype=_dtype) + _dtype = self._process_dtype(dtype) + return np.linspace(start, stop, steps, dtype=_dtype) def flatten( self, input: np.ndarray[Any, Any], start_dim: int = 0, end_dim: int = -1 @@ -459,3 +394,15 @@ def multinomial( samples = np.squeeze(samples, axis=0) return samples + + def _process_dtype( + self, + dtype: Dtype | None = None, + default_type: type[float] | type[int] | type[bool] = float, + ) -> np.dtype[Any]: + if isinstance(dtype, Dtype): + return utils.dtype_map[dtype.name] + elif dtype is None: + return utils.dtype_map[default_type.__name__ + str(self.precision)] + else: + raise ValueError(f"Invalid dtype {dtype}") diff --git a/mithril/backends/with_manualgrad/numpy_backend/utils.py b/mithril/backends/with_manualgrad/numpy_backend/utils.py index c7a101c..42d0076 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/utils.py +++ b/mithril/backends/with_manualgrad/numpy_backend/utils.py @@ -308,55 +308,6 @@ def calc_input_slices( return slices -def handle_dtype(dtype: Any) -> Any: - if isinstance(dtype, core.Dtype): - return dtype_map[dtype.name] - elif isinstance(dtype, str) and dtype in dtype_map: - return dtype_map[dtype] - else: - try: - return np.dtype(dtype) - except TypeError as err: - raise TypeError(f"Provided data type '{dtype}' not understood") from err - - -def creation_fn_wrapper( - *args: Any, - fn: Callable[..., np.ndarray[Any, Any]], - precision: int, - dtype: core.Dtype | np.dtype[Any] | None = None, - **kwargs: Any, -): - if dtype is not None: - dtype = handle_dtype(dtype) - data = fn(*args, dtype=dtype, **kwargs) - else: - data = fn(*args, **kwargs) - data = handle_data_precision(data, precision=precision) - return data - - -def conversion_fn_wrapper( - data: Any, - *args: Any, - fn: Callable[..., np.ndarray[Any, Any]], - precision: int, - dtype: np.dtype[Any] | None = None, - **kwargs: Any, -): - if dtype is not None: - dtype = handle_dtype(dtype) - if isinstance(data, ArrayType): - if dtype is not None: - return data.astype(dtype) - return handle_data_precision(data, precision=precision) - else: - _data = fn(data, *args, dtype=dtype, **kwargs) - if dtype is None: - return handle_data_precision(_data, precision=precision) - return _data - - def handle_data_precision( data: np.ndarray[Any, Any], precision: int ) -> np.ndarray[Any, Any]: @@ -497,6 +448,18 @@ def calculate_cross_entropy_class_weights( return _weights +def determine_dtype(input: Any, dtype: core.Dtype | None, precision: int) -> str: + if isinstance(dtype, core.Dtype): + return dtype.name + + if isinstance(input, (np.ndarray | np.generic)): + dtype_name = "".join(char for char in str(input.dtype) if not char.isdigit()) + else: + dtype_name = find_dominant_type(input).__name__ + + return dtype_name + str(precision) if dtype_name != "bool" else "bool" + + def get_type( input: int | float | bool | Sequence[int | float | bool | Sequence[Any]], precision: int,