Skip to content

Commit

Permalink
💬 Update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Mar 15, 2024
1 parent 5a01905 commit c245498
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 18 deletions.
3 changes: 2 additions & 1 deletion src/arpes/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from matplotlib.figure import Figure
from matplotlib.patches import Patch
from matplotlib.patheffects import AbstractPathEffect
from matplotlib.ticker import Locator
from matplotlib.transforms import BboxBase, Transform
from matplotlib.typing import (
CapStyleType,
Expand Down Expand Up @@ -450,7 +451,7 @@ class ColorbarParam(TypedDict, total=False):
extend: Literal["neither", "both", "min", "max"]
extendfrac: None | Literal["auto"] | float | tuple[float, float] | list[float]
spacing: Literal["uniform", "proportional"]
ticks: None | list[float]
ticks: None | Sequence[float] | Locator
format: str | None
drawedges: bool
label: str
Expand Down
17 changes: 16 additions & 1 deletion src/arpes/plotting/basic_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import contextlib
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -27,6 +28,19 @@

from arpes._typing import DataType, XrTypes

LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[1]
logger = getLogger(__name__)
fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
formatter = Formatter(fmt)
handler = StreamHandler()
handler.setLevel(LOGLEVEL)
logger.setLevel(LOGLEVEL)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False


__all__ = (
"path_tool",
"mask_tool",
Expand Down Expand Up @@ -140,6 +154,7 @@ def roi_changed(self, _: Incomplete) -> None:
self.path_changed(self.path)

def path_changed(self, path: Incomplete) -> None:
logger.debug(f"path: {path}")
raise NotImplementedError

def add_controls(self) -> None:
Expand Down Expand Up @@ -215,7 +230,7 @@ def alt_path(self) -> list[Point]:
return self.compute_path_from_roi(self.alt_roi)

def path_changed(self, path: Incomplete) -> None:
pass
logger.debug(f"path: {path}")

@property
def calibration(self) -> DetectorCalibration:
Expand Down
4 changes: 2 additions & 2 deletions src/arpes/plotting/stack_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import contextlib
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING, Literal, Unpack
from typing import TYPE_CHECKING, Literal, Unpack, reveal_type

import matplotlib as mpl
import matplotlib.colorbar
Expand Down Expand Up @@ -281,7 +281,7 @@ def flat_stack_plot( # noqa: PLR0913

color = kwargs.pop("color", "viridis")

for i, (_coord_dict, marginal) in enumerate(data_array.G.iterate_axis(stack_axis)):
for i, (_, marginal) in enumerate(data_array.G.iterate_axis(stack_axis)):
if mode == "line":
ax.plot(
horizontal,
Expand Down
25 changes: 11 additions & 14 deletions src/arpes/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from datetime import UTC
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Unpack
from typing import TYPE_CHECKING, Any, Literal, Unpack, reveal_type

import matplotlib as mpl
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -155,7 +155,7 @@ def mod_plot_to_ax(
ax.plot(xs, ys, **kwargs)


class GradientFillParam(IMshowParam):
class GradientFillParam(IMshowParam, total=False):
step: Literal["pre", "mid", "post", None]


Expand Down Expand Up @@ -434,18 +434,16 @@ def transform_labels(
def summarize(data: xr.DataArray, axes: NDArray[np.object_] | None = None) -> NDArray[np.object_]:
"""Makes a summary plot with different marginal plots represented."""
data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
axes_shapes_for_dims = {
axes_shapes_for_dims: dict[int, tuple[int, int]] = {
1: (1, 1),
2: (1, 1),
3: (2, 2), # one extra here
4: (3, 2), # corresponds to 4 choose 2 axes
}
assert len(data_arr.dims) <= len(axes_shapes_for_dims)
if axes is None:
_, axes = plt.subplots(
axes_shapes_for_dims.get(len(data_arr.dims), (3, 2)),
figsize=(8, 8),
)
n_rows, n_cols = axes_shapes_for_dims.get(len(data_arr.dims), (3, 2))
_, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(8, 8))
assert isinstance(axes, np.ndarray)
flat_axes = axes.ravel()
combinations = list(itertools.combinations(data_arr.dims, 2))
Expand Down Expand Up @@ -721,7 +719,7 @@ def imshow_arr(
assert isinstance(ax, Axes)

x, y = arr.coords[arr.dims[0]].values, arr.coords[arr.dims[1]].values
default_kwargs = {
default_kwargs: IMshowParam = {
"origin": "lower",
"aspect": "auto",
"alpha": 1.0,
Expand All @@ -730,8 +728,8 @@ def imshow_arr(
"cmap": "viridis",
"extent": (y[0], y[-1], x[0], x[-1]),
}
default_kwargs.update(kwargs)
kwargs = default_kwargs
for k, v in default_kwargs.items():
kwargs.setdefault(str(k), v)
if over is None:
if kwargs["alpha"] != 1:
if isinstance(kwargs["cmap"], str):
Expand Down Expand Up @@ -997,8 +995,8 @@ def phase_angle_colorbar(


def temperature_colorbar(
low: float = 0,
high: float = 300,
low: float = 0.0,
high: float = 300.0,
ax: Axes | None = None,
**kwargs: Unpack[ColorbarParam],
) -> colorbar.Colorbar:
Expand Down Expand Up @@ -1143,7 +1141,7 @@ def generic_colorbarmap_for_data(
low, high = data.min().item(), data.max().item()
ticks = None
if keep_ticks:
ticks = data.values
ticks = data.values.tolist()
return (
generic_colorbar(
low=low,
Expand Down Expand Up @@ -1219,7 +1217,6 @@ def __init__(
size_bar.add_artist(vline2)
txt = mpl.offsetbox.TextArea(
label,
minimumdescent=False,
textprops={
"color": label_color,
},
Expand Down

0 comments on commit c245498

Please sign in to comment.