Skip to content

Commit

Permalink
💬 Update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Feb 4, 2024
1 parent a8a7eb5 commit 569abf6
Show file tree
Hide file tree
Showing 16 changed files with 162 additions and 95 deletions.
34 changes: 17 additions & 17 deletions arpes/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,17 +146,17 @@ class ANALYZERINFO(TypedDict, total=False):
lens_mode_name: str | None
acquisition_mode: str
pass_energy: float
slit_shape: str
slit_shape: str | None
slit_width: float
slit_number: str | int
lens_table: None
analyzer_type: str
analyzer_type: str | None
mcp_voltage: float
work_function: float
#
analyzer_radius: float
analyzer: str
analyzer_name: str
analyzer: str | None
analyzer_name: str | None
parallel_deflectors: bool
perpendicular_deflectors: bool

Expand All @@ -174,10 +174,10 @@ class _PUMPINFO(TypedDict, total=False):
pump_spot_size: float | tuple[float, float]
pump_spot_size_x: float
pump_spot_size_y: float
pump_profile: None
pump_profile: Incomplete
pump_linewidth: float
pump_duration: float
pump_polarization: str | tuple[float | None, float | None]
pump_polarization: str | tuple[float, float]
pump_polarization_theta: float
pump_polarization_alpha: float

Expand All @@ -198,7 +198,7 @@ class _PROBEINFO(TypedDict, total=False):
probe_profile: None
probe_linewidth: float
probe_duration: float
probe_polarization: str | tuple[float | None, float | None]
probe_polarization: str | tuple[float, float]
probe_polarization_theta: float
probe_polarization_alpha: float

Expand All @@ -211,21 +211,21 @@ class _BEAMLINEINFO(TypedDict, total=False):

hv: float | xr.DataArray
linewidth: float
photon_polarization: tuple[float | None, float | None]
photon_polarization: tuple[float, float]
undulation_info: Incomplete
repetition_rate: float
beam_current: float
entrance_slit: float
exit_slit: float
monochrometer_info: dict[str, None | float]
entrance_slit: float | str
exit_slit: float | str
monochrometer_info: dict[str, float]


class LIGHTSOURCEINFO(_PROBEINFO, _PUMPINFO, _BEAMLINEINFO, total=False):
polarization: float | tuple[float, float] | str
photon_flux: float
photocurrent: float
probe: None
probe_detail: None
probe: Incomplete
probe_detail: Incomplete


class SAMPLEINFO(TypedDict, total=False):
Expand All @@ -234,9 +234,9 @@ class SAMPLEINFO(TypedDict, total=False):
see sample_info in xarray_extensions
"""

id: int | str
sample_name: str
source: str
id: int | str | None
sample_name: str | None
source: str | None
reflectivity: float


Expand All @@ -258,7 +258,7 @@ class DAQINFO(TypedDict, total=False):
see daq_info in xarray_extensions.py
"""

daq_type: str
daq_type: str | None
region: str | None
region_name: str | None
center_energy: float
Expand Down
3 changes: 2 additions & 1 deletion arpes/plotting/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from mpl_toolkits.mplot3d import Axes3D

from arpes.constants import TWO_DIMENSION
from arpes.utilities.conversion.forward import convert_coordinates_to_kspace_forward

from .utils import name_for_dim, unit_for_dim

Expand Down Expand Up @@ -181,6 +180,8 @@ def annotate_cuts(
include_text_labels: Whether to include text labels
kwargs: Defines the coordinates of the cut location
"""
from arpes.utilities.conversion.forward import convert_coordinates_to_kspace_forward

converted_coordinates = convert_coordinates_to_kspace_forward(data)
assert converted_coordinates, xr.Dataset | xr.DataArray
assert len(plotted_axes) == TWO_DIMENSION
Expand Down
1 change: 0 additions & 1 deletion arpes/plotting/dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,6 @@ class LabeledFermiSurfaceParam(TypedDict, total=False):
include_symmetry_points: bool
include_bz: bool
fermi_energy: float
out: str


@save_plot_provenance
Expand Down
20 changes: 12 additions & 8 deletions arpes/plotting/qt_tool/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Provides a Qt based implementation of Igor's ImageTool."""

# pylint: disable=import-error
from __future__ import annotations

import contextlib
import warnings
import weakref
from collections.abc import Sequence
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING, reveal_type

Expand All @@ -31,7 +33,9 @@
from .BinningInfoWidget import BinningInfoWidget

if TYPE_CHECKING:
import xarray as xr
from _typeshed import Incomplete
from PySide6.QtCore import QEvent
from PySide6.QtWidgets import QWidget

from arpes._typing import DataType
Expand Down Expand Up @@ -113,20 +117,20 @@ def compile_key_bindings(self) -> list[KeyBinding]:
),
]

def center_cursor(self, event) -> None:
def center_cursor(self, event: QEvent) -> None:
logger.debug(f"method: center_cursor {event!s}")
self.app().center_cursor()

def transpose_roll(self, event) -> None:
def transpose_roll(self, event: QEvent) -> None:
logger.debug(f"method: transpose_roll {event!s}")
self.app().transpose_to_front(-1)

def transpose_swap(self, event) -> None:
def transpose_swap(self, event: QEvent) -> None:
logger.debug(f"method: transpose_swap {event!s}")
self.app().transpose_to_front(1)

@staticmethod
def _update_scroll_delta(delta, event: QtGui.QKeyEvent) -> tuple:
def _update_scroll_delta(delta: tuple[float, ...], event: QtGui.QKeyEvent) -> tuple:
logger.debug(f"method: _update_scroll_delta {event!s}")
if event.nativeModifiers() & 1: # shift key
delta = (delta[0], delta[1] * 5)
Expand Down Expand Up @@ -190,7 +194,7 @@ class QtTool(SimpleApp):
def __init__(self) -> None:
"""Initialize attributes to safe empty values."""
super().__init__()
self.data = None
self.data: xr.Dataset | xr.DataArray

self.content_layout = None
self.main_layout: QtWidgets.QGridLayout | None = None
Expand Down Expand Up @@ -419,13 +423,13 @@ def safe_slice(vlow: float, vhigh: float, axis: int = 0) -> slice:
),
)
if isinstance(reactive.view, DataArrayImageView):
image_data = self.data.isel(**select_coord)
image_data = self.data.isel(select_coord)
if select_coord:
image_data = image_data.mean(list(select_coord.keys()))
reactive.view.setImage(image_data, keep_levels=keep_levels)

elif isinstance(reactive.view, pg.PlotWidget):
for_plot = self.data.isel(**select_coord)
for_plot = self.data.isel(select_coord)
if select_coord:
for_plot = for_plot.mean(list(select_coord.keys()))

Expand All @@ -449,7 +453,7 @@ def safe_slice(vlow: float, vhigh: float, axis: int = 0) -> slice:
except IndexError:
pass

def construct_axes_tab(self) -> tuple[QtWidgets, list[AxisInfoWidget]]:
def construct_axes_tab(self) -> tuple[QWidget, list[AxisInfoWidget]]:
"""Controls for axis order and transposition."""
inner_items = [
AxisInfoWidget(axis_index=i, root=weakref.ref(self)) for i in range(len(self.data.dims))
Expand Down
2 changes: 1 addition & 1 deletion arpes/plotting/stack_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def stack_dispersion_plot( # noqa: PLR0913
negate(bool): _description_
**kwargs:
set figsize to change the default figisize=(7,7)
set title, if not specified the attrs[description] (or S.scan_name) is used.
set title, if not specified the attrs[description] (or S.label) is used.
other kwargs is passed to ax.plot (or ax.scatter). Can set linewidth/s etc., here.
"""
data_arr, stack_axis, other_axis = _rebinning(
Expand Down
1 change: 1 addition & 0 deletions arpes/utilities/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

def normalize_to_spectrum(data: DataType | str) -> xr.DataArray:
"""Tries to extract the actual ARPES spectrum from a dataset containing other variables."""
import arpes.xarray_extensions # noqa: F401
from arpes.io import load_data

if isinstance(data, xr.Dataset):
Expand Down
35 changes: 27 additions & 8 deletions arpes/utilities/qt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""Infrastructure code for Qt based analysis tools."""

from __future__ import annotations

import functools
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from multiprocessing import Process
from typing import TYPE_CHECKING

import dill
import pyqtgraph as pg
from pyqtgraph import ViewBox
from PySide6.QtCore import QCoreApplication
from PySide6.QtWidgets import QWidget

from arpes._typing import xr_types

Expand Down Expand Up @@ -35,6 +39,18 @@
"run_tool_in_daemon_process",
)

LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[1]
logger = getLogger(__name__)
fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
formatter = Formatter(fmt)
handler = StreamHandler()
handler.setLevel(LOGLEVEL)
logger.setLevel(LOGLEVEL)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False


def run_tool_in_daemon_process(tool_handler: Callable) -> Callable:
"""Starts a Qt based tool as a daemon process.
Expand Down Expand Up @@ -89,15 +105,15 @@ def remove_dangling_viewboxes() -> None:
* ViewBox.AllViews
* ViewBox.NamedViews
"""
import sip
import sipbuild # TODO: CHECK.

for_deletion = set()

# In each case, we need to coerce the collection to
# a list before we iterate because we are modifying the
# underlying collection
for v in list(ViewBox.AllViews):
if sip.isdeleted(v):
if sipbuild.isdeleted(v):
# first remove it from the ViewBox references
# and then we will delete it later to prevent an
# error
Expand All @@ -107,7 +123,7 @@ def remove_dangling_viewboxes() -> None:
for vname in list(ViewBox.NamedViews):
v = ViewBox.NamedViews[vname]

if sip.isdeleted(v):
if sipbuild.isdeleted(v):
for_deletion.add(v)
del ViewBox.NamedViews[vname]

Expand All @@ -128,25 +144,28 @@ def init_from_app(self, app: QApplication) -> None:

self._inited = True
dpis = [screen.physicalDotsPerInch() for screen in app.screens()]
self.screen_dpi = sum(dpis) / len(dpis)
self.screen_dpi = int(sum(dpis) / len(dpis))

def apply_settings_to_app(self, app: QApplication) -> None:
# Adjust the font size based on screen DPI
font = app.font()
font.setPointSize(self.inches_to_px(0.1))
logger.debug(f"Type of app {type(app)}")
font_size = self.inches_to_px(0.1)
assert isinstance(font_size, int)
font.setPointSize(font_size)
app.instance().setFont(font)

def inches_to_px(
self,
arg: float | tuple[float, ...],
) -> int | Generator[int, None, None]:
) -> int | tuple[int, ...]:
if isinstance(
arg,
int | float,
):
return int(self.screen_dpi * arg)

return (int(x * self.screen_dpi) for x in arg)
return tuple(int(x * self.screen_dpi) for x in arg)

def setup_pyqtgraph(self) -> None:
"""Does any patching required on PyQtGraph and configures options."""
Expand Down Expand Up @@ -177,7 +196,7 @@ def patchedLinkedViewChanged(
We also don't handle inverted axes for now.
"""
if self.linksBlocked or view is None:
if view is None:
return

vr = view.viewRect()
Expand Down
Loading

0 comments on commit 569abf6

Please sign in to comment.