Skip to content

Commit

Permalink
🔥 remove function: propagate_statistical_error & derivative
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Nov 22, 2023
1 parent b0caf44 commit a11794e
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 38 deletions.
4 changes: 0 additions & 4 deletions arpes/plotting/spin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from arpes.analysis.statistics import mean_and_deviation
from arpes.bootstrap import bootstrap
from arpes.provenance import save_plot_provenance
from arpes.utilities.math import polarization, propagate_statistical_error

from .tof import scatter_with_std
from .utils import label_for_dim, path_for_plot, polarization_colorbar, savefig
Expand All @@ -35,9 +34,6 @@
)


test_polarization = propagate_statistical_error(polarization)


@save_plot_provenance
def spin_colored_spectrum(
spin_dr: xr.Dataset,
Expand Down
35 changes: 1 addition & 34 deletions arpes/utilities/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

import itertools
from collections.abc import Callable, Iterable
from collections.abc import Iterable
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -16,44 +16,11 @@
from numpy.typing import NDArray


def derivative(f: Callable[..., float], arg_idx: int = 0) -> float:
"""Defines a simple midpoint derivative."""

def d(*args: Incomplete):
args = list(args)
ref_arg = args[arg_idx]
d = ref_arg / 100
args[arg_idx] = ref_arg + d
high = f(*args)
args[arg_idx] = ref_arg - d
low = f(*args)
return (high - low) / (2 * d)

return d


def polarization(up: NDArray[np.float_], down: NDArray[np.float_]) -> NDArray[np.float_]:
"""The equivalent normalized difference for a two component signal."""
return (up - down) / (up + down)


def propagate_statistical_error(f):
"""To compute a function which propagates statistical error.
It Uses numerical derivatives and sampling.
"""

def compute_propagated_error(*args):
running_sum = 0
for i, arg in enumerate(args):
df_darg_i = derivative(f, i)
running_sum += df_darg_i(*args) ** 2 * arg

return np.sqrt(running_sum)

return compute_propagated_error


def shift_by(
arr: NDArray[np.float_],
value: xr.DataArray | NDArray[np.float_],
Expand Down

0 comments on commit a11794e

Please sign in to comment.