-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Provided one global DeviceConnector instance. Moved typedefs from uti…
…ls to a new file. (#131) Also moved `Workload` to a new file.
- Loading branch information
1 parent
3f849d6
commit afa672e
Showing
11 changed files
with
75 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from enum import Enum | ||
from typing import Union | ||
|
||
import jax | ||
from flax import linen, nnx | ||
|
||
# Convenience alias. Could be used to represent jax.Array, torch.Tensor, np.ndarray, etc. | ||
Tensor = Union[jax.Array] | ||
|
||
# Convenience alias. Could be used to represent nnx.Module, torch.nn.Module, etc. | ||
# NOTE nnx.Module is the newest API, linen.Module is legacy but it is used in some | ||
# huggingface models. | ||
Model = Union[nnx.Module, linen.Module] | ||
|
||
|
||
class Framework(Enum): | ||
JAX = "jax" | ||
TORCH = "torch" | ||
NUMPY = "numpy" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from dataclasses import dataclass | ||
from typing import Any, Callable, Mapping, Optional, Sequence | ||
|
||
|
||
@dataclass | ||
class Workload: | ||
""" | ||
Convenience dataclass storing a callable and its positional and keyword arguments. | ||
""" | ||
|
||
executable: Callable | ||
args: Sequence[Any] | ||
kwargs: Optional[Mapping[str, Any]] = None | ||
|
||
def __post_init__(self): | ||
# If kwargs is None, initialize it to an empty dictionary. | ||
if self.kwargs is None: | ||
self.kwargs = {} | ||
|
||
def execute(self) -> Any: | ||
"""Calls callable passing stored args and kwargs directly.""" | ||
return self.executable(*self.args, **self.kwargs) |