Commit 9d3f5eb: MyPy fixes

dgasmith committed Jul 10, 2024
1 parent 684027d
Showing 11 changed files with 15 additions and 16 deletions.
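The recurring change in this commit appends `# type: ignore` to imports of optional, untyped backends (numpy, jax, tensorflow, theano, torch) so that mypy stops flagging those import statements when no type stubs are installed. A minimal sketch of the pattern, using a hypothetical backend name rather than any module from this diff:

    try:
        import somebackend  # type: ignore  # hypothetical untyped optional dependency
        HAS_BACKEND = True
    except ImportError:
        HAS_BACKEND = False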
2 changes: 1 addition & 1 deletion opt_einsum/backends/dispatch.py
@@ -59,7 +59,7 @@ def _import_func(func: str, backend: str, default: Any = None) -> Any:
 }

 try:
-    import numpy as np
+    import numpy as np  # type: ignore

     _cached_funcs[("tensordot", "numpy")] = np.tensordot
     _cached_funcs[("transpose", "numpy")] = np.transpose
4 changes: 2 additions & 2 deletions opt_einsum/backends/jax.py
@@ -10,7 +10,7 @@
 def _get_jax_and_to_jax():
     global _JAX
     if _JAX is None:
-        import jax
+        import jax  # type: ignore

         @to_backend_cache_wrap
         @jax.jit
@@ -29,7 +29,7 @@ def build_expression(_, expr): # pragma: no cover
     jax_expr = jax.jit(expr._contract)

     def jax_contract(*arrays):
-        import numpy as np
+        import numpy as np  # type: ignore

         return np.asarray(jax_expr(arrays))

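For context on the hunk above: the jax backend jit-compiles the contraction once and converts each result back to a numpy array. A self-contained sketch of that pattern, assuming jax is installed (the names here are illustrative, not opt_einsum's own):

    import jax  # type: ignore
    import numpy as np

    @jax.jit
    def contract(x, y):  # stand-in for the jitted expr._contract
        return x @ y

    a = np.ones((2, 2))
    out = np.asarray(contract(a, a))  # jax output converted back to numpy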
2 changes: 1 addition & 1 deletion opt_einsum/backends/object_arrays.py
@@ -27,7 +27,7 @@ def object_einsum(eq: str, *arrays: ArrayType) -> ArrayType:
     out : numpy.ndarray
         The output tensor, with ``dtype=object``.
     """
-    import numpy as np
+    import numpy as np  # type: ignore

     # when called by ``opt_einsum`` we will always be given a full eq
     lhs, output = eq.split("->")
2 changes: 1 addition & 1 deletion opt_einsum/backends/tensorflow.py
@@ -12,7 +12,7 @@ def _get_tensorflow_and_device():
     global _CACHED_TF_DEVICE

     if _CACHED_TF_DEVICE is None:
-        import tensorflow as tf
+        import tensorflow as tf  # type: ignore

         try:
             eager = tf.executing_eagerly()
2 changes: 1 addition & 1 deletion opt_einsum/backends/theano.py
@@ -9,7 +9,7 @@
 @to_backend_cache_wrap(constants=True)
 def to_theano(array, constant=False):
     """Convert a numpy array to ``theano.tensor.TensorType`` instance."""
-    import theano
+    import theano  # type: ignore

     if has_array_interface(array):
         if constant:
2 changes: 1 addition & 1 deletion opt_einsum/backends/torch.py
@@ -24,7 +24,7 @@ def _get_torch_and_device():
     global _TORCH_HAS_TENSORDOT

     if _TORCH_DEVICE is None:
-        import torch
+        import torch  # type: ignore

         device = "cuda" if torch.cuda.is_available() else "cpu"
         _TORCH_DEVICE = torch, device
5 changes: 2 additions & 3 deletions opt_einsum/parser.py
@@ -1,8 +1,7 @@
"""A functionally equivalent parser of the numpy.einsum input parser."""

import itertools
from collections.abc import Sequence
from typing import Any, Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Sequence, Tuple

from opt_einsum.typing import ArrayType, TensorShapeType

@@ -219,7 +218,7 @@ def possibly_convert_to_numpy(x: Any) -> Any:
"""
if not hasattr(x, "shape"):
try:
import numpy as np
import numpy as np # type: ignore
except ModuleNotFoundError:
raise ModuleNotFoundError(
"numpy is required to convert non-array objects to arrays. This function will be deprecated in the future."
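A note on the parser.py import change: folding `Sequence` into the existing `typing` import leaves one import line instead of two, and it keeps subscripted annotations working on older interpreters, since `collections.abc.Sequence[...]` only became subscriptable in Python 3.9 while `typing.Sequence[...]` works earlier as well (the commit does not state a motivation, so this is an inference). A small runnable illustration:

    from typing import Sequence

    def total(xs: Sequence[int]) -> int:
        # typing.Sequence[int] parses on Python 3.8 as well as 3.9+
        return sum(xs)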
4 changes: 2 additions & 2 deletions opt_einsum/tests/test_backends.py
@@ -9,10 +9,10 @@
 try:
     # needed so tensorflow doesn't allocate all gpu mem
     try:
-        from tensorflow import ConfigProto
+        from tensorflow import ConfigProto  # type: ignore
         from tensorflow import Session as TFSession
     except ImportError:
-        from tensorflow.compat.v1 import ConfigProto
+        from tensorflow.compat.v1 import ConfigProto  # type: ignore
         from tensorflow.compat.v1 import Session as TFSession
     _TF_CONFIG = ConfigProto()
     _TF_CONFIG.gpu_options.allow_growth = True
4 changes: 2 additions & 2 deletions opt_einsum/tests/test_input.py
@@ -2,7 +2,7 @@
 Tests the input parsing for opt_einsum. Duplicates the np.einsum input tests.
 """

-from typing import Any
+from typing import Any, List

 import pytest

@@ -12,7 +12,7 @@
 np = pytest.importorskip("numpy")


-def build_views(string: str) -> list[ArrayType]:
+def build_views(string: str) -> List[ArrayType]:
     """Builds random numpy arrays for testing by using a fixed size dictionary and an input string."""

     chars = "abcdefghij"
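The `list[ArrayType]` to `List[ArrayType]` change in this hunk is the same compatibility concern in annotation form: built-in generics such as `list[...]` require Python 3.9+ (PEP 585), whereas `typing.List` also works on earlier versions. A tiny sketch with a hypothetical helper, not the actual test code:

    from typing import List

    def make_views(n: int) -> List[float]:
        # List[float] rather than list[float] keeps Python 3.8 parsing happy
        return [float(i) for i in range(n)]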
2 changes: 1 addition & 1 deletion opt_einsum/tests/test_sharing.py
@@ -30,7 +30,7 @@
 cupy_if_found = pytest.param("cupy", marks=[pytest.mark.skip(reason="CuPy not installed.")])  # type: ignore

 try:
-    import torch  # noqa
+    import torch  # type: ignore # noqa

     torch_if_found = "torch"
 except ImportError:
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -88,5 +88,5 @@ branch = true
 relative_files = true

 [[tool.mypy.overrides]]
-module = "cupy.*, jax.*, theano.*, tensorflow.*, torch.*"
+module = "cupy.*, jax.*, numpy.*, theano.*, tensorflow.*, torch.*"
 ignore_missing_imports = true
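This override adds `numpy.*` to the modules whose missing stubs mypy should tolerate. As a side note, mypy's pyproject.toml documentation shows the `module` key as a TOML list, which avoids packing patterns into one comma-separated string; an equivalent list form (assuming mypy treats the two spellings the same) would be:

    [[tool.mypy.overrides]]
    module = ["cupy.*", "jax.*", "numpy.*", "theano.*", "tensorflow.*", "torch.*"]
    ignore_missing_imports = true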
