Skip to content

Commit

Permalink
💬 Update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Feb 3, 2024
1 parent 8a5f03f commit a8a7eb5
Show file tree
Hide file tree
Showing 13 changed files with 138 additions and 113 deletions.
1 change: 1 addition & 0 deletions arpes/analysis/kfermi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tools related to finding the Fermi momentum in a cut."""

from __future__ import annotations

from typing import TYPE_CHECKING
Expand Down
43 changes: 29 additions & 14 deletions arpes/plotting/annotations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Annotations onto plots for experimental conditions or locations."""

from __future__ import annotations

from typing import TYPE_CHECKING, Literal, Unpack
Expand All @@ -19,7 +20,7 @@

from numpy.typing import NDArray

from arpes._typing import DataType, ExperimentalConditions, MPLTextParam
from arpes._typing import EXPERIMENTINFO, DataType, MPLTextParam

__all__ = (
"annotate_cuts",
Expand Down Expand Up @@ -70,7 +71,19 @@ def annotate_experimental_conditions(
delta = 1
current = 0

fontsize_keyword = kwargs.get("fontsize", 16)
fontsize_keyword: (
float
| Literal[
"xx-small",
"x-small",
"small",
"medium",
"large",
"x-large",
"xx-large",
"smaller",
]
) = kwargs.get("fontsize", 16)
if isinstance(fontsize_keyword, float):
fontsize = fontsize_keyword
elif fontsize_keyword in (
Expand All @@ -81,6 +94,7 @@ def annotate_experimental_conditions(
"large",
"x-large",
"xx-large",
"smaller",
):
font_scalings = { # see matplotlib.font_manager
"xx-small": 0.579,
Expand All @@ -93,13 +107,13 @@ def annotate_experimental_conditions(
"larger": 1.2,
"smaller": 0.833,
}
fontsize = mpl.rc_params["font.size"] * font_scalings[fontsize_keyword]
fontsize = mpl.rc_params()["font.size"] * font_scalings[fontsize_keyword]
else:
err_msg = "Incorrect font size setting"
raise RuntimeError(err_msg)
delta = fontsize * delta

conditions: ExperimentalConditions = data.S.experimental_conditions
conditions: EXPERIMENTINFO = data.S.experimental_conditions

renderers = {
"temp": lambda c: "\\textbf{T = " + "{:.3g}".format(c["temp"]) + " K}",
Expand Down Expand Up @@ -153,7 +167,7 @@ def annotate_cuts(
plotted_axes: NDArray[np.object_],
*,
include_text_labels: bool = False,
**kwargs: tuple | list | NDArray,
**kwargs: tuple | list | NDArray[np.float_],
) -> None:
"""Annotates a cut location onto a plot.
Expand All @@ -172,7 +186,7 @@ def annotate_cuts(
assert len(plotted_axes) == TWO_DIMENSION

for k, v in kwargs.items():
selected = converted_coordinates.sel(**dict([[k, v]]), method="nearest")
selected = converted_coordinates.sel({k: v}, method="nearest")

for coords_dict, obj in selected.G.iterate_axis(k):
css = [obj[d].values for d in plotted_axes]
Expand All @@ -193,18 +207,19 @@ def annotate_cuts(
def annotate_point(
ax: Axes | Axes3D,
location: Sequence[float],
label: str,
delta: tuple[float, ...] = (),
**kwargs: Unpack[MPLTextParam],
) -> None:
"""Annotates a point or high symmetry location into a plot."""
label = {
"G": "$\\Gamma$",
"X": r"\textbf{X}",
"Y": r"\textbf{Y}",
"K": r"\textbf{K}",
"M": r"\textbf{M}",
}.get(label, label)
if "label" in kwargs:
label = {
"G": "$\\Gamma$",
"X": r"\textbf{X}",
"Y": r"\textbf{Y}",
"K": r"\textbf{K}",
"M": r"\textbf{M}",
}.get(kwargs["label"], "")
kwargs.pop("label")

if not delta:
delta = (
Expand Down
43 changes: 19 additions & 24 deletions arpes/plotting/dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from matplotlib.axes import Axes
from mpl_toolkits.mplot3d import Axes3D

Expand All @@ -23,7 +24,6 @@
from collections.abc import Sequence
from pathlib import Path

import xarray as xr
from matplotlib.colors import Colormap, Normalize
from matplotlib.figure import Figure, FigureBase
from numpy.typing import NDArray
Expand Down Expand Up @@ -309,14 +309,14 @@ def hv_reference_scan(
e_cut: float = -0.05,
bkg_subtraction: float = 0.8,
**kwargs: Unpack[LabeledFermiSurfaceParam],
) -> Path | None:
) -> Path | Axes:
"""A reference plot for photon energy scans. Used internally by other code."""
fs = data.S.fat_sel(eV=e_cut)
fs = normalize_dim(fs, "hv", keep_id=True)
fs.data -= bkg_subtraction * np.mean(fs.data)
fs.data[fs.data < 0] = 0

_, ax = labeled_fermi_surface(fs, hold=True, **kwargs)
_, ax = labeled_fermi_surface(fs, **kwargs)

all_scans = data.attrs["df"]
all_scans = all_scans[all_scans.id != data.attrs["id"]]
Expand Down Expand Up @@ -372,13 +372,13 @@ def reference_scan_fermi_surface(
data: DataType,
out: str | Path = "",
**kwargs: Unpack[LabeledFermiSurfaceParam],
) -> Path | None:
) -> Path | Axes:
"""A reference plot for Fermi surfaces. Used internally by other code.
Warning: Not work correctly. (Because S.referenced_scans has been removed.)
"""
fs = data.S.fermi_surface
_, ax = labeled_fermi_surface(fs, hold=True, **kwargs)
_, ax = labeled_fermi_surface(fs, **kwargs)

referenced_scans = data.S.referenced_scans
handles = []
Expand All @@ -399,26 +399,26 @@ def reference_scan_fermi_surface(
plt.savefig(path_for_plot(out), dpi=400)
return path_for_plot(out)

plt.show()
return None
return ax


@save_plot_provenance
def labeled_fermi_surface( # noqa: PLR0913
data: DataType,
data: xr.DataArray,
title: str = "",
ax: Axes | None = None,
*,
hold: bool = False,
include_symmetry_points: bool = True,
include_bz: bool = True,
out: str | Path = "",
fermi_energy: float = 0,
) -> Path | None | tuple[Figure | None, Axes]:
) -> Path | tuple[Figure | None, Axes]:
"""Plots a Fermi surface with high symmetry points annotated onto it."""
assert isinstance(data, xr.DataArray)
fig = None
if ax is None:
fig, ax = plt.subplots(figsize=(7, 7))
assert isinstance(ax, Axes)

if not title:
title = "{} Fermi Surface".format(data.S.label.replace("_", " "))
Expand All @@ -434,7 +434,6 @@ def labeled_fermi_surface( # noqa: PLR0913

dim_order = [ax.get_xlabel(), ax.get_ylabel()]

ax.dim_order = dim_order
ax.set_xlabel(label_for_dim(data, ax.get_xlabel()))
ax.set_ylabel(label_for_dim(data, ax.get_ylabel()))
ax.set_title(title)
Expand Down Expand Up @@ -467,15 +466,12 @@ def labeled_fermi_surface( # noqa: PLR0913
plt.savefig(path_for_plot(out), dpi=400)
return path_for_plot(out)

if not hold:
plt.show()
return None
return fig, ax


@save_plot_provenance
def fancy_dispersion(
data: DataType,
data: xr.DataArray,
title: str = "",
ax: Axes | None = None,
out: str | Path = "",
Expand All @@ -488,10 +484,10 @@ def fancy_dispersion(
Useful for brief slides/quick presentations.
Args:
data: [TODO:description]
title: [TODO:description]
ax: [TODO:description]
out: [TODO:description]
data (xr.DataArray): ARPES data.
title (str): Title of Figure.
ax (Axes): matpplotlib Axes object
out (str | Path): str or Path object for output image.
include_symmetry_points: [TODO:description]
kwargs: pass to xr.Dataset.plot or xr.DataArray.plot()
Expand Down Expand Up @@ -543,18 +539,17 @@ def fancy_dispersion(
plt.savefig(path_for_plot(out), dpi=400)
return path_for_plot(out)

plt.show()
return ax


@save_plot_provenance
def scan_var_reference_plot(
data: DataType,
data: xr.DataArray,
title: str = "",
ax: Axes | None = None,
norm: Normalize | None = None,
out: str | Path = "",
) -> None | Path:
) -> Axes | Path:
"""Makes a straightforward plot of a DataArray with reasonable axes.
Used internally by other scripts.
Expand All @@ -569,6 +564,7 @@ def scan_var_reference_plot(
Returns:
[TODO:description]
"""
assert isinstance(data, xr.DataArray)
if ax is None:
_, ax = plt.subplots(figsize=(8, 5))
assert isinstance(ax, Axes)
Expand All @@ -587,5 +583,4 @@ def scan_var_reference_plot(
plt.savefig(path_for_plot(out), dpi=400)
return path_for_plot(out)

plt.show()
return None
return ax
8 changes: 4 additions & 4 deletions arpes/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import re
import warnings
from collections import Counter
from collections.abc import Generator, Iterable, Iterator, Sequence
from collections.abc import Generator, Hashable, Iterable, Iterator, Sequence
from datetime import UTC
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from pathlib import Path
Expand Down Expand Up @@ -1517,7 +1517,7 @@ def label_for_colorbar(data: DataType) -> str:

def label_for_dim(
data: DataType | None = None,
dim_name: str = "",
dim_name: Hashable = "",
*,
escaped: bool = True,
) -> str:
Expand Down Expand Up @@ -1594,7 +1594,7 @@ def label_for_dim(
else:
raw_dim_names["eV"] = "Binding Energy ( eV )"
if dim_name in raw_dim_names:
label_dim_name = raw_dim_names.get(dim_name, "")
label_dim_name = raw_dim_names.get(str(dim_name), "")
if not escaped:
label_dim_name = label_dim_name.replace("$", "")
return label_dim_name
Expand All @@ -1618,7 +1618,7 @@ def titlecase(s: str) -> str:
"""
return s.title()

return titlecase(dim_name.replace("_", " "))
return titlecase(str(dim_name).replace("_", " "))


def fancy_labels(
Expand Down
6 changes: 2 additions & 4 deletions arpes/utilities/bz.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@ def as_2d(points_3d: ArrayLike) -> NDArray[np.float_]:
def parse_single_path(path: str) -> list[SpecialPoint]:
"""Converts a path given by high symmetry point names to numerical coordinate arrays.
[TODO:description]
Args:
path: [TODO:description]
Expand Down Expand Up @@ -135,7 +133,7 @@ def parse_single_path(path: str) -> list[SpecialPoint]:
return points


def parse_path(paths: str | list[str]) -> list[list[SpecialPoint]]:
def _parse_path(paths: str | list[str]) -> list[list[SpecialPoint]]:
"""Converts paths to arrays with the coordinate locations for those paths.
Args:
Expand Down Expand Up @@ -227,7 +225,7 @@ def process_kpath(

return [
[special_point_to_vector(elem, icell, special_points) for elem in p]
for p in parse_path(paths)
for p in _parse_path(paths)
]


Expand Down
13 changes: 13 additions & 0 deletions arpes/utilities/combine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Syntax suger for combination of ARPES data."""

from __future__ import annotations

import xarray as xr

from arpes.provenance import provenance_multiple_parents

__all__ = ("concat_along_phi",)


Expand Down Expand Up @@ -61,6 +64,16 @@ def concat_along_phi(
combine_attrs="drop_conflicts",
).sortby("phi")
concat_array.attrs["id"] = id_add
provenance_multiple_parents(
concat_array,
[arr_a, arr_b],
{
"what": "concat_along_phi",
"parant_id": (id_arr_a, id_arr_b),
"occupation_ratio": occupation_ratio,
},
keep_parent_ref=True,
)
return concat_array


Expand Down
7 changes: 2 additions & 5 deletions arpes/utilities/conversion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

from __future__ import annotations

from collections.abc import Hashable
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING

import numpy as np
import xarray as xr

if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Callable, Hashable

from _typeshed import Incomplete
from numpy.typing import NDArray
Expand Down Expand Up @@ -86,7 +85,6 @@ def prep(self, arr: xr.DataArray) -> None:
cache computations as they arrive. This is the technique that is used in
ConvertKxKy below
"""
...
assert isinstance(arr, xr.DataArray)

@property
Expand Down Expand Up @@ -121,9 +119,8 @@ def kspace_to_BE(
logger.debug(msg)
return binding_energy

def conversion_for(self, dim: str) -> Callable:
def conversion_for(self, dim: str) -> Callable[[NDArray[np.float_]], NDArray[np.float_]]:
"""Fetches the method responsible for calculating `dim` from momentum coordinates."""
...
assert isinstance(dim, str)

def identity_transform(self, axis_name: str, *args: Incomplete) -> NDArray[np.float_]:
Expand Down
Loading

0 comments on commit a8a7eb5

Please sign in to comment.