Skip to content

Commit

Permalink
refactor: Remove backend array creation/conversion wrapper (synnada-a…
Browse files Browse the repository at this point in the history
  • Loading branch information
aturker-synnada authored Dec 27, 2024
1 parent 6165786 commit d36cf87
Show file tree
Hide file tree
Showing 9 changed files with 412 additions and 745 deletions.
30 changes: 26 additions & 4 deletions mithril/backends/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,30 @@ def arange(
dtype: core.Dtype | None = None,
) -> DataType: ...

def arange(self, *args: int | float, **kwargs: Any) -> DataType:
raise NotImplementedError("arange is not implemented!")
def arange(self, *args: int | float, **kwargs) -> DataType:
"""Generate an array of evenly spaced values within a specified range."""
if len(args) == 0:
raise RuntimeError(
"arange() missing 1 required positional argument: 'stop'"
)
elif len(args) == 1:
return self._arange(0, args[0], 1, **kwargs) # type: ignore
elif len(args) == 2:
if args[0] >= args[1]:
return self.array([])

return self._arange( # type: ignore
args[0], args[1], 1, **kwargs
)
elif len(args) == 3:
return self._arange( # type: ignore
args[0], args[1], args[2], **kwargs
)
else:
raise RuntimeError(
"arange() accepts 1 to 3 positional arguments,"
" but `f{len(args)}` were provided"
)

def flatten(
self, input: DataType, start_dim: int = 0, end_dim: int = -1
Expand Down Expand Up @@ -459,7 +481,7 @@ def linspace(
self,
start: int | float | bool | DataType,
stop: int | float | bool | DataType,
steps: int | DataType,
steps: int,
dtype: core.Dtype | None = None,
) -> DataType:
"""
Expand Down Expand Up @@ -1349,7 +1371,7 @@ def linspace(
self,
start: int | float | bool | DataType,
stop: int | float | bool | DataType,
steps: int | DataType,
steps: int,
dtype: core.Dtype | None = None,
device_mesh: tuple[int, ...] | None = None,
) -> DataType:
Expand Down
Loading

0 comments on commit d36cf87

Please sign in to comment.