diff --git a/phys2denoise/metrics/cardiac.py b/phys2denoise/metrics/cardiac.py index 06d8b50..72c7e8d 100644 --- a/phys2denoise/metrics/cardiac.py +++ b/phys2denoise/metrics/cardiac.py @@ -7,11 +7,17 @@ from ..due import due from .responses import crf from .utils import apply_function_in_sliding_window as afsw -from .utils import convolve_and_rescale +from .utils import convolve_and_rescale, return_physio_or_metric def _cardiac_metrics( - data, metric, fs=None, peaks=None, window=6, central_measure="mean" + data, + metric, + fs=None, + peaks=None, + window=6, + central_measure="mean", + return_physio=False, ): """ Compute cardiac metrics. @@ -146,6 +152,7 @@ def _cardiac_metrics( @due.dcite(references.CHANG_CUNNINGHAM_GLOVER_2009) +@return_physio_or_metric() @physio.make_operation() def heart_rate(data, fs=None, peaks=None, window=6, central_measure="mean"): """ @@ -200,14 +207,19 @@ def heart_rate(data, fs=None, peaks=None, window=6, central_measure="mean"): Biology Society (EMBC), doi: 10.1109/EMBC.2016.7591347. """ data, hr = _cardiac_metrics( - data, metric="hr", fs=fs, peaks=peaks, window=6, central_measure="mean" + data, + metric="hr", + fs=fs, + peaks=peaks, + window=window, + central_measure=central_measure, ) - data._computed_metrics["heart_rate"] = dict(metric=hr, has_lags=False) return data, hr @due.dcite(references.PINHERO_ET_AL_2016) +@return_physio_or_metric() @physio.make_operation() def heart_rate_variability(data, fs=None, peaks=None, window=6, central_measure="mean"): """ @@ -260,14 +272,19 @@ def heart_rate_variability(data, fs=None, peaks=None, window=6, central_measure= Biology Society (EMBC), doi: 10.1109/EMBC.2016.7591347. """ data, hrv = _cardiac_metrics( - data, metric="hrv", fs=None, peaks=None, window=6, central_measure="std" + data, + metric="hrv", + fs=fs, + peaks=peaks, + window=window, + central_measure=central_measure, ) - data._computed_metrics["heart_rate_variability"] = dict(metric=hrv) return data, hrv @due.dcite(references.CHEN_2020) +@return_physio_or_metric() @physio.make_operation() def heart_beat_interval(data, fs=None, peaks=None, window=6, central_measure="mean"): """ @@ -313,13 +330,18 @@ def heart_beat_interval(data, fs=None, peaks=None, window=6, central_measure="me vol. 213, pp. 116707, 2020. """ data, hbi = _cardiac_metrics( - data, metric="hbi", fs=None, peaks=None, window=6, central_measure="mean" + data, + metric="hbi", + fs=fs, + peaks=peaks, + window=window, + central_measure=central_measure, ) - data._computed_metrics["heart_beat_interval"] = dict(metric=hbi) return data, hbi +@return_physio_or_metric() @physio.make_operation() def cardiac_phase(data, slice_timings, n_scans, t_r, fs=None, peaks=None): """Calculate cardiac phase from cardiac peaks. @@ -402,6 +424,4 @@ def cardiac_phase(data, slice_timings, n_scans, t_r, fs=None, peaks=None): ) / (t2 - t1) phase_card[:, i_slice] = phase_card_crSlice - data._computed_metrics["cardiac_phase"] = dict(metric=phase_card) - return data, phase_card diff --git a/phys2denoise/metrics/chest_belt.py b/phys2denoise/metrics/chest_belt.py index 1c20aeb..0a7b800 100644 --- a/phys2denoise/metrics/chest_belt.py +++ b/phys2denoise/metrics/chest_belt.py @@ -9,10 +9,11 @@ from ..due import due from .responses import rrf from .utils import apply_function_in_sliding_window as afsw -from .utils import convolve_and_rescale, rms_envelope_1d +from .utils import convolve_and_rescale, return_physio_or_metric, rms_envelope_1d @due.dcite(references.BIRN_2006) +@return_physio_or_metric() @physio.make_operation() def respiratory_variance_time( data, fs=None, peaks=None, troughs=None, lags=(0, 4, 8, 12) @@ -106,11 +107,11 @@ def respiratory_variance_time( ) rvt_lags[:, ind] = temp_rvt - data._computed_metrics["rvt"] = dict(metric=rvt_lags, has_lags=True) return data, rvt_lags @due.dcite(references.POWER_2018) +@return_physio_or_metric() @physio.make_operation() def respiratory_pattern_variability(data, window): """Calculate respiratory pattern variability. @@ -154,11 +155,11 @@ def respiratory_pattern_variability(data, window): # Calculate standard deviation rpv_val = np.std(rpv_upper_env) - data._computed_metrics["rpv"] = dict(metric=rpv_val) return data, rpv_val @due.dcite(references.POWER_2020) +@return_physio_or_metric() @physio.make_operation() def env(data, fs=None, window=10): """Calculate respiratory pattern variability across a sliding window. @@ -238,11 +239,11 @@ def _respiratory_pattern_variability(data, window): ) env_arr[np.isnan(env_arr)] = 0.0 - data._computed_metrics["env"] = dict(metric=env_arr) return data, env_arr @due.dcite(references.CHANG_GLOVER_2009) +@return_physio_or_metric() @physio.make_operation() def respiratory_variance(data, fs=None, window=6): """Calculate respiratory variance. @@ -304,11 +305,10 @@ def respiratory_variance(data, fs=None, window=6): # Convolve with rrf rv_out = convolve_and_rescale(rv_arr, rrf(data.fs), rescale="zscore") - data._computed_metrics["respiratory_variance"] = dict(metric=rv_out) - return data, rv_out +@return_physio_or_metric() @physio.make_operation() def respiratory_phase(data, n_scans, slice_timings, t_r, fs=None): """Calculate respiratory phase from respiratory signal. @@ -374,6 +374,4 @@ def respiratory_phase(data, n_scans, slice_timings, t_r, fs=None): phase_resp[:, i_slice] = phase_resp_crSlice - data._computed_metrics["respiratory_phase"] = dict(metric=phase_resp) - return data, phase_resp diff --git a/phys2denoise/metrics/utils.py b/phys2denoise/metrics/utils.py index 34fb21b..4378866 100644 --- a/phys2denoise/metrics/utils.py +++ b/phys2denoise/metrics/utils.py @@ -1,8 +1,11 @@ """Miscellaneous utility functions for metric calculation.""" +import functools import logging import numpy as np +from loguru import logger from numpy.lib.stride_tricks import sliding_window_view as swv +from physutils.physio import Physio from scipy.interpolate import interp1d from scipy.stats import zscore @@ -332,3 +335,38 @@ def export_metric( ) return fileprefix + + +def return_physio_or_metric(*, return_physio=True): + """ + Decorator to check if the input is a Physio object. + + Parameters + ---------- + func : function + Function to be decorated + + Returns + ------- + function + Decorated function + """ + + def determine_return_type(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + physio, metric = func(*args, **kwargs) + if isinstance(args[0], Physio): + physio._computed_metrics[func.__name__] = dict( + metric=metric, args=kwargs + ) + if return_physio: + return physio, metric + else: + return metric + else: + return metric + + return wrapper + + return determine_return_type diff --git a/phys2denoise/tests/test_metrics_cardiac.py b/phys2denoise/tests/test_metrics_cardiac.py index 25faa9a..f9b69ad 100644 --- a/phys2denoise/tests/test_metrics_cardiac.py +++ b/phys2denoise/tests/test_metrics_cardiac.py @@ -28,7 +28,7 @@ def test_cardiac_phase_smoke(): slice_timings = np.linspace(0, t_r, 22)[1:-1] peaks = np.array([0.534, 0.577, 10.45, 20.66, 50.55, 90.22]) data = np.zeros(peaks.shape) - _, card_phase = cardiac.cardiac_phase( + card_phase = cardiac.cardiac_phase( data, peaks=peaks, fs=sample_rate, diff --git a/phys2denoise/tests/test_metrics_chest_belt.py b/phys2denoise/tests/test_metrics_chest_belt.py index 5e5bfd1..a4257ec 100644 --- a/phys2denoise/tests/test_metrics_chest_belt.py +++ b/phys2denoise/tests/test_metrics_chest_belt.py @@ -30,7 +30,7 @@ def test_respiratory_phase_smoke(): slice_timings = np.linspace(0, t_r, 22)[1:-1] n_samples = int(np.rint((n_scans * t_r) * sample_rate)) resp = np.random.normal(size=n_samples) - _, resp_phase = chest_belt.respiratory_phase( + resp_phase = chest_belt.respiratory_phase( resp, fs=sample_rate, slice_timings=slice_timings, @@ -46,7 +46,7 @@ def test_respiratory_pattern_variability_smoke(): n_samples = 2000 resp = np.random.normal(size=n_samples) window = 50 - _, rpv_val = chest_belt.respiratory_pattern_variability(resp, window) + rpv_val = chest_belt.respiratory_pattern_variability(resp, window) assert isinstance(rpv_val, float) @@ -56,7 +56,7 @@ def test_env_smoke(): resp = np.random.normal(size=n_samples) samplerate = 1 / 0.01 window = 6 - _, env_arr = chest_belt.env(resp, fs=samplerate, window=window) + env_arr = chest_belt.env(resp, fs=samplerate, window=window) assert env_arr.ndim == 1 assert env_arr.shape == (n_samples,) @@ -67,6 +67,6 @@ def test_respiratory_variance_smoke(): resp = np.random.normal(size=n_samples) samplerate = 1 / 0.01 window = 6 - _, rv_arr = chest_belt.respiratory_variance(resp, fs=samplerate, window=window) + rv_arr = chest_belt.respiratory_variance(resp, fs=samplerate, window=window) assert rv_arr.ndim == 2 assert rv_arr.shape == (n_samples, 2) diff --git a/phys2denoise/tests/test_rvt.py b/phys2denoise/tests/test_rvt.py index 67a7bf9..20d1690 100644 --- a/phys2denoise/tests/test_rvt.py +++ b/phys2denoise/tests/test_rvt.py @@ -17,8 +17,8 @@ def test_respiratory_variance_time(fake_phys): phys = peakdet.operations.peakfind_physio(phys) # TODO: Change to a simpler call once physutils are - # integrated to peakdet - phys, r = respiratory_variance_time( + # integrated to peakdet/prep4phys + r = respiratory_variance_time( phys.data, fs=phys.fs, peaks=phys.peaks, troughs=phys.troughs ) assert r is not None