Skip to content

Commit

Permalink
Merge branch 'daredevil' of https://github.com/arafune/arpes into dar…
Browse files Browse the repository at this point in the history
…edevil
  • Loading branch information
arafune committed Oct 23, 2023
2 parents 8539d3d + d5c62dd commit 60602dd
Show file tree
Hide file tree
Showing 12 changed files with 113 additions and 87 deletions.
9 changes: 7 additions & 2 deletions arpes/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def from_param(cls: type, model_param: lf.Model.Parameter):
return cls(center=model_param.value, stderr=model_param.stderr)


def propagate_errors(f) -> Callable:
def propagate_errors(f: Callable) -> Callable:
"""A decorator which provides transparent propagation of statistical errors.
The way that this is accommodated is that the inner function is turned into one which
Expand Down Expand Up @@ -332,7 +332,12 @@ def bootstrap(
elif resample_method == "cycle":
resample_fn = resample_cycle

def bootstrapped(*args, n: int = 20, prior_adjustment=1, **kwargs: Incomplete):
def bootstrapped(
*args,
n: int = 20,
prior_adjustment: int = 1,
**kwargs: Incomplete,
):
# examine args to determine which to resample
resample_indices = [
i
Expand Down
30 changes: 22 additions & 8 deletions arpes/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,20 +172,14 @@ def stitch(
Returns:
The concatenated data.
"""
list_of_files = None
if isinstance(df_or_list, pd.DataFrame):
list_of_files = list(df_or_list.index)
else:
if not isinstance(df_or_list, list | tuple):
msg = "Expected an interable for a list of the scans to stitch together"
raise TypeError(msg)
list_of_files = list(df_or_list)
list_of_files = _df_or_list_to_files(df_or_list)
if not built_axis_name:
assert isinstance(attr_or_axis, str)
built_axis_name = attr_or_axis
if not list_of_files:
msg = "Must supply at least one file to stitch"
raise ValueError(msg)
#
loaded = [
f if isinstance(f, xr.DataArray | xr.Dataset) else load_data(f) for f in list_of_files
]
Expand Down Expand Up @@ -221,6 +215,26 @@ def stitch(
return concatenated


def _df_or_list_to_files(
df_or_list: list[str] | pd.DataFrame,
) -> list[str]:
"""Helper function for stitch.
Args:
df_or_list(pd.DataFrame, list): input data file
Returns: (list[str])
list of files to stitch.
"""
if isinstance(df_or_list, pd.DataFrame):
return list(df_or_list.index)
assert not isinstance(
df_or_list,
list | tuple,
), "Expected an interable for a list of the scans to stitch together"
return list(df_or_list)


def file_for_pickle(name: str) -> Path | str:
here = Path()
from arpes.config import CONFIG
Expand Down
2 changes: 1 addition & 1 deletion arpes/models/band.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Rudimentary band analyis code."""
"""Rudimentary band analysis code."""
from __future__ import annotations

from typing import TYPE_CHECKING
Expand Down
4 changes: 2 additions & 2 deletions arpes/plotting/path_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,14 @@ def click_main_image(event):
("Path", "path"),
]

def convert_to_xarray():
def convert_to_xarray() -> dict[str, xr.dataset]:
"""Creates a Dataset consisting of one array for each path.
For each of the paths, we will create a dataset which has an index dimension,
and datavariables for each of the coordinate dimensions
"""

def convert_single_path_to_xarray(points):
def convert_single_path_to_xarray(points) -> xr.Dataset:
vars = {d: np.array([p[i] for p in points]) for i, d in enumerate(self.arr.dims)}
coords = {
"index": np.array(range(len(points))),
Expand Down
8 changes: 4 additions & 4 deletions arpes/preparation/axis_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,7 @@ def normalize(arr: xr.Dataset | xr.DataArray) -> xr.DataArray:

def transform_dataarray_axis(
func: Callable[..., ...],
old_axis_name: str,
new_axis_name: str,
old_and_new_axis_names: tuple[str, str],
new_axis: NDArray[np.float_] | xr.DataArray,
dataset: xr.Dataset,
prep_name: Callable[[str], str],
Expand All @@ -195,14 +194,15 @@ def transform_dataarray_axis(
Args:
func ([TODO:type]): [TODO:description]
old_axis_name(str): [TODO:description]
new_axis_name(str): [TODO:description]
old_and_new_axis_names (tuple[str, str]) : old and new axis names as the tuple form
new_axis ([TODO:type]): [TODO:description]
dataset(xr.Dataset): [TODO:description]
prep_name ([TODO:type]): [TODO:description]
transform_spectra ([TODO:type]): [TODO:description]
remove_old ([TODO:type]): [TODO:description]
"""
old_axis_name, new_axis_name = old_and_new_axis_names

ds = dataset.copy()
if transform_spectra is None:
# transform *all* DataArrays in the dataset that have old_axis_name in their dimensions.
Expand Down
6 changes: 4 additions & 2 deletions arpes/preparation/coord_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@
import numpy as np

if TYPE_CHECKING:
from collections.abc import Sequence

import xarray as xr

__all__ = ["disambiguate_coordinates"]


def disambiguate_coordinates(
datasets: xr.Dataset,
possibly_clashing_coordinates,
):
possibly_clashing_coordinates: Sequence[str],
) -> list[xr.DataArray]:
"""Finds and unifies duplicated coordinates or ambiguous coordinates.
This is useful if two regions claim to have an energy axis, but one is a core level
Expand Down
6 changes: 2 additions & 4 deletions arpes/preparation/tof_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,7 @@ def process_SToF(dataset: xr.Dataset) -> xr.Dataset:

dataset = transform_dataarray_axis(
build_KE_coords_to_time_coords(dataset, ke_axis),
"time",
"eV",
("time", "eV"),
ke_axis,
dataset,
lambda x: x,
Expand Down Expand Up @@ -276,8 +275,7 @@ def process_DLD(dataset: xr.Dataset) -> xr.Dataset:
)
return transform_dataarray_axis(
build_KE_coords_to_time_pixel_coords(dataset, ke_axis),
"t_pixels",
"kinetic",
("t_pixels", "kinetic"),
ke_axis,
dataset,
lambda: "kinetic_spectrum",
Expand Down
5 changes: 0 additions & 5 deletions arpes/utilities/bz.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,6 @@ def flat_bz_indices_list(
bz_indices_list = [(0, 0)]

assert len(bz_indices_list[0]) in {2, 3}
try:
if len(bz_indices_list[0]) not in {2, 3}:
raise ValueError
except (ValueError, TypeError):
bz_indices_list = [bz_indices_list]

indices = []
if len(bz_indices_list[0]) == 2: # noqa: PLR2004
Expand Down
13 changes: 10 additions & 3 deletions arpes/utilities/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,12 @@ def segment_contains_point(
return 0 - epsilon < delta.dot(delta_p) / delta.dot(delta) < 1 + epsilon


def polyhedron_intersect_plane(poly_faces, plane_normal, plane_point, epsilon: float = 1e-6):
def polyhedron_intersect_plane(
poly_faces: list[NDArray[np.float_]],
plane_normal: NDArray[np.float_],
plane_point: NDArray[np.float_],
epsilon: float = 1e-6,
) -> ConvexHull:
"""Determines the intersection of a convex polyhedron intersecting a plane.
The polyhedron faces should be given by a list of np.arrays, where each np.array at
Expand All @@ -94,7 +99,7 @@ def polyhedron_intersect_plane(poly_faces, plane_normal, plane_point, epsilon: f
"""
collected_points = []

def add_point(c):
def add_point(c: NDArray[np.float_]) -> None:
already_collected = False
for other in collected_points:
delta = c - other
Expand All @@ -106,7 +111,9 @@ def add_point(c):
collected_points.append(c)

for poly_face in poly_faces:
segments = list(zip(poly_face, np.concatenate([poly_face[1:], [poly_face[0]]])))
segments = list(
zip(poly_face, np.concatenate([poly_face[1:], [poly_face[0]]]), strict=True),
)
for a, b in segments:
intersection = point_plane_intersection(
plane_normal,
Expand Down
25 changes: 6 additions & 19 deletions arpes/widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import itertools
import pathlib
import pprint
import warnings
from collections.abc import Sequence
from functools import wraps
Expand All @@ -38,6 +39,7 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pyperclip
import xarray as xr
from matplotlib import gridspec
from matplotlib.axes import Axes
Expand Down Expand Up @@ -340,7 +342,7 @@ def mask_cmap(self) -> Colormap:
return self._mask_cmap

@property
def mask(self):
def mask(self): # noqa: ANN202
return self._mask

@mask.setter
Expand Down Expand Up @@ -468,10 +470,6 @@ def on_add_new_peak(selection) -> None:
ctx["data"] = data

def on_copy_settings(event: MouseEvent) -> None:
import pprint

import pyperclip

pyperclip.copy(pprint.pformat(compute_parameters()))

copy_settings_button = Button(ax_test, "Copy Settings")
Expand All @@ -489,7 +487,7 @@ def pca_explorer(
*,
transpose_mask: bool = False,
) -> CURRENTCONTEXT:
"""A tool providing PCA decomposition exploration of a dataset.
"""A tool providing PCA (Principal component analysis) decomposition exploration of a dataset.
Args:
pca: The decomposition of the data, the output of an sklearn PCA decomp.
Expand Down Expand Up @@ -517,7 +515,7 @@ def pca_explorer(
}
arpes.config.CONFIG["CURRENT_CONTEXT"] = context

def compute_for_scatter():
def compute_for_scatter() -> tuple[xr.DataArray | xr.Dataset, int]:
for_scatter = pca.copy(deep=True).isel(
**dict([[component_dim, context["selected_components"]]]),
)
Expand Down Expand Up @@ -716,18 +714,7 @@ def compute_offsets() -> dict[str, float]:
return {k: v.val for k, v in sliders.items()}

def on_copy_settings(event: MouseEvent) -> None:
try:
import pprint

import pyperclip

pyperclip.copy(pprint.pformat(compute_offsets()))
except ImportError:
pass
finally:
import pprint

print(pprint.pformat(compute_offsets()))
pyperclip.copy(pprint.pformat(compute_offsets()))

def apply_offsets(event: MouseEvent) -> None:
for name, offset in compute_offsets().items():
Expand Down
73 changes: 36 additions & 37 deletions arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def transpose_to_back(self, dim: str) -> xr.DataArray | xr.Dataset:

def select_around_data(
self,
points: dict[str, float] | xr.Dataset,
points: dict[str, xr.DataArray] | xr.Dataset,
radius: dict[str, float] | float | None = None, # radius={"phi": 0.005}
*,
mode: Literal["sum", "mean"] = "sum",
Expand Down Expand Up @@ -432,24 +432,9 @@ def select_around_data(
if isinstance(points, xr.Dataset):
points = {k: points[k].item() for k in points.data_vars}

if isinstance(radius, float):
radius = {str(d): radius for d in points}
else:
collected_terms = {f"{k}_r" for k in points}.intersection(
set(kwargs.keys()),
)
if collected_terms:
radius = {
str(d): kwargs.get(f"{d}_r", DEFAULT_RADII.get(str(d), UNSPESIFIED))
for d in points
}
elif radius is None:
radius = {str(d): DEFAULT_RADII.get(str(d), UNSPESIFIED) for d in points}
radius = self._radius(points, radius, **kwargs)

assert isinstance(radius, dict)
radius = {
str(d): radius.get(str(d), DEFAULT_RADII.get(str(d), UNSPESIFIED)) for d in points
}
logger.debug(f"iter(points.values()): {iter(points.values())}")

along_dims = next(iter(points.values())).dims
selected_dims = list(points.keys())
Expand Down Expand Up @@ -538,25 +523,7 @@ def select_around(
if isinstance(points, xr.Dataset):
points = {k: points[k].item() for k in points.data_vars}
logger.debug(f"points: {points}")
if isinstance(radius, float):
radius = {str(d): radius for d in points}
else:
collected_terms = {f"{k}_r" for k in points}.intersection(
set(kwargs.keys()),
)
if collected_terms:
radius = {
str(d): kwargs.get(f"{d}_r", DEFAULT_RADII.Get(str(d), UNSPESIFIED))
for d in points
}
elif radius is None:
radius = {str(d): DEFAULT_RADII.get(str(d), UNSPESIFIED) for d in points}

assert isinstance(radius, dict)
radius = {
str(d): radius.get(str(d), DEFAULT_RADII.get(str(d), UNSPESIFIED)) for d in points
}

radius = self._radius(points, radius, **kwargs)
logger.debug(f"radius: {radius}")
nearest_sel_params = {}

Expand Down Expand Up @@ -585,6 +552,38 @@ def select_around(
return selected.sum(list(radius.keys()))
return selected.mean(list(radius.keys()))

def _radius(
self,
points: dict[str, float],
radius: float | dict[str, float] | None,
**kwargs: float,
) -> dict[str, float]:
"""Helper function. Generate radius dict.
When radius is dict form, nothing has been done, essentially.
Args:
points (dict[str, float]): Selection point
radius (dict[str, float] | float | None): radius
kwargs (float): [TODO:description]
Returns: dict[str, float]
radius for selection.
"""
if isinstance(radius, float):
radius = {str(d): radius for d in points}
else:
collectted_terms = {f"{k}_r" for k in points}.intersection(set(kwargs.keys()))
if collectted_terms:
radius = {
str(d): kwargs.get(f"{d}_r", DEFAULT_RADII.get(str(d), UNSPESIFIED))
for d in points
}
elif radius is None:
radius = {str(d): DEFAULT_RADII.get(str(d), UNSPESIFIED) for d in points}
assert isinstance(radius, dict)
return {str(d): radius.get(str(d), DEFAULT_RADII.get(str(d), UNSPESIFIED)) for d in points}

def short_history(self, key: str = "by") -> list:
"""Return the short version of history.
Expand Down
Loading

0 comments on commit 60602dd

Please sign in to comment.