Skip to content

Commit

Permalink
💬 Update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Mar 15, 2024
1 parent e553fa0 commit fa7e975
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 36 deletions.
6 changes: 4 additions & 2 deletions src/arpes/plotting/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, Unpack

import matplotlib.pyplot as plt
from matplotlib import colorbar
from matplotlib.axes import Axes

from arpes.provenance import save_plot_provenance
Expand Down Expand Up @@ -42,7 +43,6 @@ def plot_with_bands(
out: [TODO:description]
kwargs: pass to data.plot()
Returns:
[TODO:description]
"""
Expand All @@ -54,7 +54,9 @@ def plot_with_bands(
title = data.S.label.replace("_", " ")

mesh: AxesImage = data.plot(ax=ax, **kwargs)
mesh.colorbar.set_label(label_for_colorbar(data))
mesh_colorbar = mesh.colorbar
assert isinstance(mesh_colorbar, colorbar.Colorbar)
mesh_colorbar.set_label(label_for_colorbar(data))

if data.S.is_differentiated:
mesh.set_cmap("Blues")
Expand Down
7 changes: 3 additions & 4 deletions src/arpes/plotting/bz.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,9 +534,8 @@ def draw(self, renderer) -> None:
def annotate_special_paths(
ax: Axes,
paths: list[str] | str,
cell: Sequence[Sequence[float]] | None = None,
transformations=None,
offset: dict[str, float | Sequence[float]] | None = None,
cell: NDArray[np.float_] | Sequence[Sequence[float]] | None = None,
offset: dict[str, Sequence[float]] | None = None,
special_points: dict[str, NDArray[np.float_]] | None = None,
labels=None,
**kwargs: Incomplete,
Expand All @@ -560,7 +559,7 @@ def annotate_special_paths(
if labels is None:
labels = paths

converted_paths = process_kpath(paths, cell, special_points=special_points)
converted_paths = process_kpath(paths, np.array(cell), special_points=special_points)
logger.debug(f"converted_paths: {converted_paths}")

if not isinstance(labels[0], list):
Expand Down
4 changes: 2 additions & 2 deletions src/arpes/plotting/dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@
from arpes._typing import PColorMeshKwargs, XrTypes
from arpes.models.band import Band

__all__ = [
__all__ = (
"plot_dispersion",
"labeled_fermi_surface",
"cut_dispersion_plot",
"fancy_dispersion",
"reference_scan_fermi_surface",
"hv_reference_scan",
"scan_var_reference_plot",
]
)


@save_plot_provenance
Expand Down
37 changes: 9 additions & 28 deletions src/arpes/utilities/bz.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import itertools
import re
from collections import Counter
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING, Literal, NamedTuple

import matplotlib.path
import numpy as np
Expand Down Expand Up @@ -117,7 +117,7 @@ def parse_single_path(path: str) -> list[SpecialPoint]:
negate = True
rest = rest[1:]

bz_coords: tuple[float, float, float] | tuple[float, float] = (
bz_coords: tuple[float, ...] = (
0.0,
0.0,
0.0,
Expand Down Expand Up @@ -213,7 +213,7 @@ def process_kpath(
ToDo: Test
"""
if len(cell) == TWO_DIMENSION:
cell = [[*c, 0] for c in cell] + [[0, 0, 1]]
cell = np.array([[*c, 0] for c in cell] + [[0, 0, 1]])

icell = np.linalg.inv(cell).T

Expand Down Expand Up @@ -302,7 +302,6 @@ def flat_bz_indices_list(
Returns:
[TODO:description]
ToDo: Test
"""
if bz_indices_list is None:
Expand All @@ -311,7 +310,7 @@ def flat_bz_indices_list(
assert len(bz_indices_list[0]) in {2, 3}

indices = []
if len(bz_indices_list[0]) == 2: # noqa: PLR2004
if len(bz_indices_list[0]) == TWO_DIMENSION:
for bz_x, bz_y in bz_indices_list:
rx = range(bz_x, bz_x + 1) if isinstance(bz_x, int) else range(*bz_x)
ry = range(bz_y, bz_y + 1) if isinstance(bz_y, int) else range(*bz_y)
Expand All @@ -335,8 +334,6 @@ def generate_2d_equivalent_points(
) -> NDArray[np.float_]:
"""Generates the equivalent points in higher order Brillouin zones.
[TODO:description]
Args:
points: [TODO:description]
icell: [TODO:description]
Expand Down Expand Up @@ -393,7 +390,7 @@ def build_2dbz_poly(

if vertices is None:
if icell is None:
icell = np.linalg.inv(cell).T
icell = np.linalg.inv(np.array(cell)).T

vertices = bz_vertices(icell)

Expand All @@ -403,11 +400,9 @@ def build_2dbz_poly(
return raw_poly_to_mask(points_2d)


def bz_symmetry(flat_symmetry_points) -> str | None:
def bz_symmetry(flat_symmetry_points) -> Literal["rect", "square", "hex"] | None:
"""Determines symmetry from a list of the symmetry points.
[TODO:description]
Args:
flat_symmetry_points ([TODO:type]): [TODO:description]
Expand All @@ -420,7 +415,7 @@ def bz_symmetry(flat_symmetry_points) -> str | None:
flat_symmetry_points = flat_symmetry_points.items()

largest_identified = 0
symmetry: str | None = None
symmetry: Literal["rect", "square", "hex"] | None = None

point_names = {k for k, _ in flat_symmetry_points}

Expand All @@ -440,8 +435,6 @@ def reduced_bz_axis_to(
) -> NDArray[np.float_]:
"""Calculates a displacement vector to a modded high symmetry point.
[TODO:description]
Args:
data: [TODO:description]
symbol: [TODO:description]
Expand All @@ -456,7 +449,7 @@ def reduced_bz_axis_to(
ToDo: Test
"""
symmetry = bz_symmetry(data.S.iter_own_symmetry_points)
symmetry: Literal["rect", "square", "hex"] = bz_symmetry(data.S.iter_own_symmetry_points)
assert symmetry
point_names = _POINT_NAMES_FOR_SYMMETRY[symmetry]

Expand Down Expand Up @@ -496,7 +489,7 @@ def reduced_bz_axes(data: XrTypes) -> tuple[NDArray[np.float_], NDArray[np.float
ToDo: Test
"""
symmetry = bz_symmetry(data.S.iter_own_symmetry_points)
symmetry: Literal["rect", "square", "hex"] = bz_symmetry(data.S.iter_own_symmetry_points)
point_names = _POINT_NAMES_FOR_SYMMETRY[symmetry]

symmetry_points, _ = data.S.symmetry_points()
Expand All @@ -522,8 +515,6 @@ def reduced_bz_axes(data: XrTypes) -> tuple[NDArray[np.float_], NDArray[np.float
def axis_along(data: XrTypes, symbol: str) -> float:
"""Determines which axis lies principally along the direction G->S.
[TODO:description]
Args:
data: [TODO:description]
symbol: [TODO:description]
Expand Down Expand Up @@ -556,8 +547,6 @@ def axis_along(data: XrTypes, symbol: str) -> float:
def reduced_bz_poly(data: XrTypes, *, scale_zone: bool = False) -> NDArray[np.float_]:
"""Returns a polynomial representing the reduce first Brillouin zone.
[TODO:description]
Args:
data: [TODO:description]
scale_zone: [TODO:description]
Expand Down Expand Up @@ -698,8 +687,6 @@ def reduced_bz_mask(data: XrTypes, **kwargs: Incomplete) -> NDArray[np.float_]:
def reduced_bz_selection(data: DataType) -> DataType:
"""Sets data outside the Brillouin zone mask for a piece of data to be nan.
[TODO:description]
Args:
data: [TODO:description]
Expand All @@ -717,8 +704,6 @@ def reduced_bz_selection(data: DataType) -> DataType:
def bz_cutter(symmetry_points, *, reduced: bool = True):
"""Cuts data so that it areas outside the Brillouin zone are masked away.
[TODO:description]
Args:
symmetry_points ([TODO:type]): [TODO:description]
reduced: [TODO:description]
Expand All @@ -729,8 +714,6 @@ def bz_cutter(symmetry_points, *, reduced: bool = True):
def build_bz_mask(data) -> None:
"""[TODO:summary].
[TODO:description]
Args:
data ([TODO:type]): [TODO:description]
Expand All @@ -741,8 +724,6 @@ def build_bz_mask(data) -> None:
def cutter(data, cut_value: float = np.nan):
"""[TODO:summary].
[TODO:description]
Args:
data ([TODO:type]): [TODO:description]
cut_value: [TODO:description]
Expand Down

0 comments on commit fa7e975

Please sign in to comment.