
Commit

🔨 Separate np.nan from NONE
💬  update type hints
🔨  black procedure is removed from lefthook.yml
arafune committed Oct 15, 2023
1 parent 247777e commit 6671e10
Showing 20 changed files with 273 additions and 201 deletions.
2 changes: 1 addition & 1 deletion arpes/_typing.py
@@ -181,7 +181,7 @@ class SAMPLEINFO(TypedDict, total=False):
reflectivity: float | None


class WORKSPACETYPE(TypedDict, total=True):
class WORKSPACETYPE(TypedDict, total=False):
path: str | Path
name: str

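
As an aside (not from this commit): a minimal sketch of what switching WORKSPACETYPE to total=False means for callers, assuming only the two keys shown above.

from pathlib import Path
from typing import TypedDict

class WORKSPACETYPE(TypedDict, total=False):
    path: str | Path
    name: str

# total=False makes every key optional, so partially filled workspace dicts type-check:
empty_workspace: WORKSPACETYPE = {}
named_workspace: WORKSPACETYPE = {"name": "example-workspace"}  # value is illustrative
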
2 changes: 1 addition & 1 deletion arpes/config.py
@@ -45,7 +45,7 @@

ureg = pint.UnitRegistry()

DATA_PATH = None
DATA_PATH: str | None = None
SOURCE_ROOT = str(Path(__file__).parent)

SETTINGS: ConfigSettings = {
142 changes: 85 additions & 57 deletions arpes/endstations/__init__.py
@@ -6,8 +6,9 @@
import os.path
import re
import warnings
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar, TypedDict
from typing import TYPE_CHECKING, ClassVar, NoReturn, Self, TypedDict

import h5py
import numpy as np
@@ -30,6 +31,7 @@
from _typeshed import Incomplete

from arpes._typing import SPECTROMETER

__all__ = [
"endstation_name_from_alias",
"endstation_from_alias",
@@ -43,14 +45,28 @@
"resolve_endstation",
]

_ENDSTATION_ALIASES = {}
LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[0]
logger = getLogger(__name__)
fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
formatter = Formatter(fmt)
handler = StreamHandler()
handler.setLevel(LOGLEVEL)
logger.setLevel(LOGLEVEL)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False


_ENDSTATION_ALIASES: dict[str, type[EndstationBase]] = {}


class SCANDESC(TypedDict, total=False):
file: str
file: str | Path
location: str
path: str
path: str | Path
note: dict[str, str | float] # used as attrs basically.
id: int | str
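
For illustration (not from this commit): with the updated hints, file and path accept either str or pathlib.Path; the values below are made up.

from pathlib import Path

scan: SCANDESC = {
    "file": Path("../Data/hdf5/scan_0001.h5"),  # a Path now satisfies the hint
    "location": "example-beamline",
    "note": {"temperature": 20.0},
}
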


class EndstationBase:
@@ -83,15 +99,15 @@ class EndstationBase:
ATTR_TRANSFORMS: ClassVar[dict[str, str]] = {}
MERGE_ATTRS: ClassVar[SPECTROMETER] = {}

_SEARCH_DIRECTORIES = (
_SEARCH_DIRECTORIES: tuple[str, ...] = (
"",
"hdf5",
"fits",
"../Data",
"../Data/hdf5",
"../Data/fits",
)
_SEARCH_PATTERNS = (
_SEARCH_PATTERNS: tuple[str, ...] = (
r"[\-a-zA-Z0-9_\w]+_[0]+{}$",
r"[\-a-zA-Z0-9_\w]+_{}$",
r"[\-a-zA-Z0-9_\w]+{}$",
@@ -145,13 +161,13 @@ def __init__(self) -> None:
def is_file_accepted(
cls: type[EndstationBase],
file: str | Path,
scan_desc: dict[str, str],
scan_desc: SCANDESC,
) -> bool:
"""Determines whether this loader can load this file."""
if Path(file).exists() and len(str(file).split(os.path.sep)) > 1:
# looks like an actual file, we are going to just check that the extension is kosher
# and that the filename matches something reasonable.
p = Path(str(file))
p = Path(file)

if p.suffix not in cls._TOLERATED_EXTENSIONS:
return False
@@ -163,7 +179,7 @@ def is_file_accepted(

return False
try:
_ = cls.find_first_file(file, scan_desc)
_ = cls.find_first_file(str(file), scan_desc)
except ValueError:
return False
return True
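
A hedged usage sketch (not from this commit), assuming a concrete plugin such as SESEndstation and a hypothetical file name:

candidate = Path("../Data/sample_0001.pxt")
if SESEndstation.is_file_accepted(candidate, scan_desc={}):
    print(f"{candidate} looks loadable by SESEndstation")
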
@@ -179,11 +195,11 @@ def files_for_search(cls: type[EndstationBase], directory: str | Path) -> list[s
@classmethod
def find_first_file(
cls: type[EndstationBase],
file,
scan_desc,
file: str,
scan_desc: SCANDESC,
*,
allow_soft_match: bool = False,
):
) -> NoReturn | None:
"""Attempts to find file associated to the scan given the user provided path or scan number.
This is mostly done by regex matching over available options.
@@ -197,11 +213,12 @@ def find_first_file(
"""
workspace = arpes.config.CONFIG["WORKSPACE"]
workspace_path = os.path.join(workspace["path"], "data")
workspace = workspace["name"]
workspace_name = workspace["name"]

base_dir = workspace_path or os.path.join(arpes.config.DATA_PATH, workspace)
base_dir = workspace_path or Path(arpes.config.DATA_PATH) / workspace_name
dir_options = [os.path.join(base_dir, option) for option in cls._SEARCH_DIRECTORIES]

logger.debug(f"arpes.config.DATA_PATH: {arpes.config.DATA_PATH}")
logger.debug(f"dir_options: {dir_options}")
# another plugin related option here is we can restrict the number of regexes by allowing
# plugins to install regexes for particular endstations, if this is needed in the future it
# might be a good way of preventing clashes where there is ambiguity in file naming scheme
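
The regex matching described in the docstring can be sketched as follows (not from this commit; the file stems are invented):

import re

scan_number = 152
patterns = [p.format(scan_number) for p in EndstationBase._SEARCH_PATTERNS]
for stem in ("sample_A_00152", "sample_A_152", "unrelated_file"):
    if any(re.match(pattern, stem) for pattern in patterns):
        print(f"{stem} is associated with scan {scan_number}")
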
@@ -239,7 +256,11 @@ def find_first_file(
msg = f"Could not find file associated to {file}"
raise ValueError(msg)

def concatenate_frames(self, frames=list[xr.Dataset], scan_desc: dict[str, str] | None = None):
def concatenate_frames(
self,
frames: list[xr.Dataset],
scan_desc: SCANDESC | None = None,
) -> xr.Dataset:
"""Performs concatenation of frames in multi-frame scans.
The way this happens is that we look for an axis on which the frames are changing uniformly
@@ -271,24 +292,30 @@ def concatenate_frames(self, frames=list[xr.Dataset], scan_desc: dict[str, str]
frames.sort(key=lambda x: x.coords[scan_coord])
return xr.concat(frames, scan_coord)
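
Not from this commit, but a sketch of the concatenation described above: frames that differ only along one uniformly changing coordinate (here an invented beta scan) are sorted on that coordinate and stacked with xr.concat.

import numpy as np
import xarray as xr

frames = [
    xr.Dataset(
        {"spectrum": ("eV", np.zeros(3))},
        coords={"eV": [0.0, 0.1, 0.2], "beta": b},
    )
    for b in (0.1, 0.0, 0.2)
]
frames.sort(key=lambda f: float(f.coords["beta"]))  # order frames along the scan axis
volume = xr.concat(frames, "beta")  # one Dataset with beta as a new dimension
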

def resolve_frame_locations(self, scan_desc: SCANDESC | None = None) -> list[str]:
def resolve_frame_locations(self, scan_desc: SCANDESC | None = None) -> list[Path] | NoReturn:
"""Determine all files and frames associated to this piece of data.
This always needs to be overridden in subclasses to handle data appropriately.
"""
if scan_desc:
msg = "You need to define resolve_frame_locations or subclass SingleFileEndstation."
msg = "You need to define resolve_frame_locations or subclass SingleFileEndstation."
raise NotImplementedError(msg)

def load_single_frame(
self,
frame_path: str = "", # TODO<RA> should be str and default is ""
frame_path: str | Path = "",
scan_desc: SCANDESC | None = None,
**kwargs: Incomplete,
) -> xr.Dataset:
"""Hook for loading a single frame of data.
This always needs to be overridden in subclasses to handle data appropriately.
"""
if scan_desc:
logger.debug(scan_desc)
if kwargs:
logger.debug(kwargs)
return xr.Dataset()
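
A hedged sketch of the override this hook expects (not from this commit; the plugin name and one-value-per-line text format are hypothetical):

class PlainTextEndstation(SingleFileEndstation):
    def load_single_frame(self, frame_path="", scan_desc=None, **kwargs):
        # hypothetical format: one intensity value per line
        spectrum = xr.DataArray(np.loadtxt(frame_path), dims=["eV"])
        return xr.Dataset({"spectrum": spectrum})
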

def postprocess(self, frame: xr.Dataset) -> xr.Dataset:
@@ -370,37 +397,37 @@ def postprocess_final(
data.spectrum.attrs["spectrum_type"] = spectrum_type

ls = [data, *data.S.spectra]
for _ in ls:
for a_data in ls:
for k, key_fn in self.ATTR_TRANSFORMS.items():
if k in _.attrs:
transformed = key_fn(_.attrs[k])
if k in a_data.attrs:
transformed = key_fn(a_data.attrs[k])
if isinstance(transformed, dict):
_.attrs.update(transformed)
a_data.attrs.update(transformed)
else:
_.attrs[k] = transformed
a_data.attrs[k] = transformed

for _ in ls:
for a_data in ls:
for k, v in self.MERGE_ATTRS.items():
if k not in _.attrs:
_.attrs[k] = v
if k not in a_data.attrs:
a_data.attrs[k] = v

for _ in ls:
for a_data in ls:
for c in self.ENSURE_COORDS_EXIST:
if c not in _.coords:
if c in _.attrs:
_.coords[c] = _.attrs[c]
if c not in a_data.coords:
if c in a_data.attrs:
a_data.coords[c] = a_data.attrs[c]
else:
warnings_msg = f"Could not assign coordinate {c} from attributes,"
warnings_msg += "assigning np.nan instead."
warnings.warn(
warnings_msg,
stacklevel=2,
)
_.coords[c] = np.nan
a_data.coords[c] = np.nan

for _ in ls:
if "chi" in _.coords and "chi_offset" not in _.attrs:
_.attrs["chi_offset"] = _.coords["chi"].item()
for a_data in ls:
if "chi" in a_data.coords and "chi_offset" not in a_data.attrs:
a_data.attrs["chi_offset"] = a_data.coords["chi"].item()

# go and change endianness and datatypes to something reasonable
# this is done for performance reasons in momentum space conversion, primarily
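
The warning above documents the fallback: a coordinate listed in ENSURE_COORDS_EXIST that appears in neither .coords nor .attrs is set to np.nan. A standalone sketch (not from this commit; the coordinate names are illustrative):

import numpy as np
import xarray as xr

data = xr.Dataset({"spectrum": ("eV", np.zeros(3))}, coords={"eV": [0.0, 0.1, 0.2]})
for c in ("x", "y", "z"):  # stand-ins for ENSURE_COORDS_EXIST entries
    if c not in data.coords:
        data.coords[c] = data.attrs[c] if c in data.attrs else np.nan
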
@@ -503,7 +530,7 @@ class SESEndstation(EndstationBase):
These files have special frame names, at least at the beamlines Conrad has encountered.
"""

def resolve_frame_locations(self, scan_desc: SCANDESC | None = None):
def resolve_frame_locations(self, scan_desc: SCANDESC | None = None) -> list[Path] | NoReturn:
if scan_desc is None:
msg = "Must pass dictionary as file scan_desc to all endstation loading code."
raise ValueError(
@@ -526,12 +553,13 @@ def resolve_frame_locations(self, scan_desc: SCANDESC | None = None):

def load_single_frame(
self,
frame_path: str = "",
frame_path: str | Path = "",
scan_desc: SCANDESC | None = None,
**kwargs: Incomplete,
) -> xr.Dataset:
ext = Path(frame_path).suffix

if scan_desc is None:
scan_desc = {}
if "nc" in ext:
# was converted to hdf5/NetCDF format with Conrad's Igor scripts
scan_desc = copy.deepcopy(scan_desc)
@@ -542,7 +570,7 @@ def load_single_frame(
pxt_data = negate_energy(read_single_pxt(frame_path))
return xr.Dataset({"spectrum": pxt_data}, attrs=pxt_data.attrs)

def postprocess(self, frame: xr.Dataset):
def postprocess(self, frame: xr.Dataset) -> Self:
frame = super().postprocess(frame)
return frame.assign_attrs(frame.S.spectrum.attrs)

@@ -566,6 +594,9 @@ def load_SES_nc(
Returns:
Loaded data.
"""
if kwargs:
for k, v in kwargs.items():
logger.info(f"load_SES_nc: unused kwargs, k: {k}, value : {v}")
if scan_desc is None:
scan_desc = {}
scan_desc = copy.deepcopy(scan_desc)
@@ -591,21 +622,20 @@ def load_SES_nc(
# Use dimension labels instead of
dimension_labels = list(f["/" + primary_dataset_name].attrs["IGORWaveDimensionLabels"][0])
if any(x == "" for x in dimension_labels):
print(dimension_labels)
logger.info(dimension_labels)

if not robust_dimension_labels:
msg = "Missing dimension labels. Use robust_dimension_labels=True to override"
raise ValueError(
msg,
)
else:
used_blanks = 0
for i in range(len(dimension_labels)):
if dimension_labels[i] == "":
dimension_labels[i] = f"missing{used_blanks}"
used_blanks += 1
used_blanks = 0
for i in range(len(dimension_labels)):
if dimension_labels[i] == "":
dimension_labels[i] = f"missing{used_blanks}"
used_blanks += 1

print(dimension_labels)
logger.info(dimension_labels)

scaling = f["/" + primary_dataset_name].attrs["IGORWaveScaling"][-len(dimension_labels) :]
raw_data = f["/" + primary_dataset_name][:]
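
For illustration: the robust_dimension_labels branch above renames blank labels in place; with an invented label list it behaves like this.

labels = ["eV", "", "phi", ""]
used_blanks = 0
for i, label in enumerate(labels):
    if label == "":
        labels[i] = f"missing{used_blanks}"
        used_blanks += 1
# labels == ["eV", "missing0", "phi", "missing1"]
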
@@ -619,13 +649,13 @@ def load_SES_nc(
attrs = scan_desc.pop("note", {})
attrs.update(wave_note)

built_coords = dict(zip(dimension_labels, scaling))
built_coords = dict(zip(dimension_labels, scaling, strict=True))

deg_to_rad_coords = {"theta", "beta", "phi", "alpha", "psi"}

# the hemisphere axis is handled below
built_coords = {
k: c * (np.pi / 180) if k in deg_to_rad_coords else c for k, c in built_coords.items()
k: np.deg2rad(c) if k in deg_to_rad_coords else c for k, c in built_coords.items()
}

deg_to_rad_attrs = {"theta", "beta", "alpha", "psi", "chi"}
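
Not from this commit: np.deg2rad(c) and c * (np.pi / 180) give identical radian values, so the swap above is purely for readability.

import numpy as np

angles_deg = np.array([0.0, 45.0, 90.0, 180.0])
assert np.allclose(np.deg2rad(angles_deg), angles_deg * (np.pi / 180))
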
@@ -642,7 +672,7 @@ def load_SES_nc(

provenance_from_file(
dataset_contents["spectrum"],
data_loc,
str(data_loc),
{"what": "Loaded SES dataset from HDF5.", "by": "load_SES"},
)

@@ -708,14 +738,12 @@ class FITSEndstation(EndstationBase):
"LMOTOR6": "alpha",
}

def resolve_frame_locations(self, scan_desc: SCANDESC | None = None) -> list[Path]:
"""These are stored as single files, so just use the one from the description."""
def resolve_frame_locations(self, scan_desc: SCANDESC | None = None) -> list[Path] | NoReturn:
if scan_desc is None:
msg = "Must pass dictionary as file scan_desc to all endstation loading code."
raise ValueError(
msg,
)

original_data_loc = scan_desc.get("path", scan_desc.get("file"))
assert original_data_loc is not None
assert original_data_loc != ""
@@ -991,7 +1019,7 @@ def endstation_name_from_alias(alias) -> str:
return endstation_from_alias(alias).PRINCIPAL_NAME


def add_endstation(endstation_cls) -> None:
def add_endstation(endstation_cls: type[EndstationBase]) -> None:
"""Registers a data loading plugin (Endstation class) together with its aliases.
You can use this to add a plugin after the original search if it is defined in another
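
A hedged sketch of registering a plugin (not from this commit). PRINCIPAL_NAME appears in the surrounding code; ALIASES is assumed to be the class attribute the registry reads.

class MyLabEndstation(SingleFileEndstation):
    PRINCIPAL_NAME = "MyLab"
    ALIASES = ["my-lab"]  # assumption: aliases are declared on the plugin class

add_endstation(MyLabEndstation)
# After registration the plugin should resolve through the alias table,
# e.g. endstation_from_alias("my-lab").
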
@@ -1051,12 +1079,12 @@ def load_scan(
scan_desc: dict[str, str],
*,
retry: bool = True,
trace: Callable = None, # noqa: RUF013
trace: Trace | None = None,
**kwargs: Incomplete,
) -> xr.Dataset:
"""Resolves a plugin and delegates loading a scan.
This is used interally by `load_data` and should not be invoked directly
This is used internally by `load_data` and should not be invoked directly
by users.
Determines which data loading class is appropriate for the data,
Expand All @@ -1078,7 +1106,7 @@ def load_scan(
full_note.update(note)

endstation_cls = resolve_endstation(retry=retry, **full_note)
trace(f"Using plugin class {endstation_cls}")
trace(f"Using plugin class {endstation_cls}") if trace else None

key = "file" if "file" in scan_desc else "path"

Expand All @@ -1091,7 +1119,7 @@ def load_scan(
except ValueError:
pass

trace(f"Loading {scan_desc}")
trace(f"Loading {scan_desc}") if trace else None
endstation = endstation_cls()
endstation.trace = trace
return endstation.load(scan_desc, trace=trace, **kwargs)
(Diffs for the remaining 17 changed files are not shown here.)
