Skip to content

Commit

Permalink
🚨 Remove ruff Warning in xarray_extensions.py
Browse files Browse the repository at this point in the history
    -  PLW2901
💬  Update type hints
  • Loading branch information
arafune committed Mar 19, 2024
1 parent 8e230e0 commit fa7b283
Show file tree
Hide file tree
Showing 17 changed files with 136 additions and 66 deletions.
14 changes: 8 additions & 6 deletions src/arpes/analysis/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def polys_to_mask(
radius: float = 0,
*,
invert: bool = False,
) -> NDArray[np.float_] | NDArray[np.bool_]:
) -> NDArray[np.bool_]:
"""Converts a mask definition in terms of the underlying polygon to a True/False mask array.
Uses the coordinates and shape of the target data in order to determine which pixels
Expand Down Expand Up @@ -120,18 +120,20 @@ def apply_mask_to_coords(
Returns:
The masked data.
"""
p = Path(mask["poly"])

as_array = np.stack([data.data_vars[d].values for d in dims], axis=-1)
shape = as_array.shape
dest_shape = shape[:-1]
new_shape = [np.prod(dest_shape), len(dims)]
mask_array = (
Path(np.array(mask["poly"]))
.contains_points(as_array.reshape(new_shape))
.reshape(dest_shape)
)

mask = p.contains_points(as_array.reshape(new_shape)).reshape(dest_shape)
if invert:
mask = np.logical_not(mask)
mask_array = np.logical_not(mask_array)

return mask
return mask_array


@update_provenance("Apply boolean mask to data")
Expand Down
3 changes: 2 additions & 1 deletion src/arpes/endstations/nexus_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
if TYPE_CHECKING:
from collections.abc import Callable

from _typeshed import Incomplete
import xarray as xr

__all__ = ("read_data_attributes_from",)
Expand Down Expand Up @@ -63,7 +64,7 @@ class Target:

value: Any = None

def read_h5(self, g, path) -> None:
def read_h5(self, g: Incomplete, path: Incomplete) -> None:
    """Populate ``self.value`` from an open HDF5 group.

    Resets ``self.value`` to ``None`` first so that a failed read leaves
    ``None`` rather than a stale value, then parses the group's data via
    ``self.read(read_group_data(g))``.

    NOTE(review): *path* is accepted but not used in this body — confirm
    whether callers rely on it or it exists only for interface symmetry.
    """
    self.value = None
    self.value = self.read(read_group_data(g))

Expand Down
31 changes: 22 additions & 9 deletions src/arpes/endstations/plugin/ANTARES.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@
}


def parse_axis_name_from_long_name(
    name: str,
    keep_segments: int = 1,
    separator: str = "_",
) -> str:
    """Build a short axis name from the tail of a "/"-separated long name.

    Keeps the last *keep_segments* components of *name*, strips any single
    quotes from each component, and joins them with *separator*.

    Args:
        name: Full axis name, with components separated by "/".
        keep_segments: Number of trailing components to retain.
        separator: String used to join the retained components.

    Returns:
        The joined, quote-stripped axis name.
    """
    tail = [component.replace("'", "") for component in name.split("/")[-keep_segments:]]
    return separator.join(tail)
Expand All @@ -99,14 +103,18 @@ def infer_scan_type_from_data(group: dict) -> str:
raise NotImplementedError(scan_name)


class ANTARESEndstation(HemisphericalEndstation, SynchrotronEndstation, SingleFileEndstation):
class ANTARESEndstation(
HemisphericalEndstation,
SynchrotronEndstation,
SingleFileEndstation,
):
"""Implements data loading for ANTARES at SOLEIL.
There's not too much metadata here except what comes with the analyzer settings.
"""

PRINCIPAL_NAME = "ANTARES"
ALIASES: ClassVar[list] = []
ALIASES: ClassVar[list[str]] = []

_TOLERATED_EXTENSIONS: ClassVar[set[str]] = {".nxs"}

Expand All @@ -120,14 +128,12 @@ def load_top_level_scan(
) -> xr.Dataset:
"""Reads a spectrum from the top level group in a NeXuS scan format.
[TODO:description]
Args:
group ([TODO:type]): [TODO:description]
scan_desc: [TODO:description]
spectrum_index ([TODO:type]): [TODO:description]
Returns:
Returns (xr.Dataset):
[TODO:description]
"""
if scan_desc:
Expand Down Expand Up @@ -177,7 +183,10 @@ def get_coords(self, group: Incomplete, scan_name: str, shape: Incomplete):
(
name
if set_names[name] == 1
else parse_axis_name_from_long_name(actuator_long_names[i], keep_segments)
else parse_axis_name_from_long_name(
actuator_long_names[i],
keep_segments,
)
)
for i, name in enumerate(actuator_names)
]
Expand Down Expand Up @@ -241,13 +250,17 @@ def take_last(vs):
energy = data[e_keys[0]][0], data[e_keys[1]][0], data[e_keys[2]][0]
angle = data[ang_keys[0]][0], data[ang_keys[1]][0], data[ang_keys[2]][0]

def get_first(item):
def get_first(item: NDArray[np.float_] | float):
if isinstance(item, np.ndarray):
return item.ravel()[0]

return item

def build_axis(low: float, high: float, step_size: float) -> tuple[NDArray[np.float_], int]:
def build_axis(
low: float,
high: float,
step_size: float,
) -> tuple[NDArray[np.float_], int]:
# this might not work out to be the right thing to do, we will see
low, high, step_size = get_first(low), get_first(high), get_first(step_size)
est_n: int = int((high - low) / step_size)
Expand Down
2 changes: 1 addition & 1 deletion src/arpes/endstations/plugin/BL10_SARPES.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def load_single_region(
"""Loads a single region for multi-region scans."""
from arpes.load_pxt import read_single_pxt

name, _ = Path(region_path).stem
name = Path(region_path).stem
num = name[-3:]

pxt_data = read_single_pxt(region_path, allow_multiple=True)
Expand Down
23 changes: 19 additions & 4 deletions src/arpes/endstations/plugin/Elettra_spectromicroscopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ def unwrap_bytestring(
)


class SpectromicroscopyElettraEndstation(HemisphericalEndstation, SynchrotronEndstation):
class SpectromicroscopyElettraEndstation(
HemisphericalEndstation,
SynchrotronEndstation,
):
"""Data loading for the nano-ARPES beamline "Spectromicroscopy Elettra".
Information available on the beamline can be accessed
Expand Down Expand Up @@ -145,7 +148,12 @@ def files_for_search(cls: type, directory: str | Path) -> list[Path]:
else:
base_files = [*base_files, Path(file)]

return list(filter(lambda f: Path(f).suffix in cls._TOLERATED_EXTENSIONS, base_files))
return list(
filter(
lambda f: Path(f).suffix in cls._TOLERATED_EXTENSIONS,
base_files,
)
)

ANALYZER_INFORMATION: ClassVar[dict[str, str | float | bool]] = {
"analyzer": "Custom: in vacuum hemispherical",
Expand Down Expand Up @@ -228,7 +236,10 @@ def concatenate_frames(

return xr.Dataset({"spectrum": xr.concat(fs, scan_coord)})

def resolve_frame_locations(self, scan_desc: ScanDesc | None = None) -> list[Path]:
def resolve_frame_locations(
self,
scan_desc: ScanDesc | None = None,
) -> list[Path]:
"""Determines all files associated with a given scan.
This beamline saves several HDF files in scan associated folders, so this
Expand Down Expand Up @@ -269,7 +280,11 @@ def load_single_frame(

return xr.Dataset(arrays)

def postprocess_final(self, data: xr.Dataset, scan_desc: ScanDesc | None = None) -> xr.Dataset:
def postprocess_final(
self,
data: xr.Dataset,
scan_desc: ScanDesc | None = None,
) -> xr.Dataset:
"""Performs final postprocessing of the data.
This mostly amounts to:
Expand Down
11 changes: 9 additions & 2 deletions src/arpes/endstations/plugin/HERS.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
__all__ = ("HERSEndstation",)


class HERSEndstation(SynchrotronEndstation, HemisphericalEndstation):
class HERSEndstation(
SynchrotronEndstation,
HemisphericalEndstation,
):
"""Implements data loading at the ALS HERS beamline.
This should be unified with the FITs endstation code, but I don't have any projects at BL10
Expand All @@ -33,7 +36,11 @@ class HERSEndstation(SynchrotronEndstation, HemisphericalEndstation):
PRINCIPAL_NAME = "ALS-BL1001"
ALIASES: ClassVar[list[str]] = ["ALS-BL1001", "HERS", "ALS-HERS", "BL1001"]

def load(self, scan_desc: ScanDesc | None = None, **kwargs: Incomplete) -> xr.Dataset:
def load(
self,
scan_desc: ScanDesc | None = None,
**kwargs: Incomplete,
) -> xr.Dataset:
"""Loads HERS data from FITS files. Shares a lot in common with Lanzara group formats.
Args:
Expand Down
5 changes: 4 additions & 1 deletion src/arpes/endstations/plugin/IF_UMCS.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
__all__ = ("IF_UMCS",)


class IF_UMCS(HemisphericalEndstation, SingleFileEndstation): # noqa: N801
class IF_UMCS( # noqa: N801
HemisphericalEndstation,
SingleFileEndstation,
):
"""Implements loading xy text files from the Specs Prodigy software."""

PRINCIPAL_NAME = "IF_UMCS"
Expand Down
6 changes: 5 additions & 1 deletion src/arpes/endstations/plugin/MAESTRO.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@
__all__ = ("MAESTROMicroARPESEndstation", "MAESTRONanoARPESEndstation")


class MAESTROARPESEndstationBase(SynchrotronEndstation, HemisphericalEndstation, FITSEndstation):
class MAESTROARPESEndstationBase(
SynchrotronEndstation,
HemisphericalEndstation,
FITSEndstation,
):
"""Common code for the MAESTRO ARPES endstations at the Advanced Light Source."""

PRINCIPAL_NAME = ""
Expand Down
4 changes: 2 additions & 2 deletions src/arpes/endstations/plugin/kaindl.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ def concatenate_frames(

frames.sort(key=lambda x: x.coords[axis_name])
return xr.concat(frames, axis_name)
except Exception as err:
logger.info(f"Exception occurs. {err=}, {type(err)=}")
except Exception:
logger.exception("Exception occurs.")
return None

def postprocess_final(self, data: xr.Dataset, scan_desc: ScanDesc | None = None) -> xr.Dataset:
Expand Down
4 changes: 1 addition & 3 deletions src/arpes/endstations/plugin/merlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,7 @@ def load_single_frame(
scan_desc["path"] = frame_path
return self.load_SES_nc(scan_desc=scan_desc, **kwargs)

original_data_loc: Path | str = scan_desc.get("path", scan_desc.get("file"))

p = Path(original_data_loc)
p = Path(scan_desc.get("path", scan_desc.get("file", "")))

# find files with same name stem, indexed in format R###
regions = find_ses_files_associated(p, separator="R")
Expand Down
45 changes: 21 additions & 24 deletions src/arpes/plotting/bz.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def apply_transformations(
def plot_plane_to_bz(
cell: Sequence[Sequence[float]] | NDArray[np.float_],
plane: str | list[NDArray[np.float_]],
ax: Axes,
ax: Axes3D,
special_points: dict[str, NDArray[np.float_]] | None = None,
facecolor: ColorType = "red",
) -> None:
Expand All @@ -209,7 +209,7 @@ def plot_plane_to_bz(
if isinstance(plane, str):
plane_points: list[NDArray[np.float_]] = process_kpath(
plane,
cell,
np.array(cell),
special_points=special_points,
)[0]
else:
Expand All @@ -226,7 +226,7 @@ def plot_plane_to_bz(


def plot_data_to_bz(
data: DataType,
data: xr.DataArray,
cell: Sequence[Sequence[float]] | NDArray[np.float_],
**kwargs: Incomplete,
) -> Path | tuple[Figure, Axes]:
Expand Down Expand Up @@ -313,10 +313,10 @@ def plot_data_to_bz2d( # noqa: PLR0913


def plot_data_to_bz3d(
data: DataType,
data: xr.DataArray,
cell: Sequence[Sequence[float]] | NDArray[np.float_],
**kwargs: Incomplete,
) -> None:
) -> Path | tuple[Figure, Axes]:
"""Plots ARPES data onto a 3D Brillouin zone."""
msg = "plot_data_to_bz3d is not implemented yet."
logger.debug(f"id of data: {data.attrs.get('id', None)}")
Expand Down Expand Up @@ -533,25 +533,19 @@ def draw(self, renderer: Incomplete) -> None:

def annotate_special_paths(
ax: Axes,
paths: list[str] | str,
paths: list[str] | str = "",
cell: NDArray[np.float_] | Sequence[Sequence[float]] | None = None,
offset: dict[str, Sequence[float]] | None = None,
special_points: dict[str, NDArray[np.float_]] | None = None,
labels: Incomplete = None,
**kwargs: Incomplete,
) -> None:
"""Annotates user indicated paths in k-space by plotting lines (or points) over the BZ."""
logger.debug(f"annotate-ax: {ax}")
logger.debug(f"annotate-paths: {paths}")
logger.debug(f"annotate-cell: {cell}")
logger.debug(f"annotate-offset: {offset}")
logger.debug(f"annotate-special_points: {special_points}")
logger.debug(f"annotate-labels: {labels}")
if kwargs:
for k, v in kwargs.items():
logger.debug(f"kwargs: kyes: {k}, value: {v}")

if paths == "":
if not paths:
msg = "Must provide a proper path."
raise ValueError(msg)

Expand Down Expand Up @@ -668,7 +662,7 @@ def twocell_to_bz1(cell: NDArray[np.float_]) -> Incomplete:


def bz2d_plot(
cell: Sequence[Sequence[float]],
cell: Sequence[Sequence[float]] | NDArray[np.float_],
paths: str | list[float] | None = None,
points: Sequence[float] | None = None,
repeat: tuple[int, int] | None = None,
Expand All @@ -687,16 +681,8 @@ def bz2d_plot(
Plots a Brillouin zone corresponding to a given unit cell
"""
logger.debug(f"bz2d_plot-cell: {cell}")
logger.debug(f"bz2d_plot-paths: {paths}")
logger.debug(f"bz2d_plot-points: {points}")
logger.debug(f"bz2d_plot-repeat: {repeat}")
logger.debug(f"bz2d_plot-transformations: {transformations}")
logger.debug(f"bz2d_plot-hide_ax: {hide_ax}")
logger.debug(f"bz2d_plot-vectors: {vectors}")
logger.debug(f"bz2d_plot-set_equal_aspect: {set_equal_aspect}")
kpoints = points
bz1, icell, cell = twocell_to_bz1(cell)
bz1, icell, cell = twocell_to_bz1(np.array(cell))
logger.debug(f"bz1 : {bz1}")
if ax is None:
ax = plt.axes()
Expand All @@ -710,6 +696,12 @@ def bz2d_plot(
path_string = cell_structure.special_path if paths == "all" else paths
paths = []
for names in parse_path_string(path_string):
"""
>>> parse_path_string('GX')
[['G', 'X']]
>>> parse_path_string('GX,M1A')
[['G', 'X'], ['M1', 'A']]
"""
points = []
for name in names:
points.append(np.dot(icell.T, special_points[name]))
Expand Down Expand Up @@ -774,7 +766,12 @@ def bz2d_plot(
)

if paths is not None:
annotate_special_paths(ax, paths, offset=offset, transformations=transformations)
annotate_special_paths(
ax,
paths,
offset=offset,
transformations=transformations,
)

if kpoints is not None:
for p in kpoints:
Expand Down
Loading

0 comments on commit fa7b283

Please sign in to comment.