
Commit 0419ce7

arafune committed Mar 19, 2024
Merge of 2 parents: ab4b0f5 + fa7b283
Showing 39 changed files with 263 additions and 140 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -96,6 +96,7 @@ lint.ignore = [
"G004", # logging-f-string
#
"NPY201", # Numpy 2.0,
"ISC001", # single-line-implicit-string-concatenation
]
lint.select = ["ALL"]
target-version = "py311"
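The newly ignored ISC001 rule covers single-line implicit string concatenation, a layout ruff's own formatter can produce; ruff warns when ISC001 and the formatter are enabled together, so disabling the rule is the usual workaround. A small illustration of what it flags, with hypothetical strings:

    # ISC001 flags adjacent string literals concatenated on one line.
    message = "loaded " "spectrum"    # ISC001: implicit concatenation
    message = "loaded " + "spectrum"  # the explicit form the rule prefers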
26 changes: 19 additions & 7 deletions src/arpes/analysis/band_analysis.py
@@ -7,6 +7,7 @@
import functools
import itertools
from itertools import pairwise
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
@@ -36,6 +37,18 @@
"fit_for_effective_mass",
)

LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[1]
logger = getLogger(__name__)
fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
formatter = Formatter(fmt)
handler = StreamHandler()
handler.setLevel(LOGLEVEL)
logger.setLevel(LOGLEVEL)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False


def fit_for_effective_mass(
data: xr.DataArray,
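The logging block added above is a pattern repeated across the package: a module-level logger with its own handler and propagation disabled. A minimal self-contained sketch of the behavior this configuration gives:

    from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger

    LOGLEVELS = (DEBUG, INFO)
    LOGLEVEL = LOGLEVELS[1]  # INFO: index 0 would enable DEBUG output
    logger = getLogger("arpes.analysis.band_analysis")
    handler = StreamHandler()
    handler.setLevel(LOGLEVEL)
    handler.setFormatter(Formatter("%(asctime)s %(levelname)s %(name)s :%(message)s"))
    logger.setLevel(LOGLEVEL)
    logger.addHandler(handler)
    logger.propagate = False  # otherwise the root logger would emit each record again

    logger.debug("dropped: below the configured level")
    logger.info("emitted once, with the timestamped format above")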
@@ -248,7 +261,7 @@ def dataarray_for_value(param_name: str, i: int = i, *, is_value: bool) -> xr.Da


@update_provenance("Fit bands from pattern")
def fit_patterned_bands(
def fit_patterned_bands( # noqa: PLR0913
arr: xr.DataArray,
band_set: dict[Incomplete, Incomplete],
fit_direction: str = "",
@@ -280,10 +293,9 @@ def fit_patterned_bands(
band_set: dictionary with bands and points along the spectrum
fit_direction (str):
stray (float, optional):
orientation: edc or mdc
direction_normal
preferred_k_direction
dataset: if True, return as Dataset
background (bool):
interactive(bool):
dataset(bool): if true, return as xr.Dataset.
Returns:
Dataset or DataArray, as controlled by the parameter "dataset"
@@ -296,7 +308,7 @@
free_directions = list(arr.dims)
free_directions.remove(fit_direction)

def resolve_partial_bands_from_description(
def resolve_partial_bands_from_description( # noqa: PLR0913
coord_dict: dict[str, Incomplete],
name: str = "",
band: Incomplete = None,
@@ -512,7 +524,7 @@ def fit_bands(
# be stable
closest_model_params = initial_fits # fix me
dist = float("inf")
frozen_coordinate = tuple(coordinate[k] for k in template.dims)
frozen_coordinate = tuple(coordinate[str(k)] for k in template.dims)
for c, v in all_fit_parameters.items():
delta = np.array(c) - frozen_coordinate
current_distance = delta.dot(delta)
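The `str(k)` cast matters because `template.dims` is typed as a tuple of `Hashable`, while `coordinate` is keyed by plain strings; the cast makes the lookup well-typed without changing runtime behavior. A minimal sketch, assuming a hypothetical two-dimensional template:

    import numpy as np
    import xarray as xr

    template = xr.DataArray(np.zeros((3, 3)), dims=("kp", "eV"))
    coordinate = {"kp": 0.1, "eV": -0.2}
    # template.dims yields Hashable values; str(k) pins them to str keys.
    frozen_coordinate = tuple(coordinate[str(k)] for k in template.dims)
    # frozen_coordinate -> (0.1, -0.2)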
12 changes: 7 additions & 5 deletions src/arpes/analysis/mask.py
@@ -120,18 +120,20 @@ def apply_mask_to_coords(
Returns:
The masked data.
"""
p = Path(mask["poly"])

as_array = np.stack([data.data_vars[d].values for d in dims], axis=-1)
shape = as_array.shape
dest_shape = shape[:-1]
new_shape = [np.prod(dest_shape), len(dims)]
mask_array = (
Path(np.array(mask["poly"]))
.contains_points(as_array.reshape(new_shape))
.reshape(dest_shape)
)

mask = p.contains_points(as_array.reshape(new_shape)).reshape(dest_shape)
if invert:
mask = np.logical_not(mask)
mask_array = np.logical_not(mask_array)

return mask
return mask_array


@update_provenance("Apply boolean mask to data")
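The rewrite folds the polygon test into a single expression and renames the boolean result to mask_array, so it no longer shadows the mask argument (the dict holding the polygon). A minimal sketch of the underlying matplotlib.path.Path.contains_points call, with a hypothetical unit-square polygon:

    import numpy as np
    from matplotlib.path import Path

    poly = np.array([(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)])
    points = np.array([[0.5, 0.5], [2.0, 2.0]])
    inside = Path(poly).contains_points(points)
    # inside -> array([ True, False]): only the first point lies in the square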
4 changes: 2 additions & 2 deletions src/arpes/bootstrap.py
@@ -249,7 +249,7 @@ def draw_samples(self, n_samples: int = Distribution.DEFAULT_N_SAMPLES) -> NDArr
return scipy.stats.norm.rvs(self.center, scale=self.stderr, size=n_samples)

@classmethod
def from_param(cls: type, model_param: lf.Model.Parameter):
def from_param(cls: type, model_param: lf.Model.Parameter) -> Incomplete:
"""Generates a Normal from an `lmfit.Parameter`."""
return cls(center=model_param.value, stderr=model_param.stderr)

@@ -353,7 +353,7 @@ def bootstrapped(
n: int = 20,
prior_adjustment: int = 1,
**kwargs: Incomplete,
):
) -> Incomplete:
# examine args to determine which to resample
resample_indices = [
i
2 changes: 1 addition & 1 deletion src/arpes/deep_learning/interpret.py
@@ -137,7 +137,7 @@ def items(self) -> list[InterpretationItem]:
def top_losses(self, *, ascending: bool = False) -> list[InterpretationItem]:
"""Orders the items by loss."""

def key(item: Incomplete):
def key(item: Incomplete) -> Incomplete:
return item.loss if ascending else -item.loss

return sorted(self.items, key=key)
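The annotation on key satisfies the linter; the function itself gets descending order by negating the loss rather than passing reverse=True. The same idea in isolation, with hypothetical loss values:

    losses = [0.3, 1.2, 0.7]
    ascending = False
    # Negating the key sorts descending without a separate reverse flag.
    ordered = sorted(losses, key=lambda loss: loss if ascending else -loss)
    # ordered -> [1.2, 0.7, 0.3]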
35 changes: 32 additions & 3 deletions src/arpes/deep_learning/transforms.py
@@ -15,15 +15,44 @@ class Identity:
"""Represents a reversible identity transform."""

def encodes(self, x: Incomplete) -> Incomplete:
"""[TODO:summary].
Args:
x: [TODO:description]
Returns:
[TODO:description]
"""
return x

def __call__(self, x: Incomplete) -> Incomplete:
"""[TODO:summary].
Args:
x: [TODO:description]
Returns:
[TODO:description]
"""
return x

def decodes(self, x: Incomplete) -> Incomplete:
"""[TODO:summary].
Args:
x: [TODO:description]
Returns:
[TODO:description]
"""
return x

def __repr__(self) -> str:
"""[TODO:summary].
Returns:
[TODO:description]
"""
return "Identity()"


@@ -54,9 +83,9 @@ def __post_init__(self) -> None:
for t in self.transforms:
if isinstance(t, tuple | list):
xt, yt = t
t = [xt or _identity, yt or _identity]

safe_transforms.append(t)
safe_transforms.append([xt or _identity, yt or _identity])
else:
safe_transforms.append(t)

self.original_transforms = self.transforms
self.transforms = safe_transforms
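The old loop rebound the loop variable t before appending, a pattern ruff flags as PLW2901 (redefined-loop-name); appending directly in each branch keeps the behavior identical while satisfying the lint. A minimal sketch with hypothetical transforms:

    def _identity(x):
        return x

    transforms = [(None, str), int]  # a mix of pair-wise and single transforms
    safe_transforms = []
    for t in transforms:
        if isinstance(t, tuple | list):
            xt, yt = t
            # Substitute the identity for any missing half of the pair.
            safe_transforms.append([xt or _identity, yt or _identity])
        else:
            safe_transforms.append(t)
    # safe_transforms -> [[_identity, str], int]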
57 changes: 31 additions & 26 deletions src/arpes/endstations/__init__.py
@@ -278,7 +278,7 @@ def concatenate_frames(
frames.sort(key=lambda x: x.coords[scan_coord])
return xr.concat(frames, scan_coord)

def resolve_frame_locations(self, scan_desc: ScanDesc | None = None) -> list[Path | str]:
def resolve_frame_locations(self, scan_desc: ScanDesc | None = None) -> list[Path]:
"""Determine all files and frames associated to this piece of data.
This always needs to be overridden in subclasses to handle data appropriately.
@@ -358,31 +358,12 @@ def postprocess_final(
coord_names: tuple[str, ...] = tuple(sorted([str(c) for c in data.dims if c != "cycle"]))
spectrum_type = _spectrum_type(coord_names)

if "phi" not in data.coords:
data.coords["phi"] = 0
for s in data.S.spectra:
s.coords["phi"] = 0

if spectrum_type is not None:
data.attrs["spectrum_type"] = spectrum_type
if "spectrum" in data.data_vars:
data.spectrum.attrs["spectrum_type"] = spectrum_type

ls = [data, *data.S.spectra]
for a_data in ls:
for k, key_fn in self.ATTR_TRANSFORMS.items():
if k in a_data.attrs:
transformed = key_fn(a_data.attrs[k])
if isinstance(transformed, dict):
a_data.attrs.update(transformed)
else:
a_data.attrs[k] = transformed

for a_data in ls:
for k, v in self.MERGE_ATTRS.items():
a_data.attrs.setdefault(k, v)

for a_data in [_ensure_coords(a_data, self.ENSURE_COORDS_EXIST) for a_data in ls]:
modified_data = [
self._modify_a_data(a_data, spectrum_type) for a_data in [data, *data.S.spectra]
]
for a_data in [
_ensure_coords(a_data, self.ENSURE_COORDS_EXIST) for a_data in modified_data
]:
if "chi" in a_data.coords and "chi_offset" not in a_data.attrs:
a_data.attrs["chi_offset"] = a_data.coords["chi"].item()

@@ -449,6 +430,30 @@ def load(self, scan_desc: ScanDesc | None = None, **kwargs: Incomplete) -> xr.Da

return concatted

def _modify_a_data(self, a_data: DataType, spectrum_type: str | None) -> DataType:
"""Helper function to modify the Dataset and DataArray that are contained in the Dataset.
Args:
a_data: [TODO:description]
spectrum_type: [TODO:description]
Returns:
[TODO:description]
"""
if "phi" not in a_data.coords:
a_data.coords["phi"] = 0
a_data.attrs["spectrum_type"] = spectrum_type
for k, key_fn in self.ATTR_TRANSFORMS.items():
if k in a_data.attrs:
transformed = key_fn(a_data.attrs[k])
if isinstance(transformed, dict):
a_data.attrs.update(transformed)
else:
a_data.attrs[k] = transformed
for k, v in self.MERGE_ATTRS.items():
a_data.attrs.setdefault(k, v)
return a_data


def _spectrum_type(
coord_names: Sequence[str],
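The new _modify_a_data helper collapses what were three separate loops over [data, *data.S.spectra] into one pass per object: it backfills a phi coordinate, stamps the spectrum type, applies ATTR_TRANSFORMS, and merges MERGE_ATTRS defaults. A minimal sketch of the two attrs passes, with hypothetical transform and default tables:

    ATTR_TRANSFORMS = {"hv": lambda v: {"photon_energy": float(v)}}
    MERGE_ATTRS = {"analyzer": "hemispherical"}

    attrs = {"hv": "21.2"}
    for k, key_fn in ATTR_TRANSFORMS.items():
        if k in attrs:
            transformed = key_fn(attrs[k])
            if isinstance(transformed, dict):
                attrs.update(transformed)  # dict results merge new keys in
            else:
                attrs[k] = transformed     # scalar results replace the value
    for k, v in MERGE_ATTRS.items():
        attrs.setdefault(k, v)             # defaults never overwrite real metadata
    # attrs -> {'hv': '21.2', 'photon_energy': 21.2, 'analyzer': 'hemispherical'}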
3 changes: 2 additions & 1 deletion src/arpes/endstations/nexus_utils.py
@@ -15,6 +15,7 @@
if TYPE_CHECKING:
from collections.abc import Callable

from _typeshed import Incomplete
import xarray as xr

__all__ = ("read_data_attributes_from",)
@@ -63,7 +64,7 @@ class Target:

value: Any = None

def read_h5(self, g, path) -> None:
def read_h5(self, g: Incomplete, path: Incomplete) -> None:
self.value = None
self.value = self.read(read_group_data(g))

31 changes: 22 additions & 9 deletions src/arpes/endstations/plugin/ANTARES.py
@@ -73,7 +73,11 @@
}


def parse_axis_name_from_long_name(name: str, keep_segments: int = 1, separator: str = "_") -> str:
def parse_axis_name_from_long_name(
name: str,
keep_segments: int = 1,
separator: str = "_",
) -> str:
segments = name.split("/")[-keep_segments:]
segments = [s.replace("'", "") for s in segments]
return separator.join(segments)
@@ -99,14 +103,18 @@ def infer_scan_type_from_data(group: dict) -> str:
raise NotImplementedError(scan_name)


class ANTARESEndstation(HemisphericalEndstation, SynchrotronEndstation, SingleFileEndstation):
class ANTARESEndstation(
HemisphericalEndstation,
SynchrotronEndstation,
SingleFileEndstation,
):
"""Implements data loading for ANTARES at SOLEIL.
There's not too much metadata here except what comes with the analyzer settings.
"""

PRINCIPAL_NAME = "ANTARES"
ALIASES: ClassVar[list] = []
ALIASES: ClassVar[list[str]] = []

_TOLERATED_EXTENSIONS: ClassVar[set[str]] = {".nxs"}

@@ -120,14 +128,12 @@ def load_top_level_scan(
) -> xr.Dataset:
"""Reads a spectrum from the top level group in a NeXuS scan format.
[TODO:description]
Args:
group ([TODO:type]): [TODO:description]
scan_desc: [TODO:description]
spectrum_index ([TODO:type]): [TODO:description]
Returns:
Returns (xr.Dataset):
[TODO:description]
"""
if scan_desc:
@@ -177,7 +183,10 @@ def get_coords(self, group: Incomplete, scan_name: str, shape: Incomplete):
(
name
if set_names[name] == 1
else parse_axis_name_from_long_name(actuator_long_names[i], keep_segments)
else parse_axis_name_from_long_name(
actuator_long_names[i],
keep_segments,
)
)
for i, name in enumerate(actuator_names)
]
@@ -241,13 +250,17 @@ def take_last(vs):
energy = data[e_keys[0]][0], data[e_keys[1]][0], data[e_keys[2]][0]
angle = data[ang_keys[0]][0], data[ang_keys[1]][0], data[ang_keys[2]][0]

def get_first(item):
def get_first(item: NDArray[np.float_] | float):
if isinstance(item, np.ndarray):
return item.ravel()[0]

return item

def build_axis(low: float, high: float, step_size: float) -> tuple[NDArray[np.float_], int]:
def build_axis(
low: float,
high: float,
step_size: float,
) -> tuple[NDArray[np.float_], int]:
# this might not work out to be the right thing to do, we will see
low, high, step_size = get_first(low), get_first(high), get_first(step_size)
est_n: int = int((high - low) / step_size)
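The newly annotated get_first normalizes endpoints that may arrive from the NeXus file as arrays rather than scalars, so build_axis can do float arithmetic on them. A quick illustration, with hypothetical beamline values:

    import numpy as np

    def get_first(item):
        """Return the first element if item is an array, else item unchanged."""
        if isinstance(item, np.ndarray):
            return item.ravel()[0]
        return item

    get_first(np.array([16.9, 17.0]))  # -> 16.9
    get_first(16.9)                    # -> 16.9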
2 changes: 1 addition & 1 deletion src/arpes/endstations/plugin/BL10_SARPES.py
@@ -121,7 +121,7 @@ def load_single_region(
"""Loads a single region for multi-region scans."""
from arpes.load_pxt import read_single_pxt

name, _ = Path(region_path).stem
name = Path(region_path).stem
num = name[-3:]

pxt_data = read_single_pxt(region_path, allow_multiple=True)
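This fix matters because Path.stem returns a single string, not a splittable pair: the old two-target unpacking would raise ValueError for any stem longer than two characters. A short check, with a hypothetical region file name:

    from pathlib import Path

    region_path = "scans/Region_001.pxt"
    name = Path(region_path).stem   # -> 'Region_001' (a str)
    num = name[-3:]                 # -> '001'
    # name, _ = Path(region_path).stem  # ValueError: too many values to unpack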
(Diffs for the remaining changed files are not shown.)
