From 6671e10faf73884f6499a1c2ae8d45a7f4dbb067 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Sat, 14 Oct 2023 18:36:39 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A8=20=20Separate=20np.nan=20from=20NO?= =?UTF-8?q?NE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 💬 update type hints 🔨 black procedure is removed from leftfook.yml --- arpes/_typing.py | 2 +- arpes/config.py | 2 +- arpes/endstations/__init__.py | 142 +++++++++++------- arpes/endstations/fits_utils.py | 29 ++-- arpes/endstations/nexus_utils.py | 4 +- arpes/endstations/plugin/BL10_SARPES.py | 15 +- .../plugin/Elettra_spectromicroscopy.py | 4 +- arpes/endstations/plugin/MAESTRO.py | 4 +- arpes/endstations/plugin/SPD_main.py | 2 - arpes/endstations/plugin/merlin.py | 51 ++++--- arpes/plotting/basic_tools/__init__.py | 14 +- arpes/plotting/fermi_surface.py | 28 +++- arpes/plotting/fit_tool/__init__.py | 4 +- arpes/plotting/stack_plot.py | 3 +- arpes/plotting/utils.py | 116 ++++++-------- arpes/provenance.py | 2 +- arpes/utilities/collections.py | 6 +- arpes/xarray_extensions.py | 2 +- lefthook.yml | 40 ++++- tests/test_basic_data_loading.py | 4 +- 20 files changed, 273 insertions(+), 201 deletions(-) diff --git a/arpes/_typing.py b/arpes/_typing.py index 1ec2f441..cafb0d38 100644 --- a/arpes/_typing.py +++ b/arpes/_typing.py @@ -181,7 +181,7 @@ class SAMPLEINFO(TypedDict, total=False): reflectivity: float | None -class WORKSPACETYPE(TypedDict, total=True): +class WORKSPACETYPE(TypedDict, total=False): path: str | Path name: str diff --git a/arpes/config.py b/arpes/config.py index b8706d2f..b4befd3f 100644 --- a/arpes/config.py +++ b/arpes/config.py @@ -45,7 +45,7 @@ ureg = pint.UnitRegistry() -DATA_PATH = None +DATA_PATH: str | None = None SOURCE_ROOT = str(Path(__file__).parent) SETTINGS: ConfigSettings = { diff --git a/arpes/endstations/__init__.py b/arpes/endstations/__init__.py index f2f8564a..607c70e8 100644 --- a/arpes/endstations/__init__.py +++ b/arpes/endstations/__init__.py @@ -6,8 +6,9 @@ import os.path import re import warnings +from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger from pathlib import Path -from typing import TYPE_CHECKING, ClassVar, TypedDict +from typing import TYPE_CHECKING, ClassVar, NoReturn, Self, TypedDict import h5py import numpy as np @@ -30,6 +31,7 @@ from _typeshed import Incomplete from arpes._typing import SPECTROMETER + __all__ = [ "endstation_name_from_alias", "endstation_from_alias", @@ -43,14 +45,28 @@ "resolve_endstation", ] -_ENDSTATION_ALIASES = {} +LOGLEVELS = (DEBUG, INFO) +LOGLEVEL = LOGLEVELS[0] +logger = getLogger(__name__) +fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s" +formatter = Formatter(fmt) +handler = StreamHandler() +handler.setLevel(LOGLEVEL) +logger.setLevel(LOGLEVEL) +handler.setFormatter(formatter) +logger.addHandler(handler) +logger.propagate = False + + +_ENDSTATION_ALIASES: dict[str, type[EndstationBase]] = {} class SCANDESC(TypedDict, total=False): - file: str + file: str | Path location: str - path: str + path: str | Path note: dict[str, str | float] # used as attrs basically. + id: int | str class EndstationBase: @@ -83,7 +99,7 @@ class EndstationBase: ATTR_TRANSFORMS: ClassVar[dict[str, str]] = {} MERGE_ATTRS: ClassVar[SPECTROMETER] = {} - _SEARCH_DIRECTORIES = ( + _SEARCH_DIRECTORIES: tuple[str, ...] = ( "", "hdf5", "fits", @@ -91,7 +107,7 @@ class EndstationBase: "../Data/hdf5", "../Data/fits", ) - _SEARCH_PATTERNS = ( + _SEARCH_PATTERNS: tuple[str, ...] 
= ( r"[\-a-zA-Z0-9_\w]+_[0]+{}$", r"[\-a-zA-Z0-9_\w]+_{}$", r"[\-a-zA-Z0-9_\w]+{}$", @@ -145,13 +161,13 @@ def __init__(self) -> None: def is_file_accepted( cls: type[EndstationBase], file: str | Path, - scan_desc: dict[str, str], + scan_desc: SCANDESC, ) -> bool: """Determines whether this loader can load this file.""" if Path(file).exists() and len(str(file).split(os.path.sep)) > 1: # looks like an actual file, we are going to just check that the extension is kosher # and that the filename matches something reasonable. - p = Path(str(file)) + p = Path(file) if p.suffix not in cls._TOLERATED_EXTENSIONS: return False @@ -163,7 +179,7 @@ def is_file_accepted( return False try: - _ = cls.find_first_file(file, scan_desc) + _ = cls.find_first_file(str(file), scan_desc) except ValueError: return False return True @@ -179,11 +195,11 @@ def files_for_search(cls: type[EndstationBase], directory: str | Path) -> list[s @classmethod def find_first_file( cls: type[EndstationBase], - file, - scan_desc, + file: str, + scan_desc: SCANDESC, *, allow_soft_match: bool = False, - ): + ) -> NoReturn | None: """Attempts to find file associated to the scan given the user provided path or scan number. This is mostly done by regex matching over available options. @@ -197,11 +213,12 @@ def find_first_file( """ workspace = arpes.config.CONFIG["WORKSPACE"] workspace_path = os.path.join(workspace["path"], "data") - workspace = workspace["name"] + workspace_name = workspace["name"] - base_dir = workspace_path or os.path.join(arpes.config.DATA_PATH, workspace) + base_dir = workspace_path or Path(arpes.config.DATA_PATH) / workspace_name dir_options = [os.path.join(base_dir, option) for option in cls._SEARCH_DIRECTORIES] - + logger.debug(f"arpes.config.DATA_PATH: {arpes.config.DATA_PATH}") + logger.debug(f"dir_options: {dir_options}") # another plugin related option here is we can restrict the number of regexes by allowing # plugins to install regexes for particular endstations, if this is needed in the future it # might be a good way of preventing clashes where there is ambiguity in file naming scheme @@ -239,7 +256,11 @@ def find_first_file( msg = f"Could not find file associated to {file}" raise ValueError(msg) - def concatenate_frames(self, frames=list[xr.Dataset], scan_desc: dict[str, str] | None = None): + def concatenate_frames( + self, + frames: list[xr.Dataset], + scan_desc: SCANDESC | None = None, + ) -> xr.Dataset: """Performs concatenation of frames in multi-frame scans. The way this happens is that we look for an axis on which the frames are changing uniformly @@ -271,17 +292,19 @@ def concatenate_frames(self, frames=list[xr.Dataset], scan_desc: dict[str, str] frames.sort(key=lambda x: x.coords[scan_coord]) return xr.concat(frames, scan_coord) - def resolve_frame_locations(self, scan_desc: SCANDESC | None = None) -> list[str]: + def resolve_frame_locations(self, scan_desc: SCANDESC | None = None) -> list[Path] | NoReturn: """Determine all files and frames associated to this piece of data. This always needs to be overridden in subclasses to handle data appropriately. """ + if scan_desc: + msg = "You need to define resolve_frame_locations or subclass SingleFileEndstation." msg = "You need to define resolve_frame_locations or subclass SingleFileEndstation." 
raise NotImplementedError(msg) def load_single_frame( self, - frame_path: str = "", # TODO should be str and default is "" + frame_path: str | Path = "", scan_desc: SCANDESC | None = None, **kwargs: Incomplete, ) -> xr.Dataset: @@ -289,6 +312,10 @@ def load_single_frame( This always needs to be overridden in subclasses to handle data appropriately. """ + if scan_desc: + logger.debug(scan_desc) + if kwargs: + logger.debug(kwargs) return xr.Dataset() def postprocess(self, frame: xr.Dataset) -> xr.Dataset: @@ -370,25 +397,25 @@ def postprocess_final( data.spectrum.attrs["spectrum_type"] = spectrum_type ls = [data, *data.S.spectra] - for _ in ls: + for a_data in ls: for k, key_fn in self.ATTR_TRANSFORMS.items(): - if k in _.attrs: - transformed = key_fn(_.attrs[k]) + if k in a_data.attrs: + transformed = key_fn(a_data.attrs[k]) if isinstance(transformed, dict): - _.attrs.update(transformed) + a_data.attrs.update(transformed) else: - _.attrs[k] = transformed + a_data.attrs[k] = transformed - for _ in ls: + for a_data in ls: for k, v in self.MERGE_ATTRS.items(): - if k not in _.attrs: - _.attrs[k] = v + if k not in a_data.attrs: + a_data.attrs[k] = v - for _ in ls: + for a_data in ls: for c in self.ENSURE_COORDS_EXIST: - if c not in _.coords: - if c in _.attrs: - _.coords[c] = _.attrs[c] + if c not in a_data.coords: + if c in a_data.attrs: + a_data.coords[c] = a_data.attrs[c] else: warnings_msg = f"Could not assign coordinate {c} from attributes," warnings_msg += "assigning np.nan instead." @@ -396,11 +423,11 @@ def postprocess_final( warnings_msg, stacklevel=2, ) - _.coords[c] = np.nan + a_data.coords[c] = np.nan - for _ in ls: - if "chi" in _.coords and "chi_offset" not in _.attrs: - _.attrs["chi_offset"] = _.coords["chi"].item() + for a_data in ls: + if "chi" in a_data.coords and "chi_offset" not in a_data.attrs: + a_data.attrs["chi_offset"] = a_data.coords["chi"].item() # go and change endianness and datatypes to something reasonable # this is done for performance reasons in momentum space conversion, primarily @@ -503,7 +530,7 @@ class SESEndstation(EndstationBase): These files have special frame names, at least at the beamlines Conrad has encountered. """ - def resolve_frame_locations(self, scan_desc: SCANDESC | None = None): + def resolve_frame_locations(self, scan_desc: SCANDESC | None = None) -> list[Path] | NoReturn: if scan_desc is None: msg = "Must pass dictionary as file scan_desc to all endstation loading code." raise ValueError( @@ -526,12 +553,13 @@ def resolve_frame_locations(self, scan_desc: SCANDESC | None = None): def load_single_frame( self, - frame_path: str = "", + frame_path: str | Path = "", scan_desc: SCANDESC | None = None, **kwargs: Incomplete, ) -> xr.Dataset: ext = Path(frame_path).suffix - + if scan_desc is None: + scan_desc = {} if "nc" in ext: # was converted to hdf5/NetCDF format with Conrad's Igor scripts scan_desc = copy.deepcopy(scan_desc) @@ -542,7 +570,7 @@ def load_single_frame( pxt_data = negate_energy(read_single_pxt(frame_path)) return xr.Dataset({"spectrum": pxt_data}, attrs=pxt_data.attrs) - def postprocess(self, frame: xr.Dataset): + def postprocess(self, frame: xr.Dataset) -> Self: frame = super().postprocess(frame) return frame.assign_attrs(frame.S.spectrum.attrs) @@ -566,6 +594,9 @@ def load_SES_nc( Returns: Loaded data. 
""" + if kwargs: + for k, v in kwargs.items(): + logger.info(f"load_SES_nc: unused kwargs, k: {k}, value : {v}") if scan_desc is None: scan_desc = {} scan_desc = copy.deepcopy(scan_desc) @@ -591,21 +622,20 @@ def load_SES_nc( # Use dimension labels instead of dimension_labels = list(f["/" + primary_dataset_name].attrs["IGORWaveDimensionLabels"][0]) if any(x == "" for x in dimension_labels): - print(dimension_labels) + logger.info(dimension_labels) if not robust_dimension_labels: msg = "Missing dimension labels. Use robust_dimension_labels=True to override" raise ValueError( msg, ) - else: - used_blanks = 0 - for i in range(len(dimension_labels)): - if dimension_labels[i] == "": - dimension_labels[i] = f"missing{used_blanks}" - used_blanks += 1 + used_blanks = 0 + for i in range(len(dimension_labels)): + if dimension_labels[i] == "": + dimension_labels[i] = f"missing{used_blanks}" + used_blanks += 1 - print(dimension_labels) + logger.info(dimension_labels) scaling = f["/" + primary_dataset_name].attrs["IGORWaveScaling"][-len(dimension_labels) :] raw_data = f["/" + primary_dataset_name][:] @@ -619,13 +649,13 @@ def load_SES_nc( attrs = scan_desc.pop("note", {}) attrs.update(wave_note) - built_coords = dict(zip(dimension_labels, scaling)) + built_coords = dict(zip(dimension_labels, scaling, strict=True)) deg_to_rad_coords = {"theta", "beta", "phi", "alpha", "psi"} # the hemisphere axis is handled below built_coords = { - k: c * (np.pi / 180) if k in deg_to_rad_coords else c for k, c in built_coords.items() + k: np.deg2rad(c) if k in deg_to_rad_coords else c for k, c in built_coords.items() } deg_to_rad_attrs = {"theta", "beta", "alpha", "psi", "chi"} @@ -642,7 +672,7 @@ def load_SES_nc( provenance_from_file( dataset_contents["spectrum"], - data_loc, + str(data_loc), {"what": "Loaded SES dataset from HDF5.", "by": "load_SES"}, ) @@ -708,14 +738,12 @@ class FITSEndstation(EndstationBase): "LMOTOR6": "alpha", } - def resolve_frame_locations(self, scan_desc: SCANDESC | None = None) -> list[Path]: - """These are stored as single files, so just use the one from the description.""" + def resolve_frame_locations(self, scan_desc: SCANDESC | None = None) -> list[Path] | NoReturn: if scan_desc is None: msg = "Must pass dictionary as file scan_desc to all endstation loading code." raise ValueError( msg, ) - original_data_loc = scan_desc.get("path", scan_desc.get("file")) assert original_data_loc is not None assert original_data_loc != "" @@ -991,7 +1019,7 @@ def endstation_name_from_alias(alias) -> str: return endstation_from_alias(alias).PRINCIPAL_NAME -def add_endstation(endstation_cls) -> None: +def add_endstation(endstation_cls: type[EndstationBase]) -> None: """Registers a data loading plugin (Endstation class) together with its aliases. You can use this to add a plugin after the original search if it is defined in another @@ -1051,12 +1079,12 @@ def load_scan( scan_desc: dict[str, str], *, retry: bool = True, - trace: Callable = None, # noqa: RUF013 + trace: Trace | None = None, **kwargs: Incomplete, ) -> xr.Dataset: """Resolves a plugin and delegates loading a scan. - This is used interally by `load_data` and should not be invoked directly + This is used internally by `load_data` and should not be invoked directly by users. 
Determines which data loading class is appropriate for the data, @@ -1078,7 +1106,7 @@ def load_scan( full_note.update(note) endstation_cls = resolve_endstation(retry=retry, **full_note) - trace(f"Using plugin class {endstation_cls}") + trace(f"Using plugin class {endstation_cls}") if trace else None key = "file" if "file" in scan_desc else "path" @@ -1091,7 +1119,7 @@ def load_scan( except ValueError: pass - trace(f"Loading {scan_desc}") + trace(f"Loading {scan_desc}") if trace else None endstation = endstation_cls() endstation.trace = trace return endstation.load(scan_desc, trace=trace, **kwargs) diff --git a/arpes/endstations/fits_utils.py b/arpes/endstations/fits_utils.py index 97583c2c..0d3837b6 100644 --- a/arpes/endstations/fits_utils.py +++ b/arpes/endstations/fits_utils.py @@ -5,14 +5,14 @@ import warnings from ast import literal_eval from collections.abc import Callable, Iterable -from logging import DEBUG, Formatter, StreamHandler, getLogger +from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger from typing import TYPE_CHECKING, Any, TypeAlias import numpy as np from numpy import ndarray from numpy._typing import NDArray -from arpes.trace import traceable +from arpes.trace import Trace, traceable from arpes.utilities.funcutils import collect_leaves, iter_leaves if TYPE_CHECKING: @@ -23,7 +23,8 @@ "find_clean_coords", ) -LOGLEVEL = DEBUG +LOGLEVELS = (DEBUG, INFO) +LOGLEVEL = LOGLEVELS[0] logger = getLogger(__name__) fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s" formatter = Formatter(fmt) @@ -55,7 +56,7 @@ def extract_coords( attrs: dict[str, Any], dimension_renamings: dict[str, str] | None = None, - trace: Callable = None, # noqa: RUF013 + trace: Trace | None = None, ) -> tuple[CoordsDict, list[Dimension], list[int]]: """Does the hard work of extracting coordinates from the scan description. @@ -73,7 +74,7 @@ def extract_coords( try: n_loops = attrs["LWLVLPN"] - trace(f"Found n_loops={n_loops}") + trace(f"Found n_loops={n_loops}") if trace else None except KeyError: # Looks like no scan, this happens for instance in the SToF when you take a single # EDC @@ -87,9 +88,9 @@ def extract_coords( scan_coords = {} for loop in range(n_loops): n_scan_dimensions = attrs[f"NMSBDV{loop}"] - trace(f"Considering loop {loop}, n_scan_dimensions={n_scan_dimensions}") + trace(f"Considering loop {loop}, n_scan_dimensions={n_scan_dimensions}") if trace else None if attrs[f"SCNTYP{loop}"] == 0: - trace("Loop is computed") + trace("Loop is computed") if trace else None for i in range(n_scan_dimensions): name, start, end, n = ( attrs[f"NM_{loop}_{i}"], @@ -116,13 +117,13 @@ def extract_coords( # # As of 2021, that is the perspective we are taking on the issue. if n_scan_dimensions > 1: - trace("Loop is tabulated and is not region based") + trace("Loop is tabulated and is not region based") if trace else None for i in range(n_scan_dimensions): name = attrs[f"NM_{loop}_{i}"] if f"ST_{loop}_{i}" not in attrs and f"PV_{loop}_{i}_0" in attrs: msg = f"Determined that coordinate {name} " msg += "is tabulated based on scan coordinate. Skipping!" 
- trace(msg) + trace(msg) if trace else None continue start, end, n = ( float(attrs[f"ST_{loop}_{i}"]), @@ -132,14 +133,14 @@ def extract_coords( old_name = name name = dimension_renamings.get(name, name) - trace(f"Renaming: {old_name} -> {name}") + trace(f"Renaming: {old_name} -> {name}") if trace else None scan_dimension.append(name) scan_shape.append(n) scan_coords[name] = np.linspace(start, end, n, endpoint=True) else: - trace("Loop is tabulated and is region based") + trace("Loop is tabulated and is region based") if trace else None name, n = ( attrs[f"NM_{loop}_0"], attrs[f"NMPOS_{loop}"], @@ -157,9 +158,9 @@ def extract_coords( n_regions = 1 name = dimension_renamings.get(name, name) - trace(f"Loop (name, n_regions, size) = {(name, n_regions, n)}") + trace(f"Loop (name, n_regions, size) = {(name, n_regions, n)}") if trace else None - coord = np.array(()) + coord: NDArray[np.float_] = np.array(()) for region in range(n_regions): start, end, n = ( attrs[f"ST_{loop}_{region}"], @@ -169,7 +170,7 @@ def extract_coords( msg = f"Reading coordinate {region} from loop. (start, end, n)" msg += f"{(start, end, n)}" - trace(msg) + trace(msg) if trace else None coord = np.concatenate((coord, np.linspace(start, end, n, endpoint=True))) diff --git a/arpes/endstations/nexus_utils.py b/arpes/endstations/nexus_utils.py index e0393fad..2a5118aa 100644 --- a/arpes/endstations/nexus_utils.py +++ b/arpes/endstations/nexus_utils.py @@ -18,8 +18,8 @@ __all__ = ["read_data_attributes_from"] -def read_group_data(group, attribute_name=None) -> Any: - if attribute_name is not None: +def read_group_data(group: dict, attribute_name: str = "") -> Any: + if attribute_name: try: data = group[attribute_name]["data"] except ValueError: diff --git a/arpes/endstations/plugin/BL10_SARPES.py b/arpes/endstations/plugin/BL10_SARPES.py index 49f02c90..a519aa43 100644 --- a/arpes/endstations/plugin/BL10_SARPES.py +++ b/arpes/endstations/plugin/BL10_SARPES.py @@ -69,14 +69,17 @@ class BL10012SARPESEndstation(SynchrotronEndstation, HemisphericalEndstation, SE def load_single_frame( self, - frame_path: str = "", + frame_path: str | Path = "", scan_desc: SCANDESC | None = None, **kwargs: Incomplete, - ): + ) -> xr.Dataset: """Loads all regions for a single .pxt frame, and perform per-frame normalization.""" from arpes.load_pxt import find_ses_files_associated, read_single_pxt + if scan_desc is None: + scan_desc = {} original_data_loc = scan_desc.get("path", scan_desc.get("file")) + assert isinstance(original_data_loc, str | Path) p = Path(original_data_loc) @@ -109,16 +112,14 @@ def load_single_frame( def load_single_region( self, - region_path: str | None = None, + region_path: str | Path = "", scan_desc: SCANDESC | None = None, **kwargs: Incomplete, - ): + ) -> xr.Dataset: """Loads a single region for multi-region scans.""" - import os - from arpes.load_pxt import read_single_pxt - name, _ = os.path.splitext(region_path) + name, _ = Path(region_path).stem num = name[-3:] pxt_data = read_single_pxt(region_path, allow_multiple=True) diff --git a/arpes/endstations/plugin/Elettra_spectromicroscopy.py b/arpes/endstations/plugin/Elettra_spectromicroscopy.py index d3e053e7..1f3f900e 100644 --- a/arpes/endstations/plugin/Elettra_spectromicroscopy.py +++ b/arpes/endstations/plugin/Elettra_spectromicroscopy.py @@ -145,12 +145,12 @@ def files_for_search(cls, directory): filter(lambda f: os.path.splitext(f)[1] in cls._TOLERATED_EXTENSIONS, base_files), ) - ANALYZER_INFORMATION: ClassVar[dict[str, str | None | bool]] = { + 
ANALYZER_INFORMATION: ClassVar[dict[str, str | float | bool]] = { "analyzer": "Custom: in vacuum hemispherical", "analyzer_name": "Spectromicroscopy analyzer", "parallel_deflectors": False, "perpendicular_deflectors": False, - "analyzer_radius": None, + "analyzer_radius": np.nan, "analyzer_type": "hemispherical", } diff --git a/arpes/endstations/plugin/MAESTRO.py b/arpes/endstations/plugin/MAESTRO.py index a721c39e..298ddc99 100644 --- a/arpes/endstations/plugin/MAESTRO.py +++ b/arpes/endstations/plugin/MAESTRO.py @@ -85,7 +85,7 @@ class MAESTROMicroARPESEndstation(MAESTROARPESEndstationBase): "analyzer_name": "Scienta R4000", "parallel_deflectors": False, "perpendicular_deflectors": True, - "analyzer_radius": None, + "analyzer_radius": np.nan, "analyzer_type": "hemispherical", } @@ -176,7 +176,7 @@ class MAESTRONanoARPESEndstation(MAESTROARPESEndstationBase): "analyzer_name": "Scienta DA-30", "parallel_deflectors": False, "perpendicular_deflectors": False, - "analyzer_radius": None, + "analyzer_radius": np.nan, "analyzer_type": "hemispherical", } diff --git a/arpes/endstations/plugin/SPD_main.py b/arpes/endstations/plugin/SPD_main.py index 18d65f66..27d4face 100644 --- a/arpes/endstations/plugin/SPD_main.py +++ b/arpes/endstations/plugin/SPD_main.py @@ -146,8 +146,6 @@ def load_single_frame( def is_dim_coords_same(a: xr.DataArray, b: xr.DataArray) -> bool: """Returns true if the coords used in dims are same in two DataArray.""" - assert isinstance(a, xr.DataArray) - assert isinstance(b, xr.DataArray) try: return all(np.array_equal(a.coords[dim], b.coords[dim]) for dim in a.dims) except KeyError: diff --git a/arpes/endstations/plugin/merlin.py b/arpes/endstations/plugin/merlin.py index 7639530b..1a696e00 100644 --- a/arpes/endstations/plugin/merlin.py +++ b/arpes/endstations/plugin/merlin.py @@ -8,7 +8,12 @@ import numpy as np import xarray as xr -from arpes.endstations import HemisphericalEndstation, SESEndstation, SynchrotronEndstation +from arpes.endstations import ( + SCANDESC, + HemisphericalEndstation, + SESEndstation, + SynchrotronEndstation, +) if TYPE_CHECKING: from _typeshed import Incomplete @@ -81,7 +86,7 @@ class BL403ARPESEndstation(SynchrotronEndstation, HemisphericalEndstation, SESEn "analyzer_name": "Scienta R8000", "parallel_deflectors": False, "perpendicular_deflectors": False, - "analyzer_radius": None, + "analyzer_radius": np.nan, "analyzer_type": "hemispherical", "repetition_rate": 5e8, "undulator_harmonic": 2, # TODO: @@ -101,16 +106,21 @@ class BL403ARPESEndstation(SynchrotronEndstation, HemisphericalEndstation, SESEn }, } - def concatenate_frames(self, frames=list[xr.Dataset], scan_desc: dict | None = None): + def concatenate_frames( + self, + frames: list[xr.Dataset], + scan_desc: SCANDESC | None = None, + ) -> xr.Dataset: """Concatenates frames from different files into a single scan. Above standard process here, we need to look for a Motor_Pos.txt file which contains the coordinates of the scanned axis so that we can stitch the different elements together. 
""" - if len(frames) < 2: + if len(frames) < 2: # noqa: PLR2004 return super().concatenate_frames(frames) - + if scan_desc is None: + scan_desc = {} # determine which axis to stitch them together along, and then do this original_filename = scan_desc.get("file", scan_desc.get("path")) assert original_filename is not None @@ -132,7 +142,7 @@ def concatenate_frames(self, frames=list[xr.Dataset], scan_desc: dict | None = N axis_name = self.RENAME_KEYS.get(axis_name, axis_name) values = [float(_.strip()) for _ in lines[1 : len(frames) + 1]] - for v, f in zip(values, frames): + for v, f in zip(values, frames, strict=True): f.coords[axis_name] = v frames.sort(key=lambda x: x.coords[axis_name]) @@ -159,25 +169,26 @@ def concatenate_frames(self, frames=list[xr.Dataset], scan_desc: dict | None = N def load_single_frame( self, - frame_path: str | None = None, - scan_desc: dict | None = None, + frame_path: str | Path = "", + scan_desc: SCANDESC | None = None, **kwargs: Incomplete, - ): + ) -> xr.Dataset: """Loads all regions for a single .pxt frame, and perform per-frame normalization.""" import copy - import os from arpes.load_pxt import find_ses_files_associated, read_single_pxt from arpes.repair import negate_energy - _, ext = os.path.splitext(frame_path) + if scan_desc is None: + scan_desc = {} + ext = Path(frame_path).suffix if "nc" in ext: # was converted to hdf5/NetCDF format with Conrad's Igor scripts scan_desc = copy.deepcopy(scan_desc) scan_desc["path"] = frame_path return self.load_SES_nc(scan_desc=scan_desc, **kwargs) - original_data_loc = scan_desc.get("path", scan_desc.get("file")) + original_data_loc: Path | str = scan_desc.get("path", scan_desc.get("file")) p = Path(original_data_loc) @@ -210,17 +221,15 @@ def load_single_frame( def load_single_region( self, - region_path: str | None = None, - scan_desc: dict | None = None, + region_path: str | Path = "", + scan_desc: SCANDESC | None = None, **kwargs: Incomplete, - ): + ) -> xr.Dataset: """Loads a single region for multi-region scans.""" - import os - from arpes.load_pxt import read_single_pxt from arpes.repair import negate_energy - name, ext = os.path.splitext(region_path) + name = Path(region_path).stem num = name[-3:] pxt_data = negate_energy(read_single_pxt(region_path)) @@ -232,7 +241,11 @@ def load_single_region( attrs=pxt_data.attrs, ) # separate spectra for possibly unrelated data - def postprocess_final(self, data: xr.Dataset, scan_desc: dict | None = None): + def postprocess_final( + self, + data: xr.Dataset, + scan_desc: SCANDESC | None = None, + ) -> xr.Dataset: """Performs final data normalization for MERLIN data. 
Additional steps we perform here are: diff --git a/arpes/plotting/basic_tools/__init__.py b/arpes/plotting/basic_tools/__init__.py index f6e2a016..d87112df 100644 --- a/arpes/plotting/basic_tools/__init__.py +++ b/arpes/plotting/basic_tools/__init__.py @@ -2,7 +2,7 @@ from __future__ import annotations import contextlib -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, NoReturn import numpy as np import pyqtgraph as pg @@ -20,8 +20,8 @@ from _typeshed import Incomplete from numpy.typing import NDArray - from PySicde6.QtWidgets import QLayout - from PySide6.QtCore import Point + from pyqtgraph import Point + from PySide6.QtWidgets import QLayout from arpes._typing import DataType @@ -42,8 +42,8 @@ class CoreToolWindow(SimpleWindow): def compile_key_bindings(self) -> list[KeyBinding]: return [ *super().compile_key_bindings(), - KeyBinding("Transpose - Roll Axis", [QtCore.Qt.Key_T], self.transpose_roll), - KeyBinding("Transpose - Swap Front Axes", [QtCore.Qt.Key_Y], self.transpose_swap), + KeyBinding("Transpose - Roll Axis", [QtCore.Qt.Key.Key_T], self.transpose_roll), + KeyBinding("Transpose - Swap Front Axes", [QtCore.Qt.Key.Key_Y], self.transpose_swap), ] def transpose_roll(self, event) -> None: @@ -137,7 +137,7 @@ def compute_path_from_roi(self, roi: pg.PolyLineROI) -> list[Point]: def path(self) -> list[Point]: return self.compute_path_from_roi(self.roi) - def roi_changed(self, _): + def roi_changed(self, _: Incomplete) -> None: with contextlib.suppress(Exception): self.path_changed(self.path) @@ -279,7 +279,7 @@ def add_controls(self) -> None: pass -def wrap(cls: type) -> Callable[..., object]: +def wrap(cls: type) -> Callable[[DataType], object]: def tool_function(data: DataType) -> object: tool = cls() tool.set_data(data) diff --git a/arpes/plotting/fermi_surface.py b/arpes/plotting/fermi_surface.py index 9a947154..f9c48e62 100644 --- a/arpes/plotting/fermi_surface.py +++ b/arpes/plotting/fermi_surface.py @@ -2,6 +2,7 @@ from __future__ import annotations +from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger from typing import TYPE_CHECKING import holoviews as hv @@ -33,6 +34,18 @@ "magnify_circular_regions_plot", ) +LOGLEVELS = (DEBUG, INFO) +LOGLEVEL = LOGLEVELS[1] +logger = getLogger(__name__) +fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s" +formatter = Formatter(fmt) +handler = StreamHandler() +handler.setLevel(LOGLEVEL) +logger.setLevel(LOGLEVEL) +handler.setFormatter(formatter) +logger.addHandler(handler) +logger.propagate = False + @save_plot_provenance def fermi_surface_slices( @@ -73,7 +86,7 @@ def magnify_circular_regions_plot( magnified_points: NDArray[np.float_] | list[float], mag: float = 10, radius: float = 0.05, - # below this can be treated as kwargs? + # below this two can be treated as kwargs? 
cmap: Colormap | ColorType = "viridis", color: ColorType | None = None, edgecolor: ColorType = "red", @@ -143,8 +156,8 @@ def magnify_circular_regions_plot( pts[:, 1] = (pts[:, 1]) / (xlim[1] - xlim[0]) pts[:, 0] = (pts[:, 0]) / (ylim[1] - ylim[0]) - print(np.min(pts[:, 1]), np.max(pts[:, 1])) - print(np.min(pts[:, 0]), np.max(pts[:, 0])) + logger.debug(np.min(pts[:, 1]), np.max(pts[:, 1])) + logger.debug(np.min(pts[:, 0]), np.max(pts[:, 0])) for c, ec, point in zip(color, edgecolor, magnified_points, strict=True): patch = matplotlib.patches.Ellipse( @@ -172,7 +185,14 @@ def magnify_circular_regions_plot( aspect = ax.get_aspect() extent = (xlim[0], xlim[1], ylim[0], ylim[1]) - ax.imshow(data_masked.values, cmap=cm, extent=extent, zorder=3, clim=clim, origin="lower") + ax.imshow( + data_masked.values, + cmap=cm, + extent=extent, + zorder=3, + clim=clim, + origin="lower", + ) ax.set_aspect(aspect) for spine in ["left", "top", "right", "bottom"]: diff --git a/arpes/plotting/fit_tool/__init__.py b/arpes/plotting/fit_tool/__init__.py index e1fe9e6f..52fc9b83 100644 --- a/arpes/plotting/fit_tool/__init__.py +++ b/arpes/plotting/fit_tool/__init__.py @@ -83,7 +83,7 @@ def transpose_swap(self, event) -> None: self.app().transpose_to_front(1) @staticmethod - def _update_scroll_delta(delta, event: QtGui.QKeyEvent): + def _update_scroll_delta(delta: tuple[int, int], event: QtGui.QKeyEvent) -> tuple[int, int]: if event.nativeModifiers() & 1: # shift key delta = (delta[0], delta[1] * 5) @@ -114,7 +114,7 @@ def scroll(self, event: QtGui.QKeyEvent) -> None: QtCore.Qt.Key.Key_Up: (1, 1), } - delta = self._update_scroll_delta(key_map.get(event.key()), event) + delta = self._update_scroll_delta(key_map.get(event.key(), (0, 0)), event) if delta is not None and self.app() is not None: self.app().scroll(delta) diff --git a/arpes/plotting/stack_plot.py b/arpes/plotting/stack_plot.py index 77eaf490..bc1688a5 100644 --- a/arpes/plotting/stack_plot.py +++ b/arpes/plotting/stack_plot.py @@ -321,7 +321,8 @@ def stack_dispersion_plot( # noqa: PLR0913 color(RGBAColorType | Colormap): color of the plot mode(Literal["liine", "fill_between", "hide_line", "scatter"]): Draw mode offset_correction(Literal["zero", "constant", "constant_right"] | None): offset correction - mode (default to "zero") + mode (default to + "zero") shift(float): shift of the plot along the horizontal direction negate(bool): _description_ **kwargs: diff --git a/arpes/plotting/utils.py b/arpes/plotting/utils.py index 922149eb..03782945 100644 --- a/arpes/plotting/utils.py +++ b/arpes/plotting/utils.py @@ -6,16 +6,15 @@ import errno import itertools import json -import os.path import pickle import re import warnings from collections import Counter -from collections.abc import Generator, Iterable, Sequence +from collections.abc import Generator, Iterable, Iterator, Sequence from datetime import UTC -from logging import DEBUG, Formatter, StreamHandler, getLogger +from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger from pathlib import Path -from typing import TYPE_CHECKING, Literal, Unpack +from typing import TYPE_CHECKING, Any, Literal, Unpack import matplotlib as mpl import matplotlib.pyplot as plt @@ -108,7 +107,8 @@ "h_gradient_fill", ) -LOGLEVEL = DEBUG +LOGLEVELS = (DEBUG, INFO) +LOGLEVEL = LOGLEVELS[1] logger = getLogger(__name__) fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s" formatter = Formatter(fmt) @@ -124,7 +124,7 @@ @contextlib.contextmanager -def unchanged_limits(ax: Axes): +def unchanged_limits(ax: Axes) -> 
Iterator[None]: """Context manager that retains axis limits.""" xlim, ylim = ax.get_xlim(), ax.get_ylim() @@ -177,7 +177,7 @@ def h_gradient_fill( x1(float): lower side of x x2(float): height side of x x_solid: If x_solid is not None, the gradient will be extended at the maximum opacity from - the closer limit towards x_solid. + the closer limit towards x_solid. fill_color (str): Color name, pass it as "c" in mpl.colors.to_rgb ax(Axes): matplotlib Axes object **kwargs: Pass to imshow (Z order can be set here.) @@ -672,6 +672,10 @@ def imshow_mask( """Plots a mask by using a fixed color and transparency.""" assert over is not None + if ax is None: + ax = plt.gca() + assert isinstance(ax, Axes) + default_kwargs = { "origin": "lower", "aspect": ax.get_aspect(), @@ -685,9 +689,6 @@ def imshow_mask( default_kwargs.update(kwargs) kwargs = default_kwargs - if ax is None: - ax = plt.gca() - assert isinstance(ax, Axes) if isinstance(kwargs["cmap"], str): kwargs["cmap"] = mpl.colormaps.get_cmap(cmap=kwargs["cmap"]) @@ -870,7 +871,7 @@ def resolve(name: str, value: slice | int) -> NDArray[np.float_]: if missing_dims: assert reference_data is not None - print(missing_dims) + logger.info(missing_dims) if n_cut_dims == TwoDimensional: # a region cut, illustrate with a rect or by suppressing background @@ -956,14 +957,10 @@ def generic_colorbar( ax(Axes): Matplotlib Axes object **kwargs: Pass to ColoarbarBase """ - default_kwargs: ColorbarParam = { - "cmap": mpl.colormaps.get_cmap("Blues"), - "norm": colors.Normalize(vmin=low, vmax=high), - "ticks": [low, high], - "orientation": "horizontal", - } - default_kwargs.update(kwargs) - kwargs = default_kwargs + kwargs.setdefault("cmap", mpl.colormaps.get_cmap("Blues")) + kwargs.setdefault("norm", colors.Normalize(vmin=low, vmax=high)) + kwargs.setdefault("ticks", [low, high]) + kwargs.setdefault("orientation", "horizontal") delta = high - low low = low - delta / 6 @@ -982,15 +979,11 @@ def phase_angle_colorbar( assert isinstance(ax, Axes) assert "use_tex" in SETTINGS - default_kwargs = { - "cmap": mpl.colormaps.get_cmap("Blue_r"), - "norm": colors.Normalize(vmin=low, vmax=high), - "label": "Angle (rad)", - "ticks": ["0", r"$\pi$", r"$2\pi$"], - "orientation": "horizontal", - } - default_kwargs.update(kwargs) - kwargs = default_kwargs + kwargs.setdefault("cmap", mpl.colormaps.get_cmap("Blues_r")) + kwargs.setdefault("norm", colors.Normalize(vmin=low, vmax=high)) + kwargs.setdefault("label", "Angle (rad)") + kwargs.setdefault("ticks", ["0", r"$\pi$", r"$2\pi$"]) + kwargs.setdefault("orientation", "horizontal") if not SETTINGS["use_tex"]: kwargs["ticks"] = ["0", "π", "2π"] @@ -1006,15 +999,11 @@ def temperature_colorbar( ) -> colorbar.Colorbar: """Generates a colorbar suitable for temperature data with fixed extent.""" assert isinstance(ax, Axes) - default_kwargs = { - "cmap": "Blues_r", - "norm": colors.Normalize(vmin=low, vmax=high), - "label": "Temperature (K)", - "ticks": [low, high], - "orientation": "horizontal", - } - default_kwargs.update(kwargs) - kwargs = default_kwargs + kwargs.setdefault("cmap", mpl.colormaps.get_cmap("Blues_r")) + kwargs.setdefault("norm", colors.Normalize(vmin=low, vmax=high)) + kwargs.setdefault("label", "Temperature (K)") + kwargs.setdefault("ticks", [low, high]) + kwargs.setdefault("orientation", "horizontal") if isinstance(kwargs["cmap"], str): kwargs["cmap"] = mpl.colormaps.get_cmap(kwargs["cmap"]) @@ -1036,16 +1025,11 @@ def delay_colorbar( TODO make this nonsequential for use in case where you want to have a long time period 
after the delay or before. """ - default_kwargs = { - "cmap": mpl.colormaps.get_cmap("coolwarm"), - "norm": colors.Normalize(vmin=low, vmax=high), - "label": "Probe pulse delay (fs)", - "ticks": [low, 0, high], - "orientation": "horizontal", - } - default_kwargs.update(kwargs) - kwargs = default_kwargs - + kwargs.setdefault("cmap", mpl.colormaps.get_cmap("coolwarm")) + kwargs.setdefault("norm", colors.Normalize(vmin=low, vmax=high)) + kwargs.setdefault("label", "Probe pulse delay (fs)") + kwargs.setdefault("ticks", [low, 0, high]) + kwargs.setdefault("orientation", "horizontal") return colorbar.Colorbar(ax, **kwargs) @@ -1057,20 +1041,17 @@ def temperature_colorbar_around( ) -> colorbar.Colorbar: """Generates a colorbar suitable for temperature axes around a central value.""" assert isinstance(ax, Axes) - - default_kwargs = { - "cmap": mpl.colormaps.get_cmap("RdBu_r"), - "norm": colors.Normalize( + kwargs.setdefault("cmap", mpl.colormaps.get_cmap("RdBu_r")) + kwargs.setdefault( + "norm", + colors.Normalize( vmin=central - temperature_range, vmax=central + temperature_range, ), - "label": "Temperature (K)", - "orientation": "horizontal", - "ticks": [central - temperature_range, central + temperature_range], - } - - default_kwargs.update(kwargs) - kwargs = default_kwargs + ) + kwargs.setdefault("label", "Temperature (K)") + kwargs.setdefault("ticks", [central - temperature_range, central + temperature_range]) + kwargs.setdefault("orientation", "horizontal") return colorbar.Colorbar(ax, **kwargs) @@ -1122,7 +1103,7 @@ def remove_colorbars(fig: Figure | None = None) -> None: try: if fig is not None: for ax in fig.axes: - if ax.get_aspect() == 20: # a bit of a hack + if ax.get_aspect() >= 20: # a bit of a hack ax.remove() else: remove_colorbars(plt.gcf()) @@ -1268,9 +1249,9 @@ def load_data_for_figure(p: str | Path) -> None: def savefig( - desired_path: str, + desired_path: str | Path, dpi: int = 400, - data=None, + data: list[DataType] | tuple[DataType, ...] | set[DataType] | None = None, save_data=None, *, paper: bool = False, @@ -1288,7 +1269,9 @@ def savefig( after the fact if you have many many plots. """ - if not os.path.splitext(desired_path)[1]: + desired_path = Path(desired_path) + assert isinstance(desired_path, Path) + if not desired_path.suffix: paper = True if save_data is None: @@ -1299,7 +1282,7 @@ def savefig( msg, ) else: - output_location = path_for_plot(os.path.splitext(desired_path)[0]) + output_location = path_for_plot(desired_path.parent / desired_path.stem) with Path(str(output_location) + ".pickle").open("wb") as f: pickle.dump(save_data, f) @@ -1330,11 +1313,8 @@ def savefig( "name": "savefig", } - def extract(for_data): - try: - return for_data.attrs.get("provenance", {}) - except Exception: - return {} + def extract(for_data: DataType) -> dict[str, Any]: + return for_data.attrs.get("provenance", {}) if data is not None: assert isinstance( diff --git a/arpes/provenance.py b/arpes/provenance.py index b71b13ea..76718392 100644 --- a/arpes/provenance.py +++ b/arpes/provenance.py @@ -55,7 +55,7 @@ def attach_id(data: DataType) -> None: data.attrs["id"] = str(uuid.uuid1()) -def provenance_from_file(child_arr: DataType, file: str, record: str) -> None: +def provenance_from_file(child_arr: DataType, file: str, record: dict[str, str | float]) -> None: """Builds a provenance entry for a dataset corresponding to loading data from a file. This is used by data loaders at the start of an analysis. 
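Note on the colorbar hunks above: the helpers rewritten in arpes/plotting/utils.py (generic_colorbar, phase_angle_colorbar, temperature_colorbar, delay_colorbar, temperature_colorbar_around) all converge on the same idiom — instead of building a default dict and merging the caller's kwargs into it, they call kwargs.setdefault so caller-supplied values always win. A minimal sketch of that idiom, assuming a hypothetical helper (the name example_colorbar and its defaults are illustrative, not part of the patch):

import matplotlib as mpl
from matplotlib import colorbar, colors
from matplotlib.axes import Axes


def example_colorbar(low: float, high: float, ax: Axes, **kwargs) -> colorbar.Colorbar:
    # Caller-supplied kwargs take precedence; setdefault only fills in what is missing,
    # which removes the separate default_kwargs dict and the update/overwrite dance.
    kwargs.setdefault("cmap", mpl.colormaps.get_cmap("Blues"))
    kwargs.setdefault("norm", colors.Normalize(vmin=low, vmax=high))
    kwargs.setdefault("ticks", [low, high])
    kwargs.setdefault("orientation", "horizontal")
    return colorbar.Colorbar(ax, **kwargs)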
diff --git a/arpes/utilities/collections.py b/arpes/utilities/collections.py index 143c38bd..54e5ad11 100644 --- a/arpes/utilities/collections.py +++ b/arpes/utilities/collections.py @@ -86,14 +86,16 @@ def deep_equals( | tuple[str, ...] | tuple[float, ...] | set[str | float] - | dict[str, float | str], + | dict[str, float | str] + | None, b: float | str | list[float | str] | tuple[str, ...] | tuple[float, ...] | set[str | float] - | dict[str, float | str], + | dict[str, float | str] + | None, ) -> bool | None: """An equality check that looks into common collection types.""" if not isinstance(b, type(a)): diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py index 7b4993dc..8cfeebbe 100644 --- a/arpes/xarray_extensions.py +++ b/arpes/xarray_extensions.py @@ -1810,7 +1810,7 @@ def _repr_html_spectrometer_info(self) -> str: return ARPESAccessorBase.dict_to_html(ordered_settings) @staticmethod - def _repr_html_experimental_conditions(conditions): + def _repr_html_experimental_conditions(conditions: dict) -> str: transforms = { "polarization": lambda p: { "p": "Linear Horizontal", diff --git a/lefthook.yml b/lefthook.yml index 508bed0e..af073d45 100644 --- a/lefthook.yml +++ b/lefthook.yml @@ -1,6 +1,34 @@ -pre-commit: - parallel: true - commands: - black: - glob: "*.py" - run: yarn check-black {staged_files} +# EXAMPLE USAGE +# Refer for explanation to following link: +# https://github.com/evilmartians/lefthook/blob/master/docs/full_guide.md +# +# pre-push: +# commands: +# packages-audit: +# tags: frontend security +# run: yarn audit +# gems-audit: +# tags: backend security +# run: bundle audit +# +# pre-commit: +# parallel: true +# commands: +# eslint: +# glob: "*.{js,ts}" +# run: yarn eslint {staged_files} +# rubocop: +# tags: backend style +# glob: "*.rb" +# exclude: "application.rb|routes.rb" +# run: bundle exec rubocop --force-exclusion {all_files} +# govet: +# tags: backend style +# files: git ls-files -m +# glob: "*.go" +# run: go vet {files} +# scripts: +# "hello.js": +# runner: node +# "any.go": +# runner: go run diff --git a/tests/test_basic_data_loading.py b/tests/test_basic_data_loading.py index e7a68b84..ef5b4f21 100644 --- a/tests/test_basic_data_loading.py +++ b/tests/test_basic_data_loading.py @@ -164,7 +164,7 @@ class TestMetadata: "name": "Scienta R8000", "parallel_deflectors": False, "perpendicular_deflectors": False, - "radius": None, + "radius": np.nan, "type": "hemispherical", }, }, @@ -251,7 +251,7 @@ class TestMetadata: "analyzer": "R4000", "analyzer_detail": { "type": "hemispherical", - "radius": None, + "radius": np.nan, "name": "Scienta R4000", "parallel_deflectors": False, "perpendicular_deflectors": True,
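
Note on the commit subject: the unifying change is separating np.nan from None. Numeric metadata that used to be stored as None (the analyzer_radius entries in the Elettra, MAESTRO, and MERLIN plugins, and the radius values in the test fixtures) now defaults to np.nan, and postprocess_final assigns np.nan to coordinates it cannot recover from attributes. A short sketch of the practical difference, with purely illustrative values:

import numpy as np

# Numeric metadata stays numeric: an unknown radius is np.nan rather than None.
attrs = {"analyzer_radius": np.nan, "analyzer_type": "hemispherical"}

# Arithmetic propagates nan instead of raising, so downstream plotting and
# momentum-conversion code does not need a None check on every numeric attribute.
doubled = attrs["analyzer_radius"] * 2.0  # nan; None * 2.0 would raise TypeError

# The trade-off: nan is not equal to itself, so equality-based comparisons
# must use np.isnan rather than ==.
assert attrs["analyzer_radius"] != attrs["analyzer_radius"]
assert np.isnan(attrs["analyzer_radius"])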
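
Note on the type-hint updates: they center on the new SCANDESC TypedDict in arpes/endstations/__init__.py (and on relaxing WORKSPACETYPE to total=False). Scan descriptions are partial dictionaries, so every key is optional. A small usage sketch that restates the TypedDict from the patch; the file name and note values are made up for illustration:

from __future__ import annotations

from pathlib import Path
from typing import TypedDict


class SCANDESC(TypedDict, total=False):
    # total=False makes every key optional, matching how loaders receive
    # partial scan descriptions (sometimes only a file, sometimes only a note).
    file: str | Path
    location: str
    path: str | Path
    note: dict[str, str | float]
    id: int | str


# Any subset of the keys type-checks.
desc: SCANDESC = {"file": "example_scan_001.pxt", "note": {"temperature": 20.0}}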
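
Note on logging: arpes/endstations/__init__.py, arpes/endstations/fits_utils.py, arpes/plotting/fermi_surface.py, and arpes/plotting/utils.py all gain the same module-level logger block and swap bare print calls for logger.info / logger.debug. Restated from the patch with comments added; the final message is only an example of the substitution:

from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger

LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[1]  # INFO here; the endstation and fits_utils modules use LOGLEVELS[0] (DEBUG)

logger = getLogger(__name__)  # one logger per module, named after that module
handler = StreamHandler()
handler.setLevel(LOGLEVEL)
handler.setFormatter(Formatter("%(asctime)s %(levelname)s %(name)s :%(message)s"))
logger.setLevel(LOGLEVEL)
logger.addHandler(handler)
logger.propagate = False  # do not forward records to the root logger, avoiding duplicate output

# Former print() calls become leveled messages, for example:
logger.info("dimension labels: %s", ["eV", "phi"])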