diff --git a/tests/test_axes.py b/tests/test_axes.py new file mode 100644 index 0000000..44ed19a --- /dev/null +++ b/tests/test_axes.py @@ -0,0 +1,43 @@ +from astropy.io import fits +import pytest + +from xarrayfits.axes import Axes + + +@pytest.fixture(params=[(10, 20, 30)]) +def header(request): + # Reverse into FORTRAN order + rev_dims = list(reversed(request.param)) + naxes = {f"NAXIS{d + 1}": s for d, s in enumerate(rev_dims)} + crpix = {f"CRPIX{d + 1}": 5 + d for d, _ in enumerate(rev_dims)} + crval = {f"CRVAL{d + 1}": 1.0 + d for d, _ in enumerate(rev_dims)} + cdelt = {f"CDELT{d + 1}": 2.0 + d for d, _ in enumerate(rev_dims)} + cunit = {f"CUNIT{d + 1}": f"UNIT-{len(rev_dims) - d}" for d in range(len(rev_dims))} + ctype = {f"CTYPE{d + 1}": f"TYPE-{len(rev_dims) - d}" for d in range(len(rev_dims))} + cname = {f"CNAME{d + 1}": f"NAME-{len(rev_dims) - d}" for d in range(len(rev_dims))} + + return fits.Header( + { + "NAXIS": len(request.param), + **naxes, + **crpix, + **crval, + **cdelt, + **cname, + **ctype, + **cunit, + } + ) + + +def test_axes(header): + axes = Axes(header) + ndims = axes.ndims + assert ndims == header["NAXIS"] + assert axes.naxis == [10, 20, 30] + assert axes.crpix == [7, 6, 5] + assert axes.crval == [3.0, 2.0, 1.0] + assert axes.cdelt == [4.0, 3.0, 2.0] + assert axes.cname == [header[f"CNAME{ndims - i}"] for i in range(ndims)] + assert axes.cunit == [header[f"CUNIT{ndims - i}"] for i in range(ndims)] + assert axes.ctype == [header[f"CTYPE{ndims - i}"] for i in range(ndims)] diff --git a/tests/test_xarrayfits.py b/tests/test_xarrayfits.py index 981e785..c27f1ac 100644 --- a/tests/test_xarrayfits.py +++ b/tests/test_xarrayfits.py @@ -153,7 +153,7 @@ def beam_cube(tmp_path_factory): def test_name_prefix(beam_cube): """Test specification of a name prefix""" (xds,) = xds_from_fits(beam_cube, prefix="beam") - assert xds.beam0.dims == ("beam0-0", "beam0-1", "beam0-2") + assert xds.beam0.dims == ("X0", "Y0", "FREQ0") def test_beam_creation(beam_cube): @@ -162,7 +162,7 @@ def test_beam_creation(beam_cube): cmp_data = cmp_data.reshape(xds.hdu0.shape) assert_array_equal(xds.hdu0.data, cmp_data) assert xds.hdu0.data.shape == (257, 257, 32) - assert xds.hdu0.dims == ("hdu0-0", "hdu0-1", "hdu0-2") + assert xds.hdu0.dims == ("X0", "Y0", "FREQ0") assert xds.hdu0.attrs == { "header": { "BITPIX": -64, diff --git a/xarrayfits/axes.py b/xarrayfits/axes.py new file mode 100644 index 0000000..322222a --- /dev/null +++ b/xarrayfits/axes.py @@ -0,0 +1,73 @@ +import numpy as np + +HEADER_PREFIXES = ["NAXIS", "CTYPE", "CRPIX", "CRVAL", "CDELT", "CUNIT", "CNAME"] + + +def property_factory(prefix): + def impl(self): + return getattr(self, f"_{prefix}") + + return property(impl) + + +class UndefinedGridError(ValueError): + pass + + +class AxesMetaClass(type): + def __new__(cls, name, bases, dct): + for prefix in (p.lower() for p in HEADER_PREFIXES): + dct[prefix] = property_factory(prefix) + return type.__new__(cls, name, bases, dct) + + +class Axes(metaclass=AxesMetaClass): + """Presents a C-ordered view over FITS Header grid attributes""" + + def __init__(self, header): + self._ndims = ndims = header["NAXIS"] + axr = tuple(range(1, ndims + 1)) + + # Read headers into C-order + for prefix in HEADER_PREFIXES: + values = reversed([header.get(f"{prefix}{n}") for n in axr]) + values = [s.strip() if isinstance(s, str) else s for s in values] + setattr(self, f"_{prefix.lower()}", values) + + # We must have all NAXIS + for i, a in enumerate(self.naxis): + if a is None: + raise UndefinedGridError(f"NAXIS{ndims - i} undefined") + + # Fill in any None CRVAL + self._crval = [0 if v is None else v for v in self._crval] + # Fill in any None CRPIX + self._crpix = [1 if p is None else p for p in self._crpix] + # Fill in any None CDELT + self._cdelt = [1 if d is None else d for d in self._cdelt] + + self._grid = [None] * ndims + + @property + def ndims(self): + return self._ndims + + def name(self, dim): + """Return a name for dimension :code:`dim`""" + if result := self.cname[dim]: + return result + elif result := self.ctype[dim]: + return result + else: + return None + + def grid(self, dim): + """Return the axis grid for dimension :code:`dim`""" + if self._grid[dim] is None: + # Create the grid + pixels = np.arange(1, self.naxis[dim] + 1) + self._grid[dim] = (pixels - self.crpix[dim]) * self.cdelt[dim] + self.crval[ + dim + ] + + return self._grid[dim] diff --git a/xarrayfits/fits.py b/xarrayfits/fits.py index 56f7d7b..38da14c 100644 --- a/xarrayfits/fits.py +++ b/xarrayfits/fits.py @@ -17,6 +17,7 @@ import xarray as xr +from xarrayfits.axes import Axes, UndefinedGridError from xarrayfits.fits_proxy import FitsProxy log = logging.getLogger("xarray-fits") @@ -160,18 +161,17 @@ def array_from_fits_hdu( shape = [] flat_chunks = [] + axes = Axes(hdu.header) - # At this point we are dealing with FORTRAN ordered axes - for i in range(naxis): - ax_key = f"NAXIS{naxis - i}" - ax_shape = hdu.header[ax_key] - shape.append(ax_shape) + # Determine shapes and apply chunking + for i in range(axes.ndims): + shape.append(axes.naxis[i]) try: # Try add existing chunking strategies to the list flat_chunks.append(chunks[i]) except KeyError: - flat_chunks.append(ax_shape) + flat_chunks.append(axes.naxis[i]) array = generate_slice_gets( fits_proxy, @@ -181,9 +181,15 @@ def array_from_fits_hdu( tuple(flat_chunks), ) - dims = tuple(f"{prefix}{hdu_index}-{i}" for i in range(0, naxis)) + dims = tuple( + f"{name}{hdu_index}" if (name := axes.name(i)) else f"{prefix}{hdu_index}-{i}" + for i in range(axes.ndims) + ) + + coords = {d: (d, axes.grid(i)) for i, d in enumerate(dims)} + attrs = {"header": {k: v for k, v in sorted(hdu.header.items())}} - return xr.DataArray(array, dims=dims, attrs=attrs) + return xr.DataArray(array, dims=tuple(dims), coords=coords, attrs=attrs) def xds_from_fits(fits_filename, hdus=None, prefix="hdu", chunks=None):