Skip to content

Commit

Permalink
Merge pull request astropy#73 from astrofrog/support-all-dtypes
Browse files Browse the repository at this point in the history
Make sure that we accept any input numerical types
  • Loading branch information
astrofrog committed May 6, 2015
2 parents 17b2061 + e74ab32 commit 9093b42
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 15 deletions.
6 changes: 4 additions & 2 deletions reproject/healpix/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def healpix_to_image(healpix_data, coord_system_in, wcs_out, shape_out,
* 'bilinear'
or an integer. A value of ``0`` indicates nearest neighbor
interpolation.
interpolation.
nested : bool
The order of the healpix_data, either nested or ring. Stored in
FITS headers in the ORDERING keyword.
Expand All @@ -68,6 +68,8 @@ def healpix_to_image(healpix_data, coord_system_in, wcs_out, shape_out,
"""
import healpy as hp

healpix_data = np.asarray(healpix_data, dtype=float)

# Look up lon, lat of pixels in reference system
yinds, xinds = np.indices(shape_out)
lon_out, lat_out = wcs_out.wcs_pix2world(xinds, yinds, 0)
Expand Down Expand Up @@ -130,7 +132,7 @@ def image_to_healpix(data, wcs_in, coord_system_out,
* 'bicubic'
or an integer. A value of ``0`` indicates nearest neighbor
interpolation.
interpolation.
nested : bool
The order of the healpix_data, either nested or ring. Stored in
FITS headers in the ORDERING keyword.
Expand Down
10 changes: 5 additions & 5 deletions reproject/healpix/tests/test_healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from ..core import healpix_to_image, image_to_healpix
from ..high_level import reproject_from_healpix

from ...tests.test_high_level import ALL_DTYPES

DATA = os.path.join(os.path.dirname(__file__), 'data')

Expand All @@ -39,16 +39,16 @@ def get_reference_header(oversample=2, nside=1):


@pytest.mark.importorskip('healpy')
@pytest.mark.parametrize("nside,nested,healpix_system,image_system",
itertools.product([1, 2, 4, 8, 16, 32, 64], [True, False], 'C', 'C'))
@pytest.mark.parametrize("nside,nested,healpix_system,image_system,dtype",
itertools.product([1, 2, 4, 8, 16, 32, 64], [True, False], 'C', 'C', ALL_DTYPES))
def test_reproject_healpix_to_image_round_trip(
nside, nested, healpix_system, image_system):
nside, nested, healpix_system, image_system, dtype):
"""Test round-trip HEALPix->WCS->HEALPix conversion for a random map,
with a WCS projection large enough to store each HEALPix pixel"""
import healpy as hp

npix = hp.nside2npix(nside)
healpix_data = np.random.uniform(size=npix)
healpix_data = np.random.uniform(size=npix).astype(dtype)

reference_header = get_reference_header(oversample=2, nside=nside)

Expand Down
11 changes: 6 additions & 5 deletions reproject/spherical_intersect/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
unicode_literals)

import signal
import warnings

import numpy as np

from astropy import log as logger
from astropy.utils.exceptions import AstropyUserWarning

from ..wcs_utils import convert_world_coordinates

Expand Down Expand Up @@ -35,8 +36,8 @@ def _reproject_celestial(array, wcs_in, wcs_out, shape_out, parallel=True, _meth
nproc = None if parallel else 1

# Convert input array to float values. If this comes from a FITS, it might have
# float32 as value type and that can break things in cythin.
array = array.astype(float)
# float32 as value type and that can break things in Cython
array = np.asarray(array, dtype=float)

# TODO: make this work for n-dimensional arrays
if wcs_in.naxis != 2:
Expand Down Expand Up @@ -196,8 +197,8 @@ def parallel_impl(nproc):
# the serial one.
raise
except Exception as e:
logger.warn("The parallel implementation failed, the reported error message is: '{0}'".format(repr(e,)))
logger.warn("Running the serial implementation instead")
warnings.warn("The parallel implementation failed, the reported error message is: '{0}'".format(repr(e,)), AstropyUserWarning)
warnings.warn("Running the serial implementation instead", AstropyUserWarning)
return serial_impl()

raise ValueError('unrecognized method "{0}"'.format(_method,))
16 changes: 13 additions & 3 deletions reproject/tests/test_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)

import itertools

import numpy as np
from astropy.io import fits
from astropy.wcs import WCS
Expand All @@ -18,6 +20,14 @@
'bicubic',
'flux-conserving')

ALL_DTYPES = []
for endian in ('<', '>'):
for kind in ('u', 'i', 'f'):
for size in ('1', '2', '4', '8'):
if kind == 'f' and size == '1':
continue
ALL_DTYPES.append(np.dtype(endian + kind + size))


class TestReproject(object):

Expand Down Expand Up @@ -82,8 +92,8 @@ def test_array_header_header(self):
"""


@pytest.mark.parametrize('projection_type', ALL_MODES)
def test_surface_brightness(projection_type):
@pytest.mark.parametrize('projection_type, dtype', itertools.product(ALL_MODES, ALL_DTYPES))
def test_surface_brightness(projection_type, dtype):

header_in = fits.Header.fromstring(INPUT_HDR, sep='\n')
header_in['NAXIS'] = 2
Expand All @@ -97,7 +107,7 @@ def test_surface_brightness(projection_type):
header_out['NAXIS1'] *= 2
header_out['NAXIS2'] *= 2

data_in = np.ones((10, 10))
data_in = np.ones((10, 10), dtype=dtype)

if projection_type == 'flux-conserving':
data_out, footprint = reproject_exact((data_in, header_in), header_out)
Expand Down

0 comments on commit 9093b42

Please sign in to comment.