Skip to content

Commit

Permalink
💬 update type annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Sep 29, 2023
1 parent ea41530 commit 2c10d58
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 21 deletions.
14 changes: 10 additions & 4 deletions arpes/plotting/dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import warnings
from collections import defaultdict
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -290,7 +290,7 @@ def hv_reference_scan(
out: str | Path = "",
e_cut: float = -0.05,
bkg_subtraction: float = 0.8,
**kwargs: Incomplete,
**kwargs: Unpack[LabeledFermiSurfaceParam],
) -> Path | None:
"""A reference plot for photon energy scans. Used internally by other code."""
fs = data.S.fat_sel(eV=e_cut)
Expand Down Expand Up @@ -342,11 +342,18 @@ def hv_reference_scan(
return None


class LabeledFermiSurfaceParam(TypedDict, total=False):
out: str | Path
include_symmetry_points: bool
include_bz: bool
fermi_energy: float


@save_plot_provenance
def reference_scan_fermi_surface(
data: DataType,
out: str | Path = "",
**kwargs: Incomplete,
**kwargs: Unpack[LabeledFermiSurfaceParam],
) -> Path | None:
"""A reference plot for Fermi surfaces. Used internally by other code."""
fs = data.S.fermi_surface
Expand Down Expand Up @@ -386,7 +393,6 @@ def labeled_fermi_surface(
include_bz: bool = True,
out: str | Path = "",
fermi_energy: float = 0,
**kwargs: Incomplete,
) -> Path | None | tuple[Figure, Axes]:
"""Plots a Fermi surface with high symmetry points annotated onto it."""
fig = None
Expand Down
5 changes: 3 additions & 2 deletions arpes/plotting/fermi_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
if TYPE_CHECKING:
from pathlib import Path

from matplotlib.colors import Normalize
from numpy.typing import NDArray

__all__ = ["fermi_edge_reference", "plot_fit"]
Expand Down Expand Up @@ -109,7 +110,7 @@ def fermi_edge_reference(
title: str = "",
ax: Axes | None = None,
out: str = "",
norm=None,
norm: Normalize | None = None,
) -> Path | None:
"""Fits for and plots results for the Fermi edge on a piece of data.
Expand All @@ -118,7 +119,7 @@ def fermi_edge_reference(
title: A title to attach to the plot
ax: The axes to plot to, if not specified will be generated
out: Where to save the plot
norm ([TODO:type]): [TODO:description]
norm (matplotlib.colors.Normalize): [TODO:description]
Returns:
[TODO:description]
Expand Down
2 changes: 0 additions & 2 deletions arpes/plotting/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
if TYPE_CHECKING:
from pathlib import Path

from _typeshed import Incomplete
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from numpy.typing import NDArray
Expand Down Expand Up @@ -204,7 +203,6 @@ def plot_spatial_reference(
def reference_scan_spatial(
data: DataType,
out: str | Path = "",
**kwargs: Incomplete,
) -> Path | tuple[Figure, NDArray[Axes]]:
"""Plots the spatial content of a dataset, useful as a quick reference."""
data_arr = normalize_to_spectrum(data)
Expand Down
31 changes: 18 additions & 13 deletions arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
import warnings
from collections import OrderedDict
from logging import DEBUG, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, Unpack

import lmfit
import matplotlib.pyplot as plt
Expand All @@ -61,6 +61,7 @@
from arpes.analysis.general import rebin
from arpes.models.band import MultifitBand
from arpes.plotting.dispersion import (
LabeledFermiSurfaceParam,
fancy_dispersion,
hv_reference_scan,
labeled_fermi_surface,
Expand All @@ -84,6 +85,7 @@
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.typing import RGBColorType
from matplotlibn.colors import Normalize
from numpy.typing import DTypeLike, NDArray

from arpes._typing import ANGLE, SPECTROMETER, DataType
Expand Down Expand Up @@ -1976,10 +1978,11 @@ def plot(
) -> Incomplete:
"""Utility delegate to `xr.DataArray.plot` which rasterizes`.
[TODO:description]
Args:
rasterized: [TODO:description]
rasterized (bool): if True, rasterized (Not vector) drawing
*args: Pass to xr.DataArray.plot
*kwargs: Pass to xr.DataArray.plot
"""
object_is_two_dimensional = 2
if len(self._obj.dims) == object_is_two_dimensional and "rasterized" not in kwargs:
Expand Down Expand Up @@ -2023,12 +2026,10 @@ def fs_plot(
def fermi_edge_reference_plot(
self,
pattern: str = "{}.png",
**kwargs: Incomplete,
**kwargs: str | Normalize | None,
) -> Path | None:
"""Provides a reference plot for a Fermi edge reference.
[TODO:description]
Args:
pattern ([TODO:type]): [TODO:description]
kwargs: pass to plotting.fermi_edge.fermi_edge_reference
Expand All @@ -2048,8 +2049,8 @@ def _referenced_scans_for_spatial_plot(
*,
use_id: bool = True,
pattern: str = "{}.png",
**kwargs: Incomplete,
):
**kwargs: str,
) -> Path | tuple[Figure, NDArray[Axes]]:
"""[TODO:summary].
[TODO:description]
Expand All @@ -2065,15 +2066,15 @@ def _referenced_scans_for_spatial_plot(
out = pattern.format(f"{label}_reference_scan_fs")
kwargs["out"] = out

return plotting.reference_scan_spatial(self._obj, **kwargs)
return plotting.spatial.reference_scan_spatial(self._obj, **kwargs)

def _referenced_scans_for_map_plot(
self,
pattern: str = "{}.png",
*,
use_id: bool = True,
**kwargs: IncompleteMPL,
):
**kwargs: Unpack[LabeledFermiSurfaceParam],
) -> Path | None:
out = kwargs.get("out")
label = self._obj.attrs["id"] if use_id else self.label
if out is not None and isinstance(out, bool):
Expand All @@ -2082,13 +2083,17 @@ def _referenced_scans_for_map_plot(

return reference_scan_fermi_surface(self._obj, **kwargs)

class HvRefScanParam(LabeledFermiSurfaceParam):
e_cut: float
bkg_subtraction: float

def _referenced_scans_for_hv_map_plot(
self,
pattern: str = "{}.png",
*,
use_id: bool = True,
**kwargs: IncompleteMPL,
):
) -> Path | None:
out = kwargs.get("out")
label = self._obj.attrs["id"] if use_id else self.label
if out is not None and isinstance(out, bool):
Expand Down

0 comments on commit 2c10d58

Please sign in to comment.