Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic Affine Grid Coordinates to xarray Datasets #35

Merged
merged 6 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History

X.Y.Z (YYYY-MM-DD)
------------------
* Add basic Affine Grid coordinates to xarray datasets (:pr:`35`)
* Constrain dask versions (:pr:`34`)
* Specify dtype during chunk normalisation (:pr:`33`)
* Configure dependabot for github actions (:pr:`28`)
Expand Down
58 changes: 58 additions & 0 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from astropy.io import fits
import numpy as np
from numpy.testing import assert_array_equal
import pytest

from xarrayfits.grid import AffineGrid


@pytest.fixture(params=[(10, 20, 30)])
def header(request):
# Reverse into FORTRAN order
rev_dims = list(reversed(request.param))
naxes = {f"NAXIS{d + 1}": s for d, s in enumerate(rev_dims)}
crpix = {f"CRPIX{d + 1}": 5 + d for d, _ in enumerate(rev_dims)}
crval = {f"CRVAL{d + 1}": 1.0 + d for d, _ in enumerate(rev_dims)}
cdelt = {f"CDELT{d + 1}": 2.0 + d for d, _ in enumerate(rev_dims)}
cunit = {f"CUNIT{d + 1}": f"UNIT-{len(rev_dims) - d}" for d in range(len(rev_dims))}
ctype = {f"CTYPE{d + 1}": f"TYPE-{len(rev_dims) - d}" for d in range(len(rev_dims))}
cname = {f"CNAME{d + 1}": f"NAME-{len(rev_dims) - d}" for d in range(len(rev_dims))}

return fits.Header(
{
"NAXIS": len(request.param),
**naxes,
**crpix,
**crval,
**cdelt,
**cname,
**ctype,
**cunit,
}
)


def test_affine_grid(header):
grid = AffineGrid(header)
ndims = grid.ndims
assert ndims == header["NAXIS"]
assert grid.naxis == [10, 20, 30]
assert grid.crpix == [7, 6, 5]
assert grid.crval == [3.0, 2.0, 1.0]
assert grid.cdelt == [4.0, 3.0, 2.0]
assert grid.cname == [header[f"CNAME{ndims - i}"] for i in range(ndims)]
assert grid.cunit == [header[f"CUNIT{ndims - i}"] for i in range(ndims)]
assert grid.ctype == [header[f"CTYPE{ndims - i}"] for i in range(ndims)]

# Worked coordinate example
assert_array_equal(grid.coords(0), 3.0 + (np.arange(1, 10 + 1) - 7) * 4.0)
assert_array_equal(grid.coords(1), 2.0 + (np.arange(1, 20 + 1) - 6) * 3.0)
assert_array_equal(grid.coords(2), 1.0 + (np.arange(1, 30 + 1) - 5) * 2.0)

# More automatic version
for d in range(ndims):
assert_array_equal(
grid.coords(d),
grid.crval[d]
+ (np.arange(1, grid.naxis[d] + 1) - grid.crpix[d]) * grid.cdelt[d],
)
4 changes: 2 additions & 2 deletions tests/test_xarrayfits.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def beam_cube(tmp_path_factory):
def test_name_prefix(beam_cube):
"""Test specification of a name prefix"""
(xds,) = xds_from_fits(beam_cube, prefix="beam")
assert xds.beam0.dims == ("beam0-0", "beam0-1", "beam0-2")
assert xds.beam0.dims == ("X0", "Y0", "FREQ0")


def test_beam_creation(beam_cube):
Expand All @@ -162,7 +162,7 @@ def test_beam_creation(beam_cube):
cmp_data = cmp_data.reshape(xds.hdu0.shape)
assert_array_equal(xds.hdu0.data, cmp_data)
assert xds.hdu0.data.shape == (257, 257, 32)
assert xds.hdu0.dims == ("hdu0-0", "hdu0-1", "hdu0-2")
assert xds.hdu0.dims == ("X0", "Y0", "FREQ0")
sjperkins marked this conversation as resolved.
Show resolved Hide resolved
assert xds.hdu0.attrs == {
"header": {
"BITPIX": -64,
Expand Down
20 changes: 12 additions & 8 deletions xarrayfits/fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import xarray as xr

from xarrayfits.grid import AffineGrid
from xarrayfits.fits_proxy import FitsProxy

log = logging.getLogger("xarray-fits")
Expand Down Expand Up @@ -160,18 +161,17 @@ def array_from_fits_hdu(

shape = []
flat_chunks = []
grid = AffineGrid(hdu.header)

# At this point we are dealing with FORTRAN ordered axes
for i in range(naxis):
ax_key = f"NAXIS{naxis - i}"
ax_shape = hdu.header[ax_key]
shape.append(ax_shape)
# Determine shapes and apply chunking
for i in range(grid.ndims):
shape.append(grid.naxis[i])

try:
# Try add existing chunking strategies to the list
flat_chunks.append(chunks[i])
except KeyError:
flat_chunks.append(ax_shape)
flat_chunks.append(grid.naxis[i])

array = generate_slice_gets(
fits_proxy,
Expand All @@ -181,9 +181,13 @@ def array_from_fits_hdu(
tuple(flat_chunks),
)

dims = tuple(f"{prefix}{hdu_index}-{i}" for i in range(0, naxis))
dims = tuple(
f"{name}{hdu_index}" if (name := grid.name(i)) else f"{prefix}{hdu_index}-{i}"
for i in range(grid.ndims)
)
coords = {d: (d, grid.coords(i)) for i, d in enumerate(dims)}
attrs = {"header": {k: v for k, v in sorted(hdu.header.items())}}
return xr.DataArray(array, dims=dims, attrs=attrs)
return xr.DataArray(array, dims=dims, coords=coords, attrs=attrs)


def xds_from_fits(fits_filename, hdus=None, prefix="hdu", chunks=None):
Expand Down
74 changes: 74 additions & 0 deletions xarrayfits/grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from collections.abc import Mapping
import numpy as np

HEADER_PREFIXES = ["NAXIS", "CTYPE", "CRPIX", "CRVAL", "CDELT", "CUNIT", "CNAME"]


class UndefinedGridError(ValueError):
pass


def property_factory(prefix: str):
def impl(self):
return getattr(self, f"_{prefix}")

return property(impl)


class AffineGridMetaclass(type):
def __new__(cls, name, bases, dct):
for prefix in (p.lower() for p in HEADER_PREFIXES):
dct[prefix] = property_factory(prefix)
return type.__new__(cls, name, bases, dct)


class AffineGrid(metaclass=AffineGridMetaclass):
"""Presents a C-ordered view over FITS Header grid attributes"""

def __init__(self, header: Mapping):
self._ndims = ndims = header["NAXIS"]
axr = tuple(range(1, ndims + 1))
h = header

# Read headers into C-order
for prefix in HEADER_PREFIXES:
values = reversed([header.get(f"{prefix}{n}") for n in axr])
values = [s.strip() if isinstance(s, str) else s for s in values]
setattr(self, f"_{prefix.lower()}", values)

# We must have all NAXIS
for i, a in enumerate(self.naxis):
if a is None:
raise UndefinedGridError(f"NAXIS{ndims - i} undefined")

# Fill in any missing CRVAL
self._crval = [0.0 if v is None else v for v in self._crval]
# Fill in any missing CRPIX
self._crpix = [1 if p is None else p for p in self._crpix]
# Fill in any missing CDELT
self._cdelt = [1.0 if d is None else d for d in self._cdelt]
sjperkins marked this conversation as resolved.
Show resolved Hide resolved

self._grid = []

for d in range(ndims):
pixels = np.arange(1, self._naxis[d] + 1, dtype=np.float64)
self._grid.append(
(pixels - self._crpix[d]) * self._cdelt[d] + self._crval[d]
)

@property
def ndims(self):
return self._ndims

def name(self, dim: int):
"""Return a name for dimension :code:`dim`"""
if result := self.cname[dim]:
return result
elif result := self.ctype[dim]:
return result
else:
return None
Copy link
Member Author

@sjperkins sjperkins Apr 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tries to find the axis CNAME first, else falls back to CTYPE


def coords(self, dim: int):
"""Return the affine coordinates for dimension :code:`dim`"""
return self._grid[dim]