Skip to content

Commit

Permalink
Metric or Physio returning optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
maestroque committed Aug 11, 2024
1 parent 181304f commit 43107db
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 13 deletions.
17 changes: 12 additions & 5 deletions phys2denoise/metrics/cardiac.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _cardiac_metrics(
peaks=None,
window=6,
central_measure="mean",
return_physio=False,
**kwargs,
):
"""
Compute cardiac metrics.
Expand Down Expand Up @@ -154,7 +154,7 @@ def _cardiac_metrics(
@due.dcite(references.CHANG_CUNNINGHAM_GLOVER_2009)
@return_physio_or_metric()
@physio.make_operation()
def heart_rate(data, fs=None, peaks=None, window=6, central_measure="mean"):
def heart_rate(data, fs=None, peaks=None, window=6, central_measure="mean", **kwargs):
"""
Compute average heart rate (HR) in a sliding window.
Expand Down Expand Up @@ -213,6 +213,7 @@ def heart_rate(data, fs=None, peaks=None, window=6, central_measure="mean"):
peaks=peaks,
window=window,
central_measure=central_measure,
**kwargs,
)

return data, hr
Expand All @@ -221,7 +222,9 @@ def heart_rate(data, fs=None, peaks=None, window=6, central_measure="mean"):
@due.dcite(references.PINHERO_ET_AL_2016)
@return_physio_or_metric()
@physio.make_operation()
def heart_rate_variability(data, fs=None, peaks=None, window=6, central_measure="mean"):
def heart_rate_variability(
data, fs=None, peaks=None, window=6, central_measure="mean", **kwargs
):
"""
Compute average heart rate variability (HRV) in a sliding window.
Expand Down Expand Up @@ -278,6 +281,7 @@ def heart_rate_variability(data, fs=None, peaks=None, window=6, central_measure=
peaks=peaks,
window=window,
central_measure=central_measure,
**kwargs,
)

return data, hrv
Expand All @@ -286,7 +290,9 @@ def heart_rate_variability(data, fs=None, peaks=None, window=6, central_measure=
@due.dcite(references.CHEN_2020)
@return_physio_or_metric()
@physio.make_operation()
def heart_beat_interval(data, fs=None, peaks=None, window=6, central_measure="mean"):
def heart_beat_interval(
data, fs=None, peaks=None, window=6, central_measure="mean", **kwargs
):
"""
Compute average heart beat interval (HBI) in a sliding window.
Expand Down Expand Up @@ -336,14 +342,15 @@ def heart_beat_interval(data, fs=None, peaks=None, window=6, central_measure="me
peaks=peaks,
window=window,
central_measure=central_measure,
**kwargs,
)

return data, hbi


@return_physio_or_metric()
@physio.make_operation()
def cardiac_phase(data, slice_timings, n_scans, t_r, fs=None, peaks=None):
def cardiac_phase(data, slice_timings, n_scans, t_r, fs=None, peaks=None, **kwargs):
"""Calculate cardiac phase from cardiac peaks.
Assumes that timing of cardiac events are given in same units
Expand Down
10 changes: 5 additions & 5 deletions phys2denoise/metrics/chest_belt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
@return_physio_or_metric()
@physio.make_operation()
def respiratory_variance_time(
data, fs=None, peaks=None, troughs=None, lags=(0, 4, 8, 12)
data, fs=None, peaks=None, troughs=None, lags=(0, 4, 8, 12), **kwargs
):
"""
Implement the Respiratory Variance over Time (Birn et al. 2006).
Expand Down Expand Up @@ -113,7 +113,7 @@ def respiratory_variance_time(
@due.dcite(references.POWER_2018)
@return_physio_or_metric()
@physio.make_operation()
def respiratory_pattern_variability(data, window):
def respiratory_pattern_variability(data, window, **kwargs):
"""Calculate respiratory pattern variability.
Parameters
Expand Down Expand Up @@ -161,7 +161,7 @@ def respiratory_pattern_variability(data, window):
@due.dcite(references.POWER_2020)
@return_physio_or_metric()
@physio.make_operation()
def env(data, fs=None, window=10):
def env(data, fs=None, window=10, **kwargs):
"""Calculate respiratory pattern variability across a sliding window.
Parameters
Expand Down Expand Up @@ -245,7 +245,7 @@ def _respiratory_pattern_variability(data, window):
@due.dcite(references.CHANG_GLOVER_2009)
@return_physio_or_metric()
@physio.make_operation()
def respiratory_variance(data, fs=None, window=6):
def respiratory_variance(data, fs=None, window=6, **kwargs):
"""Calculate respiratory variance.
Parameters
Expand Down Expand Up @@ -310,7 +310,7 @@ def respiratory_variance(data, fs=None, window=6):

@return_physio_or_metric()
@physio.make_operation()
def respiratory_phase(data, n_scans, slice_timings, t_r, fs=None):
def respiratory_phase(data, n_scans, slice_timings, t_r, fs=None, **kwargs):
"""Calculate respiratory phase from respiratory signal.
Parameters
Expand Down
1 change: 1 addition & 0 deletions phys2denoise/metrics/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def retroicor(
physio_type=None,
fs=None,
cardiac_peaks=None,
**kwargs,
):
"""Compute RETROICOR regressors.
Expand Down
5 changes: 3 additions & 2 deletions phys2denoise/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,9 @@ def wrapper(*args, **kwargs):
physio._computed_metrics[func.__name__] = dict(
metric=metric, args=kwargs
)
if return_physio:
return physio, metric
return_physio_value = kwargs.get("return_physio", return_physio)
if return_physio_value:
return physio
else:
return metric
else:
Expand Down
18 changes: 17 additions & 1 deletion phys2denoise/tests/test_metrics_cardiac.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,28 @@ def test_cardiac_phase_smoke_physio_obj():
data = np.zeros(peaks.shape)
phys = physio.Physio(data, sample_rate, physio_type="cardiac")
phys._metadata["peaks"] = peaks
phys, card_phase = cardiac.cardiac_phase(

# Test where the physio object is returned
phys = cardiac.cardiac_phase(
phys,
slice_timings=slice_timings,
n_scans=n_scans,
t_r=t_r,
)
assert phys.history[0][0] == "phys2denoise.metrics.cardiac.cardiac_phase"
assert phys.computed_metrics["cardiac_phase"]["metric"].ndim == 2
assert phys.computed_metrics["cardiac_phase"]["metric"].shape == (
n_scans,
slice_timings.size,
)

# Test where the metric is returned
card_phase = cardiac.cardiac_phase(
phys,
slice_timings=slice_timings,
n_scans=n_scans,
t_r=t_r,
return_physio=False,
)
assert card_phase.ndim == 2
assert card_phase.shape == (n_scans, slice_timings.size)

0 comments on commit 43107db

Please sign in to comment.