Skip to content

Commit

Permalink
💬 update type hints
Browse files Browse the repository at this point in the history
🔧  update pyproject.toml

(cherry picked from commit 2e2045dbeb7980958e9fc7b7a78934950d449c6d)
  • Loading branch information
arafune committed Oct 12, 2023
1 parent c73a769 commit fe14b35
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 24 deletions.
21 changes: 17 additions & 4 deletions arpes/preparation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
"""Utility functions used during data preparation and loading."""
from __future__ import annotations

from .axis_preparation import *
from .coord_preparation import *
from .hemisphere_preparation import *
from .tof_preparation import *
from .axis_preparation import (
dim_normalizer,
flip_axis,
normalize_dim,
normalize_total,
sort_axis,
transform_dataarray_axis,
vstack_data,
)
from .coord_preparation import disambiguate_coordinates
from .hemisphere_preparation import stitch_maps
from .tof_preparation import (
build_KE_coords_to_time_coords,
build_KE_coords_to_time_pixel_coords,
process_DLD,
process_SToF,
)
28 changes: 16 additions & 12 deletions arpes/preparation/axis_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,12 @@ def normalize_dim(
"""
dims: list[str]
dims = [dim_or_dims] if isinstance(dim_or_dims, str) else dim_or_dims
assert isinstance(dims, list)

summed_arr = arr.fillna(arr.mean()).sum([d for d in arr.dims if d not in dims])
normalized_arr = arr / (summed_arr / np.prod(summed_arr.shape))

to_return = xr.DataArray(normalized_arr.values, arr.coords, arr.dims, attrs=arr.attrs)
to_return = xr.DataArray(normalized_arr.values, arr.coords, tuple(arr.dims), attrs=arr.attrs)

if not keep_id and "id" in to_return.attrs:
del to_return.attrs["id"]
Expand All @@ -149,28 +150,29 @@ def normalize_dim(


@update_provenance("Normalize total spectrum intensity")
def normalize_total(data: DataType) -> xr.DataArray:
def normalize_total(data: DataType, *, total_intensity: float = 1000000) -> xr.DataArray:
"""Normalizes data so that the total intensity is 1000000 (a bit arbitrary).
Args:
data(DataType): [TODO:description]
data(DataType): Input ARPES data
total_intensity: value for normalization
Returns:
xr.DataArray
"""
data_array = normalize_to_spectrum(data)
assert isinstance(data_array, xr.DataArray)
return data_array / (data_array.sum(data.dims) / 1000000)
return data_array / (data_array.sum(data.dims) / total_intensity)


def dim_normalizer(dim_name):
def dim_normalizer(dim_name: str) -> Callable[[xr.Dataset | xr.DataArray], xr.DataArray]:
"""Safe partial application of dimension normalization.
Args:
dim_name ([TODO:type]): [TODO:description]
dim_name (str): [TODO:description]
"""

def normalize(arr: xr.DataArray):
def normalize(arr: xr.Dataset | xr.DataArray) -> xr.DataArray:
if dim_name not in arr.dims:
return arr
return normalize_dim(arr, dim_name)
Expand All @@ -179,14 +181,14 @@ def normalize(arr: xr.DataArray):


def transform_dataarray_axis(
f,
func: Callable[..., ...],
old_axis_name: str,
new_axis_name: str,
new_axis: NDArray[np.float_] | xr.DataArray,
dataset: xr.Dataset,
prep_name: Callable[[str], str],
transform_spectra=None,
remove_old=True,
transform_spectra: dict[str, xr.DataArray] | None = None,
remove_old: bool = True,
) -> xr.Dataset:
"""Applies a function onto a DataArray axis.
Expand All @@ -202,8 +204,10 @@ def transform_dataarray_axis(
"""
ds = dataset.copy()
if transform_spectra is None:
# transform *all* DataArrays in the dataset that have old_axis_name in their dimensions
# transform *all* DataArrays in the dataset that have old_axis_name in their dimensions.
# In the standard usage, k is "spectra", v is xr.DataArray
transform_spectra = {k: v for k, v in ds.data_vars.items() if old_axis_name in v.dims}
assert isinstance(transform_spectra, dict)

ds.coords[new_axis_name] = new_axis

Expand All @@ -217,7 +221,7 @@ def transform_dataarray_axis(
new_dims = list(dr.dims)
new_dims[old_axis] = new_axis_name

g = functools.partial(f, axis=old_axis)
g = functools.partial(func, axis=old_axis)
output = geometric_transform(dr.values, g, output_shape=shape, output="f", order=1)

new_coords = dict(dr.coords)
Expand Down
6 changes: 3 additions & 3 deletions arpes/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import functools
import time
from dataclasses import dataclass, field
from logging import INFO, Formatter, StreamHandler, getLogger
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING

if TYPE_CHECKING:
Expand All @@ -15,8 +15,8 @@
__all__ = [
"traceable",
]

LOGLEVEL = INFO
LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[0]
logger = getLogger(__name__)
fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
formatter = Formatter(fmt)
Expand Down
16 changes: 16 additions & 0 deletions arpes/utilities/widgets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Wraps Qt widgets in ones which use rx for signaling, Conrad's personal preference."""
from __future__ import annotations

from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from pathlib import Path
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -35,6 +36,18 @@
"SubjectiveTextEdit",
)

LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[0]
logger = getLogger(__name__)
fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
formatter = Formatter(fmt)
handler = StreamHandler()
handler.setLevel(LOGLEVEL)
logger.setLevel(LOGLEVEL)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False


class SubjectiveComboBox(QComboBox):
"""A QComboBox using rx instead of signals."""
Expand Down Expand Up @@ -188,6 +201,9 @@ class SubjectiveCheckBox(QCheckBox):

def __init__(self, *args: Incomplete, **kwargs: Incomplete) -> None:
"""Wrap signals in ``rx.BehaviorSubject``s."""
if kwargs:
for k, v in kwargs.items():
logger.debug(f"unused kwargs: key: {k}, value{v}")
super().__init__(*args)
self.subject = BehaviorSubject(self.checkState())
self.stateChanged.connect(self.subject.on_next)
Expand Down
2 changes: 1 addition & 1 deletion arpes/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _open_path(p: Path | str) -> None:
if "win" in sys.platform:
subprocess.Popen(rf"explorer {p}")
else:
print(p) # noqa: T201
print(p)


@with_workspace
Expand Down
3 changes: 2 additions & 1 deletion arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@

ANGLE_VARS = ("alpha", "beta", "chi", "psi", "phi", "theta")

LOGLEVEL = (DEBUG, INFO)[1]
LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[1]
logger = getLogger(__name__)
fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
formatter = Formatter(fmt)
Expand Down
5 changes: 4 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
PyARPES
=======
**Non maintainer update**



**December 2020, V3 Release**: The current relase focuses on improving
usage and workflow for less experienced Python users, lifting version
Expand Down Expand Up @@ -227,4 +230,4 @@ design, Michael Khachatrian
dev-guide
api
CHANGELOG


6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ ignore = [
"PD011", # pandas-use-of-dot-values
"FBT002", # boolean-default-value-in-function-definition
"FIX002", # line-contains-todo (FIX002)#
"G004", # logging-f-string
]
select = ["ALL"]
line-length = 100
Expand All @@ -93,9 +94,10 @@ exclude = ["scripts", "docs", "conda"]


[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"] # unused-import

"__init__.py" = ["F401"] # unused-import
"arpes/__init__.py" = ["T201"] # print used
"arpes/experiment/__init__.py" = ["ALL"]
"arpes/workflow.py" = ["T201", "T203"]

# Bokeh based plotting tools will be removed.
"arpes/plotting/band_tool.py" = ["ALL"]
Expand Down

0 comments on commit fe14b35

Please sign in to comment.