Skip to content

Commit

Permalink
Add high level <> values converters to frames
Browse files Browse the repository at this point in the history
This adds back a more sane equivalent of coordinates and coordinates_to_quantity.
  • Loading branch information
Cadair committed Nov 20, 2024
1 parent 3ed5458 commit f8cad99
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 19 deletions.
56 changes: 56 additions & 0 deletions gwcs/coordinate_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
from astropy.wcs.wcsapi.low_level_api import (validate_physical_types,
VALID_UCDS)
from astropy.wcs.wcsapi.fitswcs import CTYPE_TO_UCD1
from astropy.wcs.wcsapi.high_level_api import high_level_objects_to_values, values_to_high_level_objects
from astropy.coordinates import StokesCoord

__all__ = ['BaseCoordinateFrame', 'Frame2D', 'CelestialFrame', 'SpectralFrame', 'CompositeFrame',
Expand Down Expand Up @@ -507,6 +508,61 @@ def world_axis_object_classes(self):
def _native_world_axis_object_components(self):
return [(f"{at}{i}" if i != 0 else at, 0, 'value') for i, at in enumerate(self._prop.axes_type)]

@property
def serialized_classes(self):
"""
This property is used by the low level WCS API in Astropy.
By providing it we can duck type as a low level WCS object.
"""
return False

def to_high_level_coordinates(self, *values):
"""
Convert "values" to high level coordinate objects described by this frame.
"values" are the coordinates in array or scalar form, and high level
objects are things such as ``SkyCoord`` or ``Quantity``. See
:ref:`wcsapi` for details.
Parameters
----------
values : `numbers.Number` or `numpy.ndarray`
``naxis`` number of coordinates as scalars or arrays.
Returns
-------
high_level_coordinates
One (or more) high level object describing the coordinate.
"""
high_level = values_to_high_level_objects(*values, low_level_wcs=self)
if len(high_level) == 1:
high_level = high_level[0]
return high_level

def from_high_level_coordinates(self, *high_level_coords):
"""
Convert high level coordinate objects to "values" as described by this frame.
"values" are the coordinates in array or scalar form, and high level
objects are things such as ``SkyCoord`` or ``Quantity``. See
:ref:`wcsapi` for details.
Parameters
----------
high_level_coordinates
One (or more) high level object describing the coordinate.
Returns
-------
values : `numbers.Number` or `numpy.ndarray`
``naxis`` number of coordinates as scalars or arrays.
"""
values = high_level_objects_to_values(*high_level_coords, low_level_wcs=self)
if len(values) == 1:
values = values[0]
return values


class CelestialFrame(CoordinateFrame):
"""
Expand Down
23 changes: 4 additions & 19 deletions gwcs/tests/test_coordinate_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from .. import WCS
from .. import coordinate_frames as cf

from astropy.wcs.wcsapi.high_level_api import values_to_high_level_objects, high_level_objects_to_values
import astropy
astropy_version = astropy.__version__

Expand Down Expand Up @@ -56,19 +55,6 @@
inputs3 = [(xscalar, yscalar, xscalar), (xarr, yarr, xarr)]


@pytest.fixture(autouse=True, scope="module")
def serialized_classes():
"""
In the rest of this test file we are passing the CoordinateFrame object to
astropy helper functions as if they were a low level WCS object.
This little patch means that this works.
"""
cf.CoordinateFrame.serialized_classes = False
yield
del cf.CoordinateFrame.serialized_classes


def test_units():
assert(comp1.unit == (u.deg, u.deg, u.Hz))
assert(comp2.unit == (u.m, u.m, u.m))
Expand All @@ -81,14 +67,13 @@ def test_units():
# These two functions fake the old methods on CoordinateFrame to reduce the
# amount of refactoring that needed doing in these tests.
def coordinates(*inputs, frame):
results = values_to_high_level_objects(*inputs, low_level_wcs=frame)
if isinstance(results, list) and len(results) == 1:
return results[0]
return results
return frame.to_high_level_coordinates(*inputs)


def coordinate_to_quantity(*inputs, frame):
results = high_level_objects_to_values(*inputs, low_level_wcs=frame)
results = frame.from_high_level_coordinates(*inputs)
if not isinstance(results, list):
results = [results]
results = [r << unit for r, unit in zip(results, frame.unit)]
return results

Expand Down

0 comments on commit f8cad99

Please sign in to comment.