Skip to content

Commit

Permalink
🚨 Introduce helper function.
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Mar 19, 2024
1 parent 945f3d2 commit 8e230e0
Showing 1 changed file with 31 additions and 26 deletions.
57 changes: 31 additions & 26 deletions src/arpes/endstations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def concatenate_frames(
frames.sort(key=lambda x: x.coords[scan_coord])
return xr.concat(frames, scan_coord)

def resolve_frame_locations(self, scan_desc: ScanDesc | None = None) -> list[Path | str]:
def resolve_frame_locations(self, scan_desc: ScanDesc | None = None) -> list[Path]:
"""Determine all files and frames associated to this piece of data.
This always needs to be overridden in subclasses to handle data appropriately.
Expand Down Expand Up @@ -358,31 +358,12 @@ def postprocess_final(
coord_names: tuple[str, ...] = tuple(sorted([str(c) for c in data.dims if c != "cycle"]))
spectrum_type = _spectrum_type(coord_names)

if "phi" not in data.coords:
data.coords["phi"] = 0
for s in data.S.spectra:
s.coords["phi"] = 0

if spectrum_type is not None:
data.attrs["spectrum_type"] = spectrum_type
if "spectrum" in data.data_vars:
data.spectrum.attrs["spectrum_type"] = spectrum_type

ls = [data, *data.S.spectra]
for a_data in ls:
for k, key_fn in self.ATTR_TRANSFORMS.items():
if k in a_data.attrs:
transformed = key_fn(a_data.attrs[k])
if isinstance(transformed, dict):
a_data.attrs.update(transformed)
else:
a_data.attrs[k] = transformed

for a_data in ls:
for k, v in self.MERGE_ATTRS.items():
a_data.attrs.setdefault(k, v)

for a_data in [_ensure_coords(a_data, self.ENSURE_COORDS_EXIST) for a_data in ls]:
modified_data = [
self._modify_a_data(a_data, spectrum_type) for a_data in [data, *data.S.spectra]
]
for a_data in [
_ensure_coords(a_data, self.ENSURE_COORDS_EXIST) for a_data in modified_data
]:
if "chi" in a_data.coords and "chi_offset" not in a_data.attrs:
a_data.attrs["chi_offset"] = a_data.coords["chi"].item()

Expand Down Expand Up @@ -449,6 +430,30 @@ def load(self, scan_desc: ScanDesc | None = None, **kwargs: Incomplete) -> xr.Da

return concatted

def _modify_a_data(self, a_data: DataType, spectrum_type: str | None) -> DataType:
"""Helper function to modify the Dataset and DataArray that are contained in the Dataset.
Args:
a_data: [TODO:description]
spectrum_type: [TODO:description]
Returns:
[TODO:description]
"""
if "phi" not in a_data.coords:
a_data.coords["phi"] = 0
a_data.attrs["spectrum_type"] = spectrum_type
for k, key_fn in self.ATTR_TRANSFORMS.items():
if k in a_data.attrs:
transformed = key_fn(a_data.attrs[k])
if isinstance(transformed, dict):
a_data.attrs.update(transformed)
else:
a_data.attrs[k] = transformed
for k, v in self.MERGE_ATTRS.items():
a_data.attrs.setdefault(k, v)
return a_data


def _spectrum_type(
coord_names: Sequence[str],
Expand Down

0 comments on commit 8e230e0

Please sign in to comment.