Skip to content

Commit

Permalink
🔨 Refactoring to reducde complexity.
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Apr 12, 2024
1 parent e55d386 commit 007b859
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 34 deletions.
57 changes: 24 additions & 33 deletions src/arpes/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,10 @@ def bootstrap_intensity_polarization(data: xr.Dataset, n: int = 100) -> xr.Datas


def bootstrap(
fn: Callable,
fn: Callable[..., xr.Dataset | xr.DataArray],
skip: set[int] | list[int] | None = None,
resample_method: str | None = None,
) -> Callable:
) -> Callable[..., xr.DataArray | xr.Dataset]:
"""Produces function which performs a bootstrap of an arbitrary function by sampling.
This is a functor which takes a function operating on plain data and produces one which
Expand All @@ -341,10 +341,7 @@ def bootstrap(
Returns:
A function which vectorizes the output of the input function `fn` over samples.
"""
if skip is None:
skip = []

skip = set(skip)
skip = set(skip) if skip else set()

if resample_method is None:
resample_fn = resample
Expand All @@ -363,35 +360,20 @@ def bootstrapped(
for i, arg in enumerate(args)
if isinstance(arg, xr.DataArray | xr.Dataset) and i not in skip
]
data_is_arraylike: bool = False

runs = []

def get_label(i: int) -> str:
if isinstance(args[i], xr.Dataset):
return "xr.Dataset: [{}]".format(", ".join(args[i].data_vars.keys()))
if args[i].name:
return args[i].name
try:
return args[i].attrs["id"]
except KeyError:
return "Label-less DataArray"

msg = "Resampling args: {}".format(",".join([get_label(i) for i in resample_indices]))
msg = "Resampling args: "
msg += f"{','.join([_get_label_from_args(args, i) for i in resample_indices])}"
logger.info(msg)

# examine kwargs to determine which to resample
resample_kwargs = [
k for k, v in kwargs.items() if isinstance(v, xr.DataArray) and k not in skip
]
msg = "Resampling kwargs: {}".format(",".join(resample_kwargs))
logger.info(msg)

logger.info(
f"Resampling kwargs: {','.join(resample_kwargs)}"
"Fair warning 1: Make sure you understand whether"
" it is appropriate to resample your data.",
)
logger.info(
" it is appropriate to resample your data."
"Fair warning 2: Ensure that the data to resample is in a DataArray and not a Dataset",
)

Expand All @@ -404,16 +386,25 @@ def get_label(i: int) -> str:
new_kwargs[k] = resample_fn(kwargs[k], prior_adjustment=prior_adjustment)

run = fn(*new_args, **new_kwargs)
if isinstance(run, xr.DataArray | xr.Dataset):
data_is_arraylike = True
runs.append(run)

if data_is_arraylike:
for i, run in enumerate(runs):
run = run.assign_coords(bootstrap=i)
return xr.concat(
[
run.assign_coords(bootstrap=i)
for i, run in enumerate(runs)
if isinstance(run, xr.DataArray | xr.Dataset)
],
)

return xr.concat(runs, dim="bootstrap")
return functools.wraps(fn)(bootstrapped)

return runs

return functools.wraps(fn)(bootstrapped)
def _get_label_from_args(args: tuple[Any, ...], i: int) -> str:
if isinstance(args[i], xr.Dataset):
return "xr.Dataset: [{}]".format(", ".join(args[i].data_vars.keys()))
if args[i].name:
return args[i].name
try:
return args[i].attrs["id"]
except KeyError:
return "Label-less DataArray"
3 changes: 2 additions & 1 deletion src/arpes/preparation/tof_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from .axis_preparation import transform_dataarray_axis
from typing import TYPE_CHECKING
from arpes.constants import BARE_ELECTRON_MASS

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -117,7 +118,7 @@ def build_KE_coords_to_time_pixel_coords(
"""Constructs a coordinate conversion function from kinetic energy to time pixels."""
conv = (
dataset.S.spectrometer["mstar"]
* (9.11e6)
* (BARE_ELECTRON_MASS * 1e37)
* 0.5
* (dataset.S.spectrometer["length"] ** 2)
/ 1.6
Expand Down

0 comments on commit 007b859

Please sign in to comment.