Skip to content

Commit

Permalink
Respect the bounding_box in inverse transforms (#498)
Browse files Browse the repository at this point in the history
  • Loading branch information
nden authored Dec 13, 2024
1 parent 8b4d8d7 commit 863b6e4
Show file tree
Hide file tree
Showing 11 changed files with 212 additions and 74 deletions.
8 changes: 6 additions & 2 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

- Add support for compound bounding boxes and ignored bounding box entries. [#519]


- Add ``gwcs.examples`` module, based on the examples located in the testing ``conftest.py``. [#521]

- Force ``bounding_box`` to always be returned as a ``F`` ordered box. [#522]
Expand All @@ -19,8 +20,11 @@

- Adjust ``world_to_array_index_values`` to round to integer coordinates as specified by APE 14. [#525]

- Add warning filter to asdf extension to prevent the ``bounding_box`` order warning for gwcs
objects originating from a file. [#526]
- Add warning filter to asdf extension to prevent the ``bounding_box`` order warning for gwcs objects originating from a file. [#526]

- Fixed a bug where evaluating the inverse transform did not
respect the bounding box. [#498]


0.21.0 (2024-03-10)
-------------------
Expand Down
9 changes: 7 additions & 2 deletions gwcs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,13 @@ def world_to_pixel_values(self, *world_arrays):
be returned in the ``(x, y)`` order, where for an image, ``x`` is the
horizontal coordinate and ``y`` is the vertical coordinate.
"""
world_arrays = self._add_units_input(world_arrays, self.backward_transform, self.output_frame)
try:
backward_transform = self.backward_transform
world_arrays = self._add_units_input(world_arrays,
backward_transform,
self.output_frame)
except NotImplementedError:
pass

result = self.invert(*world_arrays, with_units=False)

Expand Down Expand Up @@ -317,7 +323,6 @@ def world_to_pixel(self, *world_objects):
Convert world coordinates to pixel values.
"""
result = self.invert(*world_objects, with_units=True)

if self.input_frame.naxes > 1:
first_res = result[0]
if not utils.isnumerical(first_res):
Expand Down
2 changes: 1 addition & 1 deletion gwcs/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def sellmeier_zemax():

@pytest.fixture(scope="function")
def gwcs_3d_galactic_spectral():
return examples.gwcs_3d_galactic_spectral()

return examples.gwcs_3d_galactic_spectral()

@pytest.fixture(scope="function")
def gwcs_1d_spectral():
Expand Down
17 changes: 4 additions & 13 deletions gwcs/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def wcs_ndim_types_units(request):

@fixture_all_wcses
def test_lowlevel_types(wcsobj):
pytest.importorskip("typeguard")
try:
# Skip this on older versions of astropy where it dosen't exist.
from astropy.wcs.wcsapi.tests.utils import validate_low_level_wcs_types
Expand Down Expand Up @@ -236,12 +235,12 @@ def test_world_axis_object_classes_4d(gwcs_4d_identity_units):
def _compare_frame_output(wc1, wc2):
if isinstance(wc1, coord.SkyCoord):
assert isinstance(wc1.frame, type(wc2.frame))
assert u.allclose(wc1.spherical.lon, wc2.spherical.lon)
assert u.allclose(wc1.spherical.lat, wc2.spherical.lat)
assert u.allclose(wc1.spherical.distance, wc2.spherical.distance)
assert u.allclose(wc1.spherical.lon, wc2.spherical.lon, equal_nan=True)
assert u.allclose(wc1.spherical.lat, wc2.spherical.lat, equal_nan=True)
assert u.allclose(wc1.spherical.distance, wc2.spherical.distance, equal_nan=True)

elif isinstance(wc1, u.Quantity):
assert u.allclose(wc1, wc2)
assert u.allclose(wc1, wc2, equal_nan=True)

elif isinstance(wc1, time.Time):
assert u.allclose((wc1 - wc2).to(u.s), 0*u.s)
Expand All @@ -258,12 +257,6 @@ def _compare_frame_output(wc1, wc2):

@fixture_all_wcses
def test_high_level_wrapper(wcsobj, request):
if request.node.callspec.params['wcsobj'] in ('gwcs_4d_identity_units', 'gwcs_stokes_lookup'):
pytest.importorskip("astropy", minversion="4.0dev0")

# Remove the bounding box because the type test is a little broken with the
# bounding box.
del wcsobj._pipeline[0].transform.bounding_box

hlvl = HighLevelWCSWrapper(wcsobj)

Expand All @@ -286,8 +279,6 @@ def test_high_level_wrapper(wcsobj, request):


def test_stokes_wrapper(gwcs_stokes_lookup):
pytest.importorskip("astropy", minversion="4.0dev0")

hlvl = HighLevelWCSWrapper(gwcs_stokes_lookup)

pixel_input = [0, 1, 2, 3]
Expand Down
4 changes: 2 additions & 2 deletions gwcs/tests/test_api_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ def test_celestial_slice(gwcs_3d_galactic_spectral):
assert_allclose(wcs.pixel_to_world_values(39, 44), (10.24, 20, 25))
assert_allclose(wcs.array_index_to_world_values(44, 39), (10.24, 20, 25))

assert_allclose(wcs.world_to_pixel_values(12.4, 20, 25), (39., 44.))
assert_equal(wcs.world_to_array_index_values(12.4, 20, 25), (44, 39))
assert_allclose(wcs.world_to_pixel_values(10.24, 20, 25), (39., 44.))
assert_equal(wcs.world_to_array_index_values(10.24, 20, 25), (44, 39))

assert_equal(wcs.pixel_bounds, [(-2, 45), (5, 50)])

Expand Down
87 changes: 87 additions & 0 deletions gwcs/tests/test_bounding_box.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import numpy as np
from numpy.testing import assert_array_equal, assert_allclose

import pytest


x = [-1, 2, 4, 13]
y = [np.nan, np.nan, 4, np.nan]
y1 = [np.nan, np.nan, 4, np.nan]


@pytest.mark.parametrize((("input", "output")), [((2, 4), (2, 4)),
((100, 200), (np.nan, np.nan)),
((x, x),(y, y))
])
def test_2d_spatial(gwcs_2d_spatial_shift, input, output):
w = gwcs_2d_spatial_shift
w.bounding_box = ((-.5, 21), (4, 12))

assert_array_equal(w.invert(*w(*input)), output)
assert_array_equal(w.world_to_pixel_values(*w.pixel_to_world_values(*input)), output)
assert_array_equal(w.world_to_pixel(w.pixel_to_world(*input)), output)


@pytest.mark.parametrize((("input", "output")), [((2, 4), (2, 4)),
((100, 200), (np.nan, np.nan)),
((x, x), (y, y))
])
def test_2d_spatial_coordinate(gwcs_2d_quantity_shift, input, output):
w = gwcs_2d_quantity_shift
w.bounding_box = ((-.5, 21), (4, 12))

assert_array_equal(w.invert(*w(*input)), output)
assert_array_equal(w.world_to_pixel_values(*w.pixel_to_world_values(*input)), output)
assert_array_equal(w.world_to_pixel(*w.pixel_to_world(*input)), output)


@pytest.mark.parametrize((("input", "output")), [((2, 4), (2, 4)),
((100, 200), (np.nan, np.nan)),
((x, x), (y, y))
])
def test_2d_spatial_coordinate_reordered(gwcs_2d_spatial_reordered, input, output):
w = gwcs_2d_spatial_reordered
w.bounding_box = ((-.5, 21), (4, 12))

assert_array_equal(w.invert(*w(*input)), output)
assert_array_equal(w.world_to_pixel_values(*w.pixel_to_world_values(*input)), output)
assert_array_equal(w.world_to_pixel(w.pixel_to_world(*input)), output)


@pytest.mark.parametrize((("input", "output")), [(2, 2),
((10, 200), (10, np.nan)),
(x, (np.nan, 2, 4, 13))
])
def test_1d_freq(gwcs_1d_freq, input, output):
w = gwcs_1d_freq
w.bounding_box = (-.5, 21)
print(f"input {input}, {output}")
assert_array_equal(w.invert(w(input)), output)
assert_array_equal(w.world_to_pixel_values(w.pixel_to_world_values(input)), output)
assert_array_equal(w.world_to_pixel(w.pixel_to_world(input)), output)


@pytest.mark.parametrize((("input", "output")), [((2, 4, 5), (2, 4, 5)),
((100, 200, 5), (np.nan, np.nan, np.nan)),
((x, x, x), (y1, y1, y1))
])
def test_3d_spatial_wave(gwcs_3d_spatial_wave, input, output):
w = gwcs_3d_spatial_wave
w.bounding_box = ((-.5, 21), (4, 12), (3, 21))

assert_array_equal(w.invert(*w(*input)), output)
assert_array_equal(w.world_to_pixel_values(*w.pixel_to_world_values(*input)), output)
assert_array_equal(w.world_to_pixel(*w.pixel_to_world(*input)), output)


@pytest.mark.parametrize((("input", "output")), [((1, 2, 3, 4), (1., 2., 3., 4.)),
((100, 3, 3, 3), (np.nan, 3, 3, 3)),
((x, x, x, x), [[np.nan, 2., 4., 13.],
[np.nan, 2., 4., 13.],
[np.nan, 2., 4., 13.],
[np.nan, 2., 4., np.nan]])
])
def test_gwcs_spec_cel_time_4d(gwcs_spec_cel_time_4d, input, output):
w = gwcs_spec_cel_time_4d

assert_allclose(w.invert(*w(*input, with_bounding_box=False)), output, atol=1e-8)
3 changes: 0 additions & 3 deletions gwcs/tests/test_coordinate_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ def test_temporal_relative():
assert a[1] == Time("2018-01-01T00:00:00") + 20 * u.s


@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher")
def test_temporal_absolute():
t = cf.TemporalFrame(reference_frame=Time([], format='isot'))
assert t.coordinates("2018-01-01T00:00:00") == Time("2018-01-01T00:00:00")
Expand Down Expand Up @@ -240,7 +239,6 @@ def test_coordinate_to_quantity_spectral(inp):
(Time("2011-01-01T00:00:10"),),
(10 * u.s,)
])
@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher.")
def test_coordinate_to_quantity_temporal(inp):
temp = cf.TemporalFrame(reference_frame=Time("2011-01-01T00:00:00"), unit=u.s)

Expand Down Expand Up @@ -325,7 +323,6 @@ def test_coordinate_to_quantity_frame_2d():
assert_quantity_allclose(output, exp)


@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher.")
def test_coordinate_to_quantity_error():
frame = cf.Frame2D(unit=(u.one, u.arcsec))
with pytest.raises(ValueError):
Expand Down
6 changes: 6 additions & 0 deletions gwcs/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from astropy import units as u
from astropy import coordinates as coord
from astropy.modeling import models
from astropy import table

from astropy.tests.helper import assert_quantity_allclose
import pytest
from numpy.testing import assert_allclose
Expand Down Expand Up @@ -104,6 +106,10 @@ def test_isnumerical():
assert gwutils.isnumerical(np.array(0, dtype='>f8'))
assert gwutils.isnumerical(np.array(0, dtype='>i4'))

# check a table column
t = table.Table(data=[[1,2,3], [4,5,6]], names=['x', 'y'])
assert not gwutils.isnumerical(t['x'])


def test_get_values():
args = 2 * u.cm
Expand Down
8 changes: 6 additions & 2 deletions gwcs/tests/test_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,8 +1163,8 @@ def test_in_image():

assert np.isscalar(w2.in_image(2, 6))
assert not np.isscalar(w2.in_image([2], [6]))
assert w2.in_image(4, 6)
assert not w2.in_image(5, 0)
assert (w2.in_image(4, 6))
assert not (w2.in_image(5, 0))
assert np.array_equal(
w2.in_image(
[[9, 10, 11, 15], [8, 9, 67, 98], [2, 2, np.nan, 102]],
Expand Down Expand Up @@ -1199,6 +1199,7 @@ def test_iter_inv():
*w(x, y),
adaptive=True,
detect_divergence=True,
tolerance=1e-4, maxiter=50,
quiet=False
)
assert np.allclose((x, y), (xp, yp))
Expand All @@ -1218,6 +1219,7 @@ def test_iter_inv():
xp, yp = w.numerical_inverse(
*w(x, y),
adaptive=True,
tolerance=1e-5, maxiter=50,
detect_divergence=False,
quiet=False
)
Expand Down Expand Up @@ -1252,6 +1254,7 @@ def test_iter_inv():
xp, yp = w.numerical_inverse(
*w(x, y, with_bounding_box=False),
adaptive=False,
tolerance=1e-5, maxiter=50,
detect_divergence=True,
quiet=False,
with_bounding_box=False
Expand All @@ -1265,6 +1268,7 @@ def test_iter_inv():
xp, yp = w.numerical_inverse(
*w(x, y, with_bounding_box=False),
adaptive=False,
tolerance=1e-5, maxiter=50,
detect_divergence=True,
quiet=False,
with_bounding_box=False
Expand Down
11 changes: 5 additions & 6 deletions gwcs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from astropy import coordinates as coords
from astropy import units as u
from astropy.time import Time, TimeDelta
from astropy import table
from astropy.wcs import Celprm


Expand Down Expand Up @@ -470,14 +471,12 @@ def isnumerical(val):
Determine if a value is numerical (number or np.array of numbers).
"""
isnum = True
if isinstance(val, coords.SkyCoord):
isnum = False
elif isinstance(val, u.Quantity):
isnum = False
elif isinstance(val, (Time, TimeDelta)):
astropy_types=(coords.SkyCoord, u.Quantity, Time, TimeDelta, table.Column, table.Row)
if isinstance(val, astropy_types):
isnum = False
elif (isinstance(val, np.ndarray)
and not np.issubdtype(val.dtype, np.floating)
and not np.issubdtype(val.dtype, np.integer)):
and not np.issubdtype(val.dtype, np.integer)
):
isnum = False
return isnum
Loading

0 comments on commit 863b6e4

Please sign in to comment.