💬 Update type hints
arafune committed Mar 18, 2024
1 parent e228924 commit 330ac67
Showing 14 changed files with 95 additions and 81 deletions.
2 changes: 1 addition & 1 deletion src/arpes/analysis/mask.py
@@ -103,7 +103,7 @@ def polys_to_mask(


def apply_mask_to_coords(
data: xr.Dataset,
data: xr.Dataset, # data.data_vars is used
mask: dict[str, NDArray[np.float_] | Iterable[Iterable[float]]], # (N, 2) array
dims: list[str],
*,
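
For context on the annotated `mask` argument: below is a minimal, hypothetical sketch of a value matching the documented `(N, 2)` shape. The key name and vertex values are illustrative only and are not taken from the commit.

```python
import numpy as np

# One polygon per entry; each row is a vertex in the coordinates of the masked dims.
mask = {
    "main_region": np.array(
        [[0.0, -0.20], [0.10, -0.20], [0.10, 0.20], [0.0, 0.20]],
    ),
}
print(mask["main_region"].shape)  # -> (4, 2)
```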
2 changes: 1 addition & 1 deletion src/arpes/endstations/__init__.py
@@ -278,7 +278,7 @@ def concatenate_frames(
frames.sort(key=lambda x: x.coords[scan_coord])
return xr.concat(frames, scan_coord)

def resolve_frame_locations(self, scan_desc: ScanDesc | None = None) -> list[Path]:
def resolve_frame_locations(self, scan_desc: ScanDesc | None = None) -> list[Path | str]:
"""Determine all files and frames associated to this piece of data.
This always needs to be overridden in subclasses to handle data appropriately.
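
Since `resolve_frame_locations` is now annotated as returning `list[Path | str]`, callers have to cope with a mix of `Path` and plain string entries. A minimal sketch of one way to normalize such a result (the helper name is hypothetical and not part of this commit):

```python
from pathlib import Path


def normalize_frame_locations(locations: list[Path | str]) -> list[Path]:
    """Coerce a mixed list of str and Path entries into Path objects."""
    return [Path(location) for location in locations]


print(normalize_frame_locations(["scan_001.fits", Path("scan_002.fits")]))
```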
2 changes: 1 addition & 1 deletion src/arpes/endstations/plugin/ALG_main.py
@@ -12,7 +12,7 @@

import arpes.xarray_extensions # pylint: disable=unused-import, redefined-outer-name # noqa: F401
from arpes.config import ureg
from arpes.endstations import ScanDesc, FITSEndstation, HemisphericalEndstation
from arpes.endstations import FITSEndstation, HemisphericalEndstation, ScanDesc
from arpes.laser import electrons_per_pulse

if TYPE_CHECKING:
17 changes: 16 additions & 1 deletion src/arpes/endstations/plugin/ALG_spin_ToF.py
@@ -5,6 +5,7 @@
# pylint: disable=no-member
import itertools
import warnings
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from pathlib import Path
from typing import ClassVar

@@ -14,13 +15,26 @@
from astropy.io import fits

import arpes.config
from arpes.endstations import ScanDesc, EndstationBase, find_clean_coords
from arpes.endstations import EndstationBase, ScanDesc, find_clean_coords
from arpes.provenance import Provenance, provenance_from_file
from arpes.utilities import rename_keys

__all__ = ("SpinToFEndstation",)


LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[1]
logger = getLogger(__name__)
fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
formatter = Formatter(fmt)
handler = StreamHandler()
handler.setLevel(LOGLEVEL)
logger.setLevel(LOGLEVEL)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False


class SpinToFEndstation(EndstationBase):
"""Implements data loading for the Lanzara group Spin-ToF."""

@@ -224,6 +238,7 @@ def load_SToF_fits(self, scan_desc: ScanDesc) -> xr.Dataset:
except Exception:
# we should probably zero pad in the case where the slices are not the right
# size
logger.exception("Exception Occure")
continue

altered_dimension = dimensions[spectrum_name][0]
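
The logger block added here (and repeated in `igor_export.py` and `band.py` below) follows a fixed pattern: a module-level logger with its own handler and formatter, plus `logger.exception` inside the `except` branch so the traceback is recorded before the loop continues. A standalone, simplified sketch of that pattern (the `slice_frame` helper is a stand-in for the FITS slicing step, not code from the commit):

```python
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger

LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[1]  # INFO by default; index 0 switches to DEBUG
logger = getLogger(__name__)
handler = StreamHandler()
handler.setLevel(LOGLEVEL)
handler.setFormatter(Formatter("%(asctime)s %(levelname)s %(name)s :%(message)s"))
logger.setLevel(LOGLEVEL)
logger.addHandler(handler)
logger.propagate = False  # avoid duplicate output via the root logger


def slice_frame(values: list[int], size: int) -> list[int]:
    """Stand-in for the FITS slicing step that can raise on malformed frames."""
    if len(values) < size:
        msg = "frame shorter than expected"
        raise ValueError(msg)
    return values[:size]


try:
    slice_frame([1, 2], size=5)
except Exception:
    logger.exception("Exception occurred")  # records message and traceback, then carry on
```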
6 changes: 3 additions & 3 deletions src/arpes/endstations/plugin/ANTARES.py
@@ -153,7 +153,7 @@ def load_top_level_scan(

return ds

def get_coords(self, group, scan_name, shape):
def get_coords(self, group: Incomplete, scan_name: str, shape: Incomplete):
"""Extracts coordinates from the actuator header information.
In the future, this should be modified for data which lacks either a phi or energy axis.
@@ -247,10 +247,10 @@ def get_first(item):

return item

def build_axis(low, high, step_size) -> tuple[NDArray[np.float_], int]:
def build_axis(low: float, high: float, step_size: float) -> tuple[NDArray[np.float_], int]:
# this might not work out to be the right thing to do, we will see
low, high, step_size = get_first(low), get_first(high), get_first(step_size)
est_n = int((high - low) / step_size)
est_n: int = int((high - low) / step_size)

closest = None
diff = np.inf
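
The newly annotated `build_axis` helper estimates a point count from the step size and then constructs the axis. A simplified sketch of that idea (the real version also searches nearby counts for the closest match; this one stops at the direct estimate):

```python
import numpy as np
from numpy.typing import NDArray


def build_axis(low: float, high: float, step_size: float) -> tuple[NDArray[np.float_], int]:
    """Build a coordinate axis from bounds and a step size, returning the axis and its length."""
    est_n: int = int((high - low) / step_size)
    return np.linspace(low, high, est_n, endpoint=False), est_n


axis, n_points = build_axis(-0.2, 0.2, 0.01)
print(n_points, axis[:3])
```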
2 changes: 1 addition & 1 deletion src/arpes/endstations/plugin/MBS.py
@@ -54,7 +54,7 @@ class MBSEndstation(HemisphericalEndstation):
def resolve_frame_locations(
self,
scan_desc: ScanDesc | None = None,
) -> list[Path]:
) -> list[Path | str]:
"""There is only a single file for the MBS loader, so this is simple."""
if scan_desc is None:
scan_desc = {}
12 changes: 6 additions & 6 deletions src/arpes/endstations/plugin/SSRF_NSRL.py
@@ -24,7 +24,7 @@

import io
from configparser import ConfigParser
from logging import warning
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar
from zipfile import ZipFile
@@ -107,18 +107,18 @@ class DA30_L(SingleFileEndstation):

def load_single_frame(
self,
fpath: str | Path = "",
frame_path: str | Path = "",
scan_desc: ScanDesc | None = None,
**kwargs: Incomplete,
) -> xr.Dataset:
if kwargs:
warning.warn("Any kwargs is not supported in this function.")
warnings.warn("Any kwargs is not supported in this function.", stacklevel=2)
if scan_desc is None:
scan_desc = {}
file = Path(fpath)
file = Path(frame_path)

if file.suffix == ".pxt":
frame = read_single_pxt(fpath).rename(W="eV", X="phi")
frame = read_single_pxt(frame_path).rename(W="eV", X="phi")
frame = frame.assign_coords(phi=np.deg2rad(frame.phi))

return xr.Dataset(
@@ -127,7 +127,7 @@ def load_single_frame(
)

if file.suffix == ".zip":
zf = ZipFile(fpath)
zf = ZipFile(frame_path)
viewer_ini_ziped = zf.open("viewer.ini", "r")
viewer_ini_io = io.TextIOWrapper(viewer_ini_ziped)
viewer_ini = ConfigParser(strict=False)
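
The switch from the broken `from logging import warning` to the `warnings` module also adds `stacklevel=2`, which attributes the warning to the caller rather than to `load_single_frame` itself. A stripped-down sketch of the behaviour (the function body is reduced to the warning branch):

```python
import warnings


def load_single_frame(frame_path: str = "", **kwargs: object) -> None:
    """Accept but ignore extra keyword arguments, warning the caller about them."""
    if kwargs:
        # stacklevel=2 makes the warning point at the call site, not at this line
        warnings.warn("Any kwargs is not supported in this function.", stacklevel=2)


load_single_frame("scan_001.pxt", normalize=True)  # emits a UserWarning at this line
```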
13 changes: 13 additions & 0 deletions src/arpes/endstations/plugin/igor_export.py
@@ -2,6 +2,7 @@

from __future__ import annotations

from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from pathlib import Path
from typing import ClassVar

@@ -16,6 +17,18 @@

__all__ = ("IgorExportEndstation",)

LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[1]
logger = getLogger(__name__)
fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
formatter = Formatter(fmt)
handler = StreamHandler()
handler.setLevel(LOGLEVEL)
logger.setLevel(LOGLEVEL)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False


class IgorExportEndstation(SESEndstation):
"""Implements loading exported HDF files for ARPES data from Igor."""
6 changes: 3 additions & 3 deletions src/arpes/io.py
@@ -86,8 +86,8 @@ def load_data(
file = str(Path(file).absolute())

desc: ScanDesc = {
"file": file,
"location": location,
"file": file, # type:ignore[typeddict-item]
"location": location, # type:ignore[typeddict-item]
}

if location is None:
@@ -258,7 +258,7 @@ def file_for_pickle(name: str) -> Path | str:
def load_pickle(name: str) -> object:
"""Loads a workspace local pickle. Inverse to `save_pickle`."""
with Path(file_for_pickle(name)).open("rb") as file:
return pickle.load(file)
return pickle.load(file) # noqa: S301


def save_pickle(data: object, name: str) -> None:
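
The `# noqa: S301` on `pickle.load` silences the linter's warning that unpickling untrusted data is unsafe; these helpers are only meant for workspace-local files the package wrote itself. A simplified round-trip sketch (paths are flattened here, whereas the real code resolves them through `file_for_pickle`):

```python
import pickle
from pathlib import Path


def save_pickle(data: object, name: str) -> None:
    """Write a workspace-local pickle (simplified: current directory instead of the workspace)."""
    with Path(f"{name}.pickle").open("wb") as file:
        pickle.dump(data, file)


def load_pickle(name: str) -> object:
    """Inverse of save_pickle; only load files this code wrote."""
    with Path(f"{name}.pickle").open("rb") as file:
        return pickle.load(file)  # noqa: S301  # trusted, locally written file


save_pickle({"hv": 21.2, "temperature": 80.0}, "scan_metadata")
print(load_pickle("scan_metadata"))
```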
57 changes: 39 additions & 18 deletions src/arpes/models/band.py
@@ -2,6 +2,7 @@

from __future__ import annotations

from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING

import numpy as np
@@ -15,15 +16,25 @@
import lmfit as lf
from numpy.typing import NDArray

from arpes._typing import DataType

__all__ = [
"Band",
"MultifitBand",
"VoigtBand",
"BackgroundBand",
]

LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[1]
logger = getLogger(__name__)
fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
formatter = Formatter(fmt)
handler = StreamHandler()
handler.setLevel(LOGLEVEL)
logger.setLevel(LOGLEVEL)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False


class Band:
"""Representation of an ARPES band which supports some calculations after fitting."""
@@ -32,12 +43,12 @@ def __init__(
self,
label: str,
display_label: str | None = None,
data: DataType | None = None,
data: xr.Dataset | None = None,
) -> None:
"""Set the data but don't perform any calculation eagerly."""
self.label = label
self._display_label = display_label
self._data = data
self._data: xr.Dataset | None = data

@property
def display_label(self) -> str:
@@ -53,8 +64,6 @@ def display_label(self, value: str) -> None:
def velocity(self) -> xr.DataArray:
"""The band velocity.
[TODO:description]
Args:
self ([TODO:type]): [TODO:description]
@@ -81,7 +90,11 @@ def embed_nan(values: NDArray[np.float_], padding: int) -> NDArray[np.float_]:
nan_mask = scipy.ndimage.gaussian_filter(nan_mask, sigma, mode="mirror")
masked = scipy.ndimage.gaussian_filter(masked, sigma, mode="mirror")

return xr.DataArray(np.gradient(masked / nan_mask, spacing)[50:-50], self.coords, self.dims)
return xr.DataArray(
np.gradient(masked / nan_mask, spacing)[50:-50],
self.coords.values.tolist(),
self.dims,
)

@property
def fermi_velocity(self) -> xr.DataArray:
@@ -110,12 +123,12 @@ def get_dataarray(
clean: bool = True,
) -> xr.DataArray | NDArray[np.float_]:
"""Converts the underlying data into an array representation."""
assert isinstance(self._data, xr.DataArray | xr.Dataset)
assert isinstance(self._data, xr.Dataset)
if not clean:
return self._data[var_name].values

output = np.copy(self._data[var_name].values)
output[self._data[var_name + "_stderr"].values > 0.01] = np.nan
output[self._data[var_name + "_stderr"].values > 0.01] = np.nan # noqa: PLR2004

return xr.DataArray(
output,
@@ -126,39 +139,47 @@
@property
def center(self) -> xr.DataArray:
"""Gets the peak location along the band."""
return self.get_dataarray("center")
center_array = self.get_dataarray("center")
assert isinstance(center_array, xr.DataArray)
return center_array

@property
def center_stderr(self) -> xr.DataArray:
def center_stderr(self) -> NDArray[np.float_]:
"""Gets the peak location stderr along the band."""
return self.get_dataarray("center_stderr", clean=False)
center_stderr = self.get_dataarray("center_stderr", clean=False)
assert isinstance(center_stderr, np.ndarray)
return center_stderr

@property
def sigma(self) -> xr.DataArray:
"""Gets the peak width along the band."""
return self.get_dataarray("sigma", clean=True)
sigma_array = self.get_dataarray("sigma", clean=True)
assert isinstance(sigma_array, xr.DataArray)
return sigma_array

@property
def amplitude(self) -> xr.DataArray:
"""Gets the peak amplitude along the band."""
return self.get_dataarray("amplitude", clean=True)
amplitude_array = self.get_dataarray("amplitude", clean=True)
assert isinstance(amplitude_array, xr.DataArray)
return amplitude_array

@property
def indexes(self):
"""Fetches the indices of the originating data (after fit reduction)."""
assert isinstance(self._data, xr.DataArray | xr.Dataset)
assert isinstance(self._data, xr.Dataset)
return self._data.center.indexes

@property
def coords(self) -> xr.DataArray:
"""Fetches the coordinates of the originating data (after fit reduction)."""
assert isinstance(self._data, xr.DataArray | xr.Dataset)
assert isinstance(self._data, xr.Dataset)
return self._data.center.coords

@property
def dims(self) -> tuple[str, ...]:
"""Fetches the dimensions of the originating data (after fit reduction)."""
assert isinstance(self._data, xr.DataArray | xr.Dataset)
assert isinstance(self._data, xr.Dataset)
return self._data.center.dims


@@ -167,7 +188,7 @@ class MultifitBand(Band):

def get_dataarray(self, var_name: str):
"""Converts the underlying data into an array representation."""
assert isinstance(self._data, xr.DataArray | xr.Dataset)
assert isinstance(self._data, xr.Dataset)
full_var_name = self.label + var_name

if "stderr" in full_var_name:
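
Most of the `band.py` changes follow one pattern: `get_dataarray` keeps its `xr.DataArray | NDArray` return type, and each property narrows the union with an `assert isinstance(...)` so type checkers accept the declared return type. A reduced sketch of that pattern, assuming a cut-down class with a single property and no stderr masking:

```python
import numpy as np
import xarray as xr
from numpy.typing import NDArray


class Band:
    """Cut-down illustration of the union-returning helper plus narrowed properties."""

    def __init__(self, data: xr.Dataset) -> None:
        self._data: xr.Dataset = data

    def get_dataarray(
        self,
        var_name: str,
        *,
        clean: bool = True,
    ) -> xr.DataArray | NDArray[np.float_]:
        if not clean:
            return self._data[var_name].values  # plain ndarray branch
        return self._data[var_name]  # DataArray branch

    @property
    def center(self) -> xr.DataArray:
        """Peak location along the band, narrowed to a DataArray for the type checker."""
        center_array = self.get_dataarray("center")
        assert isinstance(center_array, xr.DataArray)
        return center_array


band = Band(xr.Dataset({"center": ("kx", np.linspace(-0.1, 0.1, 5))}))
print(band.center.values)
```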
9 changes: 5 additions & 4 deletions src/arpes/plotting/basic_tools/__init__.py
@@ -8,10 +8,12 @@

import numpy as np
import pyqtgraph as pg
import xarray as xr
from PySide6 import QtCore, QtWidgets
from scipy import interpolate

from arpes import analysis
from arpes.constants import TWO_DIMENSION
from arpes.utilities import normalize_to_spectrum
from arpes.utilities.conversion import DetectorCalibration
from arpes.utilities.qt import BasicHelpDialog, SimpleApp, SimpleWindow, qt_info
Expand All @@ -20,13 +22,12 @@
if TYPE_CHECKING:
from collections.abc import Callable, Hashable, Sequence

import xarray as xr
from _typeshed import Incomplete
from numpy.typing import NDArray
from pyqtgraph import Point
from PySide6.QtWidgets import QGridLayout

from arpes._typing import DataType, XrTypes
from arpes._typing import DataType

LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[1]
@@ -91,7 +92,7 @@ def layout(self) -> QGridLayout:
return self.main_layout

def set_data(self, data: xr.DataArray) -> None:
self.data = normalize_to_spectrum(data)
self.data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)

def transpose_to_front(self, dim: Hashable) -> None:
order = list(self.data.dims)
@@ -185,7 +186,7 @@ class PathTool(CoreTool):

def path_changed(self, path: NDArray[np.float_]) -> None:
selected_data = self.data.S.along(path)
if len(selected_data.dims) == 2: # noqa: PLR2004
if len(selected_data.dims) == TWO_DIMENSION:
self.views["P"].setImage(selected_data.data.transpose())
else:
self.views["P"].clear()
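
Replacing the literal `2` with `arpes.constants.TWO_DIMENSION` removes the need for the `noqa: PLR2004` magic-number suppression. A small sketch of the same check in isolation (the constant is assumed here to equal 2, and `print` stands in for the Qt image view):

```python
import numpy as np
import xarray as xr

TWO_DIMENSION = 2  # assumed value of arpes.constants.TWO_DIMENSION


def show_selection(selected_data: xr.DataArray) -> None:
    """Display 2D cuts; anything else clears the view."""
    if len(selected_data.dims) == TWO_DIMENSION:
        print("display image with shape", selected_data.data.transpose().shape)
    else:
        print("clear view")


show_selection(xr.DataArray(np.zeros((10, 20)), dims=("eV", "phi")))
```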