Skip to content

Commit

Permalink
Add wrapper to determine whether to return a metric or a physio object
Browse files Browse the repository at this point in the history
  • Loading branch information
maestroque committed Aug 10, 2024
1 parent 00b0f56 commit bc300f0
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 25 deletions.
40 changes: 30 additions & 10 deletions phys2denoise/metrics/cardiac.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@
from ..due import due
from .responses import crf
from .utils import apply_function_in_sliding_window as afsw
from .utils import convolve_and_rescale
from .utils import convolve_and_rescale, return_physio_or_metric


def _cardiac_metrics(
data, metric, fs=None, peaks=None, window=6, central_measure="mean"
data,
metric,
fs=None,
peaks=None,
window=6,
central_measure="mean",
return_physio=False,
):
"""
Compute cardiac metrics.
Expand Down Expand Up @@ -146,6 +152,7 @@ def _cardiac_metrics(


@due.dcite(references.CHANG_CUNNINGHAM_GLOVER_2009)
@return_physio_or_metric()
@physio.make_operation()
def heart_rate(data, fs=None, peaks=None, window=6, central_measure="mean"):
"""
Expand Down Expand Up @@ -200,14 +207,19 @@ def heart_rate(data, fs=None, peaks=None, window=6, central_measure="mean"):
Biology Society (EMBC), doi: 10.1109/EMBC.2016.7591347.
"""
data, hr = _cardiac_metrics(
data, metric="hr", fs=fs, peaks=peaks, window=6, central_measure="mean"
data,
metric="hr",
fs=fs,
peaks=peaks,
window=window,
central_measure=central_measure,
)
data._computed_metrics["heart_rate"] = dict(metric=hr, has_lags=False)

return data, hr


@due.dcite(references.PINHERO_ET_AL_2016)
@return_physio_or_metric()
@physio.make_operation()
def heart_rate_variability(data, fs=None, peaks=None, window=6, central_measure="mean"):
"""
Expand Down Expand Up @@ -260,14 +272,19 @@ def heart_rate_variability(data, fs=None, peaks=None, window=6, central_measure=
Biology Society (EMBC), doi: 10.1109/EMBC.2016.7591347.
"""
data, hrv = _cardiac_metrics(
data, metric="hrv", fs=None, peaks=None, window=6, central_measure="std"
data,
metric="hrv",
fs=fs,
peaks=peaks,
window=window,
central_measure=central_measure,
)
data._computed_metrics["heart_rate_variability"] = dict(metric=hrv)

return data, hrv


@due.dcite(references.CHEN_2020)
@return_physio_or_metric()
@physio.make_operation()
def heart_beat_interval(data, fs=None, peaks=None, window=6, central_measure="mean"):
"""
Expand Down Expand Up @@ -313,13 +330,18 @@ def heart_beat_interval(data, fs=None, peaks=None, window=6, central_measure="me
vol. 213, pp. 116707, 2020.
"""
data, hbi = _cardiac_metrics(
data, metric="hbi", fs=None, peaks=None, window=6, central_measure="mean"
data,
metric="hbi",
fs=fs,
peaks=peaks,
window=window,
central_measure=central_measure,
)
data._computed_metrics["heart_beat_interval"] = dict(metric=hbi)

return data, hbi


@return_physio_or_metric()
@physio.make_operation()
def cardiac_phase(data, slice_timings, n_scans, t_r, fs=None, peaks=None):
"""Calculate cardiac phase from cardiac peaks.
Expand Down Expand Up @@ -402,6 +424,4 @@ def cardiac_phase(data, slice_timings, n_scans, t_r, fs=None, peaks=None):
) / (t2 - t1)
phase_card[:, i_slice] = phase_card_crSlice

data._computed_metrics["cardiac_phase"] = dict(metric=phase_card)

return data, phase_card
14 changes: 6 additions & 8 deletions phys2denoise/metrics/chest_belt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from ..due import due
from .responses import rrf
from .utils import apply_function_in_sliding_window as afsw
from .utils import convolve_and_rescale, rms_envelope_1d
from .utils import convolve_and_rescale, return_physio_or_metric, rms_envelope_1d


@due.dcite(references.BIRN_2006)
@return_physio_or_metric()
@physio.make_operation()
def respiratory_variance_time(
data, fs=None, peaks=None, troughs=None, lags=(0, 4, 8, 12)
Expand Down Expand Up @@ -106,11 +107,11 @@ def respiratory_variance_time(
)
rvt_lags[:, ind] = temp_rvt

data._computed_metrics["rvt"] = dict(metric=rvt_lags, has_lags=True)
return data, rvt_lags


@due.dcite(references.POWER_2018)
@return_physio_or_metric()
@physio.make_operation()
def respiratory_pattern_variability(data, window):
"""Calculate respiratory pattern variability.
Expand Down Expand Up @@ -154,11 +155,11 @@ def respiratory_pattern_variability(data, window):
# Calculate standard deviation
rpv_val = np.std(rpv_upper_env)

data._computed_metrics["rpv"] = dict(metric=rpv_val)
return data, rpv_val


@due.dcite(references.POWER_2020)
@return_physio_or_metric()
@physio.make_operation()
def env(data, fs=None, window=10):
"""Calculate respiratory pattern variability across a sliding window.
Expand Down Expand Up @@ -238,11 +239,11 @@ def _respiratory_pattern_variability(data, window):
)
env_arr[np.isnan(env_arr)] = 0.0

data._computed_metrics["env"] = dict(metric=env_arr)
return data, env_arr


@due.dcite(references.CHANG_GLOVER_2009)
@return_physio_or_metric()
@physio.make_operation()
def respiratory_variance(data, fs=None, window=6):
"""Calculate respiratory variance.
Expand Down Expand Up @@ -304,11 +305,10 @@ def respiratory_variance(data, fs=None, window=6):
# Convolve with rrf
rv_out = convolve_and_rescale(rv_arr, rrf(data.fs), rescale="zscore")

data._computed_metrics["respiratory_variance"] = dict(metric=rv_out)

return data, rv_out


@return_physio_or_metric()
@physio.make_operation()
def respiratory_phase(data, n_scans, slice_timings, t_r, fs=None):
"""Calculate respiratory phase from respiratory signal.
Expand Down Expand Up @@ -374,6 +374,4 @@ def respiratory_phase(data, n_scans, slice_timings, t_r, fs=None):

phase_resp[:, i_slice] = phase_resp_crSlice

data._computed_metrics["respiratory_phase"] = dict(metric=phase_resp)

return data, phase_resp
38 changes: 38 additions & 0 deletions phys2denoise/metrics/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Miscellaneous utility functions for metric calculation."""
import functools
import logging

import numpy as np
from loguru import logger
from numpy.lib.stride_tricks import sliding_window_view as swv
from physutils.physio import Physio
from scipy.interpolate import interp1d
from scipy.stats import zscore

Expand Down Expand Up @@ -332,3 +335,38 @@ def export_metric(
)

return fileprefix


def return_physio_or_metric(*, return_physio=True):
"""
Decorator to check if the input is a Physio object.
Parameters
----------
func : function
Function to be decorated
Returns
-------
function
Decorated function
"""

def determine_return_type(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
physio, metric = func(*args, **kwargs)
if isinstance(args[0], Physio):
physio._computed_metrics[func.__name__] = dict(
metric=metric, args=kwargs
)
if return_physio:
return physio, metric
else:
return metric
else:
return metric

return wrapper

return determine_return_type
2 changes: 1 addition & 1 deletion phys2denoise/tests/test_metrics_cardiac.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_cardiac_phase_smoke():
slice_timings = np.linspace(0, t_r, 22)[1:-1]
peaks = np.array([0.534, 0.577, 10.45, 20.66, 50.55, 90.22])
data = np.zeros(peaks.shape)
_, card_phase = cardiac.cardiac_phase(
card_phase = cardiac.cardiac_phase(
data,
peaks=peaks,
fs=sample_rate,
Expand Down
8 changes: 4 additions & 4 deletions phys2denoise/tests/test_metrics_chest_belt.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_respiratory_phase_smoke():
slice_timings = np.linspace(0, t_r, 22)[1:-1]
n_samples = int(np.rint((n_scans * t_r) * sample_rate))
resp = np.random.normal(size=n_samples)
_, resp_phase = chest_belt.respiratory_phase(
resp_phase = chest_belt.respiratory_phase(
resp,
fs=sample_rate,
slice_timings=slice_timings,
Expand All @@ -46,7 +46,7 @@ def test_respiratory_pattern_variability_smoke():
n_samples = 2000
resp = np.random.normal(size=n_samples)
window = 50
_, rpv_val = chest_belt.respiratory_pattern_variability(resp, window)
rpv_val = chest_belt.respiratory_pattern_variability(resp, window)
assert isinstance(rpv_val, float)


Expand All @@ -56,7 +56,7 @@ def test_env_smoke():
resp = np.random.normal(size=n_samples)
samplerate = 1 / 0.01
window = 6
_, env_arr = chest_belt.env(resp, fs=samplerate, window=window)
env_arr = chest_belt.env(resp, fs=samplerate, window=window)
assert env_arr.ndim == 1
assert env_arr.shape == (n_samples,)

Expand All @@ -67,6 +67,6 @@ def test_respiratory_variance_smoke():
resp = np.random.normal(size=n_samples)
samplerate = 1 / 0.01
window = 6
_, rv_arr = chest_belt.respiratory_variance(resp, fs=samplerate, window=window)
rv_arr = chest_belt.respiratory_variance(resp, fs=samplerate, window=window)
assert rv_arr.ndim == 2
assert rv_arr.shape == (n_samples, 2)
4 changes: 2 additions & 2 deletions phys2denoise/tests/test_rvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def test_respiratory_variance_time(fake_phys):
phys = peakdet.operations.peakfind_physio(phys)

# TODO: Change to a simpler call once physutils are
# integrated to peakdet
phys, r = respiratory_variance_time(
# integrated to peakdet/prep4phys
r = respiratory_variance_time(
phys.data, fs=phys.fs, peaks=phys.peaks, troughs=phys.troughs
)
assert r is not None
Expand Down

0 comments on commit bc300f0

Please sign in to comment.