🔥 Remove trace arg and related file.

arafune committed Feb 7, 2024
1 parent 14bb08c commit cb0b6df

Showing 8 changed files with 37 additions and 176 deletions.
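The change is mechanical throughout: each guarded `trace(...)` call becomes a plain `logger.debug(...)` call, and the `trace` parameter disappears from the affected signatures. A minimal sketch of the before/after pattern, using a message that appears in the diff below (the logger setup is ordinary stdlib logging boilerplate, not code from this commit):

    from logging import getLogger

    logger = getLogger(__name__)

    # Before: a Trace instance was threaded through every call site, and each
    # message had to be guarded against tracing being disabled:
    #     self.trace("Resolving frame locations")
    #     trace("Loop is computed") if trace else None

    # After: an unconditional module-level logging call; visibility is decided
    # by the logger's level rather than by a per-call argument.
    logger.debug("Resolving frame locations")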
31 changes: 11 additions & 20 deletions arpes/endstations/__init__.py
@@ -21,7 +21,6 @@
from arpes.load_pxt import find_ses_files_associated, read_single_pxt
from arpes.provenance import PROVENANCE, provenance_from_file
from arpes.repair import negate_energy
- from arpes.trace import Trace, traceable
from arpes.utilities.dict import rename_dataarray_attrs

from .fits_utils import find_clean_coords
@@ -153,11 +152,8 @@ class EndstationBase:

RENAME_KEYS: ClassVar[dict[str, str]] = {}

- trace: Trace
-
def __init__(self) -> None:
"""Initialize."""
- self.trace = Trace(silent=True)

@classmethod
def is_file_accepted(
@@ -455,13 +451,13 @@ def load(self, scan_desc: SCANDESC | None = None, **kwargs: Incomplete) -> xr.Da
"""
if scan_desc is None:
scan_desc = {}
self.trace("Resolving frame locations")
logger.debug("Resolving frame locations")
resolved_frame_locations = self.resolve_frame_locations(scan_desc)
self.trace(f"resolved_frame_locations: {resolved_frame_locations}")
logger.debug(f"resolved_frame_locations: {resolved_frame_locations}")
if not resolved_frame_locations:
msg = "File not found"
raise RuntimeError(msg)
self.trace(f"Found frames: {resolved_frame_locations}")
logger.debug(f"Found frames: {resolved_frame_locations}")
frames = [
self.load_single_frame(fpath, scan_desc, **kwargs) for fpath in resolved_frame_locations
]
@@ -794,7 +790,7 @@ def load_single_frame(
for k, v in kwargs.items():
logger.debug(f" key {k}: value{v}")
# Use dimension labels instead of
self.trace("Opening FITS HDU list.")
logger.debug("Opening FITS HDU list.")
hdulist = fits.open(frame_path, ignore_missing_end=True)
primary_dataset_name = None

@@ -805,12 +801,12 @@
del hdulist[i].header["UN_0_0"]
hdulist[i].header["UN_0_0"] = ""
if "TTYPE2" in hdulist[i].header and hdulist[i].header["TTYPE2"] == "Delay":
self.trace("Using ps delay units. This looks like an ALG main chamber scan.")
logger.debug("Using ps delay units. This looks like an ALG main chamber scan.")
hdulist[i].header["TUNIT2"] = ""
del hdulist[i].header["TUNIT2"]
hdulist[i].header["TUNIT2"] = "ps"

self.trace(f"HDU {i}: Attempting to fix FITS errors.")
logger.debug(f"HDU {i}: Attempting to fix FITS errors.")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
hdulist[i].verify("fix+warn")
@@ -836,9 +832,8 @@
hdu,
attrs,
mode="MC",
- trace=self.trace,
)
self.trace("Recovered coordinates from FITS file.")
logger.debug("Recovered coordinates from FITS file.")

attrs = rename_keys(attrs, self.RENAME_KEYS)
scan_desc = rename_keys(scan_desc, self.RENAME_KEYS)
@@ -987,7 +982,7 @@ def prep_spectrum(data: xr.DataArray) -> xr.DataArray:
k: np.deg2rad(c) if k in deg_to_rad_coords else c for k, c in built_coords.items()
}

self.trace("Stitching together xr.Dataset.")
logger.debug("Stitching together xr.Dataset.")
return xr.Dataset(
{
f"safe-{name}" if name in data_var.coords else name: data_var
@@ -1086,12 +1081,10 @@ def resolve_endstation(*, retry: bool = True, **kwargs: Incomplete) -> type[Ends
raise ValueError(msg) from key_error


- @traceable
def load_scan(
scan_desc: SCANDESC,
*,
retry: bool = True,
- trace: Trace | None = None,
**kwargs: Incomplete,
) -> xr.Dataset:
"""Resolves a plugin and delegates loading a scan.
@@ -1106,7 +1099,6 @@
Args:
scan_desc: Information identifying the scan, typically a scan number or full path.
retry: Used to attempt a reload of plugins and subsequent data load attempt.
- trace: Trace instance for debugging, pass True or False (default) to control this parameter
kwargs: pass to the endstation.load(scan_dec, **kwargs)
Returns:
@@ -1118,7 +1110,7 @@
full_note.update(note)

endstation_cls = resolve_endstation(retry=retry, **full_note)
trace(f"Using plugin class {endstation_cls}") if trace else None
logger.debug(f"Using plugin class {endstation_cls}")

key: Literal["file", "path"] = "file" if "file" in scan_desc else "path"

@@ -1130,7 +1122,6 @@
except ValueError:
pass

trace(f"Loading {scan_desc}") if trace else None
logger.debug(f"Loading {scan_desc}")
endstation = endstation_cls()
- endstation.trace = trace
- return endstation.load(scan_desc, trace=trace, **kwargs)
+ return endstation.load(scan_desc, **kwargs)
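With the `trace` keyword gone from `load_scan`, verbosity is now opted into through the standard `logging` machinery rather than per call. A hedged sketch of enabling the new debug output (it assumes the arpes modules create their loggers via `getLogger(__name__)`, so they all sit under the "arpes" namespace; that setup is not shown in this diff):

    import logging

    # Assumption: loggers are named after their modules, e.g. "arpes.endstations".
    logging.basicConfig(format="%(name)s %(levelname)s: %(message)s")
    logging.getLogger("arpes").setLevel(logging.DEBUG)

    # load_scan() and EndstationBase.load() now report progress via
    # logger.debug(...), so no trace argument is passed at the call site.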
56 changes: 19 additions & 37 deletions arpes/endstations/fits_utils.py
@@ -5,15 +5,14 @@
import functools
import warnings
from ast import literal_eval
- from collections.abc import Callable, Iterable
+ from collections.abc import Iterable
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING, Any, TypeAlias

import numpy as np
from numpy import ndarray
from numpy._typing import NDArray

- from arpes.trace import Trace, traceable
from arpes.utilities.funcutils import collect_leaves, iter_leaves

if TYPE_CHECKING:
@@ -53,19 +52,15 @@
Dimension = str


- @traceable
def extract_coords(
attrs: dict[str, Any],
dimension_renamings: dict[str, str] | None = None,
- trace: Trace | None = None,
) -> tuple[CoordsDict, list[Dimension], list[int]]:
"""Does the hard work of extracting coordinates from the scan description.
Args:
attrs:
dimension_renamings:
- trace: A Trace instance used for debugging. You can pass True or False (including to the
- originating load_data call) to enable execution tracing.
Returns:
A tuple consisting of the coordinate arrays, the dimension names, and their shapes
@@ -75,7 +70,7 @@

try:
n_loops = attrs["LWLVLPN"]
trace(f"Found n_loops={n_loops}") if trace else None
logger.debug(f"Found n_loops={n_loops}")
except KeyError:
# Looks like no scan, this happens for instance in the SToF when you take a single
# EDC
@@ -89,9 +84,9 @@
scan_coords = {}
for loop in range(n_loops):
n_scan_dimensions = attrs[f"NMSBDV{loop}"]
trace(f"Considering loop {loop}, n_scan_dimensions={n_scan_dimensions}") if trace else None
logger.debug(f"Considering loop {loop}, n_scan_dimensions={n_scan_dimensions}")
if attrs[f"SCNTYP{loop}"] == 0:
trace("Loop is computed") if trace else None
logger.debug("Loop is computed")
for i in range(n_scan_dimensions):
name, start, end, n = (
attrs[f"NM_{loop}_{i}"],
@@ -118,13 +113,13 @@
#
# As of 2021, that is the perspective we are taking on the issue.
elif n_scan_dimensions > 1:
trace("Loop is tabulated and is not region based") if trace else None
logger.debug("Loop is tabulated and is not region based")
for i in range(n_scan_dimensions):
name = attrs[f"NM_{loop}_{i}"]
if f"ST_{loop}_{i}" not in attrs and f"PV_{loop}_{i}_0" in attrs:
msg = f"Determined that coordinate {name} "
msg += "is tabulated based on scan coordinate. Skipping!"
- trace(msg) if trace else None
+ logger.debug(msg)
continue
start, end, n = (
float(attrs[f"ST_{loop}_{i}"]),
@@ -134,14 +129,14 @@

old_name = name
name = dimension_renamings.get(name, name)
trace(f"Renaming: {old_name} -> {name}") if trace else None
logger.debug(f"Renaming: {old_name} -> {name}")

scan_dimension.append(name)
scan_shape.append(n)
scan_coords[name] = np.linspace(start, end, n, endpoint=True)

else:
trace("Loop is tabulated and is region based") if trace else None
logger("Loop is tabulated and is region based")
name, n = (
attrs[f"NM_{loop}_0"],
attrs[f"NMPOS_{loop}"],
@@ -159,7 +154,7 @@
n_regions = 1
name = dimension_renamings.get(name, name)

trace(f"Loop (name, n_regions, size) = {(name, n_regions, n)}") if trace else None
logger.debug(f"Loop (name, n_regions, size) = {(name, n_regions, n)}")

coord: NDArray[np.float_] = np.array(())
for region in range(n_regions):
@@ -171,7 +166,7 @@
msg = f"Reading coordinate {region} from loop. (start, end, n)"
msg += f"{(start, end, n)}"

- trace(msg) if trace else None
+ logger.debug(msg)

coord = np.concatenate((coord, np.linspace(start, end, n, endpoint=True)))

@@ -181,14 +176,12 @@
return scan_coords, scan_dimension, scan_shape


- @traceable
def find_clean_coords(
hdu: BinTableHDU,
attrs: dict[str, Any],
spectra: Any = None,
mode: str = "ToF",
dimension_renamings: Any = None,
- trace: Callable | None = None,
) -> tuple[CoordsDict, dict[str, list[Dimension]], dict[str, Any]]:
"""Determines the scan degrees of freedom, and reads coordinates.
@@ -224,18 +217,13 @@
scan_coords, scan_dimension, scan_shape = extract_coords(
attrs,
dimension_renamings=dimension_renamings,
- trace=trace,
)
trace(f"Found scan shape {scan_shape} and dimensions {scan_dimension}.") if trace else None
logger.debug(f"Found scan shape {scan_shape} and dimensions {scan_dimension}.")

# bit of a hack to deal with the internal motor used for the swept spectra being considered as
# a cycle
if "cycle" in scan_coords and len(scan_coords["cycle"]) > 200:
- (
- trace("Renaming swept scan coordinate to cycle and extracting. This is hack.")
- if trace
- else None
- )
+ logger.debug("Renaming swept scan coordinate to cycle and extracting. This is hack.")
idx = scan_dimension.index("cycle")

real_data_for_cycle = hdu.data.columns["null"].array
@@ -258,14 +246,14 @@
spectra = [spectra]

for spectrum_key in spectra:
trace(f"Considering potential spectrum {spectrum_key}") if trace else None
logger.debug(f"Considering potential spectrum {spectrum_key}")
skip_names = {
lambda name: bool("beamview" in name or "IMAQdx" in name),
}

if spectrum_key is None:
spectrum_key = hdu.columns.names[-1]
trace(f"Column name was None, using {spectrum_key}") if trace else None
logger.debug(f"Column name was None, using {spectrum_key}")

if isinstance(spectrum_key, str):
spectrum_key = hdu.columns.names.index(spectrum_key) + 1
@@ -279,31 +267,25 @@
if (callable(skipped) and skipped(spectrum_name)) or skipped == spectrum_name:
should_skip = True
if should_skip:
trace("Skipping column.") if trace else None
logger.debug("Skipping column.")
continue

try:
offset = hdu.header[f"TRVAL{spectrum_key}"]
delta = hdu.header[f"TDELT{spectrum_key}"]
offset = literal_eval(offset) if isinstance(offset, str) else offset
delta = literal_eval(delta) if isinstance(delta, str) else delta
trace(f"Determined (offset, delta): {(offset, delta)}.") if trace else None
logger.debug(f"Determined (offset, delta): {(offset, delta)}.")

try:
shape = hdu.header[f"TDIM{spectrum_key}"]
shape = literal_eval(shape) if isinstance(shape, str) else shape
loaded_shape_from_header = True
- (
- trace(f"Successfully loaded coordinate shape from header: {shape}")
- if trace
- else None
- )
+ logger.debug(f"Successfully loaded coordinate shape from header: {shape}")
except KeyError:
shape = hdu.data.field(spectrum_key - 1).shape
- (
- trace(f"Could not use header to determine coordinate shape, using: {shape}")
- if trace
- else None
+ logger.debug(
+ f"Could not use header to determine coordinate shape, using: {shape}",
)

try:
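After this change `find_clean_coords` no longer accepts a `trace` keyword, matching the updated call site in arpes/endstations/__init__.py above. A usage sketch under stated assumptions (the file name is hypothetical, and the unpacked names are illustrative; only `hdu`, `attrs`, and `mode="MC"` mirror the diff):

    from astropy.io import fits

    from arpes.endstations.fits_utils import find_clean_coords

    # "frame.fits" stands in for a real frame file.
    hdulist = fits.open("frame.fits", ignore_missing_end=True)
    hdu = hdulist[1]                 # the BinTableHDU carrying the spectra
    attrs = dict(hdulist[0].header)  # assumption: header cards as a plain dict

    # Returns (coordinates, dimensions per spectrum, extra attrs); passing
    # trace= here would now raise a TypeError.
    coords, dimensions, spectrum_attrs = find_clean_coords(hdu, attrs, mode="MC")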
8 changes: 2 additions & 6 deletions arpes/endstations/plugin/fallback.py
@@ -1,12 +1,12 @@
"""Implements dynamic plugin selection when users do not specify the location for their data."""

from __future__ import annotations

import warnings
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING, ClassVar

from arpes.endstations import EndstationBase, resolve_endstation
- from arpes.trace import Trace, traceable

if TYPE_CHECKING:
from pathlib import Path
@@ -59,12 +59,9 @@ class FallbackEndstation(EndstationBase):
]

@classmethod
- @traceable
def determine_associated_loader(
cls: type[FallbackEndstation],
file: str | Path,
- *,
- trace: Trace | None = None,
) -> type[EndstationBase]:
"""Determines which loading plugin to use for a given piece of data.
@@ -76,7 +73,7 @@ def determine_associated_loader(
arpes.config.load_plugins()

for location in cls.ATTEMPT_ORDER:
trace(f"{cls.__name__} is trying {location}")
logger.debug(f"{cls.__name__} is trying {location}")

try:
endstation_cls = resolve_endstation(retry=False, location=location)
@@ -104,7 +101,6 @@ def load(
associated_loader = FallbackEndstation.determine_associated_loader(
file,
scan_desc,
- trace=self.trace,
)
try:
file_number = int(file)
1 change: 0 additions & 1 deletion arpes/fits/utilities.py
@@ -150,7 +150,6 @@ def broadcast_model(
than 20 fits were requested
progress: Whether to show a progress bar
safe: Whether to mask out nan values
- trace: Controls whether execution tracing/timestamping is used for performance investigation
Returns:
An `xr.Dataset` containing the curve fitting results. These are data vars:
3 changes: 1 addition & 2 deletions arpes/io.py
@@ -98,8 +98,7 @@ def load_data(
),
stacklevel=2,
)
if kwargs.get("trace"):
logger.debug(f"contents of desc: {desc}")
logger.debug(f"contents of desc: {desc}")
return load_scan(desc, **kwargs)


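At the top-level API, `load_data` previously honored a `trace` flag in its kwargs before handing off to `load_scan`; the debug line is now emitted unconditionally and filtered by log level alone. A before/after sketch of the call site (the scan identifier and location are illustrative):

    from arpes.io import load_data

    # Before this commit: data = load_data(1, location="ALG-MC", trace=True)
    # After: the trace keyword is gone; the same diagnostics appear once the
    # "arpes" loggers are set to DEBUG (see the logging sketch above).
    data = load_data(1, location="ALG-MC")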
(Diffs for the remaining 3 changed files not shown.)
