diff --git a/phys2denoise/metrics/cardiac.py b/phys2denoise/metrics/cardiac.py index 72c7e8d..63f1475 100644 --- a/phys2denoise/metrics/cardiac.py +++ b/phys2denoise/metrics/cardiac.py @@ -17,7 +17,7 @@ def _cardiac_metrics( peaks=None, window=6, central_measure="mean", - return_physio=False, + **kwargs, ): """ Compute cardiac metrics. @@ -154,7 +154,7 @@ def _cardiac_metrics( @due.dcite(references.CHANG_CUNNINGHAM_GLOVER_2009) @return_physio_or_metric() @physio.make_operation() -def heart_rate(data, fs=None, peaks=None, window=6, central_measure="mean"): +def heart_rate(data, fs=None, peaks=None, window=6, central_measure="mean", **kwargs): """ Compute average heart rate (HR) in a sliding window. @@ -213,6 +213,7 @@ def heart_rate(data, fs=None, peaks=None, window=6, central_measure="mean"): peaks=peaks, window=window, central_measure=central_measure, + **kwargs, ) return data, hr @@ -221,7 +222,9 @@ def heart_rate(data, fs=None, peaks=None, window=6, central_measure="mean"): @due.dcite(references.PINHERO_ET_AL_2016) @return_physio_or_metric() @physio.make_operation() -def heart_rate_variability(data, fs=None, peaks=None, window=6, central_measure="mean"): +def heart_rate_variability( + data, fs=None, peaks=None, window=6, central_measure="mean", **kwargs +): """ Compute average heart rate variability (HRV) in a sliding window. @@ -278,6 +281,7 @@ def heart_rate_variability(data, fs=None, peaks=None, window=6, central_measure= peaks=peaks, window=window, central_measure=central_measure, + **kwargs, ) return data, hrv @@ -286,7 +290,9 @@ def heart_rate_variability(data, fs=None, peaks=None, window=6, central_measure= @due.dcite(references.CHEN_2020) @return_physio_or_metric() @physio.make_operation() -def heart_beat_interval(data, fs=None, peaks=None, window=6, central_measure="mean"): +def heart_beat_interval( + data, fs=None, peaks=None, window=6, central_measure="mean", **kwargs +): """ Compute average heart beat interval (HBI) in a sliding window. @@ -336,6 +342,7 @@ def heart_beat_interval(data, fs=None, peaks=None, window=6, central_measure="me peaks=peaks, window=window, central_measure=central_measure, + **kwargs, ) return data, hbi @@ -343,7 +350,7 @@ def heart_beat_interval(data, fs=None, peaks=None, window=6, central_measure="me @return_physio_or_metric() @physio.make_operation() -def cardiac_phase(data, slice_timings, n_scans, t_r, fs=None, peaks=None): +def cardiac_phase(data, slice_timings, n_scans, t_r, fs=None, peaks=None, **kwargs): """Calculate cardiac phase from cardiac peaks. Assumes that timing of cardiac events are given in same units diff --git a/phys2denoise/metrics/chest_belt.py b/phys2denoise/metrics/chest_belt.py index 0a7b800..deae46f 100644 --- a/phys2denoise/metrics/chest_belt.py +++ b/phys2denoise/metrics/chest_belt.py @@ -16,7 +16,7 @@ @return_physio_or_metric() @physio.make_operation() def respiratory_variance_time( - data, fs=None, peaks=None, troughs=None, lags=(0, 4, 8, 12) + data, fs=None, peaks=None, troughs=None, lags=(0, 4, 8, 12), **kwargs ): """ Implement the Respiratory Variance over Time (Birn et al. 2006). @@ -113,7 +113,7 @@ def respiratory_variance_time( @due.dcite(references.POWER_2018) @return_physio_or_metric() @physio.make_operation() -def respiratory_pattern_variability(data, window): +def respiratory_pattern_variability(data, window, **kwargs): """Calculate respiratory pattern variability. Parameters @@ -161,7 +161,7 @@ def respiratory_pattern_variability(data, window): @due.dcite(references.POWER_2020) @return_physio_or_metric() @physio.make_operation() -def env(data, fs=None, window=10): +def env(data, fs=None, window=10, **kwargs): """Calculate respiratory pattern variability across a sliding window. Parameters @@ -245,7 +245,7 @@ def _respiratory_pattern_variability(data, window): @due.dcite(references.CHANG_GLOVER_2009) @return_physio_or_metric() @physio.make_operation() -def respiratory_variance(data, fs=None, window=6): +def respiratory_variance(data, fs=None, window=6, **kwargs): """Calculate respiratory variance. Parameters @@ -310,7 +310,7 @@ def respiratory_variance(data, fs=None, window=6): @return_physio_or_metric() @physio.make_operation() -def respiratory_phase(data, n_scans, slice_timings, t_r, fs=None): +def respiratory_phase(data, n_scans, slice_timings, t_r, fs=None, **kwargs): """Calculate respiratory phase from respiratory signal. Parameters diff --git a/phys2denoise/metrics/multimodal.py b/phys2denoise/metrics/multimodal.py index fc48213..85f168f 100644 --- a/phys2denoise/metrics/multimodal.py +++ b/phys2denoise/metrics/multimodal.py @@ -20,6 +20,7 @@ def retroicor( physio_type=None, fs=None, cardiac_peaks=None, + **kwargs, ): """Compute RETROICOR regressors. diff --git a/phys2denoise/metrics/utils.py b/phys2denoise/metrics/utils.py index 4378866..edd66ec 100644 --- a/phys2denoise/metrics/utils.py +++ b/phys2denoise/metrics/utils.py @@ -360,8 +360,9 @@ def wrapper(*args, **kwargs): physio._computed_metrics[func.__name__] = dict( metric=metric, args=kwargs ) - if return_physio: - return physio, metric + return_physio_value = kwargs.get("return_physio", return_physio) + if return_physio_value: + return physio else: return metric else: diff --git a/phys2denoise/tests/test_metrics_cardiac.py b/phys2denoise/tests/test_metrics_cardiac.py index f9b69ad..839b040 100644 --- a/phys2denoise/tests/test_metrics_cardiac.py +++ b/phys2denoise/tests/test_metrics_cardiac.py @@ -50,12 +50,28 @@ def test_cardiac_phase_smoke_physio_obj(): data = np.zeros(peaks.shape) phys = physio.Physio(data, sample_rate, physio_type="cardiac") phys._metadata["peaks"] = peaks - phys, card_phase = cardiac.cardiac_phase( + + # Test where the physio object is returned + phys = cardiac.cardiac_phase( phys, slice_timings=slice_timings, n_scans=n_scans, t_r=t_r, ) assert phys.history[0][0] == "phys2denoise.metrics.cardiac.cardiac_phase" + assert phys.computed_metrics["cardiac_phase"]["metric"].ndim == 2 + assert phys.computed_metrics["cardiac_phase"]["metric"].shape == ( + n_scans, + slice_timings.size, + ) + + # Test where the metric is returned + card_phase = cardiac.cardiac_phase( + phys, + slice_timings=slice_timings, + n_scans=n_scans, + t_r=t_r, + return_physio=False, + ) assert card_phase.ndim == 2 assert card_phase.shape == (n_scans, slice_timings.size)