Skip to content

Commit

Permalink
💬 Update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Apr 21, 2024
1 parent 8997442 commit 872a344
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 19 deletions.
10 changes: 4 additions & 6 deletions src/arpes/plotting/fermi_surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def magnify_circular_regions_plot( # noqa: PLR0913
radius: float = 0.05,
# below this two can be treated as kwargs?
cmap: Colormap | ColorType = "viridis",
color: ColorType | None = None,
edgecolor: ColorType = "red",
color: ColorType | list[ColorType] = "blue",
edgecolor: ColorType | list[ColorType] = "red",
out: str | Path = "",
ax: Axes | None = None,
**kwargs: tuple[float, float],
Expand Down Expand Up @@ -122,14 +122,13 @@ def magnify_circular_regions_plot( # noqa: PLR0913
clim = list(mesh.get_clim())
clim[1] = clim[1] / mag

mask = np.zeros(shape=(len(data_arr.values.ravel()),))
pts = np.zeros(
shape=(
len(data_arr.values.ravel()),
2,
),
)
mask = mask > 0
mask = np.zeros(shape=len(data_arr.values.ravel())) > 0

raveled = data_arr.G.ravel()
pts[:, 0] = raveled[data_arr.dims[0]]
Expand All @@ -150,7 +149,6 @@ def magnify_circular_regions_plot( # noqa: PLR0913

if not isinstance(color, list):
color = [color for _ in range(len(magnified_points))]
assert isinstance(color, list)

pts[:, 1] = (pts[:, 1]) / (xlim[1] - xlim[0])
pts[:, 0] = (pts[:, 0]) / (ylim[1] - ylim[0])
Expand All @@ -168,7 +166,7 @@ def magnify_circular_regions_plot( # noqa: PLR0913
linewidth=2,
zorder=4,
)
patchfake = matplotlib.patches.Ellipse([point[1], point[0]], radius, radius)
patchfake = matplotlib.patches.Ellipse((point[1], point[0]), radius, radius)
ax.add_patch(patch)
mask = np.logical_or(mask, patchfake.contains_points(pts))

Expand Down
9 changes: 5 additions & 4 deletions src/arpes/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import re
import warnings
from collections import Counter
from collections.abc import Generator, Hashable, Iterable, Iterator, Sequence
from collections.abc import Callable, Generator, Hashable, Iterable, Iterator, Sequence
from datetime import UTC
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Unpack, reveal_type
from typing import TYPE_CHECKING, Any, Literal, Unpack

import matplotlib as mpl
import matplotlib.pyplot as plt
Expand All @@ -37,6 +37,7 @@
if TYPE_CHECKING:
from _typeshed import Incomplete
from lmfit.model import Model
from matplotlib.collections import PathCollection
from matplotlib.font_manager import FontProperties
from matplotlib.image import AxesImage
from matplotlib.typing import ColorType
Expand Down Expand Up @@ -576,7 +577,7 @@ def lineplot_arr(

xs = None
if arr is not None:
fn = plt.plot
fn: Callable[..., list[Line2D]] | Callable[..., PathCollection] = plt.plot
if method == "scatter":
fn = plt.scatter

Expand Down Expand Up @@ -613,7 +614,7 @@ def plot_arr(
if n_dims == TWO_DIMENSION:
quad = None
if arr is not None:
ax, quad = imshow_arr(arr, ax=ax, over=over, **kwargs)
fig, quad = imshow_arr(arr, ax=ax, over=over, **kwargs)
if mask is not None:
over = quad if over is None else over
imshow_mask(mask, ax=ax, over=over, **kwargs)
Expand Down
17 changes: 10 additions & 7 deletions src/arpes/utilities/selections.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def ravel_from_mask(data: DataType, mask: XrTypes) -> DataType:
Returns:
Raveled data with masked points removed.
"""
return data.stack(stacked=mask.dims).where(mask.stack(stacked=mask.dims), drop=True)
return data.stack(stacked=list(mask.dims)).where(mask.stack(stacked=list(mask.dims)), drop=True)


def unravel_from_mask(
Expand All @@ -65,27 +65,30 @@ def unravel_from_mask(
dest = template * 0 + 1
dest_mask = np.logical_not(
np.isnan(
template.stack(stacked=template.dims).where(mask.stack(stacked=template.dims)).values,
template.stack(stacked=list(template.dims))
.where(mask.stack(stacked=list(template.dims)))
.values,
),
)
dest = (dest * default).stack(stacked=template.dims)
dest = (dest * default).stack(stacked=list(template.dims))
dest.values[dest_mask] = values
return dest.unstack("stacked")


def _normalize_point(
data: xr.DataArray,
around: dict[str, xr.DataArray] | xr.Dataset,
**kwargs: Incomplete,
around: dict[str, xr.DataArray] | xr.Dataset | None,
**kwargs: NDArray[np.float_] | float,
) -> dict[str, xr.DataArray]:
collected_kwargs = {k: kwargs[k] for k in data.dims if k in kwargs}
collected_kwargs = {k: kwargs[str(k)] for k in data.dims if k in kwargs}

if around:
if isinstance(around, xr.Dataset):
around = unwrap_xarray_dict({d: around[d] for d in data.dims})
around = unwrap_xarray_dict({str(d): around[d] for d in data.dims})
else:
around = collected_kwargs

assert isinstance(around, dict)
assert set(around.keys()) == set(data.dims)
return around

Expand Down
2 changes: 1 addition & 1 deletion src/arpes/utilities/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def unwrap_xarray_item(item: xr.DataArray) -> xr.DataArray | float:

def unwrap_xarray_dict(
input_dict: dict[str, xr.DataArray],
) -> dict[str, xr.DataArray | NDArray[np.float_] | float]:
) -> dict[str, xr.DataArray | float]:
"""Returns the attributes as unwrapped values rather than item() instances.
Useful for unwrapping coordinate dicts where the values might be a bare type:
Expand Down
1 change: 0 additions & 1 deletion src/arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@
PColorMeshKwargs,
SampleInfo,
ScanInfo,
Spectrometer,
XrTypes,
)
from .provenance import Provenance
Expand Down

0 comments on commit 872a344

Please sign in to comment.