Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/argmax_rays' into cubewcs_mosaic…
Browse files Browse the repository at this point in the history
…_and_dask_reproject
  • Loading branch information
keflavich committed Oct 12, 2024
2 parents 1131fb3 + e8014e2 commit 92bab69
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
8 changes: 4 additions & 4 deletions spectral_cube/spectral_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,6 @@ def argmax(self, axis=None, how='auto', **kwargs):
return self.apply_numpy_function(np.nanargmax, fill=-np.inf,
reduce=False, projection=False,
how=how, axis=axis,
dtype='int',
**kwargs)

@aggregation_docstring
Expand All @@ -811,7 +810,6 @@ def argmin(self, axis=None, how='auto', **kwargs):
return self.apply_numpy_function(np.nanargmin, fill=np.inf,
reduce=False, projection=False,
how=how, axis=axis,
dtype='int',
**kwargs)

def _argmaxmin_world(self, axis, method, **kwargs):
Expand Down Expand Up @@ -1005,7 +1003,7 @@ def _cube_on_cube_operation(self, function, cube, equivalencies=[], **kwargs):

def apply_function(self, function, axis=None, weights=None, unit=None,
projection=False, progressbar=False,
update_function=None, keep_shape=False, dtype=None,
update_function=None, keep_shape=False,
**kwargs):
"""
Apply a function to valid data along the specified axis or to the whole
Expand Down Expand Up @@ -1063,7 +1061,9 @@ def apply_function(self, function, axis=None, weights=None, unit=None,
nz = self.shape[axis] if keep_shape else 1

# allocate memory for output array
if 'int' in str(dtype):
# check dtype first (for argmax/argmin)
result = function(np.arange(3, dtype=self._data.dtype), **kwargs)
if 'int' in str(result.dtype):
out = np.zeros([nz, nx, ny], dtype=dtype)
else:
out = np.empty([nz, nx, ny]) * np.nan
Expand Down
11 changes: 6 additions & 5 deletions spectral_cube/tests/test_spectral_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,14 +664,15 @@ def test_argmin(self):
self._check_numpy(self.c.argmin, d, np.nanargmin)
self.c = self.d = None

def test_arg_rays(self):
def test_arg_rays(self, use_dask):
"""
regression test: argmax must have integer dtype
"""
result = self.c.argmax(how='ray')
assert 'int' in str(result.dtype)
result = self.c.argmin(how='ray')
assert 'int' in str(result.dtype)
if not use_dask:
result = self.c.argmax(how='ray')
assert 'int' in str(result.dtype)
result = self.c.argmin(how='ray')
assert 'int' in str(result.dtype)

@pytest.mark.parametrize('iterate_rays', (True,False))
def test_median(self, iterate_rays, use_dask):
Expand Down

0 comments on commit 92bab69

Please sign in to comment.