Skip to content

Commit

Permalink
💬 Update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Mar 15, 2024
1 parent f7978fd commit 48f4dbb
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 42 deletions.
74 changes: 66 additions & 8 deletions src/arpes/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,7 @@
from __future__ import annotations

import uuid
from typing import (
TYPE_CHECKING,
Literal,
Required,
TypeAlias,
TypedDict,
TypeVar,
)
from typing import TYPE_CHECKING, Any, Literal, Required, TypeAlias, TypedDict, TypeVar

import xarray as xr

Expand Down Expand Up @@ -334,6 +327,7 @@ class Spectrometer(AnalyzerInfo, Coordinates, DAQInfo, total=False):
mstar: float
dof_type: dict[str, list[str]]
length: float
probe_linewidth: float


class ExperimentInfo(
Expand Down Expand Up @@ -379,6 +373,70 @@ class QPushButtonArgs(TypedDict, total=False):
#


class Line2DProperty(TypedDict, total=False):
agg_filter: Callable[[NDArray[np.float_], int], tuple[NDArray[np.float_], int, int]]
alpha: float | None
animated: bool
antialiased: bool | list[bool]
clip_box: BboxBase | None
clip_on: bool
clip_path: mpl.path.Path | Patch | Transform | None
color: ColorType
c: ColorType
dash_capstyple: CapStyleType
dash_joinstyle: JoinStyleType
dashes: LineStyleType
drawstyle: DrawStyleType
ds: DrawStyleType
figure: Figure
fillstyle: FillStyleType
gapcolor: ColorType | None
gid: str
in_layout: bool
label: Any
linestyle: LineStyleType
ls: LineStyleType
marker: MarkerType
markeredgecolor: ColorType
mec: ColorType
markeredgewidth: float
mew: ColorType
markerfacecloralt: ColorType
mfcalt: ColorType
markersize: float
ms: float
markevery: MarkEveryType
mouseover: bool
path_effects: list[AbstractPathEffect]
picker: float | Callable[[Artist, Event], tuple[bool, dict]]
pickradius: float
rasterized: bool
sketch_params: tuple[float, float, float]
snap: bool | None
solid_capstyle: CapStyleType
solid_joinstyle: JoinStyleType
url: str
visible: bool
zorder: float


class PolyCollectionProperty(Line2DProperty, total=False):
array: ArrayLike | None
clim: tuple[float, float]
cmap: Colormap | str | None
edgecolor: ColorType | list[ColorType]
ec: ColorType | list[ColorType]
facecolor: ColorType | list[ColorType]
fc: ColorType | list[ColorType]
hatch: Literal["/", "\\", "|", "-", "+", "x", "o", "O", ".", "*"]
norm: Normalize | str | None
offset_transform: Transform
# offsets: (N, 2) or (2, ) array-likel
sizes: NDArray[np.float_] | None
transform: Transform
urls: list[str] | None


class MPLPlotKwargsBasic(TypedDict, total=False):
"""Kwargs for Axes.plot & Axes.fill_between."""

Expand Down
7 changes: 2 additions & 5 deletions src/arpes/endstations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def is_file_accepted(

return False
try:
_ = cls.find_first_file(int(file))
_ = cls.find_first_file(int(file)) # type: ignore[arg-type]
except ValueError:
return False
return True
Expand Down Expand Up @@ -382,10 +382,7 @@ def postprocess_final(
for k, v in self.MERGE_ATTRS.items():
a_data.attrs.setdefault(k, v)

for a_data in ls:
a_data = _ensure_coords(a_data, self.ENSURE_COORDS_EXIST)

for a_data in ls:
for a_data in [_ensure_coords(a_data, self.ENSURE_COORDS_EXIST) for a_data in ls]:
if "chi" in a_data.coords and "chi_offset" not in a_data.attrs:
a_data.attrs["chi_offset"] = a_data.coords["chi"].item()

Expand Down
2 changes: 1 addition & 1 deletion src/arpes/endstations/fits_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def extract_coords(
scan_coords[name] = np.linspace(start, end, n, endpoint=True)

else:
logger("Loop is tabulated and is region based")
logger.debug("Loop is tabulated and is region based")
name, n = (
attrs[f"NM_{loop}_0"],
attrs[f"NMPOS_{loop}"],
Expand Down
2 changes: 1 addition & 1 deletion src/arpes/endstations/nexus_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import xarray as xr

__all__ = ["read_data_attributes_from"]
__all__ = ("read_data_attributes_from",)

LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[1]
Expand Down
2 changes: 1 addition & 1 deletion src/arpes/plotting/basic_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(self) -> None:
def layout(self) -> QGridLayout:
return self.main_layout

def set_data(self, data: XrTypes) -> None:
def set_data(self, data: xr.DataArray) -> None:
self.data = normalize_to_spectrum(data)

def transpose_to_front(self, dim: Hashable) -> None:
Expand Down
8 changes: 4 additions & 4 deletions src/arpes/plotting/stack_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import contextlib
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING, Literal, Unpack, reveal_type
from typing import TYPE_CHECKING, Literal, Unpack

import matplotlib as mpl
import matplotlib.colorbar
Expand Down Expand Up @@ -69,9 +69,7 @@ def offset_scatter_plot(
data: xr.Dataset,
name_to_plot: str = "",
stack_axis: str = "",
cbarmap: (
tuple[Callable[..., colorbar.Colorbar], Callable[..., Callable[..., ColorType]]] | None
) = None,
cbarmap: tuple[colorbar.Colorbar, Callable[..., ColorType]] | None = None,
ax: Axes | None = None,
out: str | Path = "",
scale_coordinate: float = 0.5,
Expand Down Expand Up @@ -135,6 +133,8 @@ def offset_scatter_plot(
skip_colorbar = True
if cbarmap is None:
skip_colorbar = False
cbar: colorbar.Colorbar | Callable[..., colorbar.Colorbar]
cmap: Colormap
try:
cbar, cmap = colorbarmaps_for_axis[stack_axis]
except KeyError:
Expand Down
45 changes: 25 additions & 20 deletions src/arpes/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from datetime import UTC
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Unpack, reveal_type
from typing import TYPE_CHECKING, Any, Literal, Unpack

import matplotlib as mpl
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -698,11 +698,11 @@ def imshow_mask(


def imshow_arr(
arr: XrTypes,
arr: xr.DataArray,
ax: Axes | None = None,
over: AxesImage | None = None,
**kwargs: Unpack[IMshowParam],
) -> tuple[Figure, Axes]:
) -> tuple[Figure, AxesImage]:
"""Similar to plt.imshow but users different default origin, and sets appropriate extents.
Args:
Expand All @@ -729,7 +729,7 @@ def imshow_arr(
"extent": (y[0], y[-1], x[0], x[-1]),
}
for k, v in default_kwargs.items():
kwargs.setdefault(str(k), v)
kwargs.setdefault(k, v) # type: ignore[misc]
if over is None:
if kwargs["alpha"] != 1:
if isinstance(kwargs["cmap"], str):
Expand All @@ -755,8 +755,6 @@ def imshow_arr(
kwargs["aspect"] = ax.get_aspect()
quad = ax.imshow(
arr.values,
extent=over.get_extent(),
aspect=ax.get_aspect(),
**kwargs,
)

Expand Down Expand Up @@ -841,7 +839,7 @@ def inset_cut_locator(

n = 200

def resolve(name: str, value: slice | int) -> NDArray[np.float_]:
def resolve(name: Hashable, value: slice | int) -> NDArray[np.float_]:
if isinstance(value, slice):
low = value.start
high = value.stop
Expand Down Expand Up @@ -881,13 +879,13 @@ def resolve(name: str, value: slice | int) -> NDArray[np.float_]:
pass


def generic_colormap(low: float, high: float) -> Callable[..., RGBAColorType]:
def generic_colormap(low: float, high: float) -> Callable[..., ColorType]:
"""Generates a colormap from the cm.Blues palette, suitable for most purposes."""
delta = high - low
low = low - delta / 6
high = high + delta / 6

def get_color(value: float) -> RGBAColorType:
def get_color(value: float) -> ColorType:
return mpl.colormaps.get_cmap("Blues")(
float((value - low) / (high - low)),
)
Expand All @@ -898,10 +896,10 @@ def get_color(value: float) -> RGBAColorType:
def phase_angle_colormap(
low: float = 0,
high: float = np.pi * 2,
) -> Callable[[float], RGBAColorType]:
) -> Callable[[float], ColorType]:
"""Generates a colormap suitable for angular data or data on a unit circle like a phase."""

def get_color(value: float) -> RGBAColorType:
def get_color(value: float) -> ColorType:
return mpl.colormaps.get_cmap("twilight_shifted")(float((value - low) / (high - low)))

return get_color
Expand All @@ -910,10 +908,10 @@ def get_color(value: float) -> RGBAColorType:
def delay_colormap(
low: float = -1,
high: float = 1,
) -> Callable[[float], RGBAColorType]:
) -> Callable[[float], ColorType]:
"""Generates a colormap suitable for pump-probe delay data."""

def get_color(value: float) -> RGBAColorType:
def get_color(value: float) -> ColorType:
return mpl.colormaps.get_cmap("coolwarm")(
float((value - low) / (high - low)),
)
Expand All @@ -925,10 +923,10 @@ def temperature_colormap(
low: float = 0,
high: float = 300,
cmap: Colormap = mpl.colormaps["Blues_r"],
) -> Callable[[float], RGBAColorType]:
) -> Callable[[float], ColorType]:
"""Generates a colormap suitable for temperature data with fixed extent."""

def get_color(value: float) -> RGBAColorType:
def get_color(value: float) -> ColorType:
return cmap(float((value - low) / (high - low)))

return get_color
Expand All @@ -937,10 +935,10 @@ def get_color(value: float) -> RGBAColorType:
def temperature_colormap_around(
central: float,
region: float = 50,
) -> Callable[[float], RGBAColorType]:
) -> Callable[[float], ColorType]:
"""Generates a colormap suitable for temperature data around a central value."""

def get_color(value: float) -> RGBAColorType:
def get_color(value: float) -> ColorType:
return mpl.colormaps.get_cmap("RdBu_r")(float((value - central) / region))

return get_color
Expand Down Expand Up @@ -1103,10 +1101,13 @@ def remove_colorbars(fig: Figure | None = None) -> None:
"""
# TODO: after colorbar removal, plots should be relaxed/rescaled to occupy space previously
# allocated to colorbars for now, can follow this with plt.tight_layout()
COLORBAR_ASPECT_RATIO = 20
try:
if fig is not None:
for ax in fig.axes:
if ax.get_aspect() >= 20: # a bit of a hack
aspect_ragio = ax.get_aspect()
assert isinstance(aspect_ragio, float)
if aspect_ragio >= COLORBAR_ASPECT_RATIO:
ax.remove()
else:
remove_colorbars(plt.gcf())
Expand Down Expand Up @@ -1190,7 +1191,7 @@ def __init__(
size: float = 1,
extent: float = 0.03,
label: str = "",
loc: int = 2,
loc: str = "uppder left",
ax: Axes | None = None,
pad: float = 0.4,
borderpad: float = 0.5,
Expand Down Expand Up @@ -1343,7 +1344,11 @@ def extract(for_data: XrTypes) -> dict[str, Any]:
)

with Path(provenance_path).open("w") as f:
json.dump(provenance_context, f, indent=2)
json.dump(
provenance_context,
f,
indent=2,
)
plt.savefig(full_path, dpi=dpi, **kwargs)


Expand Down
4 changes: 2 additions & 2 deletions src/arpes/utilities/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,12 @@ def g(
"""[TODO:summary].
Args:
arr (xr.DataArray): [TODO:description]
arr (xr.DataArray): ARPES Data
*args: Pass to function f
**kwargs: Pass to function f
Returns:
[TODO:description]
xr.DataArray
"""
return xr.DataArray(
arr.values,
Expand Down

0 comments on commit 48f4dbb

Please sign in to comment.