diff --git a/CHANGES.rst b/CHANGES.rst index 0ebc1367d..fd7066686 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,4 +1,4 @@ -2.1.0 (unreleased) +2.0.1 (unreleased) ------------------ General @@ -10,6 +10,12 @@ New Features Bug Fixes ^^^^^^^^^ +- ``photutils.background`` + + - Fixed a bug in ``SExtractorBackground`` where the dimensionality of + the returned value would not be preserved if the output was a single + value. [#1934] + API Changes ^^^^^^^^^^^ diff --git a/photutils/background/core.py b/photutils/background/core.py index 90c0acce5..2317c2ce1 100644 --- a/photutils/background/core.py +++ b/photutils/background/core.py @@ -413,21 +413,28 @@ def calc_background(self, data, axis=None, masked=False): _median = np.atleast_1d(nanmedian(data, axis=axis)) _mean = np.atleast_1d(nanmean(data, axis=axis)) _std = np.atleast_1d(nanstd(data, axis=axis)) - bkg = np.atleast_1d((2.5 * _median) - (1.5 * _mean)) + bkg = (2.5 * _median) - (1.5 * _mean) - bkg = np.where(_std == 0, _mean, bkg) + # set the background to the mean where the std is zero + mean_mask = _std == 0 + bkg[mean_mask] = _mean[mean_mask] - idx = np.where(_std != 0) - condition = (np.abs(_mean[idx] - _median[idx]) / _std[idx]) < 0.3 - bkg[idx] = np.where(condition, bkg[idx], _median[idx]) - if bkg.size == 1: + # set the background to the median when the absolute + # difference between the mean and median divided by the + # standard deviation is greater than or equal to 0.3 + + med_mask = (np.abs(_mean - _median) / _std) >= 0.3 + mask = np.logical_and(med_mask, np.logical_not(mean_mask)) + bkg[mask] = _median[mask] + + # if bkg is a scalar, return it as a float + if bkg.shape == (1,) and axis is None: bkg = bkg[0] - result = bkg - if masked and isinstance(result, np.ndarray): - result = np.ma.masked_where(np.isnan(result), result) + if masked and isinstance(bkg, np.ndarray): + bkg = np.ma.masked_where(np.isnan(bkg), bkg) - return result + return bkg class BiweightLocationBackground(BackgroundBase): diff --git a/photutils/background/tests/test_core.py b/photutils/background/tests/test_core.py index db582a464..d69f75847 100644 --- a/photutils/background/tests/test_core.py +++ b/photutils/background/tests/test_core.py @@ -105,6 +105,62 @@ def test_sourceextrator_background_skew(): assert_allclose(bkg.calc_background(data), np.median(data)) +@pytest.mark.parametrize('bkg_class', BKG_CLASS) +def test_background_ndim(bkg_class): + data1 = np.ones((1, 100, 100)) + data2 = np.ones((1, 100 * 100)) + data3 = np.ones((1, 1, 100 * 100)) + data4 = np.ones((1, 1, 1, 100 * 100)) + + bkg = bkg_class(sigma_clip=None) + val = bkg(data1, axis=None) + assert np.ndim(val) == 0 + val = bkg(data1, axis=(1, 2)) + assert val.shape == (1,) + val = bkg(data1, axis=-1) + assert val.shape == (1, 100) + val = bkg(data2, axis=-1) + assert val.shape == (1,) + val = bkg(data3, axis=-1) + assert val.shape == (1, 1) + val = bkg(data4, axis=-1) + assert val.shape == (1, 1, 1) + val = bkg(data4, axis=(2, 3)) + assert val.shape == (1, 1) + val = bkg(data4, axis=(1, 2, 3)) + assert val.shape == (1,) + val = bkg(data4, axis=(0, 1, 2)) + assert val.shape == (10000,) + + +@pytest.mark.parametrize('bkgrms_class', RMS_CLASS) +def test_background_rms_ndim(bkgrms_class): + data1 = np.ones((1, 100, 100)) + data2 = np.ones((1, 100 * 100)) + data3 = np.ones((1, 1, 100 * 100)) + data4 = np.ones((1, 1, 1, 100 * 100)) + + bkgrms = bkgrms_class(sigma_clip=None) + val = bkgrms(data1, axis=None) + assert np.ndim(val) == 0 + val = bkgrms(data1, axis=(1, 2)) + assert val.shape == (1,) + val = bkgrms(data1, axis=-1) + assert val.shape == (1, 100) + val = bkgrms(data2, axis=-1) + assert val.shape == (1,) + val = bkgrms(data3, axis=-1) + assert val.shape == (1, 1) + val = bkgrms(data4, axis=-1) + assert val.shape == (1, 1, 1) + val = bkgrms(data4, axis=(2, 3)) + assert val.shape == (1, 1) + val = bkgrms(data4, axis=(1, 2, 3)) + assert val.shape == (1,) + val = bkgrms(data4, axis=(0, 1, 2)) + assert val.shape == (10000,) + + @pytest.mark.parametrize('rms_class', RMS_CLASS) def test_background_rms(rms_class): bkgrms = rms_class(sigma_clip=SIGMA_CLIP)