Skip to content

Commit

Permalink
💬 Update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Mar 12, 2024
1 parent 21b1a6e commit ebcbf53
Show file tree
Hide file tree
Showing 11 changed files with 102 additions and 154 deletions.
6 changes: 4 additions & 2 deletions src/arpes/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
MarkerType,
MarkEveryType,
)
from matplotlib.widgets import Button
from numpy.typing import ArrayLike, NDArray
from PySide6 import QtCore
from PySide6.QtGui import QIcon, QPixmap
Expand Down Expand Up @@ -145,7 +146,7 @@ class CURRENTCONTEXT(TypedDict, total=False):
integration_region: dict[Incomplete, Incomplete]
original_data: XrTypes
data: XrTypes
widgets: list[mpl.widgets.AxesWidget]
widgets: list[dict[str, mpl.widgets.AxesWidget] | Button]
points: list[Incomplete]
rect_next: bool
#
Expand Down Expand Up @@ -325,6 +326,8 @@ class DAQINFO(TypedDict, total=False):


class SPECTROMETER(ANALYZERINFO, COORDINATES, DAQINFO, total=False):
name: str
type: str
rad_per_pixel: float
dof: list[str]
scan_dof: list[str]
Expand Down Expand Up @@ -569,7 +572,6 @@ class PLTSubplotParam(TypedDict, total=False):


class AxesImageParam(TypedDict, total=False):
ax: Axes
cmap: str | Colormap
norm: str | Normalize
interpolation: Literal[
Expand Down
12 changes: 8 additions & 4 deletions src/arpes/analysis/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ class DecompositionParam(PCAParam, FastICAParam, NMFParam, FactorAnalysisParam):
pass


class DecompositionParamBase(TypedDict, total=False):
n_composition: int | None


def decomposition_along(
data: xr.DataArray,
axes: list[str],
Expand Down Expand Up @@ -177,7 +181,7 @@ def decomposition_along(

@wraps(decomposition_along)
def pca_along(
*args: xr.DataArray | list[str],
*args: * tuple[xr.DataArray, list[str]],
**kwargs: Unpack[PCAParam],
) -> tuple[xr.DataArray, sklearn.decomposition.PCA]:
"""Specializes `decomposition_along` with `sklearn.decomposition.PCA`."""
Expand All @@ -188,7 +192,7 @@ def pca_along(

@wraps(decomposition_along)
def factor_analysis_along(
*args: xr.DataArray | list[str],
*args: * tuple[xr.DataArray, list[str]],
**kwargs: Unpack[FactorAnalysisParam],
) -> tuple[xr.DataArray, sklearn.decomposition.FactorAnalysis]:
"""Specializes `decomposition_along` with `sklearn.decomposition.FactorAnalysis`."""
Expand All @@ -197,7 +201,7 @@ def factor_analysis_along(

@wraps(decomposition_along)
def ica_along(
*args: xr.DataArray | list[str],
*args: * tuple[xr.DataArray, list[str]],
**kwargs: Unpack[FastICAParam],
) -> tuple[xr.DataArray, sklearn.decomposition.FastICA]:
"""Specializes `decomposition_along` with `sklearn.decomposition.FastICA`."""
Expand All @@ -206,7 +210,7 @@ def ica_along(

@wraps(decomposition_along)
def nmf_along(
*args: xr.DataArray | list[str],
*args: * tuple[xr.DataArray, list[str]],
**kwargs: Unpack[NMFParam],
) -> tuple[xr.DataArray, sklearn.decomposition.NMF]:
"""Specializes `decomposition_along` with `sklearn.decomposition.NMF`."""
Expand Down
9 changes: 5 additions & 4 deletions src/arpes/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from _typeshed import Incomplete
from numpy.typing import NDArray

from ._typing import DataType
__all__ = (
"bootstrap",
"estimate_prior_adjustment",
Expand Down Expand Up @@ -131,8 +130,9 @@ def resample(
data: xr.DataArray,
prior_adjustment: float = 1,
) -> xr.DataArray:
rg = np.random.default_rng()
resampled = xr.DataArray(
np.random.Generator.poisson(
rg.poisson(
lam=data.values * prior_adjustment,
size=data.values.shape,
),
Expand All @@ -159,8 +159,9 @@ def resample_true_counts(data: xr.DataArray) -> xr.DataArray:
Returns:
Poisson resampled data.
"""
rg = np.random.default_rng()
resampled = xr.DataArray(
np.random.Generator.poisson(
rg.poisson(
lam=data.values,
size=data.values.shape,
),
Expand All @@ -178,7 +179,7 @@ def resample_true_counts(data: xr.DataArray) -> xr.DataArray:
@update_provenance("Bootstrap true electron counts")
@lift_dataarray_to_generic
def bootstrap_counts(
data: DataType,
data: xr.DataArray,
n_samples: int = 1000,
name: str | None = None,
) -> xr.Dataset:
Expand Down
4 changes: 2 additions & 2 deletions src/arpes/load_pxt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, TypeAlias

import numpy as np
import xarray as xr
Expand All @@ -18,7 +18,7 @@
from _typeshed import Incomplete

from ._typing import DataType
Wave = Any # really, igor.Wave but we do not assume installation
Wave: TypeAlias = Any # really, igor.Wave but we do not assume installation

__all__ = (
"read_single_pxt",
Expand Down
41 changes: 17 additions & 24 deletions src/arpes/plotting/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,26 @@

from numpy.typing import NDArray

from arpes._typing import EXPERIMENTINFO, DataType, MPLTextParam
from arpes._typing import EXPERIMENTINFO, DataType, MPLTextParam, XrTypes

__all__ = (
"annotate_cuts",
"annotate_point",
"annotate_experimental_conditions",
)

font_scalings = { # see matplotlib.font_manager
"xx-small": 0.579,
"x-small": 0.694,
"small": 0.833,
"medium": 1.0,
"large": 1.200,
"x-large": 1.440,
"xx-large": 1.728,
"larger": 1.2,
"smaller": 0.833,
}


# TODO @<R.Arafune>: Useless: Revision required
# * In order not to use data axis, set transform = ax.Transform
Expand Down Expand Up @@ -80,32 +92,13 @@ def annotate_experimental_conditions(
"large",
"x-large",
"xx-large",
"larger",
"smaller",
]
) = kwargs.get("fontsize", 16)
if isinstance(fontsize_keyword, float):
fontsize = fontsize_keyword
elif fontsize_keyword in (
"xx-small",
"x-small",
"small",
"medium",
"large",
"x-large",
"xx-large",
"smaller",
):
font_scalings = { # see matplotlib.font_manager
"xx-small": 0.579,
"x-small": 0.694,
"small": 0.833,
"medium": 1.0,
"large": 1.200,
"x-large": 1.440,
"xx-large": 1.728,
"larger": 1.2,
"smaller": 0.833,
}
elif fontsize_keyword in font_scalings:
fontsize = mpl.rc_params()["font.size"] * font_scalings[fontsize_keyword]
else:
err_msg = "Incorrect font size setting"
Expand Down Expand Up @@ -162,7 +155,7 @@ def _render_photon(c: dict[str, float]) -> str:

def annotate_cuts(
ax: Axes,
data: DataType,
data: XrTypes,
plotted_axes: NDArray[np.object_],
*,
include_text_labels: bool = False,
Expand All @@ -183,7 +176,7 @@ def annotate_cuts(
from arpes.utilities.conversion.forward import convert_coordinates_to_kspace_forward

converted_coordinates = convert_coordinates_to_kspace_forward(data)
assert converted_coordinates, xr.Dataset | xr.DataArray
assert isinstance(converted_coordinates, xr.Dataset)
assert len(plotted_axes) == TWO_DIMENSION

for k, v in kwargs.items():
Expand Down
5 changes: 2 additions & 3 deletions src/arpes/plotting/false_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
if TYPE_CHECKING:
from pathlib import Path

from _typeshed import Incomplete
from matplotlib.figure import Figure
from numpy.typing import NDArray

Expand All @@ -31,15 +30,15 @@ def false_color_plot(
*,
invert: bool = False,
pmin_pmax: tuple[float, float] = (0, 1),
**kwargs: Incomplete,
figsize: tuple[float, float] = (7, 5),
) -> Path | tuple[Figure | None, Axes]:
"""Plots a spectrum in false color after conversion to R, G, B arrays."""
data_r_arr, data_g_arr, data_b_arr = (normalize_to_spectrum(d) for d in data_rgb)
pmin, pmax = pmin_pmax

fig: Figure | None = None
if ax is None:
fig, ax = plt.subplots(figsize=kwargs.pop("figsize", (7, 5)))
fig, ax = plt.subplots(figsize=figsize)
assert isinstance(ax, Axes)

def normalize_channel(channel: NDArray[np.float_]) -> NDArray[np.float_]:
Expand Down
11 changes: 6 additions & 5 deletions src/arpes/plotting/fermi_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
if TYPE_CHECKING:
from pathlib import Path

from _typeshed import Incomplete
from numpy.typing import NDArray

from arpes._typing import MPLPlotKwargs
Expand Down Expand Up @@ -117,7 +118,7 @@ def fermi_edge_reference(
"""Fits for and plots results for the Fermi edge on a piece of data.
Args:
data_arr: The data, this should be of type DataArray<lmfit.model.ModelResult>
data: The data, this should be of type DataArray<lmfit.model.ModelResult>
title: A title to attach to the plot
ax: The axes to plot to, if not specified will be generated
out: Where to save the plot
Expand All @@ -133,7 +134,7 @@ def fermi_edge_reference(
assert isinstance(data_arr, xr.DataArray)
sum_dimensions: set[str] = {"cycle", "phi", "kp", "kx"}
sum_dimensions.intersection_update(set(data_arr.dims))
summed_data = data_arr.sum(*list(sum_dimensions))
summed_data = data.sum(*list(sum_dimensions))

broadcast_dimensions = [str(d) for d in summed_data.dims if str(d) != "eV"]
msg = f"Could not product fermi edge reference. Too many dimensions: {broadcast_dimensions}"
Expand All @@ -156,14 +157,14 @@ def fermi_edge_reference(
_, ax = plt.subplots(figsize=(8, 5))

if not title:
title = data_arr.S.label.replace("_", " ")
title = data.S.label.replace("_", " ")

centers.plot(ax=ax, **kwargs)
widths.plot(ax=ax, **kwargs)

if isinstance(ax, Axes):
ax.set_xlabel(label_for_dim(data_arr, ax.get_xlabel()))
ax.set_ylabel(label_for_dim(data_arr, ax.get_ylabel()))
ax.set_xlabel(label_for_dim(data, ax.get_xlabel()))
ax.set_ylabel(label_for_dim(data, ax.get_ylabel()))
ax.set_title(title, font_size=14)

if out:
Expand Down
4 changes: 2 additions & 2 deletions src/arpes/plotting/movie.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def init() -> tuple[QuadMesh]:
def animate(i: int) -> tuple[QuadMesh]:
coordinate = animation_coords[i]
data_for_plot = data.sel({time_dim: coordinate})
plot.set_array(data_for_plot.values.G.ravel())
plot.set_array(data_for_plot.values.ravel())
return (plot,)

anim = animation.FuncAnimation(
Expand All @@ -104,7 +104,7 @@ def animate(i: int) -> tuple[QuadMesh]:

animation_writer = animation.writers["ffmpeg"]
writer = animation_writer(
fps=1000 / interval_ms,
fps=int(1000 / interval_ms),
metadata={"artist": "Me"},
bitrate=1800,
)
Expand Down
5 changes: 3 additions & 2 deletions src/arpes/plotting/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,13 @@ def plot_parameter( # noqa: PLR0913
color = kwargs.get("color")
e_width = None
l_width = None
if "fmt" not in kwargs:
kwargs["fmt"] = ""
if two_sigma:
_, _, lines = ax.errorbar(
x + x_shift,
ds.value.values + shift,
yerr=2 * ds.error.values,
fmt="",
elinewidth=1,
linewidth=0,
c=color,
Expand All @@ -64,11 +65,11 @@ def plot_parameter( # noqa: PLR0913
e_width = 2
l_width = 0

kwargs["fmt"] = "s"
ax.errorbar(
x + x_shift,
ds.value.values + shift,
yerr=ds.error.values,
fmt="s",
color=color,
elinewidth=e_width,
linewidth=l_width,
Expand Down
2 changes: 1 addition & 1 deletion src/arpes/utilities/conversion/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,11 @@ def interpolated_coordinate_to_raw(*coordinates: NDArray[np.float_]) -> NDArray[
@update_provenance("Automatically k-space converted")
def convert_to_kspace( # noqa: PLR0913
arr: xr.DataArray,
*,
bounds: dict[MOMENTUM, tuple[float, float]] | None = None,
resolution: dict[MOMENTUM, float] | None = None,
calibration: DetectorCalibration | None = None,
coords: dict[MOMENTUM, NDArray[np.float_]] | None = None,
*,
allow_chunks: bool = False,
**kwargs: NDArray[np.float_],
) -> xr.DataArray:
Expand Down
Loading

0 comments on commit ebcbf53

Please sign in to comment.