From fb7b00f0e83d84a21088eb9b54f2f8b5d72c3d41 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Mon, 18 Mar 2024 18:13:49 +0900 Subject: [PATCH 01/10] =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 2 +- src/arpes/utilities/jupyter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e707d81e..cee87c78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ lint.ignore = [ "NPY201", # Numpy 2.0, ] lint.select = ["ALL"] -target-version = "py310" +target-version = "py312" line-length = 100 indent-width = 4 diff --git a/src/arpes/utilities/jupyter.py b/src/arpes/utilities/jupyter.py index 52684ae5..cccf956b 100644 --- a/src/arpes/utilities/jupyter.py +++ b/src/arpes/utilities/jupyter.py @@ -118,7 +118,7 @@ def get_full_notebook_information() -> NoteBookInfomation | None: if not url.startswith(("http:", "https:")): msg = "URL must start with 'http:' or 'https:'" raise ValueError(msg) - sessions = json.load(urllib.request.urlopen(url)) + sessions = json.load(urllib.request.urlopen(url)) # noqa: S310 for sess in sessions: if sess["kernel"]["id"] == kernel_id: return { From 115bcf068ee3052374273faa0a39f9a26e72bd8b Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Tue, 19 Mar 2024 10:04:16 +0900 Subject: [PATCH 02/10] =?UTF-8?q?=F0=9F=92=AC=20=20Add=20Incomplete=20to?= =?UTF-8?q?=20type=20unknown=20args.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 2 +- src/arpes/analysis/band_analysis.py | 11 ++++++++--- src/arpes/analysis/mask.py | 14 +++++++------- src/arpes/bootstrap.py | 4 ++-- .../corrections/fermi_edge_corrections.py | 2 +- src/arpes/deep_learning/interpret.py | 2 +- src/arpes/fits/broadcast_common.py | 2 +- src/arpes/models/band.py | 5 +++-- src/arpes/plotting/bz.py | 10 +++++----- 
.../bz_tool/RangeOrSingleValueWidget.py | 4 ++-- src/arpes/plotting/dynamic_tool.py | 4 ++-- src/arpes/plotting/utils.py | 10 +++++----- src/arpes/utilities/bz.py | 12 +++++++----- src/arpes/utilities/jupyter.py | 2 +- src/arpes/utilities/ui.py | 14 +++++++------- src/arpes/widgets.py | 18 +++++++++--------- 16 files changed, 62 insertions(+), 54 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 18c3c64d..7db833d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ lint.ignore = [ "NPY201", # Numpy 2.0, ] lint.select = ["ALL"] -target-version = "py310" +target-version = "py311" line-length = 100 exclude = ["scripts", "docs", "conda"] diff --git a/src/arpes/analysis/band_analysis.py b/src/arpes/analysis/band_analysis.py index d712ba6e..280e5e29 100644 --- a/src/arpes/analysis/band_analysis.py +++ b/src/arpes/analysis/band_analysis.py @@ -435,8 +435,9 @@ def fit_bands( Args: arr(xr.DataArray): band_description: A description of the bands to fit in the region - background - direction + background: + direction: + step: Returns: Fitted bands. @@ -543,7 +544,11 @@ def fit_bands( return band_results, unpacked_bands, residual # Memo bunt_result is xr.DataArray -def _interpolate_intersecting_fragments(coord, coord_index, points): +def _interpolate_intersecting_fragments( + coord: Incomplete, + coord_index: int, + points: Incomplete, +) -> Incomplete: """Finds all consecutive pairs of points in `points`. [TODO:description] diff --git a/src/arpes/analysis/mask.py b/src/arpes/analysis/mask.py index 65e77002..b1597b56 100644 --- a/src/arpes/analysis/mask.py +++ b/src/arpes/analysis/mask.py @@ -50,8 +50,8 @@ def raw_poly_to_mask(poly: Incomplete) -> dict[str, Incomplete]: def polys_to_mask( mask_dict: dict[str, Incomplete], - coords, - shape, + coords: Incomplete, + shape: Incomplete, radius: float = 0, *, invert: bool = False, @@ -66,9 +66,9 @@ def polys_to_mask( waypoints are given in unitful values rather than index values. 
Args: - mask_dict - coords - shape + mask_dict (TODO) + coords: + shape: radius (float): Additional margin on the path in coordinates of *points*. invert (bool): @@ -108,7 +108,7 @@ def apply_mask_to_coords( dims: list[str], *, invert: bool = True, -): +) -> Incomplete: """Performs broadcasted masking along a given dimension. Args: @@ -139,7 +139,7 @@ def apply_mask( data: xr.DataArray, mask: dict[str, Incomplete], replace: float = np.nan, - radius=None, + radius: Incomplete = None, *, invert: bool = False, ) -> xr.DataArray: diff --git a/src/arpes/bootstrap.py b/src/arpes/bootstrap.py index 50778a71..f231e0e8 100644 --- a/src/arpes/bootstrap.py +++ b/src/arpes/bootstrap.py @@ -249,7 +249,7 @@ def draw_samples(self, n_samples: int = Distribution.DEFAULT_N_SAMPLES) -> NDArr return scipy.stats.norm.rvs(self.center, scale=self.stderr, size=n_samples) @classmethod - def from_param(cls: type, model_param: lf.Model.Parameter): + def from_param(cls: type, model_param: lf.Model.Parameter) -> Incomplete: """Generates a Normal from an `lmfit.Parameter`.""" return cls(center=model_param.value, stderr=model_param.stderr) @@ -353,7 +353,7 @@ def bootstrapped( n: int = 20, prior_adjustment: int = 1, **kwargs: Incomplete, - ): + ) -> Incomplete: # examine args to determine which to resample resample_indices = [ i diff --git a/src/arpes/corrections/fermi_edge_corrections.py b/src/arpes/corrections/fermi_edge_corrections.py index 4f166fa4..b95054ad 100644 --- a/src/arpes/corrections/fermi_edge_corrections.py +++ b/src/arpes/corrections/fermi_edge_corrections.py @@ -167,7 +167,7 @@ def build_direct_fermi_edge_correction( others = [d for d in arr.dims if d not in exclude_axes] edge_fit = broadcast_model(GStepBModel, arr.sum(others).sel(eV=energy_range), along).results - def sieve(_, v) -> bool: + def sieve(_: Incomplete, v: Incomplete) -> bool: return v.item().params["center"].stderr < 0.001 # noqa: PLR2004 corrections = edge_fit.G.filter_coord(along, sieve).G.map( diff --git 
a/src/arpes/deep_learning/interpret.py b/src/arpes/deep_learning/interpret.py index 42882654..92db5d62 100644 --- a/src/arpes/deep_learning/interpret.py +++ b/src/arpes/deep_learning/interpret.py @@ -137,7 +137,7 @@ def items(self) -> list[InterpretationItem]: def top_losses(self, *, ascending: bool = False) -> list[InterpretationItem]: """Orders the items by loss.""" - def key(item): + def key(item: Incomplete) -> Incomplete: return item.loss if ascending else -item.loss return sorted(self.items, key=key) diff --git a/src/arpes/fits/broadcast_common.py b/src/arpes/fits/broadcast_common.py index f7a928d1..76f6e6f6 100644 --- a/src/arpes/fits/broadcast_common.py +++ b/src/arpes/fits/broadcast_common.py @@ -38,7 +38,7 @@ def unwrap_params( def transform_or_walk( v: dict | xr.DataArray | Iterable[float], - ): + ) -> Incomplete: """[TODO:summary]. [TODO:description] diff --git a/src/arpes/models/band.py b/src/arpes/models/band.py index fb6c85df..54cc3187 100644 --- a/src/arpes/models/band.py +++ b/src/arpes/models/band.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: import lmfit as lf + from _typeshed import Incomplete from numpy.typing import NDArray from arpes._typing import DataType @@ -144,7 +145,7 @@ def amplitude(self) -> xr.DataArray: return self.get_dataarray("amplitude", clean=True) @property - def indexes(self): + def indexes(self) -> Incomplete: """Fetches the indices of the originating data (after fit reduction).""" assert isinstance(self._data, xr.DataArray | xr.Dataset) return self._data.center.indexes @@ -165,7 +166,7 @@ def dims(self) -> tuple[str, ...]: class MultifitBand(Band): """Convenience class that reimplements reading data out of a composite fit result.""" - def get_dataarray(self, var_name: str): + def get_dataarray(self, var_name: str) -> Incomplete: """Converts the underlying data into an array representation.""" assert isinstance(self._data, xr.DataArray | xr.Dataset) full_var_name = self.label + var_name diff --git a/src/arpes/plotting/bz.py 
b/src/arpes/plotting/bz.py index c5f23d03..bbd91518 100644 --- a/src/arpes/plotting/bz.py +++ b/src/arpes/plotting/bz.py @@ -230,7 +230,7 @@ def plot_data_to_bz( data: DataType, cell: Sequence[Sequence[float]] | NDArray[np.float_], **kwargs: Incomplete, -): +) -> Path | tuple[Figure, Axes]: """A dimension agnostic tool used to plot ARPES data onto a Brillouin zone.""" if len(data) == TWO_DIMENSION + 1: return plot_data_to_bz3d(data, cell, **kwargs) @@ -368,7 +368,7 @@ def bz3d_plot( stacklevel=2, ) msg = "You will need to install ASE before using Brillouin Zone plotting" - raise ImportError(msg) + logger.exception(msg) class Arrow3D(FancyArrowPatch): def __init__( @@ -382,7 +382,7 @@ def __init__( FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs) self._verts3d = xs, ys, zs - def draw(self, renderer) -> None: + def draw(self, renderer: Incomplete) -> None: xs3d, ys3d, zs3d = self._verts3d xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M) self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) @@ -537,7 +537,7 @@ def annotate_special_paths( cell: NDArray[np.float_] | Sequence[Sequence[float]] | None = None, offset: dict[str, Sequence[float]] | None = None, special_points: dict[str, NDArray[np.float_]] | None = None, - labels=None, + labels: Incomplete = None, **kwargs: Incomplete, ) -> None: """Annotates user indicated paths in k-space by plotting lines (or points) over the BZ.""" @@ -650,7 +650,7 @@ def bz2d_segments( return segments_x, segments_y -def twocell_to_bz1(cell: NDArray[np.float_]): +def twocell_to_bz1(cell: NDArray[np.float_]) -> Incomplete: from ase.dft.bz import bz_vertices # 2d in x-y plane diff --git a/src/arpes/plotting/bz_tool/RangeOrSingleValueWidget.py b/src/arpes/plotting/bz_tool/RangeOrSingleValueWidget.py index e2d12b74..af46de2f 100644 --- a/src/arpes/plotting/bz_tool/RangeOrSingleValueWidget.py +++ b/src/arpes/plotting/bz_tool/RangeOrSingleValueWidget.py @@ -61,10 +61,10 @@ def __init__( self.recompute() - def 
mode_changed(self, event, source) -> None: + def mode_changed(self, event: Incomplete, source: Incomplete) -> None: """Unused, currently.""" - def value_changed(self, event, source) -> None: + def value_changed(self, event: Incomplete, source: Incomplete) -> None: """Responds to changes in the internal value.""" if self._prevent_change_events: return diff --git a/src/arpes/plotting/dynamic_tool.py b/src/arpes/plotting/dynamic_tool.py index 17aada75..62a2e786 100644 --- a/src/arpes/plotting/dynamic_tool.py +++ b/src/arpes/plotting/dynamic_tool.py @@ -87,8 +87,8 @@ def update_data(self) -> None: try: mapped_data = self._function(self.data, **self.current_arguments) self.views["f(xy)"].setImage(mapped_data.fillna(0)) - except Exception as err: - logger.debug(f"Exception occurs. {err=}, {type(err)=}") + except Exception: + logger.exception("Exception occurs.") def add_controls(self) -> None: specification = self.calculate_control_specification() diff --git a/src/arpes/plotting/utils.py b/src/arpes/plotting/utils.py index 94409cf1..3a002ff7 100644 --- a/src/arpes/plotting/utils.py +++ b/src/arpes/plotting/utils.py @@ -1102,8 +1102,8 @@ def remove_colorbars(fig: Figure | None = None) -> None: ax.remove() else: remove_colorbars(plt.gcf()) - except Exception as err: - logger.debug(f"Exception occurs: {err=}, {type(err)=}") + except Exception: + logger.exception("Exception occurs") def generic_colorbarmap_for_data( @@ -1168,7 +1168,7 @@ class AnchoredHScaleBar(mpl.offsetbox.AnchoredOffsetbox): as alternate to the one provided through matplotlib. """ - def __init__( + def __init__( # noqa: PLR0913 self, size: float = 1, extent: float = 0.03, @@ -1235,14 +1235,14 @@ def load_data_for_figure(p: str | Path) -> None: raise ValueError(msg) with Path(pickle_file).open("rb") as f: - return pickle.load(f) + return pickle.load(f) # noqa: S301 def savefig( desired_path: str | Path, dpi: int = 400, data: list[XrTypes] | tuple[XrTypes, ...] 
| set[XrTypes] | None = None, - save_data=None, + save_data: Incomplete = None, *, paper: bool = False, **kwargs: Incomplete, diff --git a/src/arpes/utilities/bz.py b/src/arpes/utilities/bz.py index f6bbc6d3..610abc42 100644 --- a/src/arpes/utilities/bz.py +++ b/src/arpes/utilities/bz.py @@ -12,7 +12,7 @@ import itertools import re from collections import Counter -from typing import TYPE_CHECKING, Literal, NamedTuple +from typing import TYPE_CHECKING, Literal, NamedTuple, TypeVar import matplotlib.path import numpy as np @@ -57,6 +57,8 @@ "hex": {"G", "X", "BX"}, } +T = TypeVar("T") + class SpecialPoint(NamedTuple): name: str @@ -400,7 +402,7 @@ def build_2dbz_poly( return raw_poly_to_mask(points_2d) -def bz_symmetry(flat_symmetry_points) -> Literal["rect", "square", "hex"] | None: +def bz_symmetry(flat_symmetry_points: Incomplete) -> Literal["rect", "square", "hex"] | None: """Determines symmetry from a list of the symmetry points. Args: @@ -701,7 +703,7 @@ def reduced_bz_selection(data: DataType) -> DataType: return data -def bz_cutter(symmetry_points, *, reduced: bool = True): +def bz_cutter(symmetry_points: Incomplete, *, reduced: bool = True) -> Incomplete: """Cuts data so that it areas outside the Brillouin zone are masked away. Args: @@ -711,7 +713,7 @@ def bz_cutter(symmetry_points, *, reduced: bool = True): TODO: UNFINISHED, Test """ - def build_bz_mask(data) -> None: + def build_bz_mask(data: Incomplete) -> None: """[TODO:summary]. Args: @@ -721,7 +723,7 @@ def build_bz_mask(data) -> None: [TODO:description] """ - def cutter(data, cut_value: float = np.nan): + def cutter(data: Incomplete, cut_value: float = np.nan) -> Incomplete: """[TODO:summary]. 
Args: diff --git a/src/arpes/utilities/jupyter.py b/src/arpes/utilities/jupyter.py index 52684ae5..cccf956b 100644 --- a/src/arpes/utilities/jupyter.py +++ b/src/arpes/utilities/jupyter.py @@ -118,7 +118,7 @@ def get_full_notebook_information() -> NoteBookInfomation | None: if not url.startswith(("http:", "https:")): msg = "URL must start with 'http:' or 'https:'" raise ValueError(msg) - sessions = json.load(urllib.request.urlopen(url)) + sessions = json.load(urllib.request.urlopen(url)) # noqa: S310 for sess in sessions: if sess["kernel"]["id"] == kernel_id: return { diff --git a/src/arpes/utilities/ui.py b/src/arpes/utilities/ui.py index d41d638f..8cce7847 100644 --- a/src/arpes/utilities/ui.py +++ b/src/arpes/utilities/ui.py @@ -465,7 +465,7 @@ def _wrap_text(str_or_widget: str | QLabel) -> QLabel: return label(str_or_widget) if isinstance(str_or_widget, str) else str_or_widget -def _unwrap_subject(subject_or_widget): +def _unwrap_subject(subject_or_widget: Incomplete) -> Incomplete: try: return subject_or_widget.subject except AttributeError: @@ -495,7 +495,7 @@ def submit(gate: str, keys: list[str], ui: dict[str, QWidget]) -> rx.Observable: ) -def _try_unwrap_value(v): +def _try_unwrap_value(v: Incomplete) -> Incomplete: try: return v.value except AttributeError: @@ -590,16 +590,16 @@ def bind_dataclass(dataclass_instance: Incomplete, prefix: str, ui: dict[str, QW ) inverse_mapping = {v: k for k, v in forward_mapping.items()} - def extract_field(v): + def extract_field(v: Incomplete) -> Incomplete: try: return v.value except AttributeError: return v - def translate_to_field(x): + def translate_to_field(x: Incomplete) -> Incomplete: return forward_mapping[x] - def translate_from_field(x): + def translate_from_field(x: Incomplete) -> Incomplete: return inverse_mapping[extract_field(x)] current_value = translate_from_field(getattr(dataclass_instance, field_name)) @@ -609,8 +609,8 @@ def translate_from_field(x): w.subject.on_next(current_value) # close over the 
translation function - def build_setter(translate, name): - def setter(value) -> None: + def build_setter(translate: Incomplete, name: Incomplete) -> Callable[..., None]: + def setter(value: Incomplete) -> None: try: value = translate(value) except ValueError: diff --git a/src/arpes/widgets.py b/src/arpes/widgets.py index 7b9e0cde..97e7dd6e 100644 --- a/src/arpes/widgets.py +++ b/src/arpes/widgets.py @@ -161,8 +161,8 @@ def onselect(self, verts: NDArray[np.float_]) -> None: if self._on_select is not None: self._on_select(self.ind) - except Exception as err: - logger.debug(f"Exception occurs: {err=}, {type(err)=}") + except Exception: + logger.exception("Exception occurs.") def disconnect(self) -> None: self.lasso.disconnect_events() @@ -222,7 +222,7 @@ class DataArrayView: Look some more into holoviews for different features. https://github.com/pyviz/holoviews/pull/1214 """ - def __init__( + def __init__( # noqa: PLR0913 self, ax: Axes, data: xr.DataArray | None = None, @@ -280,7 +280,7 @@ def handle_select(self, event_click: MouseEvent, event_release: MouseEvent) -> N self._inner_on_select(region) - def attach_selector(self, on_select) -> None: + def attach_selector(self, on_select: Incomplete) -> None: # data should already have been set """[TODO:summary]. @@ -404,7 +404,7 @@ def mask(self): # noqa: ANN202 return self._mask @mask.setter - def mask(self, new_mask) -> None: + def mask(self, new_mask: Incomplete) -> None: """[TODO:summary]. Args: @@ -510,7 +510,7 @@ def compute_parameters() -> dict: ] return dict(itertools.chain(*[list(d.items()) for d in renamed])) - def on_add_new_peak(selection) -> None: + def on_add_new_peak(selection: Incomplete) -> None: """[TODO:summary]. 
Args: @@ -727,15 +727,15 @@ def clamp(x: int, low: int, high: int) -> int: assert val_x != val_y set_axes(val_x, val_y) - except Exception as err: - logger.debug(f"Exception occurs: {err=}, {type(err)=}") + except Exception: + logger.exception("Exception occurs.") context["axis_button"] = Button(ax_widget_1, "Change Decomp Axes") context["axis_button"].on_clicked(on_change_axes) context["axis_X_input"] = TextBox(ax_widget_2, "Axis X:", initial=str(initial_values[0])) context["axis_Y_input"] = TextBox(ax_widget_3, "Axis Y:", initial=str(initial_values[1])) - def on_select_summed(region) -> None: + def on_select_summed(region: Incomplete) -> None: """[TODO:summary]. Args: From 3cb3c71ff9093be4cb2ce5c5196785c1b71313e7 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Tue, 19 Mar 2024 10:28:01 +0900 Subject: [PATCH 03/10] =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/arpes/plotting/bz.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/arpes/plotting/bz.py b/src/arpes/plotting/bz.py index 8c4db346..b622ac43 100644 --- a/src/arpes/plotting/bz.py +++ b/src/arpes/plotting/bz.py @@ -507,10 +507,12 @@ def draw(self, renderer: Incomplete) -> None: for name, point in zip(names, points, strict=True): x, y, z = point if name == "G": - name = "\\Gamma" + name_tex = "\\Gamma" elif len(name) > 1: - name = name[0] + "_" + name[1] - ax.text(x, y, z, "$" + name + "$", ha="center", va="bottom", color="r") + name_tex = name[0] + "_" + name[1] + else: + name_tex = name + ax.text(x, y, z, f"${name_tex}$", ha="center", va="bottom", color="r") if kpoints is not None: for p in kpoints: From 5973903c0dd40d40ba6fe6d34dbdb24bdf272a60 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Tue, 19 Mar 2024 11:09:44 +0900 Subject: [PATCH 04/10] =?UTF-8?q?=F0=9F=92=AC=20=20Change=20target=20pytho?= =?UTF-8?q?n=20version.?= MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cee87c78..ddb04c60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,9 +96,10 @@ lint.ignore = [ "G004", # logging-f-string # "NPY201", # Numpy 2.0, + "ISC001", # single-line-implicit-string-concatenation ] lint.select = ["ALL"] -target-version = "py312" +target-version = "py311" line-length = 100 indent-width = 4 From b62366c07f86f19e3055cd2da4e8d0bdb8c4fcfe Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Tue, 19 Mar 2024 11:10:05 +0900 Subject: [PATCH 05/10] =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/arpes/deep_learning/interpret.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arpes/deep_learning/interpret.py b/src/arpes/deep_learning/interpret.py index 8e8e3e3e..92db5d62 100644 --- a/src/arpes/deep_learning/interpret.py +++ b/src/arpes/deep_learning/interpret.py @@ -137,7 +137,7 @@ def items(self) -> list[InterpretationItem]: def top_losses(self, *, ascending: bool = False) -> list[InterpretationItem]: """Orders the items by loss.""" - def key(item: Incomplete): + def key(item: Incomplete) -> Incomplete: return item.loss if ascending else -item.loss return sorted(self.items, key=key) From e8159fc9ba06c9af266b5494f1caea9a8424cff3 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Tue, 19 Mar 2024 11:19:17 +0900 Subject: [PATCH 06/10] =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/arpes/analysis/band_analysis.py | 13 +++++++++++++ src/arpes/plotting/fit_tool/__init__.py | 2 +- src/arpes/plotting/qt_tool/__init__.py | 2 +- src/arpes/plotting/stack_plot.py | 4 ++-- src/arpes/utilities/widgets.py | 2 +- 5 files changed, 18 insertions(+), 5 deletions(-) 
diff --git a/src/arpes/analysis/band_analysis.py b/src/arpes/analysis/band_analysis.py index cbc6113a..e2b979b8 100644 --- a/src/arpes/analysis/band_analysis.py +++ b/src/arpes/analysis/band_analysis.py @@ -7,6 +7,7 @@ import functools import itertools from itertools import pairwise +from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger from typing import TYPE_CHECKING, Any, Literal import numpy as np @@ -36,6 +37,18 @@ "fit_for_effective_mass", ) +LOGLEVELS = (DEBUG, INFO) +LOGLEVEL = LOGLEVELS[1] +logger = getLogger(__name__) +fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s" +formatter = Formatter(fmt) +handler = StreamHandler() +handler.setLevel(LOGLEVEL) +logger.setLevel(LOGLEVEL) +handler.setFormatter(formatter) +logger.addHandler(handler) +logger.propagate = False + def fit_for_effective_mass( data: xr.DataArray, diff --git a/src/arpes/plotting/fit_tool/__init__.py b/src/arpes/plotting/fit_tool/__init__.py index 6d805cc3..f9c7c22c 100644 --- a/src/arpes/plotting/fit_tool/__init__.py +++ b/src/arpes/plotting/fit_tool/__init__.py @@ -266,7 +266,7 @@ def configure_image_widgets(self) -> None: layout=self.content_layout, ) - def generate_fit_marginal_for( + def generate_fit_marginal_for( # noqa: PLR0913 self, dimensions: tuple[int, ...], column_row: tuple[int, int], diff --git a/src/arpes/plotting/qt_tool/__init__.py b/src/arpes/plotting/qt_tool/__init__.py index f27c32f5..8f210789 100644 --- a/src/arpes/plotting/qt_tool/__init__.py +++ b/src/arpes/plotting/qt_tool/__init__.py @@ -552,7 +552,7 @@ def set_data(self, data: xr.DataArray) -> None: def _qt_tool(data: XrTypes, **kwargs: Incomplete) -> None: """Starts the qt_tool using an input spectrum.""" with contextlib.suppress(TypeError): - data = dill.loads(data) + data = dill.loads(data) # noqa: S301 tool = QtTool() tool.set_data(data) diff --git a/src/arpes/plotting/stack_plot.py b/src/arpes/plotting/stack_plot.py index 8d611012..e6606137 100644 --- a/src/arpes/plotting/stack_plot.py +++ 
b/src/arpes/plotting/stack_plot.py @@ -65,7 +65,7 @@ @save_plot_provenance -def offset_scatter_plot( +def offset_scatter_plot( # noqa: PLR0913 data: xr.Dataset, name_to_plot: str = "", stack_axis: str = "", @@ -133,7 +133,7 @@ def offset_scatter_plot( skip_colorbar = True if cbarmap is None: skip_colorbar = False - cbar: colorbar.Colorbar | Callable[..., colorbar.Colorbar] + cbar: Callable[..., colorbar.Colorbar] cmap: Callable[..., ColorType] | Callable[..., Callable[..., ColorType]] try: cbar, cmap = colorbarmaps_for_axis[stack_axis] diff --git a/src/arpes/utilities/widgets.py b/src/arpes/utilities/widgets.py index 04b0c2cc..ec432501 100644 --- a/src/arpes/utilities/widgets.py +++ b/src/arpes/utilities/widgets.py @@ -138,7 +138,7 @@ def __init__(self, *args: QWidget | None) -> None: self.toggled.connect(lambda: self.subject.on_next(self.isChecked())) self.subject.subscribe(self.update_ui) - def update_ui(self, value: bool) -> None: + def update_ui(self, *, value: bool) -> None: """Forwards value change to the UI.""" self.setChecked(value) From 39c076918281f1069b7d7002527f05e2c6ecf2f3 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Tue, 19 Mar 2024 11:46:16 +0900 Subject: [PATCH 07/10] =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/arpes/deep_learning/transforms.py | 35 ++++++++++++++++++++++++--- src/arpes/plotting/spin.py | 11 +++------ src/arpes/utilities/qt/app.py | 2 +- 3 files changed, 37 insertions(+), 11 deletions(-) diff --git a/src/arpes/deep_learning/transforms.py b/src/arpes/deep_learning/transforms.py index c96313d1..2f286526 100644 --- a/src/arpes/deep_learning/transforms.py +++ b/src/arpes/deep_learning/transforms.py @@ -15,15 +15,44 @@ class Identity: """Represents a reversible identity transform.""" def encodes(self, x: Incomplete) -> Incomplete: + """[TODO:summary]. 
+ + Args: + x: [TODO:description] + + Returns: + [TODO:description] + """ return x def __call__(self, x: Incomplete) -> Incomplete: + """[TODO:summary]. + + Args: + x: [TODO:description] + + Returns: + [TODO:description] + """ return x def decodes(self, x: Incomplete) -> Incomplete: + """[TODO:summary]. + + Args: + x: [TODO:description] + + Returns: + [TODO:description] + """ return x def __repr__(self) -> str: + """[TODO:summary]. + + Returns: + [TODO:description] + """ return "Identity()" @@ -54,9 +83,9 @@ def __post_init__(self) -> None: for t in self.transforms: if isinstance(t, tuple | list): xt, yt = t - t = [xt or _identity, yt or _identity] - - safe_transforms.append(t) + safe_transforms.append([xt or _identity, yt or _identity]) + else: + safe_transforms.append(t) self.original_transforms = self.transforms self.transforms = safe_transforms diff --git a/src/arpes/plotting/spin.py b/src/arpes/plotting/spin.py index f2f77291..f372b15a 100644 --- a/src/arpes/plotting/spin.py +++ b/src/arpes/plotting/spin.py @@ -163,11 +163,9 @@ def spin_polarized_spectrum( # noqa: PLR0913 counts = spin_dr pol = to_intensity_polarization(counts) - ax_left = ax[0] - ax_right = ax[1] + ax_left, ax_right = ax[0], ax[1] - up = counts.down.data - down = counts.up.data + down, up = counts.down.data, counts.up.data energies = spin_dr.coords["eV"].values min_e, max_e = np.min(energies), np.max(energies) @@ -194,8 +192,7 @@ def spin_polarized_spectrum( # noqa: PLR0913 ax_left.set_xlabel(r"\textbf{Kinetic energy} (eV)") ax_left.set_xlim(min_e, max_e) - max_up = np.max(up) - max_down = np.max(down) + max_up, max_down = np.max(up), np.max(down) ax_left.set_ylim(0, max(max_down, max_up) * 1.2) # Plot the polarization and associated statistical error bars @@ -296,7 +293,7 @@ def hue_brightness_plot( assert isinstance(ax, Axes) assert isinstance(fig, Figure) x, y = data.coords[data.intensity.dims[0]].values, data.coords[data.intensity.dims[1]].values - extent = [y[0], y[-1], x[0], x[-1]] + 
extent = (y[0], y[-1], x[0], x[-1]) ax.imshow( polarization_intensity_to_color(data, **kwargs), extent=extent, diff --git a/src/arpes/utilities/qt/app.py b/src/arpes/utilities/qt/app.py index dc06d009..1cfffb36 100644 --- a/src/arpes/utilities/qt/app.py +++ b/src/arpes/utilities/qt/app.py @@ -143,7 +143,7 @@ def set_colormap(self, colormap: Colormap | str) -> None: if isinstance(view, DataArrayImageView): view.setColorMap(cmap) - def generate_marginal_for( + def generate_marginal_for( # noqa: PLR0913 self, dimensions: tuple[int, ...], column: int, From 945f3d25f06781cbcc87e611b661bc19b3c65b71 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Tue, 19 Mar 2024 13:03:18 +0900 Subject: [PATCH 08/10] =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/arpes/analysis/band_analysis.py | 13 ++++++------- src/arpes/fits/utilities.py | 2 +- src/arpes/plotting/annotations.py | 3 ++- src/arpes/plotting/fit_tool/__init__.py | 2 +- src/arpes/xarray_extensions.py | 9 +++++---- 5 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/arpes/analysis/band_analysis.py b/src/arpes/analysis/band_analysis.py index e2b979b8..abe3049f 100644 --- a/src/arpes/analysis/band_analysis.py +++ b/src/arpes/analysis/band_analysis.py @@ -261,7 +261,7 @@ def dataarray_for_value(param_name: str, i: int = i, *, is_value: bool) -> xr.Da @update_provenance("Fit bands from pattern") -def fit_patterned_bands( +def fit_patterned_bands( # noqa: PLR0913 arr: xr.DataArray, band_set: dict[Incomplete, Incomplete], fit_direction: str = "", @@ -293,10 +293,9 @@ def fit_patterned_bands( band_set: dictionary with bands and points along the spectrum fit_direction (str): stray (float, optional): - orientation: edc or mdc - direction_normal - preferred_k_direction - dataset: if True, return as Dataset + background (bool): + interactive(bool): + dataset(bool): if true, return as xr.Dataset. 
Returns: Dataset or DataArray, as controlled by the parameter "dataset" @@ -309,7 +308,7 @@ def fit_patterned_bands( free_directions = list(arr.dims) free_directions.remove(fit_direction) - def resolve_partial_bands_from_description( + def resolve_partial_bands_from_description( # noqa: PLR0913 coord_dict: dict[str, Incomplete], name: str = "", band: Incomplete = None, @@ -525,7 +524,7 @@ def fit_bands( # be stable closest_model_params = initial_fits # fix me dist = float("inf") - frozen_coordinate = tuple(coordinate[k] for k in template.dims) + frozen_coordinate = tuple(coordinate[str(k)] for k in template.dims) for c, v in all_fit_parameters.items(): delta = np.array(c) - frozen_coordinate current_distance = delta.dot(delta) diff --git a/src/arpes/fits/utilities.py b/src/arpes/fits/utilities.py index bee785e6..d31ec0dd 100644 --- a/src/arpes/fits/utilities.py +++ b/src/arpes/fits/utilities.py @@ -230,7 +230,7 @@ def broadcast_model( # noqa: PLR0913 def unwrap(result_data: str) -> object: # (Unpickler) # using the lmfit deserialization and serialization seems slower than double pickling # with dill - return dill.loads(result_data) + return dill.loads(result_data) # noqa: S301 exe_results = [(unwrap(res), residual, cs) for res, residual, cs in exe_results] diff --git a/src/arpes/plotting/annotations.py b/src/arpes/plotting/annotations.py index 4672b121..06eca149 100644 --- a/src/arpes/plotting/annotations.py +++ b/src/arpes/plotting/annotations.py @@ -201,7 +201,7 @@ def annotate_cuts( def annotate_point( ax: Axes | Axes3D, location: Sequence[float], - delta: tuple[float, ...] 
= (), + delta: tuple[float, float] | tuple[float, float, float] | None = None, **kwargs: Unpack[MPLTextParam], ) -> None: """Annotates a point or high symmetry location into a plot.""" @@ -220,6 +220,7 @@ def annotate_point( -0.05, 0.05, ) + assert isinstance(delta, tuple) if "color" not in kwargs: kwargs["color"] = "red" diff --git a/src/arpes/plotting/fit_tool/__init__.py b/src/arpes/plotting/fit_tool/__init__.py index f9c7c22c..11ad2839 100644 --- a/src/arpes/plotting/fit_tool/__init__.py +++ b/src/arpes/plotting/fit_tool/__init__.py @@ -500,7 +500,7 @@ def set_data(self, data: xr.Dataset) -> None: def _fit_tool(data: xr.Dataset) -> None: """Starts the fitting inspection tool using an input fit result Dataset.""" with contextlib.suppress(TypeError): - data = dill.loads(data) + data = dill.loads(data) # noqa: S301 # some sanity checks that we were actually passed a collection of fit results assert isinstance(data, xr.Dataset) diff --git a/src/arpes/xarray_extensions.py b/src/arpes/xarray_extensions.py index 7b8e780c..4bcde284 100644 --- a/src/arpes/xarray_extensions.py +++ b/src/arpes/xarray_extensions.py @@ -1196,7 +1196,7 @@ def unpack_dim(dim_name: str) -> str: # remove missing dimensions from selection for permissiveness # and to transparent composing of regions region = {k: process_region_selector(v, k) for k, v in region.items() if k in obj.dims} - obj = obj.sel(**region) + obj = obj.sel(region) return obj @@ -2520,13 +2520,14 @@ def iterate_axis( self, axis_name_or_axes: list[str] | str, ) -> Generator[tuple[dict[str, float], XrTypes], str, None]: - """[TODO:summary]. + """Generator to extract data for along the specified axis. Args: axis_name_or_axes: [TODO:description] - Returns: - [TODO:description] + Returns: (tuple[dict[str, float], XrTypes]) + dict object represents the axis(dim) name and it's value. + XrTypes object the corresponding data. 
""" assert isinstance(self._obj, xr.DataArray | xr.Dataset) if isinstance(axis_name_or_axes, str): From 8e230e05c3fc425e6727d9c9d73f493aca25ccef Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Tue, 19 Mar 2024 13:25:13 +0900 Subject: [PATCH 09/10] =?UTF-8?q?=F0=9F=9A=A8=20=20Introduce=20helper=20fu?= =?UTF-8?q?nction.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/arpes/endstations/__init__.py | 57 +++++++++++++++++-------------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/src/arpes/endstations/__init__.py b/src/arpes/endstations/__init__.py index 277ef4cb..7a64373d 100644 --- a/src/arpes/endstations/__init__.py +++ b/src/arpes/endstations/__init__.py @@ -278,7 +278,7 @@ def concatenate_frames( frames.sort(key=lambda x: x.coords[scan_coord]) return xr.concat(frames, scan_coord) - def resolve_frame_locations(self, scan_desc: ScanDesc | None = None) -> list[Path | str]: + def resolve_frame_locations(self, scan_desc: ScanDesc | None = None) -> list[Path]: """Determine all files and frames associated to this piece of data. This always needs to be overridden in subclasses to handle data appropriately. @@ -358,31 +358,12 @@ def postprocess_final( coord_names: tuple[str, ...] 
= tuple(sorted([str(c) for c in data.dims if c != "cycle"])) spectrum_type = _spectrum_type(coord_names) - if "phi" not in data.coords: - data.coords["phi"] = 0 - for s in data.S.spectra: - s.coords["phi"] = 0 - - if spectrum_type is not None: - data.attrs["spectrum_type"] = spectrum_type - if "spectrum" in data.data_vars: - data.spectrum.attrs["spectrum_type"] = spectrum_type - - ls = [data, *data.S.spectra] - for a_data in ls: - for k, key_fn in self.ATTR_TRANSFORMS.items(): - if k in a_data.attrs: - transformed = key_fn(a_data.attrs[k]) - if isinstance(transformed, dict): - a_data.attrs.update(transformed) - else: - a_data.attrs[k] = transformed - - for a_data in ls: - for k, v in self.MERGE_ATTRS.items(): - a_data.attrs.setdefault(k, v) - - for a_data in [_ensure_coords(a_data, self.ENSURE_COORDS_EXIST) for a_data in ls]: + modified_data = [ + self._modify_a_data(a_data, spectrum_type) for a_data in [data, *data.S.spectra] + ] + for a_data in [ + _ensure_coords(a_data, self.ENSURE_COORDS_EXIST) for a_data in modified_data + ]: if "chi" in a_data.coords and "chi_offset" not in a_data.attrs: a_data.attrs["chi_offset"] = a_data.coords["chi"].item() @@ -449,6 +430,30 @@ def load(self, scan_desc: ScanDesc | None = None, **kwargs: Incomplete) -> xr.Da return concatted + def _modify_a_data(self, a_data: DataType, spectrum_type: str | None) -> DataType: + """Helper function to modify the Dataset and DataArray that are contained in the Dataset. 
+ + Args: + a_data: [TODO:description] + spectrum_type: [TODO:description] + + Returns: + [TODO:description] + """ + if "phi" not in a_data.coords: + a_data.coords["phi"] = 0 + a_data.attrs["spectrum_type"] = spectrum_type + for k, key_fn in self.ATTR_TRANSFORMS.items(): + if k in a_data.attrs: + transformed = key_fn(a_data.attrs[k]) + if isinstance(transformed, dict): + a_data.attrs.update(transformed) + else: + a_data.attrs[k] = transformed + for k, v in self.MERGE_ATTRS.items(): + a_data.attrs.setdefault(k, v) + return a_data + def _spectrum_type( coord_names: Sequence[str], From fa7b283dcbf12d45c1586c1cb3ca971177cb9a5f Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Tue, 19 Mar 2024 14:18:39 +0900 Subject: [PATCH 10/10] =?UTF-8?q?=F0=9F=9A=A8=20=20Remove=20ruff=20Warning?= =?UTF-8?q?=20in=20xarray=5Fextensions.py=20=20=20=20=20-=20=20PLW2901=20?= =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/arpes/analysis/mask.py | 14 +++--- src/arpes/endstations/nexus_utils.py | 3 +- src/arpes/endstations/plugin/ANTARES.py | 31 +++++++++---- src/arpes/endstations/plugin/BL10_SARPES.py | 2 +- .../plugin/Elettra_spectromicroscopy.py | 23 ++++++++-- src/arpes/endstations/plugin/HERS.py | 11 ++++- src/arpes/endstations/plugin/IF_UMCS.py | 5 ++- src/arpes/endstations/plugin/MAESTRO.py | 6 ++- src/arpes/endstations/plugin/kaindl.py | 4 +- src/arpes/endstations/plugin/merlin.py | 4 +- src/arpes/plotting/bz.py | 45 +++++++++---------- .../bz_tool/CoordinateOffsetWidget.py | 17 ++++++- .../bz_tool/RangeOrSingleValueWidget.py | 2 + src/arpes/provenance.py | 2 + src/arpes/utilities/collections.py | 4 +- src/arpes/utilities/conversion/core.py | 8 ++-- src/arpes/xarray_extensions.py | 21 ++++++--- 17 files changed, 136 insertions(+), 66 deletions(-) diff --git a/src/arpes/analysis/mask.py b/src/arpes/analysis/mask.py index ae4e8121..ec94e8b8 100644 --- a/src/arpes/analysis/mask.py 
+++ b/src/arpes/analysis/mask.py @@ -55,7 +55,7 @@ def polys_to_mask( radius: float = 0, *, invert: bool = False, -) -> NDArray[np.float_] | NDArray[np.bool_]: +) -> NDArray[np.bool_]: """Converts a mask definition in terms of the underlying polygon to a True/False mask array. Uses the coordinates and shape of the target data in order to determine which pixels @@ -120,18 +120,20 @@ def apply_mask_to_coords( Returns: The masked data. """ - p = Path(mask["poly"]) - as_array = np.stack([data.data_vars[d].values for d in dims], axis=-1) shape = as_array.shape dest_shape = shape[:-1] new_shape = [np.prod(dest_shape), len(dims)] + mask_array = ( + Path(np.array(mask["poly"])) + .contains_points(as_array.reshape(new_shape)) + .reshape(dest_shape) + ) - mask = p.contains_points(as_array.reshape(new_shape)).reshape(dest_shape) if invert: - mask = np.logical_not(mask) + mask_array = np.logical_not(mask_array) - return mask + return mask_array @update_provenance("Apply boolean mask to data") diff --git a/src/arpes/endstations/nexus_utils.py b/src/arpes/endstations/nexus_utils.py index 26f06542..aad506d9 100644 --- a/src/arpes/endstations/nexus_utils.py +++ b/src/arpes/endstations/nexus_utils.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from _typeshed import Incomplete import xarray as xr __all__ = ("read_data_attributes_from",) @@ -63,7 +64,7 @@ class Target: value: Any = None - def read_h5(self, g, path) -> None: + def read_h5(self, g: Incomplete, path: Incomplete) -> None: self.value = None self.value = self.read(read_group_data(g)) diff --git a/src/arpes/endstations/plugin/ANTARES.py b/src/arpes/endstations/plugin/ANTARES.py index 0248bd57..668d8d88 100644 --- a/src/arpes/endstations/plugin/ANTARES.py +++ b/src/arpes/endstations/plugin/ANTARES.py @@ -73,7 +73,11 @@ } -def parse_axis_name_from_long_name(name: str, keep_segments: int = 1, separator: str = "_") -> str: +def parse_axis_name_from_long_name( + name: str, + keep_segments: int = 1, 
+ separator: str = "_", +) -> str: segments = name.split("/")[-keep_segments:] segments = [s.replace("'", "") for s in segments] return separator.join(segments) @@ -99,14 +103,18 @@ def infer_scan_type_from_data(group: dict) -> str: raise NotImplementedError(scan_name) -class ANTARESEndstation(HemisphericalEndstation, SynchrotronEndstation, SingleFileEndstation): +class ANTARESEndstation( + HemisphericalEndstation, + SynchrotronEndstation, + SingleFileEndstation, +): """Implements data loading for ANTARES at SOLEIL. There's not too much metadata here except what comes with the analyzer settings. """ PRINCIPAL_NAME = "ANTARES" - ALIASES: ClassVar[list] = [] + ALIASES: ClassVar[list[str]] = [] _TOLERATED_EXTENSIONS: ClassVar[set[str]] = {".nxs"} @@ -120,14 +128,12 @@ def load_top_level_scan( ) -> xr.Dataset: """Reads a spectrum from the top level group in a NeXuS scan format. - [TODO:description] - Args: group ([TODO:type]): [TODO:description] scan_desc: [TODO:description] spectrum_index ([TODO:type]): [TODO:description] - Returns: + Returns (xr.Dataset): [TODO:description] """ if scan_desc: @@ -177,7 +183,10 @@ def get_coords(self, group: Incomplete, scan_name: str, shape: Incomplete): ( name if set_names[name] == 1 - else parse_axis_name_from_long_name(actuator_long_names[i], keep_segments) + else parse_axis_name_from_long_name( + actuator_long_names[i], + keep_segments, + ) ) for i, name in enumerate(actuator_names) ] @@ -241,13 +250,17 @@ def take_last(vs): energy = data[e_keys[0]][0], data[e_keys[1]][0], data[e_keys[2]][0] angle = data[ang_keys[0]][0], data[ang_keys[1]][0], data[ang_keys[2]][0] - def get_first(item): + def get_first(item: NDArray[np.float_] | float): if isinstance(item, np.ndarray): return item.ravel()[0] return item - def build_axis(low: float, high: float, step_size: float) -> tuple[NDArray[np.float_], int]: + def build_axis( + low: float, + high: float, + step_size: float, + ) -> tuple[NDArray[np.float_], int]: # this might not work out to be 
the right thing to do, we will see low, high, step_size = get_first(low), get_first(high), get_first(step_size) est_n: int = int((high - low) / step_size) diff --git a/src/arpes/endstations/plugin/BL10_SARPES.py b/src/arpes/endstations/plugin/BL10_SARPES.py index 55e6801b..da58da44 100644 --- a/src/arpes/endstations/plugin/BL10_SARPES.py +++ b/src/arpes/endstations/plugin/BL10_SARPES.py @@ -121,7 +121,7 @@ def load_single_region( """Loads a single region for multi-region scans.""" from arpes.load_pxt import read_single_pxt - name, _ = Path(region_path).stem + name = Path(region_path).stem num = name[-3:] pxt_data = read_single_pxt(region_path, allow_multiple=True) diff --git a/src/arpes/endstations/plugin/Elettra_spectromicroscopy.py b/src/arpes/endstations/plugin/Elettra_spectromicroscopy.py index 08b04674..aeaf5e60 100644 --- a/src/arpes/endstations/plugin/Elettra_spectromicroscopy.py +++ b/src/arpes/endstations/plugin/Elettra_spectromicroscopy.py @@ -109,7 +109,10 @@ def unwrap_bytestring( ) -class SpectromicroscopyElettraEndstation(HemisphericalEndstation, SynchrotronEndstation): +class SpectromicroscopyElettraEndstation( + HemisphericalEndstation, + SynchrotronEndstation, +): """Data loading for the nano-ARPES beamline "Spectromicroscopy Elettra". 
Information available on the beamline can be accessed @@ -145,7 +148,12 @@ def files_for_search(cls: type, directory: str | Path) -> list[Path]: else: base_files = [*base_files, Path(file)] - return list(filter(lambda f: Path(f).suffix in cls._TOLERATED_EXTENSIONS, base_files)) + return list( + filter( + lambda f: Path(f).suffix in cls._TOLERATED_EXTENSIONS, + base_files, + ) + ) ANALYZER_INFORMATION: ClassVar[dict[str, str | float | bool]] = { "analyzer": "Custom: in vacuum hemispherical", @@ -228,7 +236,10 @@ def concatenate_frames( return xr.Dataset({"spectrum": xr.concat(fs, scan_coord)}) - def resolve_frame_locations(self, scan_desc: ScanDesc | None = None) -> list[Path]: + def resolve_frame_locations( + self, + scan_desc: ScanDesc | None = None, + ) -> list[Path]: """Determines all files associated with a given scan. This beamline saves several HDF files in scan associated folders, so this @@ -269,7 +280,11 @@ def load_single_frame( return xr.Dataset(arrays) - def postprocess_final(self, data: xr.Dataset, scan_desc: ScanDesc | None = None) -> xr.Dataset: + def postprocess_final( + self, + data: xr.Dataset, + scan_desc: ScanDesc | None = None, + ) -> xr.Dataset: """Performs final postprocessing of the data. This mostly amounts to: diff --git a/src/arpes/endstations/plugin/HERS.py b/src/arpes/endstations/plugin/HERS.py index 8e3e7ef2..4a1868fb 100644 --- a/src/arpes/endstations/plugin/HERS.py +++ b/src/arpes/endstations/plugin/HERS.py @@ -23,7 +23,10 @@ __all__ = ("HERSEndstation",) -class HERSEndstation(SynchrotronEndstation, HemisphericalEndstation): +class HERSEndstation( + SynchrotronEndstation, + HemisphericalEndstation, +): """Implements data loading at the ALS HERS beamline. 
This should be unified with the FITs endstation code, but I don't have any projects at BL10 @@ -33,7 +36,11 @@ class HERSEndstation(SynchrotronEndstation, HemisphericalEndstation): PRINCIPAL_NAME = "ALS-BL1001" ALIASES: ClassVar[list[str]] = ["ALS-BL1001", "HERS", "ALS-HERS", "BL1001"] - def load(self, scan_desc: ScanDesc | None = None, **kwargs: Incomplete) -> xr.Dataset: + def load( + self, + scan_desc: ScanDesc | None = None, + **kwargs: Incomplete, + ) -> xr.Dataset: """Loads HERS data from FITS files. Shares a lot in common with Lanzara group formats. Args: diff --git a/src/arpes/endstations/plugin/IF_UMCS.py b/src/arpes/endstations/plugin/IF_UMCS.py index 3e663617..72d77b7f 100644 --- a/src/arpes/endstations/plugin/IF_UMCS.py +++ b/src/arpes/endstations/plugin/IF_UMCS.py @@ -23,7 +23,10 @@ __all__ = ("IF_UMCS",) -class IF_UMCS(HemisphericalEndstation, SingleFileEndstation): # noqa: N801 +class IF_UMCS( # noqa: N801 + HemisphericalEndstation, + SingleFileEndstation, +): """Implements loading xy text files from the Specs Prodigy software.""" PRINCIPAL_NAME = "IF_UMCS" diff --git a/src/arpes/endstations/plugin/MAESTRO.py b/src/arpes/endstations/plugin/MAESTRO.py index 833182e7..25112acd 100644 --- a/src/arpes/endstations/plugin/MAESTRO.py +++ b/src/arpes/endstations/plugin/MAESTRO.py @@ -29,7 +29,11 @@ __all__ = ("MAESTROMicroARPESEndstation", "MAESTRONanoARPESEndstation") -class MAESTROARPESEndstationBase(SynchrotronEndstation, HemisphericalEndstation, FITSEndstation): +class MAESTROARPESEndstationBase( + SynchrotronEndstation, + HemisphericalEndstation, + FITSEndstation, +): """Common code for the MAESTRO ARPES endstations at the Advanced Light Source.""" PRINCIPAL_NAME = "" diff --git a/src/arpes/endstations/plugin/kaindl.py b/src/arpes/endstations/plugin/kaindl.py index 73103a2a..251cfdc5 100644 --- a/src/arpes/endstations/plugin/kaindl.py +++ b/src/arpes/endstations/plugin/kaindl.py @@ -180,8 +180,8 @@ def concatenate_frames( frames.sort(key=lambda x: 
x.coords[axis_name]) return xr.concat(frames, axis_name) - except Exception as err: - logger.info(f"Exception occurs. {err=}, {type(err)=}") + except Exception: + logger.exception("Exception occurs.") return None def postprocess_final(self, data: xr.Dataset, scan_desc: ScanDesc | None = None) -> xr.Dataset: diff --git a/src/arpes/endstations/plugin/merlin.py b/src/arpes/endstations/plugin/merlin.py index 5144d0f0..8bce7e8b 100644 --- a/src/arpes/endstations/plugin/merlin.py +++ b/src/arpes/endstations/plugin/merlin.py @@ -202,9 +202,7 @@ def load_single_frame( scan_desc["path"] = frame_path return self.load_SES_nc(scan_desc=scan_desc, **kwargs) - original_data_loc: Path | str = scan_desc.get("path", scan_desc.get("file")) - - p = Path(original_data_loc) + p = Path(scan_desc.get("path", scan_desc.get("file", ""))) # find files with same name stem, indexed in format R### regions = find_ses_files_associated(p, separator="R") diff --git a/src/arpes/plotting/bz.py b/src/arpes/plotting/bz.py index b622ac43..88c24aa5 100644 --- a/src/arpes/plotting/bz.py +++ b/src/arpes/plotting/bz.py @@ -188,7 +188,7 @@ def apply_transformations( def plot_plane_to_bz( cell: Sequence[Sequence[float]] | NDArray[np.float_], plane: str | list[NDArray[np.float_]], - ax: Axes, + ax: Axes3D, special_points: dict[str, NDArray[np.float_]] | None = None, facecolor: ColorType = "red", ) -> None: @@ -209,7 +209,7 @@ def plot_plane_to_bz( if isinstance(plane, str): plane_points: list[NDArray[np.float_]] = process_kpath( plane, - cell, + np.array(cell), special_points=special_points, )[0] else: @@ -226,7 +226,7 @@ def plot_plane_to_bz( def plot_data_to_bz( - data: DataType, + data: xr.DataArray, cell: Sequence[Sequence[float]] | NDArray[np.float_], **kwargs: Incomplete, ) -> Path | tuple[Figure, Axes]: @@ -313,10 +313,10 @@ def plot_data_to_bz2d( # noqa: PLR0913 def plot_data_to_bz3d( - data: DataType, + data: xr.DataArray, cell: Sequence[Sequence[float]] | NDArray[np.float_], **kwargs: Incomplete, -) 
-> None: +) -> Path | tuple[Figure, Axes]: """Plots ARPES data onto a 3D Brillouin zone.""" msg = "plot_data_to_bz3d is not implemented yet." logger.debug(f"id of data: {data.attrs.get('id', None)}") @@ -533,7 +533,7 @@ def draw(self, renderer: Incomplete) -> None: def annotate_special_paths( ax: Axes, - paths: list[str] | str, + paths: list[str] | str = "", cell: NDArray[np.float_] | Sequence[Sequence[float]] | None = None, offset: dict[str, Sequence[float]] | None = None, special_points: dict[str, NDArray[np.float_]] | None = None, @@ -541,17 +541,11 @@ def annotate_special_paths( **kwargs: Incomplete, ) -> None: """Annotates user indicated paths in k-space by plotting lines (or points) over the BZ.""" - logger.debug(f"annotate-ax: {ax}") - logger.debug(f"annotate-paths: {paths}") - logger.debug(f"annotate-cell: {cell}") - logger.debug(f"annotate-offset: {offset}") - logger.debug(f"annotate-special_points: {special_points}") - logger.debug(f"annotate-labels: {labels}") if kwargs: for k, v in kwargs.items(): logger.debug(f"kwargs: kyes: {k}, value: {v}") - if paths == "": + if not paths: msg = "Must provide a proper path." 
raise ValueError(msg) @@ -668,7 +662,7 @@ def twocell_to_bz1(cell: NDArray[np.float_]) -> Incomplete: def bz2d_plot( - cell: Sequence[Sequence[float]], + cell: Sequence[Sequence[float]] | NDArray[np.float_], paths: str | list[float] | None = None, points: Sequence[float] | None = None, repeat: tuple[int, int] | None = None, @@ -687,16 +681,8 @@ def bz2d_plot( Plots a Brillouin zone corresponding to a given unit cell """ - logger.debug(f"bz2d_plot-cell: {cell}") - logger.debug(f"bz2d_plot-paths: {paths}") - logger.debug(f"bz2d_plot-points: {points}") - logger.debug(f"bz2d_plot-repeat: {repeat}") - logger.debug(f"bz2d_plot-transformations: {transformations}") - logger.debug(f"bz2d_plot-hide_ax: {hide_ax}") - logger.debug(f"bz2d_plot-vectors: {vectors}") - logger.debug(f"bz2d_plot-set_equal_aspect: {set_equal_aspect}") kpoints = points - bz1, icell, cell = twocell_to_bz1(cell) + bz1, icell, cell = twocell_to_bz1(np.array(cell)) logger.debug(f"bz1 : {bz1}") if ax is None: ax = plt.axes() @@ -710,6 +696,12 @@ def bz2d_plot( path_string = cell_structure.special_path if paths == "all" else paths paths = [] for names in parse_path_string(path_string): + """ + >>> parse_path_string('GX') + [['G', 'X']] + >>> parse_path_string('GX,M1A') + [['G', 'X'], ['M1', 'A']] + """ points = [] for name in names: points.append(np.dot(icell.T, special_points[name])) @@ -774,7 +766,12 @@ def bz2d_plot( ) if paths is not None: - annotate_special_paths(ax, paths, offset=offset, transformations=transformations) + annotate_special_paths( + ax, + paths, + offset=offset, + transformations=transformations, + ) if kpoints is not None: for p in kpoints: diff --git a/src/arpes/plotting/bz_tool/CoordinateOffsetWidget.py b/src/arpes/plotting/bz_tool/CoordinateOffsetWidget.py index 642a8d57..7a6324e2 100644 --- a/src/arpes/plotting/bz_tool/CoordinateOffsetWidget.py +++ b/src/arpes/plotting/bz_tool/CoordinateOffsetWidget.py @@ -3,10 +3,24 @@ from __future__ import annotations from functools import 
partial +from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger from typing import TYPE_CHECKING from PySide6 import QtWidgets +LOGLEVELS = (DEBUG, INFO) +LOGLEVEL = LOGLEVELS[1] +logger = getLogger(__name__) +fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s" +formatter = Formatter(fmt) +handler = StreamHandler() +handler.setLevel(LOGLEVEL) +logger.setLevel(LOGLEVEL) +handler.setFormatter(formatter) +logger.addHandler(handler) +logger.propagate = False + + if TYPE_CHECKING: from _typeshed import Incomplete from PySide6.QtCore import QEvent @@ -28,7 +42,7 @@ def __init__( ) -> None: """Configures utility label, an inner control, and a linked spinbox for text entry.""" super().__init__(title=coordinate_name, parent=parent) - + logger.debug(f"value = {value} has not been used.") self.layout: QGridLayout = QtWidgets.QGridLayout(self) self.label = QtWidgets.QLabel("Value: ") @@ -56,6 +70,7 @@ def value_changed( if self._prevent_change_events: return + logger.debug(f"event={event} has not been used.") self._prevent_change_events = True self.slider.setValue(source.value()) self.spinbox.setValue(source.value()) diff --git a/src/arpes/plotting/bz_tool/RangeOrSingleValueWidget.py b/src/arpes/plotting/bz_tool/RangeOrSingleValueWidget.py index 81b7e411..8f0885f9 100644 --- a/src/arpes/plotting/bz_tool/RangeOrSingleValueWidget.py +++ b/src/arpes/plotting/bz_tool/RangeOrSingleValueWidget.py @@ -53,6 +53,7 @@ def __init__( """ super().__init__(title=coordinate_name, parent=parent) + logger.debug(f"value = {value} has not been used.") self.layout: QGridLayout = QtWidgets.QGridLayout(self) self.label = QtWidgets.QLabel("Value: ") @@ -83,6 +84,7 @@ def value_changed(self, event: Incomplete, source: Incomplete) -> None: if self._prevent_change_events: return + logger.debug(f"event={event} has not been used.") self._prevent_change_events = True self.slider.setValue(source.value()) self.spinbox.setValue(source.value()) diff --git a/src/arpes/provenance.py 
b/src/arpes/provenance.py
index c1fa5215..5736aeaf 100644
--- a/src/arpes/provenance.py
+++ b/src/arpes/provenance.py
@@ -80,6 +80,8 @@ class Provenance(TypedDict, total=False):
     new_axis: str
     transformed_vars: list[str]
     #
+    #
+    parant_id: tuple[str, str]
     occupation_ratio: float
     #
     correlation: bool
diff --git a/src/arpes/utilities/collections.py b/src/arpes/utilities/collections.py
index 7a5e6a0f..a6d5cf2a 100644
--- a/src/arpes/utilities/collections.py
+++ b/src/arpes/utilities/collections.py
@@ -19,8 +19,8 @@ def deep_update(destination: dict[str, T], source: dict[str, T]) -> dict[str, T]
     Instead recurse down from the root and update as appropriate.
 
     Args:
-        destination:
-        source:
+        destination: dict object to be updated.
+        source: source dict
 
     Returns:
         The destination item
diff --git a/src/arpes/utilities/conversion/core.py b/src/arpes/utilities/conversion/core.py
index 69cf1dfe..bf89ba89 100644
--- a/src/arpes/utilities/conversion/core.py
+++ b/src/arpes/utilities/conversion/core.py
@@ -658,12 +658,12 @@ def _extract_symmetry_point(
     """[TODO:summary].
 
     Args:
-        name (str): [TODO:description]
-        arr (xr.DataArray): [TODO:description]
+        name (str): Name of the symmetry point, such as G, X, L.
+        arr (xr.DataArray): ARPES data.
         extend_to_edge (bool): [TODO:description]
 
-    Returns:
-        [TODO:description]
+    Returns: dict(Hashable, float)
+        Return dict object as the symmetry point
     """
     raw_point: dict[Hashable, float] = arr.attrs["symmetry_points"][name]
     G = arr.attrs["symmetry_points"]["G"]
diff --git a/src/arpes/xarray_extensions.py b/src/arpes/xarray_extensions.py
index 4bcde284..cc2c4474 100644
--- a/src/arpes/xarray_extensions.py
+++ b/src/arpes/xarray_extensions.py
@@ -684,9 +684,15 @@ def symmetry_points(
 
         The original version was something complicated, but the coding seemed to be
         in process and the purpose was unclear, so it was streamlined considerably.
+
+
+        Returns (dict[str, dict[str, float]]):
+            Dict object representing the symmetry points in the ARPES data.
+ + Examples: + example of "symmetry_points": symmetry_points = {"G": {"phi": 0.405}} """ symmetry_points: dict[str, dict[str, float]] = {} - # An example of "symmetry_points": symmetry_points = {"G": {"phi": 0.405}} our_symmetry_points = self._obj.attrs.get("symmetry_points", {}) symmetry_points.update(our_symmetry_points) @@ -1191,12 +1197,17 @@ def unpack_dim(dim_name: str) -> str: return dim_name for region in regions: - region = {unpack_dim(k): v for k, v in normalize_region(region).items()} - # remove missing dimensions from selection for permissiveness # and to transparent composing of regions - region = {k: process_region_selector(v, k) for k, v in region.items() if k in obj.dims} - obj = obj.sel(region) + obj = obj.sel( + { + k: process_region_selector(v, k) + for k, v in { + unpack_dim(k): v for k, v in normalize_region(region).items() + }.items() + if k in obj.dims + }, + ) return obj