Skip to content

Commit

Permalink
Add basic Affine Grid Coordinates to xarray Datasets (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins authored Apr 15, 2024
1 parent 5cc48da commit ca06461
Show file tree
Hide file tree
Showing 5 changed files with 357 additions and 99 deletions.
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History

X.Y.Z (YYYY-MM-DD)
------------------
* Add basic Affine Grid coordinates to xarray datasets (:pr:`35`)
* Constrain dask versions (:pr:`34`)
* Specify dtype during chunk normalisation (:pr:`33`)
* Configure dependabot for github actions (:pr:`28`)
Expand Down
58 changes: 58 additions & 0 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from astropy.io import fits
import numpy as np
from numpy.testing import assert_array_equal
import pytest

from xarrayfits.grid import AffineGrid


@pytest.fixture(params=[(10, 20, 30)])
def header(request):
    """Synthesise a FITS header describing an affine grid for the test shape.

    FITS stores axes in FORTRAN (reversed) order, so axis 1 corresponds to
    the last entry of the parametrised shape.
    """
    shape = request.param
    ndim = len(shape)
    axes = range(1, ndim + 1)

    cards = {"NAXIS": ndim}
    # NAXIS{a} holds the reversed shape; the remaining cards follow the
    # same per-axis numbering, grouped by keyword as in a real header.
    cards.update((f"NAXIS{a}", shape[ndim - a]) for a in axes)
    cards.update((f"CRPIX{a}", 4 + a) for a in axes)
    cards.update((f"CRVAL{a}", float(a)) for a in axes)
    cards.update((f"CDELT{a}", float(a + 1)) for a in axes)
    cards.update((f"CNAME{a}", f"NAME-{ndim - a + 1}") for a in axes)
    cards.update((f"CTYPE{a}", f"TYPE-{ndim - a + 1}") for a in axes)
    cards.update((f"CUNIT{a}", f"UNIT-{ndim - a + 1}") for a in axes)

    return fits.Header(cards)


def test_affine_grid(header):
    """AffineGrid exposes header values reversed back into C axis order."""
    grid = AffineGrid(header)
    ndims = grid.ndims
    assert ndims == header["NAXIS"]

    # Values come back reversed relative to the FORTRAN-ordered header keys
    assert grid.naxis == [10, 20, 30]
    assert grid.crpix == [7, 6, 5]
    assert grid.crval == [3.0, 2.0, 1.0]
    assert grid.cdelt == [4.0, 3.0, 2.0]

    for attr in ("cname", "cunit", "ctype"):
        key = attr.upper()
        expected = [header[f"{key}{ndims - i}"] for i in range(ndims)]
        assert getattr(grid, attr) == expected

    # Worked coordinate examples: crval + (pixel - crpix) * cdelt
    assert_array_equal(grid.coords(0), 3.0 + (np.arange(1, 11) - 7) * 4.0)
    assert_array_equal(grid.coords(1), 2.0 + (np.arange(1, 21) - 6) * 3.0)
    assert_array_equal(grid.coords(2), 1.0 + (np.arange(1, 31) - 5) * 2.0)

    # The same relation expressed through the grid's own attributes
    for dim in range(ndims):
        pixels = np.arange(1, grid.naxis[dim] + 1)
        reference = grid.crval[dim] + (pixels - grid.crpix[dim]) * grid.cdelt[dim]
        assert_array_equal(grid.coords(dim), reference)
190 changes: 126 additions & 64 deletions tests/test_xarrayfits.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,57 +18,6 @@
from xarrayfits.fits_proxy import FitsProxy


@pytest.fixture(scope="session")
def multiple_files(tmp_path_factory):
    """Write three identical (10, 10) ramp FITS files; return their paths."""
    directory = tmp_path_factory.mktemp("globbing")
    shape = (10, 10)
    payload = np.arange(np.prod(shape), dtype=np.float64).reshape(shape)

    names = []
    for index in range(3):
        name = str(directory / f"data-{index}.fits")
        fits.PrimaryHDU(payload).writeto(name, overwrite=True)
        names.append(name)

    return names


def multiple_dataset_tester(datasets):
    """Verify three identical datasets individually, then their concatenations."""
    assert len(datasets) == 3

    for xds in datasets:
        expected = np.arange(np.prod(xds.hdu0.shape), dtype=np.float64)
        expected = expected.reshape(xds.hdu0.shape)
        assert_array_equal(xds.hdu0.data, expected)

    # Concatenating along an existing dimension triples that axis
    for axis, dim in enumerate(("hdu0-0", "hdu0-1")):
        combined = xarray.concat(datasets, dim=dim)
        assert_array_equal(
            combined.hdu0.data, np.concatenate([expected] * 3, axis=axis)
        )
        assert combined.hdu0.dims == ("hdu0-0", "hdu0-1")

    # Concatenating along a fresh leading "time" dimension stacks instead
    expanded = [ds.expand_dims(dim="time", axis=0) for ds in datasets]
    combined = xarray.concat(expanded, dim="time")
    assert_array_equal(combined.hdu0.data, np.stack([expected] * 3, axis=0))
    assert combined.hdu0.dims == ("time", "hdu0-0", "hdu0-1")


def test_list_files(multiple_files):
    """A list of filenames yields one dataset per file."""
    datasets = xds_from_fits(multiple_files)
    # Call rather than return: pytest emits PytestReturnNotNoneWarning
    # (and future versions error) when a test returns a non-None value.
    multiple_dataset_tester(datasets)


def test_globbing(multiple_files):
    """A glob pattern matches and loads all three files."""
    path, _ = os.path.split(multiple_files[0])
    datasets = xds_from_fits(f"{path}{os.sep}data*.fits")
    # Call rather than return: pytest emits PytestReturnNotNoneWarning
    # (and future versions error) when a test returns a non-None value.
    multiple_dataset_tester(datasets)


@pytest.fixture(scope="session")
def beam_cube(tmp_path_factory):
frequency = np.linspace(0.856e9, 0.856e9 * 2, 32, endpoint=True)
Expand Down Expand Up @@ -150,20 +99,133 @@ def beam_cube(tmp_path_factory):
yield filename


def test_name_prefix(beam_cube):
    """Dimension names honour a user-supplied variable name prefix."""
    (ds,) = xds_from_fits(beam_cube, prefix="beam")
    assert ds.beam0.dims == tuple(f"beam0-{i}" for i in range(3))
@pytest.fixture(scope="session")
def multiple_hdu_file(tmp_path_factory):
    """Write a FITS file with a primary HDU plus two differently-shaped image HDUs."""
    ctypes = ["X", "Y", "FREQ", "STOKES"]

    def make_hdu(hdu_cls, shape):
        # Ramp data so contents are trivially predictable in tests
        data = np.arange(np.prod(shape), dtype=np.float64).reshape(shape)
        # CTYPE cards are written in FORTRAN (reversed) axis order
        cards = {f"CTYPE{data.ndim - i}": ctypes[i] for i in range(data.ndim)}
        return hdu_cls(data, header=fits.Header(cards))

    hdus = fits.HDUList(
        [
            make_hdu(fits.PrimaryHDU, (10, 10)),
            make_hdu(fits.ImageHDU, (10, 20, 30)),
            make_hdu(fits.ImageHDU, (30, 40, 50)),
        ]
    )

    filename = str(tmp_path_factory.mktemp("multihdu") / "data.fits")
    hdus.writeto(filename, overwrite=True)
    return filename


@pytest.fixture(scope="session")
def multiple_files(tmp_path_factory):
    """Create three FITS files containing identical (10, 10) ramp data."""
    root = tmp_path_factory.mktemp("globbing")
    shape = (10, 10)
    data = np.arange(np.prod(shape), dtype=np.float64).reshape(shape)

    paths = [str(root / f"data-{i}.fits") for i in range(3)]

    for path in paths:
        fits.PrimaryHDU(data).writeto(path, overwrite=True)

    return paths


def multiple_dataset_tester(datasets):
    """Assert each dataset's contents, then concatenation along every axis."""
    assert len(datasets) == 3

    for ds in datasets:
        expected = np.arange(np.prod(ds.hdu.shape), dtype=np.float64)
        expected = expected.reshape(ds.hdu.shape)
        assert_array_equal(ds.hdu.data, expected)

    # Concatenating along an existing dimension triples that axis
    for axis, dim in enumerate(("hdu-0", "hdu-1")):
        merged = xarray.concat(datasets, dim=dim)
        assert_array_equal(merged.hdu.data, np.concatenate([expected] * 3, axis=axis))
        assert merged.hdu.dims == ("hdu-0", "hdu-1")

    # Concatenating along a new leading "time" dimension stacks instead
    with_time = [ds.expand_dims(dim="time", axis=0) for ds in datasets]
    merged = xarray.concat(with_time, dim="time")
    assert_array_equal(merged.hdu.data, np.stack([expected] * 3, axis=0))
    assert merged.hdu.dims == ("time", "hdu-0", "hdu-1")


def test_list_files(multiple_files):
    """A list of filenames yields one dataset per file."""
    datasets = xds_from_fits(multiple_files)
    # Call rather than return: pytest emits PytestReturnNotNoneWarning
    # (and future versions error) when a test returns a non-None value.
    multiple_dataset_tester(datasets)


def test_globbing(multiple_files):
    """A glob pattern matches and loads all three files."""
    path, _ = os.path.split(multiple_files[0])
    datasets = xds_from_fits(f"{path}{os.sep}data*.fits")
    # Call rather than return: pytest emits PytestReturnNotNoneWarning
    # (and future versions error) when a test returns a non-None value.
    multiple_dataset_tester(datasets)


def test_multiple_unnamed_hdus(multiple_hdu_file):
    """HDUs selected by index: a single index and a list of indices."""
    # A single bare index produces one unprefixed "hdu" variable
    (ds,) = xds_from_fits(multiple_hdu_file, hdus=0)
    assert len(ds.data_vars) == 1
    assert ds.hdu.shape == (10, 10)
    assert ds.hdu.dims == ("X", "Y")

    # With several indices, each variable's dims carry its name as a prefix
    (ds,) = xds_from_fits(multiple_hdu_file, hdus=[0, 2])
    assert len(ds.data_vars) == 2

    assert ds.hdu0.shape == (10, 10)
    assert ds.hdu0.dims == tuple(f"hdu0-{c}" for c in ("X", "Y"))

    assert ds.hdu2.shape == (30, 40, 50)
    assert ds.hdu2.dims == tuple(f"hdu2-{c}" for c in ("X", "Y", "FREQ"))


def test_multiple_named_hdus(multiple_hdu_file):
    """HDUs selected by name via a {index: name} dict or a list of names."""
    # A single named HDU: no prefix is applied to the dimension names
    for selection in ({0: "beam"}, ["beam"]):
        (ds,) = xds_from_fits(multiple_hdu_file, hdus=selection)
        assert ds.beam.dims == ("X", "Y")
        assert ds.beam.shape == (10, 10)

    # Dict form: indices select HDUs explicitly, names prefix the dims
    (ds,) = xds_from_fits(multiple_hdu_file, hdus={0: "beam", 2: "3C147"})
    assert ds.beam.dims == ("beam-X", "beam-Y")
    assert ds.beam.shape == (10, 10)
    assert ds["3C147"].dims == ("3C147-X", "3C147-Y", "3C147-FREQ")
    assert ds["3C147"].shape == (30, 40, 50)

    # List form: names are assigned to HDUs positionally
    (ds,) = xds_from_fits(multiple_hdu_file, hdus=["beam", "3C147"])
    assert ds.beam.dims == ("beam-X", "beam-Y")
    assert ds.beam.shape == (10, 10)
    assert ds["3C147"].dims == ("3C147-X", "3C147-Y", "3C147-FREQ")
    assert ds["3C147"].shape == (10, 20, 30)


def test_beam_creation(beam_cube):
(xds,) = xds_from_fits(beam_cube)
cmp_data = np.arange(np.prod(xds.hdu0.shape), dtype=np.float64)
cmp_data = cmp_data.reshape(xds.hdu0.shape)
assert_array_equal(xds.hdu0.data, cmp_data)
assert xds.hdu0.data.shape == (257, 257, 32)
assert xds.hdu0.dims == ("hdu0-0", "hdu0-1", "hdu0-2")
assert xds.hdu0.attrs == {
cmp_data = np.arange(np.prod(xds.hdu.shape), dtype=np.float64)
cmp_data = cmp_data.reshape(xds.hdu.shape)
assert_array_equal(xds.hdu.data, cmp_data)
assert xds.hdu.data.shape == (257, 257, 32)
assert xds.hdu.dims == ("X", "Y", "FREQ")
assert xds.hdu.attrs == {
"header": {
"BITPIX": -64,
"EQUINOX": 2000.0,
Expand Down Expand Up @@ -201,9 +263,9 @@ def test_distributed(beam_cube):
stack.enter_context(Client(cluster))

(xds,) = xds_from_fits(beam_cube, chunks={0: 100, 1: 100, 2: 15})
expected = np.arange(np.prod(xds.hdu0.shape)).reshape(xds.hdu0.shape)
assert_array_equal(expected, xds.hdu0.data)
assert xds.hdu0.data.chunks == ((100, 100, 57), (100, 100, 57), (15, 15, 2))
expected = np.arange(np.prod(xds.hdu.shape)).reshape(xds.hdu.shape)
assert_array_equal(expected, xds.hdu.data)
assert xds.hdu.data.chunks == ((100, 100, 57), (100, 100, 57), (15, 15, 2))


def test_memory_mapped(beam_cube):
Expand Down
Loading

0 comments on commit ca06461

Please sign in to comment.