Skip to content

Commit

Permalink
🔨 stop using cm.get_cmap function because it is deprecated.
Browse files Browse the repository at this point in the history
💬  update type hints
  • Loading branch information
arafune committed Sep 28, 2023
1 parent 5d749af commit 0bf8c1f
Show file tree
Hide file tree
Showing 41 changed files with 598 additions and 331 deletions.
3 changes: 3 additions & 0 deletions Changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ Major Changes from 3.0.1
- Remove arpes.all
- Certainly, this it is indeed a lazy and carefree approach, but it's too rough method that leads to a bugs and does not mathc the current pythonic style.

- Remove overlapped_stack_dispersion_plot
- use stack_dispersion_plot with appropriate args

Fix from 3.0.1

- bug of concatenating in broadcast_model
Expand Down
139 changes: 139 additions & 0 deletions arpes/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,3 +380,142 @@ class ColorbarParam(TypedDict, total=False):
boundaries: None | Sequence[float]
values: None | Sequence[float]
location: None | Literal["left", "right", "top", "bottom"]


class MPLTextParam(TypedDict, total=False):
agg_filter: Callable[[NDArray[np.float_], int], tuple[NDArray[np.float_], int, int]]
alpha: float | None
animated: bool
antialiased: bool
backgroundcolor: ColorType
color: ColorType
c: ColorType
figure: Figure
fontfamily: str
family: str
fontname: str
fontproperties: str | Path
font: str | Path
font_properties: str | Path
fontsize: float | Literal[
"xx-small",
"x-small",
"small",
"medium",
"large",
"x-large",
"xx-large",
]
size: float | Literal[
"xx-small",
"x-small",
"small",
"medium",
"large",
"x-large",
"xx-large",
]
fontstretch: float | Literal[
"ultra-condensed",
"extra-condensed",
"condensed",
"semi-condensed",
"normal",
"semi-expanded",
"expanded",
"extra-expanded",
"ultra-expanded",
]
stretch: float | Literal[
"ultra-condensed",
"extra-condensed",
"condensed",
"semi-condensed",
"normal",
"semi-expanded",
"expanded",
"extra-expanded",
"ultra-expanded",
]
fontstyle: Literal["normal", "italic", "oblique"]
style: Literal["normal", "italic", "oblique"]
fontvariant: Literal["normal", "small-caps"]
variant: Literal["normal", "small-caps"]
fontweight: float | Literal[
"ultralight",
"light",
"normal",
"regular",
"book",
"medium",
"roman",
"semibold",
"demibold",
"demi",
"bold",
"heavy",
"extra bold",
"black",
]
weight: float | Literal[
"ultralight",
"light",
"normal",
"regular",
"book",
"medium",
"roman",
"semibold",
"demibold",
"demi",
"bold",
"heavy",
"extra bold",
"black",
]
gid: str
horizontalalignment: Literal["left", "center", "right"]
ha: Literal["left", "center", "right"]
in_layout: bool
label: str
linespacing: float
math_fontfamily: str
mouseover: bool
multialignment: Literal["left", "right", "center"]
ma: Literal["left", "right", "center"]
parse_math: bool
path_effects: list[AbstractPathEffect]
picker: None | bool | float | Callable
position: tuple[float, float]
rasterized: bool
rotation: float | Literal["vertical", "horizontal"]
rotation_mode: Literal[None, "default", "anchor"]
sketch_params: tuple[float, float, float]
scale: float
length: float
randomness: float
snap: bool | None
text: str
transform: Transform
transform_rotates_text: bool
url: str
usetex: bool | None
verticalalignment: Literal["bottom", "baseline", "center", "center_baseline", "top"]
va: Literal["bottom", "baseline", "center", "center_baseline", "top"]
visible: bool
wrap: bool
zorder: float


class PLTSubplotParam(TypedDict, total=False):
sharex: bool | Literal["none", "all", "row", "col"]
sharey: bool | Literal["none", "all", "row", "col"]
squeeze: bool
width_ratios: Sequence[float] | None
height_ratios: Sequence[float] | None
subplot_kw: dict
gridspec_kw: dict


class IMshowParam(TypedDict, total=False):
pass
22 changes: 18 additions & 4 deletions arpes/analysis/mask.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
"""Utilities for applying masks to data."""
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import xarray as xr
from matplotlib.path import Path

from arpes._typing import DataType
from arpes.provenance import update_provenance
from arpes.utilities import normalize_to_spectrum

if TYPE_CHECKING:
import xarray as xr
from _typeshed import Incomplete

from arpes._typing import DataType

__all__ = (
"polys_to_mask",
"apply_mask",
Expand All @@ -15,7 +23,7 @@
)


def raw_poly_to_mask(poly) -> dict:
def raw_poly_to_mask(poly: Incomplete) -> dict[str, Incomplete]:
"""Converts a polygon into a mask definition.
There's not currently much metadata attached to masks, but this is
Expand All @@ -37,7 +45,13 @@ def raw_poly_to_mask(poly) -> dict:
}


def polys_to_mask(mask_dict, coords, shape, radius=None, invert=False):
def polys_to_mask(
mask_dict: dict[str, Incomplete],
coords,
shape,
radius=None,
invert=False,
) -> NDArray[np.float_] | NDArray[np.bool_]:
"""Converts a mask definition in terms of the underlying polygon to a True/False mask array.
Uses the coordinates and shape of the target data in order to determine which pixels
Expand Down
10 changes: 5 additions & 5 deletions arpes/deep_learning/models/regression.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Very simple regression baselines."""

import pytorch_lightning as pl
import torch.nn.functional as F
from torch import nn, optim
from torch.nn import functional

__all__ = ["BaselineRegression", "LinearRegression"]

Expand All @@ -17,7 +17,7 @@ def __init__(self) -> None:
"""Generate network components and use the mean squared error loss."""
super().__init__()
self.linear = nn.Linear(self.input_dimensions, self.output_dimensions)
self.criterion = F.mse_loss
self.criterion = functional.mse_loss

def forward(self, x):
"""Calculate the model output for the minibatch `x`."""
Expand Down Expand Up @@ -53,13 +53,13 @@ def __init__(self) -> None:
self.l1 = nn.Linear(self.input_dimensions, 256)
self.l2 = nn.Linear(256, 128)
self.l3 = nn.Linear(128, self.output_dimensions)
self.criterion = F.mse_loss
self.criterion = functional.mse_loss

def forward(self, x):
"""Calculate the model output for the minibatch `x`."""
flat_x = x.view(x.size(0), -1)
h1 = F.relu(self.l1(flat_x))
h2 = F.relu(self.l2(h1))
h1 = functional.relu(self.l1(flat_x))
h2 = functional.relu(self.l2(h1))
return self.l3(h2)

def training_step(self, batch, batch_index):
Expand Down
19 changes: 12 additions & 7 deletions arpes/fits/broadcast_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,21 @@
import operator
import warnings
from string import ascii_lowercase
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import lmfit
import xarray as xr

if TYPE_CHECKING:
from arpes.fits.fit_models.x_model_mixin import XModelMixin
from collections.abc import Iterable

from _typeshed import Incomplete

def unwrap_params(params, iter_coordinate):

def unwrap_params(params: dict[str, Any], iter_coordinate: Incomplete) -> dict[str, Any]:
"""Inspects arraylike parameters and extracts appropriate value for current fit."""

def transform_or_walk(v):
def transform_or_walk(v: dict | xr.DataArray | Iterable[float]):
if isinstance(v, dict):
return unwrap_params(v, iter_coordinate)

Expand Down Expand Up @@ -48,7 +50,7 @@ def apply_window(data: xr.DataArray, cut_coords: dict[str, float | slice], windo
return cut_data, original_cut_data


def _parens_to_nested(items):
def _parens_to_nested(items: list) -> list:
"""Turns a flat list with parentheses tokens into a nested list."""
parens = [
(
Expand All @@ -72,7 +74,9 @@ def _parens_to_nested(items):
return items


def reduce_model_with_operators(models: tuple | list[XModelMixin]) -> XModelMixin:
def reduce_model_with_operators(
models: tuple[Incomplete, ...] | list[Incomplete],
) -> Incomplete:
"""Combine models according to mathematical operators."""
if isinstance(models, tuple):
return models[0](prefix=f"{models[1]}_", nan_policy="omit")
Expand All @@ -82,7 +86,8 @@ def reduce_model_with_operators(models: tuple | list[XModelMixin]) -> XModelMixi

left, op, right = models[0], models[1], models[2:]
left, right = reduce_model_with_operators(left), reduce_model_with_operators(right)

assert left is not None
assert right is not None
if op == "+":
return left + right
if op == "*":
Expand Down
21 changes: 11 additions & 10 deletions arpes/fits/lmfit_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal, TypedDict
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack

import matplotlib.pyplot as plt
import xarray as xr
from lmfit import model

if TYPE_CHECKING:
import numpy as np
from _typeshed import Incomplete
from matplotlib.figure import Figure
from numpy.typing import NDArray
from typing_extensions import Unpack

original_plot = model.ModelResult.plot

Expand All @@ -31,12 +31,12 @@ class ModelResultPlotKwargs(TypedDict, total=False):
yerr: NDArray[np.float_]
numpoints: int
fig: Figure
data_kws: dict
fit_kws: dict
init_kws: dict
ax_res_kws: dict
ax_fit_kws: dict
fig_kws: dict
data_kws: dict[str, Incomplete]
fit_kws: dict[str, Incomplete]
init_kws: dict[str, Incomplete]
ax_res_kws: dict[str, Incomplete]
ax_fit_kws: dict[str, Incomplete]
fig_kws: dict[str, Incomplete]
show_init: bool
parse_complex: Literal["abs", "real", "imag", "angle"]
title: str
Expand All @@ -51,8 +51,9 @@ def transform_lmfit_titles(label: str = "", *, is_title: bool = False) -> str:


def patched_plot(
self: Any, **kwargs: Unpack[ModelResultPlotKwargs]
) -> Figure | Literal[False]: # noqa: ANN401
self: Incomplete,
**kwargs: Unpack[ModelResultPlotKwargs],
) -> Figure | Literal[False]:
"""A patch for `lmfit` summary plots in PyARPES.
Scientists like to have LaTeX in their plots,
Expand Down
12 changes: 0 additions & 12 deletions arpes/optics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import numpy as np

__all__ = (
"waist",
"waist_R",
"rayleigh_range",
"lens_transfer",
"magnification",
Expand All @@ -28,16 +26,6 @@
)


def waist(wavelength: float, z: float, z_R: float) -> float:
"""Calculates the waist size from the measurements at a distance from the waist."""
raise NotImplementedError


def waist_R(waist_0: float, m_squared: float = 1.0) -> float:
"""Calculates the width of t he beam a distance from the waist."""
raise NotImplementedError


def waist_from_rr(wavelength: float, rayleigh_rng: float) -> float:
"""Calculates the waist parameters from the Rayleigh range."""
return np.sqrt((wavelength * rayleigh_rng) / np.pi)
Expand Down
Loading

0 comments on commit 0bf8c1f

Please sign in to comment.