Skip to content

Commit

Permalink
💬 update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Apr 21, 2024
1 parent 27d0795 commit 2690bd6
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 21 deletions.
7 changes: 4 additions & 3 deletions src/arpes/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ class AnalyzerInfo(TypedDict, total=False):
see analyzer_info in xarray_extensions.py (around line# 1490)
"""

name: str
analyzer: str
analyzer_name: str
lens_mode: str | None
lens_mode_name: str | None
acquisition_mode: str
Expand All @@ -200,7 +201,7 @@ class AnalyzerInfo(TypedDict, total=False):
work_function: float
#
is_slit_vertical: bool
radius: str | float
analyzer_radius: str | float


class _PumpInfo(TypedDict, total=False):
Expand Down Expand Up @@ -337,7 +338,7 @@ class ExperimentInfo(
AnalyzerInfo,
total=False,
):
pass
analyzer_detail: AnalyzerInfo


class ARPESAttrs(Spectrometer, LightSourceInfo, SampleInfo, total=False):
Expand Down
2 changes: 1 addition & 1 deletion src/arpes/fits/fit_models/x_model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def _real_data_etc_from_xarray(
real_data, flat_data = data.values, data.values
assert len(real_data.shape) == self.n_dims
coord_values = {}
new_dim_order: list[str]
new_dim_order: list[str] = []
if self.n_dims == 1:
coord_values["x"] = data.coords[next(iter(data.indexes))].values
else:
Expand Down
8 changes: 4 additions & 4 deletions src/arpes/plotting/fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

if TYPE_CHECKING:
import lmfit as lf
from numpy.typing import NDArray


def plot_fit(model_result: lf.Model, ax: Axes | None = None) -> None:
Expand Down Expand Up @@ -49,11 +50,10 @@ def plot_fit(model_result: lf.Model, ax: Axes | None = None) -> None:
ax.set_xlim(left=np.min(x), right=np.max(x))


def plot_fits(model_results: list[lf.Model], ax: Axes | None = None) -> None:
def plot_fits(model_results: list[lf.Model], axs: NDArray[np.object_] | None = None) -> None:
"""Plots several fits onto a grid of axes."""
n_results = len(model_results)
if ax is None:
_fig, ax, _ax_extra = simple_ax_grid(n_results, sharex="col", sharey="row")
axs = axs if axs else simple_ax_grid(n_results, sharex="col", sharey="row")[1]

for axi, model_result in zip(ax, model_results, strict=False):
for axi, model_result in zip(axs, model_results, strict=False):
plot_fit(model_result, ax=axi)
3 changes: 1 addition & 2 deletions src/arpes/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from datetime import UTC
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Unpack
from typing import TYPE_CHECKING, Any, Literal, Unpack, reveal_type

import matplotlib as mpl
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -316,7 +316,6 @@ def simple_ax_grid(
ax, ax_rest = ax.ravel()[:n_axes], ax.ravel()[n_axes:]
for axi in ax_rest:
invisible_axes(axi)

return fig, ax, ax_rest


Expand Down
1 change: 1 addition & 0 deletions src/arpes/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class Provenance(TypedDict, total=False):

VERSION: str
jupyter_notebook_name: str
name: str

record: Provenance
jupyter_context: list[str]
Expand Down
4 changes: 2 additions & 2 deletions src/arpes/utilities/jupyter.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def get_full_notebook_information() -> NoteBookInfomation | None:
return None


def get_notebook_name() -> str | None:
def get_notebook_name() -> str:
"""Gets the unqualified name of the running Jupyter notebook if not password protected.
As an example, if you were running a notebook called "Doping-Analysis.ipynb"
Expand All @@ -142,7 +142,7 @@ def get_notebook_name() -> str | None:
jupyter_info = get_full_notebook_information()
if jupyter_info:
return Path(jupyter_info["session"]["notebook"]["name"]).stem
return None
return ""


def generate_logfile_path() -> Path:
Expand Down
2 changes: 1 addition & 1 deletion src/arpes/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def go_to_figures() -> None:
_open_path(path)


def get_running_context() -> tuple[Incomplete, Path]:
def get_running_context() -> tuple[str, Path]:
return get_notebook_name(), Path.cwd()


Expand Down
7 changes: 5 additions & 2 deletions src/arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,11 +1590,14 @@ def undulator_info(self) -> dict[str, str | float | None]:
def analyzer_detail(self) -> AnalyzerInfo:
"""Details about the analyzer, its capabilities, and metadata."""
return {
"name": self._obj.attrs.get("analyzer_name", self._obj.attrs.get("analyzer", "")),
"analyzer_name": self._obj.attrs.get(
"analyzer_name",
self._obj.attrs.get("analyzer", ""),
),
"parallel_deflectors": self._obj.attrs.get("parallel_deflectors", False),
"perpendicular_deflectors": self._obj.attrs.get("perpendicular_deflectors", False),
"analyzer_type": self._obj.attrs.get("analyzer_type", ""),
"radius": self._obj.attrs.get("analyzer_radius", np.nan),
"analyzer_radius": self._obj.attrs.get("analyzer_radius", np.nan),
}

@property
Expand Down
12 changes: 6 additions & 6 deletions tests/test_basic_data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ class TestMetadata:
"probe_detail": None,
"analyzer_detail": {
"analyzer_type": "hemispherical",
"radius": 150,
"name": "Specs PHOIBOS 150",
"analyzer_radius": 150,
"analyzer_name": "Specs PHOIBOS 150",
"parallel_deflectors": False,
"perpendicular_deflectors": False,
},
Expand Down Expand Up @@ -160,10 +160,10 @@ class TestMetadata:
"probe": None,
"probe_detail": None,
"analyzer_detail": {
"name": "Scienta R8000",
"analyzer_name": "Scienta R8000",
"parallel_deflectors": False,
"perpendicular_deflectors": False,
"radius": np.nan,
"analyzer_radius": np.nan,
"analyzer_type": "hemispherical",
},
},
Expand Down Expand Up @@ -249,8 +249,8 @@ class TestMetadata:
"probe_detail": None,
"analyzer_detail": {
"analyzer_type": "hemispherical",
"radius": np.nan,
"name": "Scienta R4000",
"analyzer_radius": np.nan,
"analyzer_name": "Scienta R4000",
"parallel_deflectors": False,
"perpendicular_deflectors": True,
},
Expand Down

0 comments on commit 2690bd6

Please sign in to comment.