Skip to content

Commit

Permalink
Merge branch 'daredevil' of github.com:arafune/arpes into daredevil
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Oct 17, 2023
2 parents ad5dad1 + 59eb8fb commit 7e5cfa0
Showing 1 changed file with 46 additions and 28 deletions.
74 changes: 46 additions & 28 deletions arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
import itertools
import warnings
from collections import OrderedDict
from collections.abc import Sequence
from collections.abc import Collection, Sequence
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING, Any, Literal, Self, TypeAlias, Unpack

Expand Down Expand Up @@ -103,7 +103,7 @@
ANGLE_VARS = ("alpha", "beta", "chi", "psi", "phi", "theta")

LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[1]
LOGLEVEL = LOGLEVELS[0]
logger = getLogger(__name__)
fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
formatter = Formatter(fmt)
Expand Down Expand Up @@ -603,8 +603,6 @@ def select_around(
def short_history(self, key: str = "by") -> list:
"""Return the short version of history.
[TODO:description]
Args:
key (str): [TODO:description]
"""
Expand Down Expand Up @@ -1781,13 +1779,24 @@ def dict_to_html(d: dict[str, float | str]) -> str:

def _repr_html_full_coords(
self,
coords: dict[str, dict[str, float | str]],
coords: dict[str, xr.DataArray],
) -> str:
significant_coords = {}
for k, v in coords.items():
if v is None:
continue
if np.any(np.isnan(v)):
continue
significant_coords[str(k)] = v

def coordinate_dataarray_to_flat_rep(
value: xr.DataArray,
) -> str:
if not isinstance(value, xr.DataArray):
return value
if len(value.dims) == 0:
tmp = "<span>{var:.5g}</span>"
return tmp.format(var=value.values)
tmp = "<span>{min:.3g}<strong> to </strong>{max:.3g}"
tmp += "<strong> by </strong>{delta:.3g}</span>"
return tmp.format(
Expand All @@ -1797,7 +1806,7 @@ def coordinate_dataarray_to_flat_rep(
)

return ARPESAccessorBase.dict_to_html(
{k: coordinate_dataarray_to_flat_rep(v) for k, v in coords.items()},
{k: coordinate_dataarray_to_flat_rep(v) for k, v in significant_coords.items()},
)

def _repr_html_spectrometer_info(self) -> str:
Expand All @@ -1810,7 +1819,7 @@ def _repr_html_spectrometer_info(self) -> str:
return ARPESAccessorBase.dict_to_html(ordered_settings)

@staticmethod
def _repr_html_experimental_conditions(conditions: dict) -> str:
def _repr_html_experimental_conditions(conditions: dict[str, str | float | None]) -> str:
transforms = {
"polarization": lambda p: {
"p": "Linear Horizontal",
Expand All @@ -1826,12 +1835,17 @@ def _repr_html_experimental_conditions(conditions: dict) -> str:
"temp": "{} Kelvin".format,
}

def id(x):
def no_change(x: str | float) -> str | float:
return x

return ARPESAccessorBase.dict_to_html(
{k: transforms.get(k, id)(v) for k, v in conditions.items() if v is not None},
)
transformed_dict = {}
for k, v in conditions.items():
if v is None:
continue
if np.isnan(v):
continue
transformed_dict[str(k)] = transforms.get(k, no_change)(v)
return ARPESAccessorBase.dict_to_html(transformed_dict)

def _repr_html_(self) -> str:
skip_data_vars = {
Expand All @@ -1853,20 +1867,19 @@ def _repr_html_(self) -> str:
ax = [ax]

for i, plot_var in enumerate(to_plot):
self._obj[plot_var].plot(ax=ax[i])
self._obj[plot_var].T.plot(ax=ax[i])
fancy_labels(ax[i])
ax[i].set_title(plot_var.replace("_", " "))

remove_colorbars()

elif 1 <= len(self._obj.dims) < 3: # noqa: PLR2004
fig, ax = plt.subplots(1, 1, figsize=(4, 3))
self._obj.plot(ax=ax)
_, ax = plt.subplots(1, 1, figsize=(4, 3))
self._obj.T.plot(ax=ax)
fancy_labels(ax)
ax.set_title("")

remove_colorbars()

wrapper_style = 'style="display: flex; flex-direction: row;"'

try:
Expand All @@ -1876,7 +1889,6 @@ def _repr_html_(self) -> str:
name = "ID: " + str(self._obj.attrs["id"])[:9] + "..."
else:
name = "No name"

warning = ""

if len(self._obj.attrs) < 10: # noqa: PLR2004
Expand All @@ -1891,9 +1903,7 @@ def _repr_html_(self) -> str:
</details>
<details open>
<summary>Full Coordinates</summary>
{self._repr_html_full_coords(
{k: v for k, v in self.full_coords.items() if v is not None},
)}
{self._repr_html_full_coords(self.full_coords)}
</details>
<details open>
<summary>Spectrometer</summary>
Expand Down Expand Up @@ -1981,8 +1991,8 @@ def plot(
Args:
rasterized (bool): if True, rasterized (Not vector) drawing
*args: Pass to xr.DataArray.plot
*kwargs: Pass to xr.DataArray.plot
args: Pass to xr.DataArray.plot
kwargs: Pass to xr.DataArray.plot
"""
if len(self._obj.dims) == 2 and "rasterized" not in kwargs: # noqa: PLR2004
Expand Down Expand Up @@ -2401,23 +2411,31 @@ def drop_nan(self) -> xr.DataArray | xr.Dataset:
mask = np.logical_not(np.isnan(self._obj.values))
return self._obj.isel(**dict([[self._obj.dims[0], mask]]))

def shift_coords(self, dims: tuple[str, ...], shift):
def shift_coords(
self,
dims: tuple[str, ...],
shift: NDArray[np.float_] | float,
) -> xr.DataArray | xr.Dataset:
if self._obj is None:
msg = "Cannot access 'G'"
raise RuntimeError(msg)
if not isinstance(shift, np.ndarray):
shift = np.ones((len(dims),)) * shift

def transform(data):
new_shift = shift
def transform(data: NDArray[np.float_]) -> NDArray[np.float_]:
new_shift: NDArray[np.float_] = shift
for _ in range(len(dims)):
new_shift = np.expand_dims(new_shift, 0)
new_shift = np.expand_dims(new_shift, axis=0)

return data + new_shift

return self.transform_coords(dims, transform)

def scale_coords(self, dims: tuple[str, ...], scale: float | NDArray[np.float_]):
def scale_coords(
self,
dims: tuple[str, ...],
scale: float | NDArray[np.float_],
) -> xr.DataArray | xr.Dataset:
if not isinstance(scale, np.ndarray):
n_dims = len(dims)
scale = np.identity(n_dims) * scale
Expand All @@ -2428,7 +2446,7 @@ def scale_coords(self, dims: tuple[str, ...], scale: float | NDArray[np.float_])

def transform_coords(
self,
dims: list[str],
dims: Collection[str],
transform: NDArray[np.float_] | Callable,
) -> xr.DataArray | xr.Dataset:
"""Transforms the given coordinate values according to an arbitrary function.
Expand All @@ -2454,7 +2472,7 @@ def transform_coords(

copied = self._obj.copy(deep=True)

for d, arr in zip(dims, np.split(transformed, transformed.shape[-1], axis=-1)):
for d, arr in zip(dims, np.split(transformed, transformed.shape[-1], axis=-1), strict=True):
copied.data_vars[d].values = np.squeeze(arr, axis=-1)

return copied
Expand Down

0 comments on commit 7e5cfa0

Please sign in to comment.