chore: Linted Physical folder (synnada-ai#138)
Co-authored-by: kberat-synnada <[email protected]>
mehmetozsoy-synnada and kberat-synnada authored Jan 2, 2025
1 parent 3a1fc42 commit 941bb50
Showing 39 changed files with 397 additions and 355 deletions.
2 changes: 1 addition & 1 deletion benchmarks/speed_benchmarks/jax_fns.py
@@ -66,7 +66,7 @@ def setup(self):
     def __call__(self, inputs):
         x = inputs
         for lyr, actv in zip(self.layers, self.jax_activations, strict=False):
-            x = lyr(x)  # type: ignore
+            x = lyr(x)
             x = actv(x)  # type: ignore
         return x

2 changes: 1 addition & 1 deletion examples/model_api/linear_regression_jax_training.py
@@ -58,5 +58,5 @@
 for i in range(num_epochs):
     outputs, gradients = pm.evaluate_all(params)
     updates, opt_state = optimizer.update(gradients, opt_state)
-    params = optax.apply_updates(params, updates)  # type: ignore
+    params = optax.apply_updates(params, updates)
     print(f"Epoch: {i} / {num_epochs} -> ", outputs["final_cost"])
10 changes: 5 additions & 5 deletions examples/model_api/variable_length_many_to_one_lstm.py
@@ -67,8 +67,8 @@
     target_end = int(input_end + target_lengths[idx])

     # NOTE: Pylance sees int, int type arguments but throws an error.
-    single_input = backend.arange(start, input_end).reshape(-1, input_dim)  # type: ignore
-    single_target = backend.arange(input_end, target_end).reshape(-1, output_dim)  # type: ignore
+    single_input = backend.arange(start, input_end).reshape(-1, input_dim)
+    single_target = backend.arange(input_end, target_end).reshape(-1, output_dim)

     single_data = (single_input, single_target)
     train_data.append(single_data)
@@ -150,7 +150,7 @@
 # Prepare the test input data.
 test_input = backend.arange(
     starting_number,
-    starting_number + inference_max_input,  # type: ignore
+    starting_number + inference_max_input,
 ).reshape(-1, input_dim)

 # Prepare the test data.
@@ -172,7 +172,7 @@

 # Prepare target values.
 test_target_values = backend.arange(
-    starting_number + inference_max_input,  # type: ignore
+    starting_number + inference_max_input,
     starting_number + inference_max_input + inference_max_target_length,
 )

@@ -204,4 +204,4 @@
 )

 # Measure test error.
-error = backend.abs(unpacked_output_data.squeeze() - test_target_values).sum()  # type: ignore
+error = backend.abs(unpacked_output_data.squeeze() - test_target_values).sum()
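The calls being un-ignored above are plain arange-plus-reshape constructions. A standalone sketch with NumPy semantics (the bounds and input_dim values below are illustrative, not taken from the example) shows why the int arguments are fine at runtime:

import numpy as np

start, input_end, input_dim = 0, 5, 1  # illustrative values
single_input = np.arange(start, input_end).reshape(-1, input_dim)
print(single_input.shape)  # (5, 1): one timestep per row
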
15 changes: 8 additions & 7 deletions mithril/backends/backend.py
@@ -57,26 +57,27 @@ def __init__(self, precision: int = 32, device: str = "cpu") -> None:
         # setattr(self, key, value)

     @property
-    def precision(self):
+    def precision(self) -> int:
         return self._precision

     #!!
     @property
-    def device(self):
+    def device(self) -> Any:
         return self._device

-    def get_device(self):
+    def get_device(self) -> str:
         return self._device

     @property
     def inf(self) -> DataType | float:
         raise NotImplementedError("inf is not implemented")

     @property
-    def pi(self):
+    def pi(self) -> float:
         return math.pi

     @property
-    def e(self):
+    def e(self) -> float:
         return math.e

     @property
@@ -104,7 +105,7 @@ def to_device(
     def block_until_ready(self, data: DataType) -> DataType | None:
         raise RuntimeError("Backend does not support block_until_ready method!")

-    def empty_cache(self):  # noqa: B027
+    def empty_cache(self) -> None:  # noqa: B027
         pass
         # print("Warning: empty_cache is not supported!")

@@ -126,7 +127,7 @@ def cast(self, value: Any) -> Any:

         return value

-    def __del__(self):
+    def __del__(self) -> None:
         self.empty_cache()

     @overload
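For context, a minimal sketch of how the newly annotated accessors read at a call site. This assumes mithril exposes a concrete JaxBackend at the package top level, as in the project README; the precision keyword mirrors the __init__ signature in the hunk header above:

import mithril as ml

backend = ml.JaxBackend(precision=32)

p: int = backend.precision    # property now annotated to return int
print(backend.pi, backend.e)  # floats: math.pi and math.e
print(backend.get_device())   # now annotated to return str
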
4 changes: 2 additions & 2 deletions mithril/backends/parallel.py
@@ -37,10 +37,10 @@ def run_callable(self, *primals: Any, fn_name: str) -> dict[str, Any]:
     @abstractmethod
     def parallelize(
         self, tensor: DataType, device_mesh: tuple[int, ...] | None = None
-    ) -> dict[str, Any]:
+    ) -> DataType:
         raise NotImplementedError()

-    def clean_up(self):
+    def clean_up(self) -> None:
         self.callables = dict()
         self.device_mesh = None
         self.n_devices = -1
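A hedged reading of the signature fix: parallelize is now declared to hand back a sharded tensor of the backend's DataType rather than a dict. An illustrative stand-in subclass (all names below are hypothetical, not from the repository) makes the contract concrete:

from typing import Any


class ToyParallel:
    """Hypothetical stand-in for a Parallel[DataType] subclass."""

    def parallelize(
        self, tensor: Any, device_mesh: tuple[int, ...] | None = None
    ) -> Any:
        # A real backend would shard `tensor` across `device_mesh`;
        # the contract is only that a tensor, not a dict, comes back.
        return tensor

    def clean_up(self) -> None:
        self.callables: dict[str, Any] = dict()
        self.device_mesh = None
        self.n_devices = -1
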
98 changes: 50 additions & 48 deletions mithril/backends/with_autograd/common_primitives.py
@@ -63,80 +63,80 @@
 ]


-def greater(left: DataType, right: DataType):
+def greater(left: DataType, right: DataType) -> DataType:
     return left > right


-def greater_equal(left: DataType, right: DataType):
+def greater_equal(left: DataType, right: DataType) -> DataType:
     return left >= right


-def less(left: DataType, right: DataType):
+def less(left: DataType, right: DataType) -> DataType:
     return left < right


-def less_equal(left: DataType, right: DataType):
+def less_equal(left: DataType, right: DataType) -> DataType:
     return left <= right


-def equal(left: DataType, right: DataType):
-    return left == right
+def equal(left: DataType, right: DataType) -> DataType:
+    return left == right  # type: ignore


-def not_equal(left: DataType, right: DataType):
-    return left != right
+def not_equal(left: DataType, right: DataType) -> DataType:
+    return left != right  # type: ignore


-def logical_not(input: DataType):
+def logical_not(input: DataType) -> DataType:
     return ~input


-def logical_or(left: DataType, right: DataType):
-    return left | right
+def logical_or(left: DataType, right: DataType) -> DataType:
+    return left | right  # type: ignore


-def logical_and(left: DataType, right: DataType):
-    return left & right
+def logical_and(left: DataType, right: DataType) -> DataType:
+    return left & right  # type: ignore


-def matrix_multiplication(left: DataType, right: DataType):
-    return left @ right
+def matrix_multiplication(left: DataType, right: DataType) -> DataType:
+    return left @ right  # type: ignore


-def add(left: DataType, right: DataType):
-    return left + right
+def add(left: DataType, right: DataType) -> DataType:
+    return left + right  # type: ignore


-def subtract(left: DataType, right: DataType):
-    return left - right
+def subtract(left: DataType, right: DataType) -> DataType:
+    return left - right  # type: ignore


-def multiplication(left: DataType, right: DataType):
-    return left * right
+def multiplication(left: DataType, right: DataType) -> DataType:
+    return left * right  # type: ignore


-def divide(numerator: DataType, denominator: DataType):
-    return numerator / denominator
+def divide(numerator: DataType, denominator: DataType) -> DataType:
+    return numerator / denominator  # type: ignore


-def floor_divide(numerator: DataType, denominator: DataType):
-    return numerator // denominator
+def floor_divide(numerator: DataType, denominator: DataType) -> DataType:
+    return numerator // denominator  # type: ignore


-def shift_left(input: DataType, shift: DataType):
-    return input << shift
+def shift_left(input: DataType, shift: DataType) -> DataType:
+    return input << shift  # type: ignore


-def shift_right(input: DataType, shift: DataType):
-    return input >> shift
+def shift_right(input: DataType, shift: DataType) -> DataType:
+    return input >> shift  # type: ignore


-def power(base: DataType, exponent: DataType):
-    return base**exponent
+def power(base: DataType, exponent: DataType) -> DataType:
+    return base**exponent  # type: ignore


-def squared_error(input: DataType, target: DataType):
-    return (input - target) ** 2
+def squared_error(input: DataType, target: DataType) -> DataType:
+    return (input - target) ** 2  # type: ignore


 def minus(input: DataType) -> DataType:
@@ -148,18 +148,18 @@ def transpose(
 ) -> DataType:
     if not axes:
         return input.T
-    return input.transpose(*axes)
+    return input.transpose(*axes)  # type: ignore


-def swapaxes(input: DataType, axis1: int, axis2: int):
+def swapaxes(input: DataType, axis1: int, axis2: int) -> DataType:
     return input.swapaxes(axis1, axis2)


-def square(input: DataType):
-    return input * input
+def square(input: DataType) -> DataType:
+    return input * input  # type: ignore


-def buffer(input: DataType):
+def buffer(input: DataType) -> DataType:
     return input

@@ -168,27 +168,29 @@ def permute_tensor(input: DataType, indices: DataType) -> DataType:


 def reshape(input: DataType, shape: tuple[int, ...]) -> DataType:
-    return input.reshape(shape)
+    return input.reshape(shape)  # type: ignore


 def item(input: DataType) -> int | float | bool:
     return input.item()  # type: ignore


-def tensor_item(input: DataType, index: int | slice | tuple[int | slice, ...]):
-    return input[index]
+def tensor_item(
+    input: DataType, index: int | slice | tuple[int | slice, ...]
+) -> DataType:
+    return input[index]  # type: ignore


-def primitive_slice(start: int | None, stop: int | None, step: int | None):
+def primitive_slice(start: int | None, stop: int | None, step: int | None) -> slice:
     return slice(start, stop, step)


 def length(input: DataType) -> int:
     return len(input)


-def cartesian_diff(left: DataType, right: DataType):
-    return left[:, None, :] - right[None, :, :]
+def cartesian_diff(left: DataType, right: DataType) -> DataType:
+    return left[:, None, :] - right[None, :, :]  # type: ignore


 def primitive_embedding(input: DataType, weight: DataType) -> DataType:
@@ -218,11 +220,11 @@ def union(*inputs: int | float | tuple[int | float, ...]) -> tuple[int | float,
     return result


-def to_tuple(*args: tuple[int | float | bool, ...]):
+def to_tuple(*args: int | float | bool) -> tuple[int | float | bool, ...]:
     return tuple(args)


-def to_list(*args: tuple[int | float | bool, ...]):
+def to_list(*args: int | float | bool) -> list[int | float | bool]:
     return list(args)

@@ -291,7 +293,7 @@ def padding_converter_2d(
 def stride_converter(
     input: int | PaddingType | tuple[int, int] | None,
     kernel_size: int | tuple[int, int],
-):
+) -> int | tuple[int, int] | PaddingType:
     if input is None:
         return kernel_size
     else:
@@ -303,7 +305,7 @@ def tuple_converter(
     input: int
     | PaddingType
    | tuple[int, int]
     | tuple[tuple[int, int], tuple[int, int]],
-):
+) -> tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] | PaddingType:
     if isinstance(input, int):
         return (input, input)
     else:
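Since these primitives are plain functions over whatever array type DataType is bound to, they can be exercised directly once the definitions above are in scope; a quick sketch, with NumPy chosen only for illustration:

import numpy as np

left = np.array([[1.0, 2.0], [3.0, 4.0]])
right = np.array([[5.0, 6.0], [7.0, 8.0]])

print(add(left, right))                    # elementwise sum
print(matrix_multiplication(left, right))  # 2x2 matmul via @
print(cartesian_diff(left, right).shape)   # (2, 2, 2): pairwise row differences
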
20 changes: 11 additions & 9 deletions mithril/backends/with_autograd/jax_backend/backend.py
@@ -75,14 +75,14 @@ def is_manualgrad(self) -> bool:
         return False

     @property
-    def inf(self):
+    def inf(self) -> float:
         return jax.numpy.inf

     @property
-    def nan(self):
+    def nan(self) -> float:
         return jax.numpy.nan

-    def get_backend_array_type(self):
+    def get_backend_array_type(self) -> type[jax.Array]:
         return jax.Array

     @property
@@ -98,7 +98,7 @@ def DataType(self):  # noqa: N802
         return utils.ArrayType

     @staticmethod
-    def get_available_devices():
+    def get_available_devices() -> list[str]:
         """Static method to get a list of available devices.

         Parameters
@@ -112,7 +112,7 @@
     def register_primitive(fn: Callable[..., Any]) -> None:
         JaxBackend.registered_primitives[fn.__name__] = fn

-    def set_seed(self, seed: int):
+    def set_seed(self, seed: int) -> None:
         self.seed = seed
         self.prng_key = jax.random.PRNGKey(seed)

@@ -145,23 +145,23 @@ def block_until_ready(self, data: jax.Array) -> jax.Array | None:

     def register_callable(
         self, fn: Callable[..., Any], fn_name: str, jit: bool = False
-    ):
+    ) -> None:
         assert (
             self._parallel_manager is not None
         ), "Parallel manager is not initialized!"

         fn_name = str(id(self)) + fn_name
         return self._parallel_manager.register_callable(fn, fn_name, jit)

-    def _run_callable(self, *primals: jax.Array, fn_name: str):
+    def _run_callable(self, *primals: jax.Array, fn_name: str) -> Any:
         assert (
             self._parallel_manager is not None
         ), "Parallel manager is not initialized!"

         fn_name = str(id(self)) + fn_name
         return self._parallel_manager.run_callable(*primals, fn_name=fn_name)

-    def _create_parallel(self, device_mesh: tuple[int, ...]):
+    def _create_parallel(self, device_mesh: tuple[int, ...]) -> None:
         self._parallel_manager = JaxParallel(math.prod(device_mesh), self._device)

     def array(
@@ -538,7 +538,9 @@ def multinomial(

         return samples

-    def jit(self, *args: Any, **kwargs: Any):
+    def jit(  # type: ignore[override]
+        self, *args: Any, **kwargs: Any
+    ) -> Callable[..., jax.Array | tuple[jax.Array, ...]] | dict[str, jax.Array]:
         return jax.jit(*args, **kwargs)

     def grad(
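As the last hunk shows, JaxBackend.jit is a thin pass-through to jax.jit. A minimal usage sketch, assuming a constructed backend and the top-level import path from the mithril README:

import jax.numpy as jnp
import mithril as ml

backend = ml.JaxBackend()

def f(x):
    return jnp.sin(x) ** 2

fast_f = backend.jit(f)  # pass-through to jax.jit(f)
print(fast_f(jnp.arange(3.0)))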