Skip to content

Commit

Permalink
v3.4 Release notes and removes NumPy kwargs (#239)
Browse files Browse the repository at this point in the history
* Removes explicit kwargs

* Adds 3.4 changelog

* Pins MyPy

* Type fix

* Fixes typo in

* Allows booleans as optimize, closes #219

* Update docs/changelog.md

Co-authored-by: Jane (Yuan) Xu <[email protected]>

* Update docs/changelog.md

Co-authored-by: Jane (Yuan) Xu <[email protected]>

---------

Co-authored-by: Jane (Yuan) Xu <[email protected]>
  • Loading branch information
dgasmith and janeyx99 authored Sep 26, 2024
1 parent 1992f4a commit c15aec2
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 63 deletions.
2 changes: 1 addition & 1 deletion devtools/conda-envs/full-environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies:

# Testing
- codecov
- mypy
- mypy ==1.11*
- pytest
- pytest-cov
- ruff ==0.5.*
4 changes: 2 additions & 2 deletions devtools/conda-envs/min-deps-environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dependencies:

# Testing
- codecov
- mypy
- mypy ==1.11*
- pytest
- pytest-cov
- ruff ==0.5.*
- ruff ==0.6.*
2 changes: 1 addition & 1 deletion devtools/conda-envs/min-ver-environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies:

# Testing
- codecov
- mypy
- mypy ==1.11*
- pytest
- pytest-cov
- ruff ==0.5.*
2 changes: 1 addition & 1 deletion devtools/conda-envs/torch-only-environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies:

# Testing
- codecov
- mypy
- mypy ==1.11*
- pytest
- pytest-cov
- ruff ==0.5.*
29 changes: 29 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,35 @@
Changelog
=========

## 3.4.0 / 2024-09-XX

NumPy has been removed from `opt_einsum` as a dependency allowing for more flexible installs.

**New Features**

- [\#160](https://github.com/dgasmith/opt_einsum/pull/160) Migrates docs to MkDocs Material and GitHub pages hosting.
- [\#161](https://github.com/dgasmith/opt_einsum/pull/161) Adds Python type annotations to the code base.
- [\#204](https://github.com/dgasmith/opt_einsum/pull/204) Removes NumPy as a hard dependency.

**Enhancements**

- [\#154](https://github.com/dgasmith/opt_einsum/pull/154) Prevents an infinite recursion error when the `memory_limit` was set very low for the `dp` algorithm.
- [\#155](https://github.com/dgasmith/opt_einsum/pull/155) Adds flake8 spell check to the doc strings
- [\#159](https://github.com/dgasmith/opt_einsum/pull/159) Migrates to GitHub actions for CI.
- [\#174](https://github.com/dgasmith/opt_einsum/pull/174) Prevents double contracts of floats in dynamic paths.
- [\#196](https://github.com/dgasmith/opt_einsum/pull/196) Allows `backend=None` which is equivalent to `backend='auto'`
- [\#208](https://github.com/dgasmith/opt_einsum/pull/208) Switches to `ConfigParser` insetad of `SafeConfigParser` for Python 3.12 compatability.
- [\#228](https://github.com/dgasmith/opt_einsum/pull/228) `backend='jaxlib'` is now an alias for the `jax` library
- [\#237](https://github.com/dgasmith/opt_einsum/pull/237) Switches to `ruff` for formatting and linting.
- [\#238](https://github.com/dgasmith/opt_einsum/pull/238) Removes `numpy`-specific keyword args from being explicitly defined in `contract` and uses `**kwargs` instead.

**Bug Fixes**

- [\#195](https://github.com/dgasmith/opt_einsum/pull/195) Fixes a bug where `dp` would not work for scalar-only contractions.
- [\#200](https://github.com/dgasmith/opt_einsum/pull/200) Fixes a bug where `parse_einsum_input` would not correctly respect shape-only contractions.
- [\#222](https://github.com/dgasmith/opt_einsum/pull/222) Fixes an erorr in `parse_einsum_input` where an output subscript specified multiple times was not correctly caught.
- [\#229](https://github.com/dgasmith/opt_einsum/pull/229) Fixes a bug where empty contraction lists in `PathInfo` would cause an error.

## 3.3.0 / 2020-07-19

Adds a `object` backend for optimized contractions on arbitrary Python objects.
Expand Down
66 changes: 19 additions & 47 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@

## Common types

_OrderKACF = Literal[None, "K", "A", "C", "F"]

_Casting = Literal["no", "equiv", "safe", "same_kind", "unsafe"]
_MemoryLimit = Union[None, int, Decimal, Literal["max_input"]]


Expand Down Expand Up @@ -284,7 +281,7 @@ def contract_path(
#> 5 defg,hd->efgh efgh->efgh
```
"""
if optimize is True:
if (optimize is True) or (optimize is None):
optimize = "auto"

# Hidden option, only einsum should call this
Expand Down Expand Up @@ -344,9 +341,11 @@ def contract_path(
naive_cost = helpers.flop_count(indices, inner_product, num_ops, size_dict)

# Compute the path
if not isinstance(optimize, (str, paths.PathOptimizer)):
if optimize is False:
path_tuple: PathType = [tuple(range(num_ops))]
elif not isinstance(optimize, (str, paths.PathOptimizer)):
# Custom path supplied
path_tuple: PathType = optimize # type: ignore
path_tuple = optimize # type: ignore
elif num_ops <= 2:
# Nothing to be optimized
path_tuple = [tuple(range(num_ops))]
Expand Down Expand Up @@ -479,9 +478,6 @@ def contract(
subscripts: str,
*operands: ArrayType,
out: ArrayType = ...,
dtype: Any = ...,
order: _OrderKACF = ...,
casting: _Casting = ...,
use_blas: bool = ...,
optimize: OptimizeKind = ...,
memory_limit: _MemoryLimit = ...,
Expand All @@ -495,9 +491,6 @@ def contract(
subscripts: ArrayType,
*operands: Union[ArrayType, Collection[int]],
out: ArrayType = ...,
dtype: Any = ...,
order: _OrderKACF = ...,
casting: _Casting = ...,
use_blas: bool = ...,
optimize: OptimizeKind = ...,
memory_limit: _MemoryLimit = ...,
Expand All @@ -510,9 +503,6 @@ def contract(
subscripts: Union[str, ArrayType],
*operands: Union[ArrayType, Collection[int]],
out: Optional[ArrayType] = None,
dtype: Optional[str] = None,
order: _OrderKACF = "K",
casting: _Casting = "safe",
use_blas: bool = True,
optimize: OptimizeKind = True,
memory_limit: _MemoryLimit = None,
Expand All @@ -527,9 +517,6 @@ def contract(
subscripts: Specifies the subscripts for summation.
*operands: These are the arrays for the operation.
out: A output array in which set the resulting output.
dtype: The dtype of the given contraction, see np.einsum.
order: The order of the resulting contraction, see np.einsum.
casting: The casting procedure for operations of different dtype, see np.einsum.
use_blas: Do you use BLAS for valid operations, may use extra memory for more intermediates.
optimize:- Choose the type of path the contraction will be optimized with
- if a list is given uses this as the path.
Expand All @@ -551,11 +538,12 @@ def contract(
- `'branch-2'` An even more restricted version of 'branch-all' that
only searches the best two options at each step. Scales exponentially
with the number of terms in the contraction.
- `'auto'` Choose the best of the above algorithms whilst aiming to
- `'auto', None, True` Choose the best of the above algorithms whilst aiming to
keep the path finding time below 1ms.
- `'auto-hq'` Aim for a high quality contraction, choosing the best
of the above algorithms whilst aiming to keep the path finding time
below 1sec.
- `False` will not optimize the contraction.
memory_limit:- Give the upper bound of the largest intermediate tensor contract will build.
- None or -1 means there is no limit.
Expand Down Expand Up @@ -586,21 +574,18 @@ def contract(
performed optimally. When NumPy is linked to a threaded BLAS, potential
speedups are on the order of 20-100 for a six core machine.
"""
if optimize is True:
if (optimize is True) or (optimize is None):
optimize = "auto"

operands_list = [subscripts] + list(operands)
einsum_kwargs = {"out": out, "dtype": dtype, "order": order, "casting": casting}

# If no optimization, run pure einsum
if optimize is False:
return _einsum(*operands_list, **einsum_kwargs)
return _einsum(*operands_list, out=out, **kwargs)

# Grab non-einsum kwargs
gen_expression = kwargs.pop("_gen_expression", False)
constants_dict = kwargs.pop("_constants_dict", {})
if len(kwargs):
raise TypeError(f"Did not understand the following kwargs: {kwargs.keys()}")

if gen_expression:
full_str = operands_list[0]
Expand All @@ -613,11 +598,9 @@ def contract(

# check if performing contraction or just building expression
if gen_expression:
return ContractExpression(full_str, contraction_list, constants_dict, dtype=dtype, order=order, casting=casting)
return ContractExpression(full_str, contraction_list, constants_dict, **kwargs)

return _core_contract(
operands, contraction_list, backend=backend, out=out, dtype=dtype, order=order, casting=casting
)
return _core_contract(operands, contraction_list, backend=backend, out=out, **kwargs)


@lru_cache(None)
Expand Down Expand Up @@ -651,9 +634,7 @@ def _core_contract(
backend: Optional[str] = "auto",
evaluate_constants: bool = False,
out: Optional[ArrayType] = None,
dtype: Optional[str] = None,
order: _OrderKACF = "K",
casting: _Casting = "safe",
**kwargs: Any,
) -> ArrayType:
"""Inner loop used to perform an actual contraction given the output
from a ``contract_path(..., einsum_call=True)`` call.
Expand Down Expand Up @@ -703,7 +684,7 @@ def _core_contract(
axes = ((), ())

# Contract!
new_view = _tensordot(*tmp_operands, axes=axes, backend=backend)
new_view = _tensordot(*tmp_operands, axes=axes, backend=backend, **kwargs)

# Build a new view if needed
if (tensor_result != results_index) or handle_out:
Expand All @@ -718,9 +699,7 @@ def _core_contract(
out_kwarg: Union[None, ArrayType] = None
if handle_out:
out_kwarg = out
new_view = _einsum(
einsum_str, *tmp_operands, backend=backend, dtype=dtype, order=order, casting=casting, out=out_kwarg
)
new_view = _einsum(einsum_str, *tmp_operands, backend=backend, out=out_kwarg, **kwargs)

# Append new items and dereference what we can
operands.append(new_view)
Expand Down Expand Up @@ -768,15 +747,11 @@ def __init__(
contraction: str,
contraction_list: ContractionListType,
constants_dict: Dict[int, ArrayType],
dtype: Optional[str] = None,
order: _OrderKACF = "K",
casting: _Casting = "safe",
**kwargs: Any,
):
self.contraction_list = contraction_list
self.dtype = dtype
self.order = order
self.casting = casting
self.contraction = format_const_einsum_str(contraction, constants_dict.keys())
self.contraction_list = contraction_list
self.kwargs = kwargs

# need to know _full_num_args to parse constants with, and num_args to call with
self._full_num_args = contraction.count(",") + 1
Expand Down Expand Up @@ -844,9 +819,7 @@ def _contract(
out=out,
backend=backend,
evaluate_constants=evaluate_constants,
dtype=self.dtype,
order=self.order,
casting=self.casting,
**self.kwargs,
)

def _contract_with_conversion(
Expand Down Expand Up @@ -943,8 +916,7 @@ def __str__(self) -> str:
for i, c in enumerate(self.contraction_list):
s.append(f"\n {i + 1}. ")
s.append(f"'{c[2]}'" + (f" [{c[-1]}]" if c[-1] else ""))
kwargs = {"dtype": self.dtype, "order": self.order, "casting": self.casting}
s.append(f"\neinsum_kwargs={kwargs}")
s.append(f"\neinsum_kwargs={self.kwargs}")
return "".join(s)


Expand Down
9 changes: 7 additions & 2 deletions opt_einsum/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,9 +499,14 @@ def branch(
output: ArrayIndexType,
size_dict: Dict[str, int],
memory_limit: Optional[int] = None,
**optimizer_kwargs: Dict[str, Any],
nbranch: Optional[int] = None,
cutoff_flops_factor: int = 4,
minimize: str = "flops",
cost_fn: str = "memory-removed",
) -> PathType:
optimizer = BranchBound(**optimizer_kwargs) # type: ignore
optimizer = BranchBound(
nbranch=nbranch, cutoff_flops_factor=cutoff_flops_factor, minimize=minimize, cost_fn=cost_fn
)
return optimizer(inputs, output, size_dict, memory_limit)


Expand Down
13 changes: 13 additions & 0 deletions opt_einsum/tests/test_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# NumPy is required for the majority of this file
np = pytest.importorskip("numpy")


tests = [
# Test scalar-like operations
"a,->a",
Expand Down Expand Up @@ -99,6 +100,18 @@
]


@pytest.mark.parametrize("optimize", (True, False, None))
def test_contract_plain_types(optimize: OptimizeKind) -> None:
expr = "ij,jk,kl->il"
ops = [np.random.rand(2, 2), np.random.rand(2, 2), np.random.rand(2, 2)]

path = contract_path(expr, *ops, optimize=optimize)
assert len(path) == 2

result = contract(expr, *ops, optimize=optimize)
assert result.shape == (2, 2)


@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare(optimize: OptimizeKind, string: str) -> None:
Expand Down
9 changes: 1 addition & 8 deletions opt_einsum/tests/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,9 @@ def test_value_errors(contract_fn: Any) -> None:
# broadcasting to new dimensions must be enabled explicitly
with pytest.raises(ValueError):
contract_fn("i", np.arange(6).reshape(2, 3))
if contract_fn is contract:
# contract_path does not have an `out` parameter
with pytest.raises(ValueError):
contract_fn("i->i", [[0, 1], [0, 1]], out=np.arange(4).reshape(2, 2))

with pytest.raises(TypeError):
contract_fn("i->i", [[0, 1], [0, 1]], bad_kwarg=True)

with pytest.raises(ValueError):
contract_fn("i->i", [[0, 1], [0, 1]], memory_limit=-1)
contract_fn("ij->ij", [[0, 1], [0, 1]], bad_kwarg=True)


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion opt_einsum/tests/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_flop_cost() -> None:


def test_bad_path_option() -> None:
with pytest.raises(TypeError):
with pytest.raises(KeyError):
oe.contract("a,b,c", [1], [2], [3], optimize="optimall", shapes=True) # type: ignore


Expand Down

0 comments on commit c15aec2

Please sign in to comment.