Skip to content

Commit

Permalink
💬 Update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Feb 3, 2024
1 parent 8fb8f34 commit 8a5f03f
Showing 1 changed file with 20 additions and 22 deletions.
42 changes: 20 additions & 22 deletions arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,22 +1032,20 @@ def find_spectrum_angular_edges_full(

delta = self._obj.G.stride(generic_dim_names=False)

low_edges = (
np.array(low_edges) * delta[angular_dim] + rebinned.coords[angular_dim].values[0]
)
high_edges = (
np.array(high_edges) * delta[angular_dim] + rebinned.coords[angular_dim].values[0]
return (
np.array(low_edges) * delta[angular_dim] + rebinned.coords[angular_dim].values[0],
np.array(high_edges) * delta[angular_dim] + rebinned.coords[angular_dim].values[0],
rebinned.coords["eV"],
)

return low_edges, high_edges, rebinned.coords["eV"]

def zero_spectrometer_edges(
self,
cut_margin: int = 0,
interp_range: float | None = None,
low: Sequence[float] | None = None,
high: Sequence[float] | None = None,
) -> xr.DataArray | xr.Dataset:
low: Sequence[float] | NDArray[np.float_] | None = None,
high: Sequence[float] | NDArray[np.float_] | None = None,
) -> xr.DataArray:
assert isinstance(self._obj, xr.DataArray)
if low is not None:
assert high is not None
assert len(low) == len(high) == 2 # noqa: PLR2004
Expand Down Expand Up @@ -1093,10 +1091,10 @@ def zero_spectrometer_edges(
other = len(rebinned_eV_coord) - 1
index = len(rebinned_eV_coord) - 2

low = int(np.interp(energy, rebinned_eV_coord, low_edges))
high = int(np.interp(energy, rebinned_eV_coord, high_edges))
copied.values[i, 0:low] = 0
copied.values[i, high:-1] = 0
low_index = int(np.interp(energy, rebinned_eV_coord, low_edges))
high_index = int(np.interp(energy, rebinned_eV_coord, high_edges))
copied.values[i, 0:low_index] = 0
copied.values[i, high_index:-1] = 0

return copied

Expand Down Expand Up @@ -1451,7 +1449,7 @@ def sample_info(self) -> SAMPLEINFO:
"""
sample_info: SAMPLEINFO = {
"id": self._obj.attrs.get("sample_id"),
"name": self._obj.attrs.get("sample_name"),
"sample_name": self._obj.attrs.get("sample_name"),
"source": self._obj.attrs.get("sample_source"),
"reflectivity": self._obj.attrs.get("sample_reflectivity", np.nan),
}
Expand Down Expand Up @@ -1710,7 +1708,7 @@ def dict_to_html(d: Mapping[str, float | str]) -> str:

def _repr_html_full_coords(
self,
coords: dict[str, xr.DataArray],
coords: dict[str, float | xr.DataArray],
) -> str:
significant_coords = {}
for k, v in coords.items():
Expand All @@ -1721,8 +1719,8 @@ def _repr_html_full_coords(
significant_coords[k] = v

def coordinate_dataarray_to_flat_rep(
value: xr.DataArray,
) -> str:
value: xr.DataArray | float,
) -> str | float:
if not isinstance(value, xr.DataArray):
return value
if len(value.dims) == 0:
Expand Down Expand Up @@ -1903,10 +1901,10 @@ def _radian_to_degree(self) -> None:
self.angle_unit = "Degrees"
for angle in ANGLE_VARS:
if angle in self._obj.attrs:
self._obj.attrs[angle] = np.rad2deg(self._obj.attrs.get(angle))
self._obj.attrs[angle] = np.rad2deg(self._obj.attrs.get(angle, np.nan))
if angle + "_offset" in self._obj.attrs:
self._obj.attrs[angle + "_offset"] = np.rad2deg(
self._obj.attrs.get(angle + "_offset"),
self._obj.attrs.get(angle + "_offset", np.nan),
)
if angle in self._obj.coords:
self._obj.coords[angle] = np.rad2deg(self._obj.coords[angle])
Expand Down Expand Up @@ -2063,7 +2061,7 @@ def _simple_spectrum_reference_plot(

return fancy_dispersion(self._obj, **kwargs)

def cut_nan_coords(self) -> xr.DataArray:
def cut_nan_coords(self) -> xr.DataArray | xr.Dataset:
"""Selects data where coordinates are not `nan`.
Returns (xr.DataArray):
Expand All @@ -2077,7 +2075,7 @@ def cut_nan_coords(self) -> xr.DataArray:
slices[cname] = slice(None, end_ind)
except IndexError:
pass
return self._obj.isel(**slices)
return self._obj.isel(slices)

def reference_plot(
self,
Expand Down

0 comments on commit 8a5f03f

Please sign in to comment.