Skip to content

Commit

Permalink
check for type instead of manually specifying
Browse files Browse the repository at this point in the history
  • Loading branch information
keflavich committed Oct 12, 2024
1 parent 3591c16 commit e2bd76f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
8 changes: 4 additions & 4 deletions spectral_cube/spectral_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,6 @@ def argmax(self, axis=None, how='auto', **kwargs):
return self.apply_numpy_function(np.nanargmax, fill=-np.inf,
reduce=False, projection=False,
how=how, axis=axis,
dtype='int',
**kwargs)

@aggregation_docstring
Expand All @@ -811,7 +810,6 @@ def argmin(self, axis=None, how='auto', **kwargs):
return self.apply_numpy_function(np.nanargmin, fill=np.inf,
reduce=False, projection=False,
how=how, axis=axis,
dtype='int',
**kwargs)

def _argmaxmin_world(self, axis, method, **kwargs):
Expand Down Expand Up @@ -1001,7 +999,7 @@ def _cube_on_cube_operation(self, function, cube, equivalencies=[], **kwargs):

def apply_function(self, function, axis=None, weights=None, unit=None,
projection=False, progressbar=False,
update_function=None, keep_shape=False, dtype=None,
update_function=None, keep_shape=False,
**kwargs):
"""
Apply a function to valid data along the specified axis or to the whole
Expand Down Expand Up @@ -1059,7 +1057,9 @@ def apply_function(self, function, axis=None, weights=None, unit=None,
nz = self.shape[axis] if keep_shape else 1

# allocate memory for output array
if 'int' in str(dtype):
# check dtype first (for argmax/argmin)
result = function(np.arange(3, dtype=self._data.dtype))
if 'int' in str(result.dtype):
out = np.zeros([nz, nx, ny], dtype=dtype)
else:
out = np.empty([nz, nx, ny]) * np.nan
Expand Down
5 changes: 2 additions & 3 deletions spectral_cube/tests/test_spectral_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,12 +664,11 @@ def test_argmin(self):
self._check_numpy(self.c.argmin, d, np.nanargmin)
self.c = self.d = None

def test_arg_rays(self):
def test_arg_rays(self, use_dask):
"""
regression test: argmax must have integer dtype
"""
# only test for non-dask
if not hasattr(self.c, 'rechunk'):
if not use_dask:
result = self.c.argmax(how='ray')
assert 'int' in str(result.dtype)
result = self.c.argmin(how='ray')
Expand Down

0 comments on commit e2bd76f

Please sign in to comment.