Skip to content

Commit

Permalink
Merge pull request astropy#80 from astrofrog/refactor-parallel-exact
Browse files Browse the repository at this point in the history
Simplify the handling of the parallel processing in exact mode
  • Loading branch information
astrofrog committed May 7, 2015
2 parents 3c892cc + 28a265d commit ceee0b4
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 61 deletions.
4 changes: 3 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ install:
# conda for packages available through conda, or pip for any other
# packages. You should leave the `numpy=$NUMPY_VERSION` in the `conda`
# install since this ensures Numpy does not get automatically upgraded.
- if [[ $OPTIONAL_DEPS == true ]]; then $CONDA_INSTALL numpy=$NUMPY_VERSION matplotlib; fi
- if [[ $OPTIONAL_DEPS == true ]]; then $PIP_INSTALL healpy; fi

# DOCUMENTATION DEPENDENCIES
Expand All @@ -109,7 +110,8 @@ install:
- if [[ $SETUP_CMD == build_sphinx* ]]; then $CONDA_INSTALL numpy=$NUMPY_VERSION Sphinx=1.2 pyqt matplotlib; fi

# COVERAGE DEPENDENCIES
- if [[ $SETUP_CMD == 'test --coverage' ]]; then $PIP_INSTALL coverage coveralls; fi
- if [[ $SETUP_CMD == 'test --coverage' ]]; then $CONDA_INSTALL coverage pyyaml requests; fi
- if [[ $SETUP_CMD == 'test --coverage' ]]; then $PIP_INSTALL coveralls; fi

script:
- python setup.py $SETUP_CMD
Expand Down
96 changes: 40 additions & 56 deletions reproject/spherical_intersect/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,27 @@
unicode_literals)

import signal
import warnings

import numpy as np

from astropy.utils.exceptions import AstropyUserWarning

from ..wcs_utils import convert_world_coordinates

from ._overlap import _compute_overlap


# Function to disable ctrl+c in the worker processes.
def _init_worker():
"""
Function to disable ctrl+c in the worker processes.
"""
signal.signal(signal.SIGINT, signal.SIG_IGN)


def _reproject_celestial(array, wcs_in, wcs_out, shape_out, parallel=True, _method="c"):
def _reproject_slice(args):
from ._overlap import _reproject_slice_cython
return _reproject_slice_cython(*args)


def _reproject_celestial(array, wcs_in, wcs_out, shape_out, parallel=True, _legacy=False):

# Check the parallel flag.
if type(parallel) != bool and type(parallel) != int:
Expand Down Expand Up @@ -79,7 +83,7 @@ def _reproject_celestial(array, wcs_in, wcs_out, shape_out, parallel=True, _meth

xp_inout, yp_inout = wcs_out.wcs_world2pix(xw_in, yw_in, 0)

if _method == "legacy":
if _legacy:
# Create output image

array_new = np.zeros(shape_out)
Expand Down Expand Up @@ -130,69 +134,49 @@ def _reproject_celestial(array, wcs_in, wcs_out, shape_out, parallel=True, _meth
# Put together the parameters common both to the serial and parallel implementations. The aca
# function is needed to enforce that the array will be contiguous when passed to the low-level
# raw C function, otherwise Cython might complain.
from numpy import ascontiguousarray as aca
from ._overlap import _reproject_slice_cython
common_func_par = [0, ny_in, nx_out, ny_out, aca(xp_inout), aca(yp_inout), aca(xw_in), aca(yw_in), aca(xw_out), aca(yw_out), aca(array), shape_out]

# Abstract the serial implementation in a separate function so we can reuse it.
def serial_impl():
array_new, weights = _reproject_slice_cython(0, nx_in, *common_func_par)
aca = np.ascontiguousarray
common_func_par = [0, ny_in, nx_out, ny_out, aca(xp_inout), aca(yp_inout),
aca(xw_in), aca(yw_in), aca(xw_out), aca(yw_out), aca(array),
shape_out]

array_new /= weights
if nproc == 1:

array_new, weights = _reproject_slice([0, nx_in] + common_func_par)

with np.errstate(invalid='ignore'):
array_new /= weights

return array_new, weights

if _method == "c" and nproc == 1:
return serial_impl()
elif (nproc is None or nproc > 1):

# Abstract the parallel implementation as well.
def parallel_impl(nproc):
from multiprocessing import Pool, cpu_count

# If needed, establish the number of processors to use.
if nproc is None:
nproc = cpu_count()

# Create the pool.
pool = None
try:
# Prime each process in the pool with a small function that disables
# the ctrl+c signal in the child process.
pool = Pool(nproc, _init_worker)

# Accumulator for the results from the parallel processes.
results = []

for i in range(nproc):
start = int(nx_in) // nproc * i
end = int(nx_in) if i == nproc - 1 else int(nx_in) // nproc * (i + 1)
results.append(pool.apply_async(_reproject_slice_cython, [start, end] + common_func_par))

array_new = sum([_.get()[0] for _ in results])
weights = sum([_.get()[1] for _ in results])

except KeyboardInterrupt: # pragma: no cover
# If we hit ctrl+c while running things in parallel, we want to terminate
# everything and erase the pool before re-raising. Note that since we inited the pool
# with the _init_worker function, we disabled catching ctrl+c from the subprocesses. ctrl+c
# can be handled only by the main process.
if not pool is None:
pool.terminate()
pool.join()
pool = None
raise

finally:
if not pool is None:
# Clean up the pool, if still alive.
pool.close()
pool.join()
# Prime each process in the pool with a small function that disables
# the ctrl+c signal in the child process.
pool = Pool(nproc, _init_worker)

inputs = []
for i in range(nproc):
start = int(nx_in) // nproc * i
end = int(nx_in) if i == nproc - 1 else int(nx_in) // nproc * (i + 1)
inputs.append([start, end] + common_func_par)

results = pool.map(_reproject_slice, inputs)

pool.close()

array_new, weights = zip(*results)

array_new = sum(array_new)
weights = sum(weights)

with np.errstate(invalid='ignore'):
array_new /= weights

return array_new, weights

if _method == "c" and (nproc is None or nproc > 1):
return parallel_impl(nproc)

raise ValueError('unrecognized method "{0}"'.format(_method,))
8 changes: 4 additions & 4 deletions reproject/spherical_intersect/tests/test_reproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def test_reproject_celestial_consistency():
wcs_in = WCS(fits.Header.fromstring(INPUT_HDR, sep='\n'))
wcs_out = WCS(fits.Header.fromstring(OUTPUT_HDR, sep='\n'))

array1, footprint1 = _reproject_celestial(DATA, wcs_in, wcs_out, (4, 4), _method='legacy')
array2, footprint2 = _reproject_celestial(DATA, wcs_in, wcs_out, (4, 4), _method='c', parallel=False)
array3, footprint3 = _reproject_celestial(DATA, wcs_in, wcs_out, (4, 4), _method='c', parallel=True)
array1, footprint1 = _reproject_celestial(DATA, wcs_in, wcs_out, (4, 4), _legacy=True)
array2, footprint2 = _reproject_celestial(DATA, wcs_in, wcs_out, (4, 4), parallel=False)
array3, footprint3 = _reproject_celestial(DATA, wcs_in, wcs_out, (4, 4), parallel=True)

np.testing.assert_allclose(array1, array2, rtol=1.e-6)
np.testing.assert_allclose(array1, array3, rtol=1.e-6)
Expand All @@ -88,7 +88,7 @@ def test_reproject_celestial_():
wcs_in = WCS(fits.Header.fromstring(INPUT_HDR, sep='\n'))
wcs_out = WCS(fits.Header.fromstring(OUTPUT_HDR, sep='\n'))

array, footprint = _reproject_celestial(DATA, wcs_in, wcs_out, (4, 4), _method='c', parallel=False)
array, footprint = _reproject_celestial(DATA, wcs_in, wcs_out, (4, 4), parallel=False)

# TODO: improve agreement with Montage - at the moment agreement is ~10%
np.testing.assert_allclose(array, MONTAGE_REF, rtol=0.09)

0 comments on commit ceee0b4

Please sign in to comment.