Skip to content

Commit

Permalink
💬 update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Oct 11, 2023
1 parent 9c976fa commit c73a769
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
6 changes: 3 additions & 3 deletions arpes/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
if TYPE_CHECKING:
from collections.abc import Callable

import lmfit as lf
from _typeshed import Incomplete
from lmfit import lf
from numpy.typing import NDArray

from arpes._typing import DataType
Expand Down Expand Up @@ -200,9 +200,9 @@ def bootstrap_counts(

resampled_arr = np.stack([s.values for s in resampled_sets], axis=0)
std = np.std(resampled_arr, axis=0)
std = xr.DataArray(std, data.coords, data.dims)
std = xr.DataArray(std, data.coords, tuple(data.dims))
mean = np.mean(resampled_arr, axis=0)
mean = xr.DataArray(mean, data.coords, data.dims)
mean = xr.DataArray(mean, data.coords, tuple(data.dims))

data_vars = {}
data_vars[name] = mean
Expand Down
4 changes: 3 additions & 1 deletion arpes/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@

from _typeshed import Incomplete

from arpes._typing import WORKSPACETYPE


def attach_id(data: DataType) -> None:
"""Ensures that an ID is attached to a piece of data, if it does not already exist.
Expand Down Expand Up @@ -158,7 +160,7 @@ def func_wrapper(*args: Incomplete, **kwargs: Incomplete) -> Incomplete:

path = plot_fn(*args, **kwargs)
if isinstance(path, str) and Path(path).exists():
workspace = arpes.config.CONFIG["WORKSPACE"]
workspace: WORKSPACETYPE = arpes.config.CONFIG["WORKSPACE"]

with contextlib.suppress(TypeError, KeyError):
workspace = workspace["name"]
Expand Down
6 changes: 3 additions & 3 deletions arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,9 +611,9 @@ def short_history(self, key: str = "by") -> list:

def _calculate_symmetry_points(
self,
symmetry_points: dict,
symmetry_points: dict[str, list[float]],
epsilon: float = 0.01,
) -> tuple:
) -> tuple[dict[str, list[float]], dict[str, list[float]]]:
# For each symmetry point, we need to determine if it is projected or not
# if it is projected, we need to calculate its projected coordinates
"""[TODO:summary].
Expand Down Expand Up @@ -670,7 +670,7 @@ def symmetry_points(
*,
raw: bool = False,
**kwargs: float,
) -> dict | tuple:
) -> dict[str, list[float]] | tuple[dict[str, list[float]], dict[str, list[float]]]:
"""[TODO:summary].
Args:
Expand Down

0 comments on commit c73a769

Please sign in to comment.