Skip to content

Commit

Permalink
🔨 update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Oct 18, 2023
1 parent 59eb8fb commit 36ad17f
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 31 deletions.
2 changes: 1 addition & 1 deletion arpes/endstations/plugin/Elettra_spectromicroscopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def collect_coord(index: int, dset: h5py.Dataset) -> tuple[str, NDArray[np.float
Args:
index: The index of the coordinate to extract from metadata.
dset: The HDF dataset containin Elettra spectromicroscopy data.
dset: The HDF dataset containing Elettra spectromicroscopy data.
Returns:
The coordinate extracted at `index` from the metadata. The return convention here is to
Expand Down
17 changes: 16 additions & 1 deletion arpes/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import warnings
from collections.abc import Iterable
from dataclasses import dataclass
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from pathlib import Path
from typing import TYPE_CHECKING

Expand All @@ -37,6 +38,19 @@
)


LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[0]
logger = getLogger(__name__)
fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
formatter = Formatter(fmt)
handler = StreamHandler()
handler.setLevel(LOGLEVEL)
logger.setLevel(LOGLEVEL)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False


def load_data(
file: str | Path | int,
location: str | type | None = None,
Expand Down Expand Up @@ -84,7 +98,8 @@ def load_data(
),
stacklevel=2,
)

if kwargs.get("trace"):
logger.debug(f"contents of desc: {desc}")
return load_scan(desc, **kwargs)


Expand Down
59 changes: 37 additions & 22 deletions arpes/utilities/conversion/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
from __future__ import annotations

import warnings
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING

import numpy as np
import xarray as xr

from arpes.analysis.filters import gaussian_filter_arr
from arpes.provenance import update_provenance
from arpes.trace import traceable
from arpes.trace import Trace, traceable
from arpes.utilities import normalize_to_spectrum

from .bounds_calculations import (
Expand All @@ -31,7 +32,7 @@
from .core import convert_to_kspace

if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from collections.abc import Sequence

from numpy.typing import NDArray

Expand All @@ -46,12 +47,25 @@
)


LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[1]
logger = getLogger(__name__)
fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
formatter = Formatter(fmt)
handler = StreamHandler()
handler.setLevel(LOGLEVEL)
logger.setLevel(LOGLEVEL)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False


@traceable
def convert_coordinate_forward(
data: DataType,
coords: dict[str, float],
*,
trace: Callable = None, # noqa: RUF013
trace: Trace | None = None,
**k_coords: NDArray[np.float_],
) -> dict[str, float]:
"""Inverse/forward transform for the small angle volumetric k-conversion code.
Expand Down Expand Up @@ -83,8 +97,8 @@ def convert_coordinate_forward(
Another approach would be to write down the exact small angle approximated transforms.
Args:
data: The data defining the coordinate offsets and experiment geometry.
coords: The coordinates of a point in angle-space to be converted.
data (DataType): The data defining the coordinate offsets and experiment geometry.
coords (dict[str, float]): The coordinates of a *point* in angle-space to be converted.
trace: Used for performance tracing and debugging.
k_coords: Coordinate for k-axis
Expand All @@ -110,20 +124,20 @@ def convert_coordinate_forward(
"ky": np.linspace(-4, 4, 300),
}
# Copying after taking a constant energy plane is much much cheaper
trace("Copying")
trace("Copying") if trace else None
data_arr = data_arr.copy(deep=True)

data_arr.loc[data_arr.G.round_coordinates(coords)] = data_arr.values.max() * 100000
trace("Filtering")
trace("Filtering") if trace else None
data_arr = gaussian_filter_arr(data_arr, default_size=3)

trace("Converting once")
trace("Converting once") if trace else None
kdata = convert_to_kspace(data_arr, **k_coords, trace=trace)

trace("argmax")
trace("argmax") if trace else None
near_target = kdata.G.argmax_coords()

trace("Converting twice")
trace("Converting twice") if trace else None
kdata_close = convert_to_kspace(
data_arr,
trace=trace,
Expand All @@ -132,7 +146,7 @@ def convert_coordinate_forward(

# inconsistently, the energy coordinate is sometimes returned here
# so we remove it just in case
trace("argmax")
trace("argmax") if trace else None
coords = kdata_close.G.argmax_coords()
if "eV" in coords:
del coords["eV"]
Expand All @@ -148,7 +162,7 @@ def convert_through_angular_pair( # noqa: PLR0913
transverse_specification: dict[str, NDArray[np.float_]],
*,
relative_coords: bool = True,
trace: Callable = None, # noqa: RUF013
trace: Trace | None = None,
**k_coords: NDArray[np.float_],
) -> dict[str, float]:
"""Converts the lower dimensional ARPES cut passing through `first_point` and `second_point`.
Expand Down Expand Up @@ -203,12 +217,12 @@ def convert_through_angular_pair( # noqa: PLR0913
k_second_point["ky"] - k_first_point["ky"],
k_second_point["kx"] - k_first_point["kx"],
)
trace(f"Determined offset angle {-offset_ang}")
trace(f"Determined offset angle {-offset_ang}") if trace else None

with data.S.with_rotation_offset(-offset_ang):
trace("Finding first momentum coordinate.")
trace("Finding first momentum coordinate.") if trace else None
k_first_point = convert_coordinate_forward(data, first_point, trace=trace, **k_coords)
trace("Finding second momentum coordinate.")
trace("Finding second momentum coordinate.") if trace else None
k_second_point = convert_coordinate_forward(data, second_point, trace=trace, **k_coords)

# adjust output coordinate ranges
Expand All @@ -229,15 +243,15 @@ def convert_through_angular_pair( # noqa: PLR0913
parallel_axis = np.linspace(left_point, right_point, len(parallel_axis))

# perform the conversion
trace("Performing final momentum conversion.")
trace("Performing final momentum conversion.") if trace else None
converted_data = convert_to_kspace(
data,
**transverse_specification,
kx=parallel_axis,
trace=trace,
).mean(list(transverse_specification.keys()))

trace("Annotating the requested point momentum values.")
trace("Annotating the requested point momentum values.") if trace else None
return converted_data.assign_attrs(
{
"first_point_kx": k_first_point[parallel_dim],
Expand All @@ -255,7 +269,7 @@ def convert_through_angular_point( # noqa: PLR0913
transverse_specification: dict[str, NDArray[np.float_]],
*,
relative_coords: bool = True,
trace: Callable = None, # noqa: RUF013
trace: Trace | None = None,
**k_coords: NDArray[np.float_],
) -> xr.DataArray:
"""Converts the lower dimensional ARPES cut passing through given angular `coords`.
Expand Down Expand Up @@ -311,7 +325,7 @@ def convert_coordinates(
) -> xr.Dataset:
"""Converts coordinates forward in momentum."""

def unwrap_coord(coord):
def unwrap_coord(coord: xr.DataArray | NDArray[np.float_]) -> NDArray[np.float_]:
try:
return coord.values
except (TypeError, AttributeError):
Expand Down Expand Up @@ -389,7 +403,7 @@ def expand_to(cname: str, c: Sequence[float]) -> float:


@update_provenance("Forward convert coordinates to momentum")
def convert_coordinates_to_kspace_forward(arr: DataType) -> xr.Dataset | None:
def convert_coordinates_to_kspace_forward(arr: DataType) -> xr.Dataset:
"""Forward converts all the individual coordinates of the data array.
Args:
Expand All @@ -406,7 +420,8 @@ def convert_coordinates_to_kspace_forward(arr: DataType) -> xr.Dataset | None:
momentum_compatibles: list[str] = list(all_indexes.keys())
momentum_compatibles.sort()
if not momentum_compatibles:
return None
msg = "Cannot convert because no momentum compatible coordinate"
raise RuntimeError(msg)
dest_coords = {
("phi",): ["kp", "kz"],
("theta",): ["kp", "kz"],
Expand Down Expand Up @@ -443,7 +458,7 @@ def broadcast_by_dim_location(
# else we are dealing with an actual array
the_slice = [None] * len(target_shape)
the_slice[dim_location] = slice(None, None, None)
print(dim_location)
logger.info(dim_location)
return np.asarray(data)[the_slice]

raw_coords = {
Expand Down
12 changes: 6 additions & 6 deletions arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from collections import OrderedDict
from collections.abc import Collection, Sequence
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING, Any, Literal, Self, TypeAlias, Unpack
from typing import TYPE_CHECKING, Any, Literal, NoReturn, Self, TypeAlias, Unpack

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -316,7 +316,7 @@ def hv(self) -> float | None:
return 5.93
return None

def fetch_ref_attrs(self) -> Incomplete:
def fetch_ref_attrs(self) -> dict[str, Any]:
"""Get reference attrs."""
if "ref_attrs" in self._obj.attrs:
return self._obj.attrs
Expand Down Expand Up @@ -3370,7 +3370,7 @@ def is_spatial(self) -> bool:
return self.spectrum.S.is_spatial

@property
def spectrum(self) -> xr.DataArray | None:
def spectrum(self) -> xr.DataArray | NoReturn:
"""Isolates a single spectrum from a dataset.
This is a convenience method which is typically used in startup for
Expand Down Expand Up @@ -3577,15 +3577,15 @@ def energy_notation(self) -> EnergyNotation:
.. Note:: The "Kinetic" energy refers to the Fermi level. (not Vacuum level)
"""
if "energy_notation" in self._obj.attrs:
if self.S.spectrum.attrs["energy_notation"] in (
if self.spectrum.attrs["energy_notation"] in (
"Kinetic",
"kinetic",
"kinetic energy",
):
self.S.spectrum.attrs["energy_notation"] = "Kinetic"
self.spectrum.attrs["energy_notation"] = "Kinetic"
return "Kinetic"
return "Binding"
self.S.spectrum.attrs["energy_notation"] = self.S.spectrum.attrs.get(
self.spectrum.attrs["energy_notation"] = self.spectrum.attrs.get(
"energy_notation",
"Binding",
)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ ignore = [
"N802", # invalid-function-name (N802)
"N806", # non-lowercase-variable-in-function
"N999", # invalid-module-name (N999)
"S101",
"S101", # assert (S101)
"TD002", # missing-todo-author
"TD003", # missing-todo-link
"PD011", # pandas-use-of-dot-values
Expand Down

0 comments on commit 36ad17f

Please sign in to comment.