Skip to content

Commit

Permalink
🔥 remvoe explicit args because they are just passed to matplotlib fun…
Browse files Browse the repository at this point in the history
…ctions

* 💬  update type hints
  • Loading branch information
arafune committed Sep 29, 2023
1 parent 58c3dfc commit b632560
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 136 deletions.
4 changes: 3 additions & 1 deletion arpes/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from _typeshed import Incomplete
from matplotlib.artist import Artist
from matplotlib.backend_bases import Event
from matplotlib.colors import Colormap
from matplotlib.colors import Colormap, Normalize
from matplotlib.figure import Figure
from matplotlib.patheffects import AbstractPathEffect
from matplotlib.transforms import BboxBase, Transform
Expand Down Expand Up @@ -380,6 +380,8 @@ class ColorbarParam(TypedDict, total=False):
boundaries: None | Sequence[float]
values: None | Sequence[float]
location: None | Literal["left", "right", "top", "bottom"]
cmap: Colormap
norm: Normalize


class MPLTextParam(TypedDict, total=False):
Expand Down
69 changes: 37 additions & 32 deletions arpes/analysis/band_analysis.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
"""Provides some band analysis tools."""
from __future__ import annotations

import contextlib
import copy
import functools
import itertools
from typing import TYPE_CHECKING, Literal

import lmfit as lf
import numpy as np
import xarray as xr
from numpy.typing import NDArray
from scipy.spatial import distance

import arpes.models.band
import arpes.utilities.math
from arpes._typing import DataType
from arpes.constants import HBAR_SQ_EV_PER_ELECTRON_MASS_ANGSTROM_SQ
from arpes.fits import AffineBackgroundModel, LorentzianModel, QuadraticModel, broadcast_model
from arpes.provenance import update_provenance
from arpes.utilities import enumerate_dataarray, normalize_to_spectrum
from arpes.utilities.conversion.forward import convert_coordinates_to_kspace_forward
from arpes.utilities.jupyter import wrap_tqdm

if TYPE_CHECKING:
from collections.abc import Generator

import lmfit as lf
from numpy.typing import NDArray

from arpes._typing import DataType

__all__ = (
"fit_bands",
"fit_for_effective_mass",
Expand Down Expand Up @@ -49,6 +57,8 @@ def fit_for_effective_mass(data: DataType, fit_kwargs: dict | None = None) -> fl
fit_kwargs = {}
assert isinstance(fit_kwargs, dict)
data_array = normalize_to_spectrum(data)
assert isinstance(data_array, xr.DataArray)

mom_dim = next(
dim for dim in ["kp", "kx", "ky", "kz", "phi", "beta", "theta"] if dim in data_array.dims
)
Expand All @@ -61,6 +71,7 @@ def fit_for_effective_mass(data: DataType, fit_kwargs: dict | None = None) -> fl
)
if mom_dim in {"phi", "beta", "theta"}:
forward = convert_coordinates_to_kspace_forward(data_array)
assert isinstance(forward, xr.Dataset)
final_mom = next(dim for dim in ["kx", "ky", "kp", "kz"] if dim in forward)
eVs = results.F.p("a_center").values
kps = [
Expand All @@ -78,8 +89,7 @@ def fit_for_effective_mass(data: DataType, fit_kwargs: dict | None = None) -> fl
def unpack_bands_from_fit(
band_results: xr.DataArray,
weights: tuple[float, float, float] | tuple[()] = (),
use_stderr_weighting=True,
):
) -> list[arpes.models.band.Band]:
"""This function is used to deconvolve the band identities of a series of overlapping bands.
Sometimes through the fitting process, or across a place in the band structure where there is a
Expand Down Expand Up @@ -108,8 +118,6 @@ def unpack_bands_from_fit(
arr
band_results
weights
use_stderr_weighting: Flag to indicate whether to scale vectors
by the uncertainty
Returns:
Unpacked bands.
Expand All @@ -122,7 +130,7 @@ def unpack_bands_from_fit(

identified_band_results = copy.deepcopy(band_results)

def as_vector(model_fit: lf.Model, prefix="") -> NDArray[np.float_]:
def as_vector(model_fit: lf.Model, prefix: str = "") -> NDArray[np.float_]:
"""[TODO:summary].
[TODO:description]
Expand Down Expand Up @@ -199,17 +207,18 @@ def as_vector(model_fit: lf.Model, prefix="") -> NDArray[np.float_]:
for i in range(len(prefixes)):
label = identified_band_results.loc[first_coordinate].values.item()[i]

def dataarray_for_value(param_name, i: int = i, *, is_value: bool) -> xr.DataArray:
def dataarray_for_value(param_name: str, i: int = i, *, is_value: bool) -> xr.DataArray:
"""[TODO:summary].
[TODO:description]
Args:
param_name ([TODO:type]): [TODO:description]
is_value ([TODO:type]): [TODO:description]
i: [TODO:description]
"""
values = np.ndarray(shape=identified_band_results.values.shape, dtype=float)
values: NDArray[np.float_] = np.ndarray(
shape=identified_band_results.values.shape,
dtype=float,
)
it = np.nditer(values, flags=["multi_index"], op_flags=[["writeonly"]])
while not it.finished:
prefix = identified_band_results.values[it.multi_index][i]
Expand Down Expand Up @@ -245,15 +254,12 @@ def dataarray_for_value(param_name, i: int = i, *, is_value: bool) -> xr.DataArr
def fit_patterned_bands(
arr: xr.DataArray,
band_set,
direction_normal=True,
fit_direction=None,
avoid_crossings=None,
stray=None,
background=True,
preferred_k_direction=None,
interactive=True,
dataset=True,
):
background: bool = True,
interactive: bool = True,
dataset: bool = True,
) -> xr.DataArray | xr.Dataset:
"""Fits bands and determines dispersion in some region of a spectrum.
The dimensions of the dataset are partitioned into three types:
Expand All @@ -277,7 +283,7 @@ def fit_patterned_bands(
orientation: edc or mdc
direction_normal
preferred_k_direction
dataset
dataset: if True, return as Dataset
Returns:
Dataset or DataArray, as controlled by the parameter "dataset"
Expand All @@ -290,7 +296,7 @@ def fit_patterned_bands(
free_directions = list(arr.dims)
free_directions.remove(fit_direction)

def is_between(x, y0, y1):
def is_between(x: float, y0: float, y1: float) -> bool:
y0, y1 = np.min([y0, y1]), np.max([y0, y1])
return y0 <= x <= y1

Expand Down Expand Up @@ -319,7 +325,7 @@ def interpolate_itersecting_fragments(coord, coord_index, points):

def resolve_partial_bands_from_description(
coord_dict,
name=None,
name: str = "",
band=arpes.models.band.Band,
dims=None,
params=None,
Expand Down Expand Up @@ -373,10 +379,7 @@ def build_params(old_params, center, center_stray=None):

low, high = np.percentile(
near_center.values,
(
20,
80,
),
(20, 80),
)
new_params["amplitude"] = new_params.get("amplitude", {})
new_params["amplitude"]["value"] = high - low
Expand All @@ -402,7 +405,7 @@ def build_params(old_params, center, center_stray=None):
total_slices = np.prod([len(arr.coords[d]) for d in free_directions])
for coord_dict, marginal in wrap_tqdm(
arr.G.iterate_axis(free_directions),
interactive,
interactive=interactive,
desc="fitting",
total=total_slices,
):
Expand Down Expand Up @@ -471,8 +474,7 @@ def instantiate_band(partial_band):
def fit_bands(
arr: xr.DataArray,
band_description,
background=None,
direction="mdc",
direction: Literal["edc", "mdc", "EDC", "MDC"] = "mdc",
preferred_k_direction=None,
step=None,
):
Expand All @@ -487,11 +489,14 @@ def fit_bands(
Returns:
Fitted bands.
"""
assert direction in ["edc", "mdc"]
assert direction in ["edc", "mdc", "EDC", "MDC"]

def iterate_marginals(arr: xr.DataArray, iterate_directions=None):
def iterate_marginals(
arr: xr.DataArray,
iterate_directions: list[str] | None = None,
) -> Generator:
if iterate_directions is None:
iterate_directions = list(arr.dims)
iterate_directions = [str(dim) for dim in arr.dims]
iterate_directions.remove("eV")

selectors = itertools.product(*[arr.coords[d] for d in iterate_directions])
Expand Down
38 changes: 19 additions & 19 deletions arpes/analysis/gap.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
__all__ = ("normalize_by_fermi_dirac", "determine_broadened_fermi_distribution", "symmetrize")


def determine_broadened_fermi_distribution(reference_data: DataType, fixed_temperature=True):
def determine_broadened_fermi_distribution(
reference_data: DataType,
fixed_temperature: bool = True,
):
"""Determine the parameters for broadening by temperature and instrumental resolution.
As a general rule, we first try to estimate the instrumental broadening and linewidth broadening
Expand Down Expand Up @@ -54,14 +57,14 @@ def determine_broadened_fermi_distribution(reference_data: DataType, fixed_tempe
sum_dims = list(reference_data_array.dims)
sum_dims.remove("eV")

return AffineBroadenedFD().guess_fit(reference_data.sum(sum_dims), params=params)
return AffineBroadenedFD().guess_fit(reference_data_array.sum(sum_dims), params=params)


@update_provenance("Normalize By Fermi Dirac")
def normalize_by_fermi_dirac(
data: DataType,
reference_data: DataType | None = None,
plot=False,
plot: bool = False,
broadening=None,
temperature_axis=None,
temp_offset=0,
Expand Down Expand Up @@ -140,9 +143,10 @@ def normalize_by_fermi_dirac(
)
# <== NEED TO CHECK (What it the type of without_background ?)

without_background_arr = normalize_to_spectrum(without_background)
assert isinstance(without_background_arr, xr.DataArray)
if temperature_axis:
without_background = normalize_to_spectrum(without_background)
divided = without_background.G.map_axes(
divided = without_background_arr.G.map_axes(
temperature_axis,
lambda x, coord: x
/ broadening_fit.eval(
Expand All @@ -155,8 +159,7 @@ def normalize_by_fermi_dirac(
),
)
else:
without_background = normalize_to_spectrum(without_background)
divided = without_background / broadening_fit.eval(
divided = without_background_arr / broadening_fit.eval(
x=data.coords["eV"].values,
conv_width=broadening,
lin_bkg=0,
Expand All @@ -170,21 +173,18 @@ def normalize_by_fermi_dirac(
return divided


def _shift_energy_interpolate(data: DataType, shift=None):
if shift is not None:
pass

data = normalize_to_spectrum(data).S.transpose_to_front("eV")
def _shift_energy_interpolate(data: DataType, shift: xr.DataArray | None = None):
data_arr = normalize_to_spectrum(data).S.transpose_to_front("eV")

new_data = data.copy(deep=True)
new_data = data_arr.copy(deep=True)
new_axis = new_data.coords["eV"]
new_values = new_data.values * 0

if shift is None:
closest_to_zero = data.coords["eV"].sel(eV=0, method="nearest")
closest_to_zero = data_arr.coords["eV"].sel(eV=0, method="nearest")
shift = -closest_to_zero

stride = data.G.stride("eV", generic_dim_names=False)
stride = data_arr.G.stride("eV", generic_dim_names=False)

if np.abs(shift) >= stride:
n_strides = int(shift / stride)
Expand All @@ -196,11 +196,11 @@ def _shift_energy_interpolate(data: DataType, shift=None):

weight = float(shift / stride)

new_values = new_values + data.values * (1 - weight)
new_values = new_values + data_arr.values * (1 - weight)
if shift > 0:
new_values[1:] = new_values[1:] + data.values[:-1] * weight
new_values[1:] = new_values[1:] + data_arr.values[:-1] * weight
if shift < 0:
new_values[:-1] = new_values[:-1] + data.values[1:] * weight
new_values[:-1] = new_values[:-1] + data_arr.values[1:] * weight

new_data.coords["eV"] = new_axis
new_data.values = new_values
Expand All @@ -209,7 +209,7 @@ def _shift_energy_interpolate(data: DataType, shift=None):


@update_provenance("Symmetrize")
def symmetrize(data: DataType, subpixel=False, full_spectrum=False):
def symmetrize(data: DataType, subpixel: bool = False, full_spectrum: bool = False):
"""Symmetrizes data across the chemical potential.
This provides a crude tool by which
Expand Down
11 changes: 4 additions & 7 deletions arpes/analysis/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


@update_provenance("Discretize Path")
def discretize_path(path: xr.Dataset, n_points=None, scaling=None) -> xr.Dataset:
def discretize_path(path: xr.Dataset, n_points: int = 0, scaling=None) -> xr.Dataset:
"""Discretizes a path into a set of points spaced along the path.
Shares logic with slice_along_path
Expand Down Expand Up @@ -57,11 +57,7 @@ def distance(a, b):
coord_low, coord_high = path.sel(index=idx_low), path.sel(index=idx_high)
length += distance(coord_low, coord_high)

if n_points is None:
# play with this until it seems reasonable
n_points = int(length / 0.03)
else:
n_points = max(n_points - 1, 1)
n_points = int(length / 0.03) if not n_points else max(n_points - 1, 1)

sep = length / n_points
points = []
Expand Down Expand Up @@ -99,7 +95,8 @@ def select_along_path(
path: xr.Dataset,
data: DataType,
radius=None,
n_points=None,
n_points: int = 0,
*,
fast: bool = True,
scaling=None,
**kwargs: Incomplete,
Expand Down
12 changes: 12 additions & 0 deletions arpes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os.path
import warnings
from dataclasses import dataclass, field
from logging import DEBUG, Formatter, StreamHandler, getLogger
from pathlib import Path
from typing import TYPE_CHECKING, Any

Expand All @@ -28,6 +29,17 @@
from arpes._typing import CONFIGTYPE, ConfigSettings
# pylint: disable=global-statement

LOGLEVEL = DEBUG
logger = getLogger(__name__)
fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
formatter = Formatter(fmt)
handler = StreamHandler()
handler.setLevel(LOGLEVEL)
logger.setLevel(LOGLEVEL)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False


ureg = pint.UnitRegistry()

Expand Down
Loading

0 comments on commit b632560

Please sign in to comment.