diff --git a/src/arpes/_typing.py b/src/arpes/_typing.py index 22462e32..552b3e8f 100644 --- a/src/arpes/_typing.py +++ b/src/arpes/_typing.py @@ -36,6 +36,7 @@ from matplotlib.figure import Figure from matplotlib.patches import Patch from matplotlib.patheffects import AbstractPathEffect + from matplotlib.ticker import Locator from matplotlib.transforms import BboxBase, Transform from matplotlib.typing import ( CapStyleType, @@ -450,7 +451,7 @@ class ColorbarParam(TypedDict, total=False): extend: Literal["neither", "both", "min", "max"] extendfrac: None | Literal["auto"] | float | tuple[float, float] | list[float] spacing: Literal["uniform", "proportional"] - ticks: None | list[float] + ticks: None | Sequence[float] | Locator format: str | None drawedges: bool label: str diff --git a/src/arpes/plotting/basic_tools/__init__.py b/src/arpes/plotting/basic_tools/__init__.py index 0e21b536..7db031ee 100644 --- a/src/arpes/plotting/basic_tools/__init__.py +++ b/src/arpes/plotting/basic_tools/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations import contextlib +from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger from typing import TYPE_CHECKING import numpy as np @@ -27,6 +28,19 @@ from arpes._typing import DataType, XrTypes +LOGLEVELS = (DEBUG, INFO) +LOGLEVEL = LOGLEVELS[1] +logger = getLogger(__name__) +fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s" +formatter = Formatter(fmt) +handler = StreamHandler() +handler.setLevel(LOGLEVEL) +logger.setLevel(LOGLEVEL) +handler.setFormatter(formatter) +logger.addHandler(handler) +logger.propagate = False + + __all__ = ( "path_tool", "mask_tool", @@ -140,6 +154,7 @@ def roi_changed(self, _: Incomplete) -> None: self.path_changed(self.path) def path_changed(self, path: Incomplete) -> None: + logger.debug(f"path: {path}") raise NotImplementedError def add_controls(self) -> None: @@ -215,7 +230,7 @@ def alt_path(self) -> list[Point]: return self.compute_path_from_roi(self.alt_roi) def path_changed(self, path: Incomplete) -> None: - pass + logger.debug(f"path: {path}") @property def calibration(self) -> DetectorCalibration: diff --git a/src/arpes/plotting/stack_plot.py b/src/arpes/plotting/stack_plot.py index ae4656b8..c1bb1109 100644 --- a/src/arpes/plotting/stack_plot.py +++ b/src/arpes/plotting/stack_plot.py @@ -7,7 +7,7 @@ import contextlib from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger -from typing import TYPE_CHECKING, Literal, Unpack +from typing import TYPE_CHECKING, Literal, Unpack, reveal_type import matplotlib as mpl import matplotlib.colorbar @@ -281,7 +281,7 @@ def flat_stack_plot( # noqa: PLR0913 color = kwargs.pop("color", "viridis") - for i, (_coord_dict, marginal) in enumerate(data_array.G.iterate_axis(stack_axis)): + for i, (_, marginal) in enumerate(data_array.G.iterate_axis(stack_axis)): if mode == "line": ax.plot( horizontal, diff --git a/src/arpes/plotting/utils.py b/src/arpes/plotting/utils.py index 9992111e..e4b557dc 100644 --- a/src/arpes/plotting/utils.py +++ b/src/arpes/plotting/utils.py @@ -15,7 +15,7 @@ from datetime import UTC from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Unpack +from typing import TYPE_CHECKING, Any, Literal, Unpack, reveal_type import matplotlib as mpl import matplotlib.pyplot as plt @@ -155,7 +155,7 @@ def mod_plot_to_ax( ax.plot(xs, ys, **kwargs) -class GradientFillParam(IMshowParam): +class GradientFillParam(IMshowParam, total=False): step: Literal["pre", "mid", "post", None] @@ -434,7 +434,7 @@ def transform_labels( def summarize(data: xr.DataArray, axes: NDArray[np.object_] | None = None) -> NDArray[np.object_]: """Makes a summary plot with different marginal plots represented.""" data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) - axes_shapes_for_dims = { + axes_shapes_for_dims: dict[int, tuple[int, int]] = { 1: (1, 1), 2: (1, 1), 3: (2, 2), # one extra here @@ -442,10 +442,8 @@ def summarize(data: xr.DataArray, axes: NDArray[np.object_] | None = None) -> ND } assert len(data_arr.dims) <= len(axes_shapes_for_dims) if axes is None: - _, axes = plt.subplots( - axes_shapes_for_dims.get(len(data_arr.dims), (3, 2)), - figsize=(8, 8), - ) + n_rows, n_cols = axes_shapes_for_dims.get(len(data_arr.dims), (3, 2)) + _, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(8, 8)) assert isinstance(axes, np.ndarray) flat_axes = axes.ravel() combinations = list(itertools.combinations(data_arr.dims, 2)) @@ -721,7 +719,7 @@ def imshow_arr( assert isinstance(ax, Axes) x, y = arr.coords[arr.dims[0]].values, arr.coords[arr.dims[1]].values - default_kwargs = { + default_kwargs: IMshowParam = { "origin": "lower", "aspect": "auto", "alpha": 1.0, @@ -730,8 +728,8 @@ def imshow_arr( "cmap": "viridis", "extent": (y[0], y[-1], x[0], x[-1]), } - default_kwargs.update(kwargs) - kwargs = default_kwargs + for k, v in default_kwargs.items(): + kwargs.setdefault(str(k), v) if over is None: if kwargs["alpha"] != 1: if isinstance(kwargs["cmap"], str): @@ -997,8 +995,8 @@ def phase_angle_colorbar( def temperature_colorbar( - low: float = 0, - high: float = 300, + low: float = 0.0, + high: float = 300.0, ax: Axes | None = None, **kwargs: Unpack[ColorbarParam], ) -> colorbar.Colorbar: @@ -1143,7 +1141,7 @@ def generic_colorbarmap_for_data( low, high = data.min().item(), data.max().item() ticks = None if keep_ticks: - ticks = data.values + ticks = data.values.tolist() return ( generic_colorbar( low=low, @@ -1219,7 +1217,6 @@ def __init__( size_bar.add_artist(vline2) txt = mpl.offsetbox.TextArea( label, - minimumdescent=False, textprops={ "color": label_color, },