
Commit

🎨 Be slim. Don't use dict([[a,b]]), instead use just {a:b}.
  💬  Update type hints
arafune committed Feb 6, 2024
1 parent 6aeb7fd commit 22110d4
Showing 14 changed files with 89 additions and 65 deletions.
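For reference, a minimal sketch of the pattern this commit applies (the toy array and its "eV" coordinate are illustrative, not from the repository): xarray's .sel, .isel, and .rename accept a plain mapping, so there is no need to build a dict from a list of [key, value] pairs and splat it with **, and passing the mapping positionally also works for dimension names that are not valid Python identifiers.

    import numpy as np
    import xarray as xr

    arr = xr.DataArray(
        np.arange(10.0),
        coords={"eV": np.linspace(-1.0, 1.0, 10)},
        dims=["eV"],
    )

    # Before: build a dict from a list of [key, value] pairs and splat it.
    old = arr.sel(**dict([["eV", slice(-0.5, 0.5)]]))

    # After: pass a dict literal (or any mapping) directly.
    new = arr.sel({"eV": slice(-0.5, 0.5)})

    assert old.equals(new)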
5 changes: 2 additions & 3 deletions arpes/analysis/background.py
@@ -1,4 +1,5 @@
 """Provides background estimation approaches."""
+
 from __future__ import annotations
 
 from itertools import pairwise
@@ -29,9 +30,7 @@ def calculate_background_hull(
         dim = arr.dims[0]
         processed = []
         for blow, bhigh in pairwise(breakpoints):
-            processed.append(
-                calculate_background_hull(arr.sel(**dict([[dim, slice(blow, bhigh)]]))),
-            )
+            processed.append(calculate_background_hull(arr.sel({dim: slice(blow, bhigh)})))
         return xr.concat(processed, dim)
 
     points = np.stack(arr.G.to_arrays(), axis=1)
4 changes: 2 additions & 2 deletions arpes/analysis/band_analysis.py
@@ -391,7 +391,7 @@ def resolve_partial_bands_from_description(
     residual.values = np.zeros(residual.shape)
 
     for coords in band_results.G.iter_coords():
-        fit_item = band_results.sel(**coords).item()
+        fit_item = band_results.sel(coords).item()
         if fit_item is None:
             continue
 
@@ -591,7 +591,7 @@ def _iterate_marginals(
     selectors = itertools.product(*[arr.coords[d] for d in iterate_directions])
     for ss in selectors:
         coords = dict(zip(iterate_directions, [float(s) for s in ss], strict=True))
-        yield arr.sel(**coords), coords
+        yield arr.sel(coords), coords
 
 
 def _build_params(
4 changes: 2 additions & 2 deletions arpes/analysis/decomposition.py
@@ -95,8 +95,8 @@ def decomposition_along(
 
     into = flattened_data.copy(deep=True)
     into_first = into.dims[0]
-    into = into.isel(**dict([[into_first, slice(0, transform.shape[1])]]))
-    into = into.rename(dict([[into_first, "components"]]))
+    into = into.isel({into_first: slice(0, transform.shape[1])})
+    into = into.rename({into_first: "components"})
 
     into.values = transform.T
 
40 changes: 27 additions & 13 deletions arpes/analysis/filters.py
@@ -11,7 +11,7 @@
 from arpes.provenance import PROVENANCE, provenance
 
 if TYPE_CHECKING:
-    from collections.abc import Callable
+    from collections.abc import Callable, Hashable
 
 __all__ = (
     "gaussian_filter_arr",
@@ -23,7 +23,7 @@
 
 def gaussian_filter_arr(
     arr: xr.DataArray,
-    sigma: dict[str, float | int] | None = None,
+    sigma: dict[Hashable, float | int] | None = None,
     repeat_n: int = 1,
     *,
     default_size: int = 1,
@@ -33,7 +33,8 @@ def gaussian_filter_arr(
     Args:
         arr(xr.DataArray): ARPES data
-        sigma: Kernel sigma, specified in terms of axis units (if use_pixel is False).
+        sigma (dict[Hashable, int]): Kernel sigma, specified in terms of axis units.
+            (if use_pixel is False).
            An axis that is not specified will have a kernel width of `default_size` in index units.
         repeat_n: Repeats n times.
         default_size: Changes the default kernel width for axes not specified in `sigma`.
@@ -46,11 +47,14 @@
     """
     if sigma is None:
         sigma = {}
-    sigma = {k: int(v / (arr.coords[k][1] - arr.coords[k][0])) for k, v in sigma.items()}
+    if use_pixel:
+        sigma_pixel: dict[Hashable, int] = {k: int(v) for k, v in sigma.items()}
+    else:
+        sigma_pixel = {k: int(v / (arr.coords[k][1] - arr.coords[k][0])) for k, v in sigma.items()}
     for dim in arr.dims:
-        if dim not in sigma:
-            sigma[dim] = default_size
-    widths_pixel: tuple[int, ...] = tuple(sigma[k] for k in arr.dims)
+        if dim not in sigma_pixel:
+            sigma_pixel[dim] = default_size
+    widths_pixel: tuple[int, ...] = tuple(sigma_pixel[k] for k in arr.dims)
     values = arr.values
     for _ in range(repeat_n):
         values = ndimage.gaussian_filter(values, widths_pixel)
@@ -70,7 +74,7 @@ def gaussian_filter_arr(
 
 def boxcar_filter_arr(
     arr: xr.DataArray,
-    size: dict[str, int | float] | None = None,
+    size: dict[Hashable, float] | None = None,
     repeat_n: int = 1,
     default_size: int = 1,
     *,
@@ -97,8 +101,12 @@
     assert isinstance(arr, xr.DataArray)
     if size is None:
         size = {}
-    integered_size = {k: int(v / (arr.coords[k][1] - arr.coords[k][0])) for k, v in size.items()}
-    del size
+    if use_pixel:
+        integered_size: dict[Hashable, int] = {k: int(v) for k, v in size.items()}
+    else:
+        integered_size = {
+            k: int(v / (arr.coords[k][1] - arr.coords[k][0])) for k, v in size.items()
+        }
     for dim in arr.dims:
         if dim not in integered_size:
             integered_size[str(dim)] = default_size
@@ -112,15 +120,18 @@
     provenance_context: PROVENANCE = {
         "what": "Boxcar filtered data",
         "by": "boxcar_filter_arr",
-        "size": integered_size,
+        "size": size,
         "use_pixel": use_pixel,
     }
 
     provenance(filtered_arr, arr, provenance_context)
     return filtered_arr
 
 
-def gaussian_filter(sigma: dict[str, float | int] | None = None, repeat_n: int = 1) -> Callable:
+def gaussian_filter(
+    sigma: dict[Hashable, float | int] | None = None,
+    repeat_n: int = 1,
+) -> Callable[[xr.DataArray], xr.DataArray]:
     """A partial application of `gaussian_filter_arr`.
 
     For further derivative analysis functions.
@@ -139,7 +150,10 @@ def f(arr: xr.DataArray) -> xr.DataArray:
     return f
 
 
-def boxcar_filter(size: dict[str, int | float] | None = None, repeat_n: int = 1) -> Callable:
+def boxcar_filter(
+    size: dict[Hashable, int | float] | None = None,
+    repeat_n: int = 1,
+) -> Callable[[xr.DataArray], xr.DataArray]:
     """A partial application of `boxcar_filter_arr`.
 
     Output can be passed to derivative analysis functions.
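A hypothetical usage sketch of the reworked filters, assuming use_pixel is the keyword-only flag referenced above (the full signature is only partly shown in this diff) and using made-up coordinate spacings in place of real ARPES axes:

    import numpy as np
    import xarray as xr

    from arpes.analysis.filters import gaussian_filter_arr

    # Illustrative 2D cut; 'eV' step ~0.002, 'phi' step ~0.1 (arbitrary units).
    cut = xr.DataArray(
        np.random.default_rng(0).random((500, 200)),
        coords={"eV": np.linspace(-0.5, 0.5, 500), "phi": np.linspace(-10.0, 10.0, 200)},
        dims=["eV", "phi"],
    )

    # sigma in axis units: converted to pixel widths via the coordinate step.
    smoothed = gaussian_filter_arr(cut, sigma={"eV": 0.01, "phi": 0.5})

    # sigma interpreted directly as pixel counts when use_pixel=True.
    smoothed_px = gaussian_filter_arr(cut, sigma={"eV": 5, "phi": 4}, use_pixel=True)

Dimensions left out of sigma fall back to a width of default_size index units in either mode, per the docstring.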
2 changes: 1 addition & 1 deletion arpes/analysis/gap.py
@@ -1,4 +1,5 @@
 """Utilities for gap fitting in ARPES, contains tools to normalize by Fermi-Dirac occupation."""
+
 from __future__ import annotations
 
 import warnings
@@ -206,7 +207,6 @@ def _shift_energy_interpolate(
     new_axis = new_axis + shift
 
     weight = float(shift / stride)
-
     new_values = new_values + data_arr.values * (1 - weight)
     if shift > 0:
         new_values[1:] = new_values[1:] + data_arr.values[:-1] * weight
10 changes: 7 additions & 3 deletions arpes/analysis/general.py
@@ -1,4 +1,5 @@
 """Some general purpose analysis routines otherwise defying categorization."""
+
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, Literal
@@ -30,7 +31,10 @@
 
 
 @update_provenance("Fit Fermi Edge")
-def fit_fermi_edge(data: DataType, energy_range: slice | None = None) -> xr.Dataset:
+def fit_fermi_edge(
+    data: DataType,
+    energy_range: slice | None = None,
+) -> xr.Dataset:
     """Fits a Fermi edge.
 
     Not much easier than doing it manually, but this can be
@@ -132,7 +136,7 @@ def symmetrize_axis(
 
     selector = {}
     selector[axis_name] = slice(None, None, -1)
-    rev = data.sel(**selector).copy()
+    rev = data.sel(selector).copy()
 
     rev.coords[axis_name].values = -rev.coords[axis_name].values
 
@@ -142,7 +146,7 @@
     for axis in flip_axes:
         selector = {}
         selector[axis] = slice(None, None, -1)
-        rev = rev.sel(**selector)
+        rev = rev.sel(selector)
         rev.coords[axis].values = -rev.coords[axis].values
 
     return rev.combine_first(data)
4 changes: 2 additions & 2 deletions arpes/analysis/mask.py
@@ -136,8 +136,8 @@ def apply_mask_to_coords(
 @update_provenance("Apply boolean mask to data")
 def apply_mask(
     data: DataType,
-    mask,
-    replace=np.nan,
+    mask: dict[str, Incomplete],
+    replace: float = np.nan,
     radius=None,
     *,
     invert: bool = False,
5 changes: 3 additions & 2 deletions arpes/analysis/pocket.py
@@ -1,4 +1,5 @@
 """Contains electron/hole pocket analysis routines."""
+
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, Any
@@ -56,7 +57,7 @@ def pocket_parameters(
     if sel is None:
         sel = {"eV": slice(-0.03, 0.05)}
 
-    kfs = [kf_method(s if sel is None else s.sel(**sel), **(method_kwargs or {})) for s in slices]
+    kfs = [kf_method(s if sel is None else s.sel(sel), **(method_kwargs or {})) for s in slices]
 
     fs_dims = list(data.dims)
     if "eV" in fs_dims:
@@ -300,7 +301,7 @@ def edcs_along_pocket(
     if sel is None:
         sel = {"eV": slice(-0.05, 0.05)}
 
-    kfs = [kf_method(s if sel is None else s.sel(**sel), **(method_kwargs or {})) for s in slices]
+    kfs = [kf_method(s if sel is None else s.sel(sel), **(method_kwargs or {})) for s in slices]
 
     fs_dims = list(data.dims)
     if "eV" in fs_dims:
12 changes: 8 additions & 4 deletions arpes/analysis/savitzky_golay.py
@@ -1,4 +1,5 @@
 """Scipy cookbook implementations of the Savitzky Golay filter for xr.DataArrays."""
+
 from __future__ import annotations
 
 from math import factorial
@@ -8,9 +9,12 @@
 import scipy.signal
 import xarray as xr
 
+from arpes.constants import TWO_DIMENSION
 from arpes.provenance import update_provenance
 
 if TYPE_CHECKING:
+    from collections.abc import Hashable
+
     from numpy.typing import NDArray
 
 
@@ -22,9 +26,9 @@ def savitzky_golay(  # noqa: PLR0913
     data: xr.DataArray,
     window_size: int,
     order: int,
-    deriv: int = 0,
+    deriv: int | Literal["col", "row", "both", None] = 0,
     rate: int = 1,
-    dim: str = "",
+    dim: Hashable = "",
 ) -> xr.DataArray:
     """Implements a Savitzky Golay filter with given window size.
@@ -58,7 +62,7 @@ def savitzky_golay(  # noqa: PLR0913
     if deriv == 0:
         deriv = None
 
-    if len(data.dims) == 3:  # noqa: PLR2004
+    if len(data.dims) == TWO_DIMENSION + 1:
         if not dim:
             dim = data.dims[-1]
         return data.G.map_axes(
@@ -72,7 +76,7 @@ def savitzky_golay(  # noqa: PLR0913
         ),
     )
 
-    if len(data.dims) == 2:  # noqa: PLR2004
+    if len(data.dims) == TWO_DIMENSION:
         if not dim:
             transformed_data = savitzky_golay_2d(
                 data.values,
3 changes: 2 additions & 1 deletion arpes/analysis/self_energy.py
@@ -1,4 +1,5 @@
 """Contains self-energy analysis routines."""
+
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, Literal, TypeAlias
@@ -127,7 +128,7 @@ def estimate_bare_band(
     inlier_data = centers.where(
         xr.DataArray(
             inliers,
-            coords=dict([[fit_dimension, centers.coords[fit_dimension]]]),
+            coords={fit_dimension: centers.coords[fit_dimension]},
            dims=[fit_dimension],
        ),
        drop=True,
4 changes: 2 additions & 2 deletions arpes/plotting/dynamic_tool.py
@@ -25,7 +25,7 @@
     from collections.abc import Callable
 
     from _typeshed import Incomplete
-    from PySide6.QtWidgets import QLayout, QWidget
+    from PySide6.QtWidgets import QGridLayout, QWidget
 
     from arpes._typing import DataType
 
@@ -72,7 +72,7 @@ def __init__(
 
         super().__init__()
 
-    def layout(self) -> QLayout:
+    def layout(self) -> QGridLayout:
         return self.main_layout
 
     def configure_image_widgets(self) -> None:
2 changes: 1 addition & 1 deletion arpes/plotting/fit_tool/fit_inspection_plot.py
@@ -7,10 +7,10 @@
 import numpy as np
 import pyqtgraph as pg
 import xarray as xr
-import xarray_extensions  # noqa: F401
 from PySide6 import QtCore, QtWidgets
 from PySide6.QtWidgets import QGridLayout, QWidget
 
+import arpes.xarray_extensions  # noqa: F401
 from arpes.utilities.qt import qt_info
 from arpes.utilities.qt.data_array_image_view import DataArrayPlot
 
6 changes: 3 additions & 3 deletions arpes/provenance.py
@@ -28,7 +28,7 @@
 import warnings
 from datetime import UTC
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, TypedDict
+from typing import TYPE_CHECKING, Any, Hashable, TypedDict
 
 import xarray as xr
 
@@ -68,8 +68,8 @@ class PROVENANCE(TypedDict, total=False):
     axis: str  # derivative.dn_along_axis
     order: int  # derivative.dn_along_axis
     #
-    sigma: dict[str, float]  # analysis.filters
-    size: dict[str, int]  # analysis.filters
+    sigma: dict[Hashable, float]  # analysis.filters
+    size: dict[Hashable, float]  # analysis.filters
     use_pixel: bool  # analysis.filters
     #
     correction: list[NDArray[np.float_]]  # fermi_edge_correction
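A brief, hypothetical illustration of why these hints widen the key type from str to Hashable: xarray reports dimension names as Hashable, so mappings keyed by entries of arr.dims only type-check against dict[Hashable, ...].

    from collections.abc import Hashable

    import numpy as np
    import xarray as xr

    arr = xr.DataArray(np.zeros((3, 4)), dims=["kx", "ky"])

    # arr.dims is typed as tuple[Hashable, ...], so this mapping is inferred as
    # dict[Hashable, int]; annotating it as dict[str, int] fails strict checking.
    sigma: dict[Hashable, int] = {dim: 1 for dim in arr.dims}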