Skip to content

Commit

Permalink
💬 Update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Feb 21, 2024
1 parent ef24869 commit 53009ee
Show file tree
Hide file tree
Showing 15 changed files with 132 additions and 89 deletions.
1 change: 0 additions & 1 deletion resources/example_configuration/local_config.example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,5 @@
"marginal_width": 300,
"palette": "magma",
},
"xarray_repr_mod": True,
"DEBUG": True,
}
7 changes: 3 additions & 4 deletions src/arpes/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ class QPushButtonARGS(TypedDict, total=False):
#


class MPLPlotKwagsBasic(TypedDict, total=False):
class MPLPlotKwargsBasic(TypedDict, total=False):
"""Kwargs for Axes.plot & Axes.fill_between."""

agg_filter: Callable[[NDArray[np.float_], int], tuple[NDArray[np.float_], int, int]]
Expand Down Expand Up @@ -402,10 +402,10 @@ class MPLPlotKwagsBasic(TypedDict, total=False):
visible: bool


class MPLPlotKwargs(MPLPlotKwagsBasic, total=False):
class MPLPlotKwargs(MPLPlotKwargsBasic, total=False):
scalex: bool
scaley: bool

fmt: str
dash_capstyle: CapStyleType
dash_joinstyle: JoinStyleType
dashes: Sequence[float | None]
Expand Down Expand Up @@ -440,7 +440,6 @@ class ColorbarParam(TypedDict, total=False):
alpha: float
orientation: None | Literal["vertical", "horizontal"]
ticklocation: Literal["auto", "right", "top", "bottom"]
drawedge: bool
extend: Literal["neither", "both", "min", "max"]
extendfrac: None | Literal["auto"] | float | tuple[float, float] | list[float]
spacing: Literal["uniform", "proportional"]
Expand Down
2 changes: 1 addition & 1 deletion src/arpes/analysis/band_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def resolve_partial_bands_from_description(
"band": band,
"name": f"{name}_{i}",
"params": _build_params(
old_params=params,
params=params,
center=band_center,
center_stray=params.get("stray", stray),
marginal=marginal,
Expand Down
8 changes: 7 additions & 1 deletion src/arpes/analysis/deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,15 @@ def make_psf(
sigmas: dict[Hashable, float],
*,
fwhm: bool = True,
clip: float | None = None,
) -> xr.DataArray:
"""Produces an n-dimensional gaussian point spread function for use in deconvolve_rl.
Args:
data (DataType): input data
sigmas (dict[str, float]): sigma values for each dimension.
fwhm (bool): if True, sigma is FWHM, not the standard deviation.
clip (float | bool): clip the region by sigma-unit.
Returns:
The PSF to use.
Expand Down Expand Up @@ -184,11 +186,15 @@ def make_psf(

coords_for_pdf_pos = np.stack(coords, axis=-1) # point distribution function (pdf)
logger.debug(f"shape of coords_for_pdf_pos: {coords_for_pdf_pos.shape}")
return xr.DataArray(
psf = xr.DataArray(
multivariate_normal(mean=np.zeros(len(sigmas)), cov=cov).pdf(
coords_for_pdf_pos,
),
dims=data.dims,
coords=psf_coords,
name="PSF",
)
if clip:
clipping_region = {k: slice(-clip * v, clip * v) for k, v in sigmas.items()}
return psf.sel(clipping_region)
return psf
4 changes: 3 additions & 1 deletion src/arpes/analysis/savitzky_golay.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,9 @@ def savitzky_golay_array(
half_window = (window_size - 1) // 2
# precompute coefficients
b = np.mat([[k**i for i in order_range] for k in range(-half_window, half_window + 1)])
m = np.linalg.pinv(b).A[deriv] * rate**deriv * factorial(deriv)
b_inv = np.linalg.pinv(b)
assert isinstance(b_inv, np.matrix)
m = b_inv.A[deriv] * rate**deriv * factorial(deriv)
# pad the signal at the extremes with
# values taken from the signal itself
firstvals = y[0] - np.abs(y[1 : half_window + 1][::-1] - y[0])
Expand Down
1 change: 0 additions & 1 deletion src/arpes/analysis/xps.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def local_minima(a: NDArray[np.float_], promenance: int = 3) -> NDArray[np.float
Returns:
A mask where the local minima are True and other values are False.
"""
conditions = a == a
for i in range(1, promenance + 1):
current_conditions = np.r_[[False] * i, a[i:] < a[:-i]] & np.r_[a[:-i] < a[i:], [False] * i]
conditions = conditions & current_conditions
Expand Down
24 changes: 12 additions & 12 deletions src/arpes/plotting/fermi_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Unpack

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -19,9 +19,10 @@
if TYPE_CHECKING:
from pathlib import Path

from _typeshed import Incomplete
from numpy.typing import NDArray

from arpes._typing import MPLPlotKwargs

__all__ = ["fermi_edge_reference", "plot_fit"]


Expand Down Expand Up @@ -107,16 +108,16 @@ def plot_fit(

@save_plot_provenance
def fermi_edge_reference(
data: xr.DataArray,
data_arr: xr.DataArray,
title: str = "",
ax: Axes | None = None,
out: str | Path = "",
**kwargs: Incomplete,
**kwargs: Unpack[MPLPlotKwargs],
) -> Path | Axes:
"""Fits for and plots results for the Fermi edge on a piece of data.
Args:
data: The data, this should be of type DataArray<lmfit.model.ModelResult>
data_arr: The data, this should be of type DataArray<lmfit.model.ModelResult>
title: A title to attach to the plot
ax: The axes to plot to, if not specified will be generated
out: Where to save the plot
Expand All @@ -129,10 +130,10 @@ def fermi_edge_reference(
"Not automatically correcting for slit shape distortions to the Fermi edge",
stacklevel=2,
)
assert isinstance(data, xr.DataArray)
assert isinstance(data_arr, xr.DataArray)
sum_dimensions: set[str] = {"cycle", "phi", "kp", "kx"}
sum_dimensions.intersection_update(set(data.dims))
summed_data = data.sum(*list(sum_dimensions))
sum_dimensions.intersection_update(set(data_arr.dims))
summed_data = data_arr.sum(*list(sum_dimensions))

broadcast_dimensions = [str(d) for d in summed_data.dims if str(d) != "eV"]
msg = f"Could not product fermi edge reference. Too many dimensions: {broadcast_dimensions}"
Expand All @@ -155,18 +156,17 @@ def fermi_edge_reference(
_, ax = plt.subplots(figsize=(8, 5))

if not title:
title = data.S.label.replace("_", " ")
title = data_arr.S.label.replace("_", " ")

centers.plot(ax=ax, **kwargs)
widths.plot(ax=ax, **kwargs)

if isinstance(ax, Axes):
ax.set_xlabel(label_for_dim(data, ax.get_xlabel()))
ax.set_ylabel(label_for_dim(data, ax.get_ylabel()))
ax.set_xlabel(label_for_dim(data_arr, ax.get_xlabel()))
ax.set_ylabel(label_for_dim(data_arr, ax.get_ylabel()))
ax.set_title(title, font_size=14)

if out:
plt.savefig(path_for_plot(out), dpi=400)
return path_for_plot(out)

return ax
12 changes: 6 additions & 6 deletions src/arpes/plotting/stack_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@
from pathlib import Path

from matplotlib.figure import Figure
from matplotlib.typing import ColorType, RGBAColorType
from matplotlib.typing import ColorType, RGBAColorType, RGBColorType
from numpy.typing import NDArray

from arpes._typing import LegendLocation, MPLPlotKwagsBasic, XrTypes
from arpes._typing import LegendLocation, MPLPlotKwargsBasic
__all__ = (
"stack_dispersion_plot",
"flat_stack_plot",
Expand Down Expand Up @@ -217,7 +217,7 @@ def flat_stack_plot( # noqa: PLR0913
title: str = "",
out: str | Path = "",
loc: LegendLocation = "upper left",
**kwargs: Unpack[MPLPlotKwagsBasic],
**kwargs: Unpack[MPLPlotKwargsBasic],
) -> Path | tuple[Figure | None, Axes]:
"""Generates a stack plot with all the lines distinguished by color rather than offset.
Expand Down Expand Up @@ -313,7 +313,7 @@ def flat_stack_plot( # noqa: PLR0913

@save_plot_provenance
def stack_dispersion_plot( # noqa: PLR0913
data: XrTypes,
data: xr.DataArray,
*,
stack_axis: str = "",
ax: Axes | None = None,
Expand All @@ -326,7 +326,7 @@ def stack_dispersion_plot( # noqa: PLR0913
negate: bool = False,
figsize: tuple[float, float] = (7, 7),
title: str = "",
**kwargs: Unpack[MPLPlotKwagsBasic],
**kwargs: Unpack[MPLPlotKwargsBasic],
) -> Path | tuple[Figure | None, Axes]:
"""Generates a stack plot with all the lines distinguished by offset (and color).
Expand Down Expand Up @@ -573,7 +573,7 @@ def _color_for_plot(
color: Colormap | ColorType,
i: int,
num_plot: int,
) -> RGBAColorType:
) -> RGBAColorType | RGBColorType:
if isinstance(color, Colormap):
cmap = color
return cmap(np.abs(i / num_plot))
Expand Down
33 changes: 20 additions & 13 deletions src/arpes/plotting/tof.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Unpack

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -27,7 +27,7 @@
import xarray as xr
from matplotlib.figure import Figure

from arpes._typing import DataType
from arpes._typing import MPLPlotKwargs, MPLPlotKwargsBasic

__all__ = (
"plot_with_std",
Expand All @@ -37,37 +37,41 @@

@save_plot_provenance
def plot_with_std(
data: DataType,
data_set: xr.Dataset, # dat_vars is used,
name_to_plot: str = "",
ax: Axes | None = None,
out: str | Path = "",
**kwargs: tuple[int, int] | float | str,
figsize: tuple[float, float] = (7, 5),
**kwargs: Unpack[MPLPlotKwargs],
) -> Path | tuple[Figure | None, Axes]:
"""Makes a fill-between line plot with error bars from associated statistical errors.
Args:
data(xr.Dataset): ARPES data that 'mean_and_deviation' is applied.
data_set (xr.Dataset): ARPES data that 'mean_and_deviation' is applied.
name_to_plot(str): data name to plot, in most case "spectrum" is used.
ax: Matplotlib Axes object
out: (str | Path): Path name to output figure.
figsize (tuple[float, float]): figure size
**kwargs: pass to subplots if figsize is set as tuple, other kwargs are pass to
ax.fill_between/xr.DataArray.plot
"""
if not name_to_plot:
var_names = [k for k in data.data_vars if "_std" not in str(k)]
var_names = [k for k in data_set.data_vars if "_std" not in str(k)]
assert len(var_names) == 1
name_to_plot = str(var_names[0])
assert (name_to_plot + "_std") in data.data_vars, "Has 'mean_and_deviation' been applied?"
assert (
name_to_plot + "_std"
) in data_set.data_vars, "Has 'mean_and_deviation' been applied?"

fig: Figure | None = None
if ax is None:
fig, ax = plt.subplots(figsize=kwargs.pop("figsize", (7, 5)))
fig, ax = plt.subplots(figsize=figsize)
assert isinstance(ax, Axes)

data.data_vars[name_to_plot].plot(ax=ax, **kwargs)
x, y = data.data_vars[name_to_plot].G.to_arrays()
data_set.data_vars[name_to_plot].plot(ax=ax, **kwargs)
x, y = data_set.data_vars[name_to_plot].G.to_arrays()

std = data.data_vars[name_to_plot + "_std"].values
std = data_set.data_vars[name_to_plot + "_std"].values
ax.fill_between(x, y - std, y + std, alpha=0.3, **kwargs)

if out:
Expand All @@ -85,7 +89,8 @@ def scatter_with_std(
name_to_plot: str = "",
ax: Axes | None = None,
out: str | Path = "",
**kwargs: tuple[int, int] | float | str,
figsize: tuple[float, float] = (7, 5),
**kwargs: Unpack[MPLPlotKwargsBasic],
) -> Path | tuple[Figure | None, Axes]:
"""Makes a scatter plot of data with error bars generated from associated statistical errors.
Expand All @@ -94,6 +99,8 @@ def scatter_with_std(
name_to_plot(str): data name to plot, in most case "spectrum" is used.
ax: Matplotlib Axes object
out: (str | Path): Path name to output figure.
figsize (tuple[float, float]): tuple for figure size.
fmt (str): THe form at for the data points/lines.
**kwargs: pass to subplots if figsize is set as tuple, other kwargs are pass to ax.errorbar
"""
if not name_to_plot:
Expand All @@ -106,7 +113,7 @@ def scatter_with_std(

fig: Figure | None = None
if ax is None:
fig, ax = plt.subplots(figsize=kwargs.pop("figsize", (7, 5)))
fig, ax = plt.subplots(figsize=figsize)
assert isinstance(ax, Axes)
x, y = data.data_vars[name_to_plot].G.to_arrays()

Expand Down
8 changes: 4 additions & 4 deletions src/arpes/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,7 +995,7 @@ def temperature_colorbar(
low: float = 0,
high: float = 300,
ax: Axes | None = None,
**kwargs: Incomplete,
**kwargs: Unpack[ColorbarParam],
) -> colorbar.Colorbar:
"""Generates a colorbar suitable for temperature data with fixed extent."""
assert isinstance(ax, Axes)
Expand All @@ -1017,7 +1017,7 @@ def delay_colorbar(
low: float = -1,
high: float = 1,
ax: Axes | None = None,
**kwargs: Incomplete,
**kwargs: Unpack[ColorbarParam],
) -> colorbar.Colorbar:
assert isinstance(ax, Axes)
"""Generates a colorbar suitable for delay data.
Expand All @@ -1037,7 +1037,7 @@ def temperature_colorbar_around(
central: float,
temperature_range: float = 50,
ax: Axes | None = None,
**kwargs: Incomplete,
**kwargs: Unpack[ColorbarParam],
) -> colorbar.Colorbar:
"""Generates a colorbar suitable for temperature axes around a central value."""
assert isinstance(ax, Axes)
Expand Down Expand Up @@ -1122,7 +1122,7 @@ def generic_colorbarmap_for_data(
ax: Axes,
*,
keep_ticks: bool = True,
**kwargs: Incomplete,
**kwargs: Unpack[ColorbarParam],
) -> tuple[colorbar.Colorbar, Callable[..., RGBAColorType]]:
"""Generates a colorbar and colormap which is useful in general context.
Expand Down
2 changes: 2 additions & 0 deletions src/arpes/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class PROVENANCE(TypedDict, total=False):
#
old_axis: str
new_axis: str
#
occupation_ratio: float


def attach_id(data: XrTypes) -> None:
Expand Down
Loading

0 comments on commit 53009ee

Please sign in to comment.