From 2690bd6a265ae8fc7b49a5419337bed2d6ad4055 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Sun, 21 Apr 2024 21:16:39 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AC=20=20update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/arpes/_typing.py | 7 ++++--- src/arpes/fits/fit_models/x_model_mixin.py | 2 +- src/arpes/plotting/fits.py | 8 ++++---- src/arpes/plotting/utils.py | 3 +-- src/arpes/provenance.py | 1 + src/arpes/utilities/jupyter.py | 4 ++-- src/arpes/workflow.py | 2 +- src/arpes/xarray_extensions.py | 7 +++++-- tests/test_basic_data_loading.py | 12 ++++++------ 9 files changed, 25 insertions(+), 21 deletions(-) diff --git a/src/arpes/_typing.py b/src/arpes/_typing.py index 6d915606..19b086c5 100644 --- a/src/arpes/_typing.py +++ b/src/arpes/_typing.py @@ -184,7 +184,8 @@ class AnalyzerInfo(TypedDict, total=False): see analyzer_info in xarray_extensions.py (around line# 1490) """ - name: str + analyzer: str + analyzer_name: str lens_mode: str | None lens_mode_name: str | None acquisition_mode: str @@ -200,7 +201,7 @@ class AnalyzerInfo(TypedDict, total=False): work_function: float # is_slit_vertical: bool - radius: str | float + analyzer_radius: str | float class _PumpInfo(TypedDict, total=False): @@ -337,7 +338,7 @@ class ExperimentInfo( AnalyzerInfo, total=False, ): - pass + analyzer_detail: AnalyzerInfo class ARPESAttrs(Spectrometer, LightSourceInfo, SampleInfo, total=False): diff --git a/src/arpes/fits/fit_models/x_model_mixin.py b/src/arpes/fits/fit_models/x_model_mixin.py index e786578a..795c2f73 100644 --- a/src/arpes/fits/fit_models/x_model_mixin.py +++ b/src/arpes/fits/fit_models/x_model_mixin.py @@ -282,7 +282,7 @@ def _real_data_etc_from_xarray( real_data, flat_data = data.values, data.values assert len(real_data.shape) == self.n_dims coord_values = {} - new_dim_order: list[str] + new_dim_order: list[str] = [] if self.n_dims == 1: coord_values["x"] = data.coords[next(iter(data.indexes))].values else: diff --git a/src/arpes/plotting/fits.py b/src/arpes/plotting/fits.py index 1a42b2cc..01edae34 100644 --- a/src/arpes/plotting/fits.py +++ b/src/arpes/plotting/fits.py @@ -18,6 +18,7 @@ if TYPE_CHECKING: import lmfit as lf + from numpy.typing import NDArray def plot_fit(model_result: lf.Model, ax: Axes | None = None) -> None: @@ -49,11 +50,10 @@ def plot_fit(model_result: lf.Model, ax: Axes | None = None) -> None: ax.set_xlim(left=np.min(x), right=np.max(x)) -def plot_fits(model_results: list[lf.Model], ax: Axes | None = None) -> None: +def plot_fits(model_results: list[lf.Model], axs: NDArray[np.object_] | None = None) -> None: """Plots several fits onto a grid of axes.""" n_results = len(model_results) - if ax is None: - _fig, ax, _ax_extra = simple_ax_grid(n_results, sharex="col", sharey="row") + axs = axs if axs else simple_ax_grid(n_results, sharex="col", sharey="row")[1] - for axi, model_result in zip(ax, model_results, strict=False): + for axi, model_result in zip(axs, model_results, strict=False): plot_fit(model_result, ax=axi) diff --git a/src/arpes/plotting/utils.py b/src/arpes/plotting/utils.py index ba94f56d..c9c196cb 100644 --- a/src/arpes/plotting/utils.py +++ b/src/arpes/plotting/utils.py @@ -15,7 +15,7 @@ from datetime import UTC from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Unpack +from typing import TYPE_CHECKING, Any, Literal, Unpack, reveal_type import matplotlib as mpl import matplotlib.pyplot as plt @@ -316,7 +316,6 @@ def simple_ax_grid( ax, ax_rest = ax.ravel()[:n_axes], ax.ravel()[n_axes:] for axi in ax_rest: invisible_axes(axi) - return fig, ax, ax_rest diff --git a/src/arpes/provenance.py b/src/arpes/provenance.py index 657f7f16..93ddade1 100644 --- a/src/arpes/provenance.py +++ b/src/arpes/provenance.py @@ -52,6 +52,7 @@ class Provenance(TypedDict, total=False): VERSION: str jupyter_notebook_name: str + name: str record: Provenance jupyter_context: list[str] diff --git a/src/arpes/utilities/jupyter.py b/src/arpes/utilities/jupyter.py index af75f76b..b8930232 100644 --- a/src/arpes/utilities/jupyter.py +++ b/src/arpes/utilities/jupyter.py @@ -130,7 +130,7 @@ def get_full_notebook_information() -> NoteBookInfomation | None: return None -def get_notebook_name() -> str | None: +def get_notebook_name() -> str: """Gets the unqualified name of the running Jupyter notebook if not password protected. As an example, if you were running a notebook called "Doping-Analysis.ipynb" @@ -142,7 +142,7 @@ def get_notebook_name() -> str | None: jupyter_info = get_full_notebook_information() if jupyter_info: return Path(jupyter_info["session"]["notebook"]["name"]).stem - return None + return "" def generate_logfile_path() -> Path: diff --git a/src/arpes/workflow.py b/src/arpes/workflow.py index 524e6aea..848a2a28 100644 --- a/src/arpes/workflow.py +++ b/src/arpes/workflow.py @@ -146,7 +146,7 @@ def go_to_figures() -> None: _open_path(path) -def get_running_context() -> tuple[Incomplete, Path]: +def get_running_context() -> tuple[str, Path]: return get_notebook_name(), Path.cwd() diff --git a/src/arpes/xarray_extensions.py b/src/arpes/xarray_extensions.py index b0f77251..33b6e2f8 100644 --- a/src/arpes/xarray_extensions.py +++ b/src/arpes/xarray_extensions.py @@ -1590,11 +1590,14 @@ def undulator_info(self) -> dict[str, str | float | None]: def analyzer_detail(self) -> AnalyzerInfo: """Details about the analyzer, its capabilities, and metadata.""" return { - "name": self._obj.attrs.get("analyzer_name", self._obj.attrs.get("analyzer", "")), + "analyzer_name": self._obj.attrs.get( + "analyzer_name", + self._obj.attrs.get("analyzer", ""), + ), "parallel_deflectors": self._obj.attrs.get("parallel_deflectors", False), "perpendicular_deflectors": self._obj.attrs.get("perpendicular_deflectors", False), "analyzer_type": self._obj.attrs.get("analyzer_type", ""), - "radius": self._obj.attrs.get("analyzer_radius", np.nan), + "analyzer_radius": self._obj.attrs.get("analyzer_radius", np.nan), } @property diff --git a/tests/test_basic_data_loading.py b/tests/test_basic_data_loading.py index 25560c72..4c4321d4 100644 --- a/tests/test_basic_data_loading.py +++ b/tests/test_basic_data_loading.py @@ -71,8 +71,8 @@ class TestMetadata: "probe_detail": None, "analyzer_detail": { "analyzer_type": "hemispherical", - "radius": 150, - "name": "Specs PHOIBOS 150", + "analyzer_radius": 150, + "analyzer_name": "Specs PHOIBOS 150", "parallel_deflectors": False, "perpendicular_deflectors": False, }, @@ -160,10 +160,10 @@ class TestMetadata: "probe": None, "probe_detail": None, "analyzer_detail": { - "name": "Scienta R8000", + "analyzer_name": "Scienta R8000", "parallel_deflectors": False, "perpendicular_deflectors": False, - "radius": np.nan, + "analyzer_radius": np.nan, "analyzer_type": "hemispherical", }, }, @@ -249,8 +249,8 @@ class TestMetadata: "probe_detail": None, "analyzer_detail": { "analyzer_type": "hemispherical", - "radius": np.nan, - "name": "Scienta R4000", + "analyzer_radius": np.nan, + "analyzer_name": "Scienta R4000", "parallel_deflectors": False, "perpendicular_deflectors": True, },