Skip to content

Commit

Permalink
refactor: Remove factory inputs (synnada-ai#130)
Browse files Browse the repository at this point in the history
  • Loading branch information
kberat-synnada authored Dec 30, 2024
1 parent 9b7b730 commit a0e6b1b
Show file tree
Hide file tree
Showing 12 changed files with 568 additions and 1,032 deletions.
4 changes: 1 addition & 3 deletions mithril/backends/with_autograd/jax_backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,7 @@ def array(
_dtype = utils.determine_dtype(input, dtype, self.precision)

with jax.default_device(self.device):
array = jax.numpy.array(
input, dtype=utils.dtype_map[_dtype], device=self.device
)
array = jax.numpy.array(input, dtype=utils.dtype_map[_dtype])

if self._parallel_manager is not None:
array = self._parallel_manager.parallelize(array, device_mesh)
Expand Down
2 changes: 1 addition & 1 deletion mithril/framework/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def match(self, other: BaseData[T]) -> Updates:
self.differentiable = other.differentiable = is_diff

if self.is_valued or other.is_valued:
valued, non_valued = (self, other) if self.is_valued else (other, self)
valued, non_valued = (other, self) if other.is_valued else (self, other)
updates |= non_valued.set_value(valued.value)
if non_valued is other:
if other.is_tensor:
Expand Down
45 changes: 0 additions & 45 deletions mithril/framework/logical/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,14 @@
from ...utils.utils import OrderedSet
from ..common import (
NOT_AVAILABLE,
NOT_GIVEN,
TBD,
Connection,
ConnectionData,
Connections,
ConnectionType,
Constraint,
ConstraintFunctionType,
ConstraintSolver,
ExtendTemplate,
IOHyperEdge,
IOKey,
MainValueInstance,
MainValueType,
NestedListType,
NotAvailable,
Expand All @@ -46,8 +41,6 @@
ShapeTemplateType,
ShapeType,
Tensor,
TensorValueType,
ToBeDetermined,
UniadicRecord,
Updates,
UpdateType,
Expand Down Expand Up @@ -95,50 +88,12 @@ class BaseModel(abc.ABC):
factory_args: dict[str, Any] = {}

def __call__(self, **kwargs: ConnectionType) -> ExtendInfo:
for key, val in self.factory_inputs.items():
if val is not TBD:
if key not in kwargs or (con := kwargs[key]) is NOT_GIVEN:
kwargs[key] = val # type: ignore
continue
match con:
case Connection():
kwargs[key] = IOKey(value=val, connections={con})
# TODO: Maybe we could check con's value if matches with val
case item if isinstance(item, MainValueInstance) and con != val:
raise ValueError(
f"Given value {con} for local key: '{key}' "
f"has already being set to {val}!"
)
case str():
kwargs[key] = IOKey(name=con, value=val, expose=False)
case IOKey():
if con.data.value is not TBD and con.data.value != val:
raise ValueError(
f"Given IOKey for local key: '{key}' is not valid!"
)
else:
kwargs[key] = IOKey(
name=con.name,
expose=con.expose,
connections=con.connections,
type=con.data.type,
shape=con.data.shape,
value=val,
)
case ExtendTemplate():
raise ValueError(
"Multi-write detected for a valued "
f"local key: '{key}' is not valid!"
)
return ExtendInfo(self, kwargs)

def __init__(self, name: str | None = None, enforce_jit: bool = True) -> None:
self.parent: BaseModel | None = (
None # TODO: maybe set it only to PrimitiveModel / Model.
)
self.factory_inputs: dict[
str, TensorValueType | MainValueType | ToBeDetermined
] = {}
self.assigned_shapes: list[ShapesType] = []
self.assigned_constraints: list[dict[str, str | list[str]]] = []
self.conns = Connections()
Expand Down
Loading

0 comments on commit a0e6b1b

Please sign in to comment.