Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Apr 5, 2024
1 parent d769349 commit a8acad8
Showing 1 changed file with 42 additions and 35 deletions.
77 changes: 42 additions & 35 deletions xarrayfits/grid.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,31 @@
from collections.abc import Mapping
import numpy as np

HEADER_PREFIXES = ["NAXIS", "CTYPE", "CRPIX", "CRVAL", "CDELT", "CUNIT", "CNAME"]


class UndefinedGridError(ValueError):
pass


def property_factory(prefix: str):
def impl(self):
return getattr(self, f"_{prefix}")

return property(impl)


class AffineGridMetaclass(type):
def __new__(cls, name, bases, dct):
for prefix in (p.lower() for p in HEADER_PREFIXES):
dct[prefix] = property_factory(prefix)
return type.__new__(cls, name, bases, dct)


class AffineGrid(metaclass=AffineGridMetaclass):
class AffineGrid:
"""Presents a C-ordered view over FITS Header grid attributes"""

def __init__(self, header: Mapping):
self._ndims = ndims = header["NAXIS"]
axr = tuple(range(1, ndims + 1))
h = header

# Read headers into C-order
for prefix in HEADER_PREFIXES:
values = reversed([header.get(f"{prefix}{n}") for n in axr])
values = [s.strip() if isinstance(s, str) else s for s in values]
setattr(self, f"_{prefix.lower()}", values)

# We must have all NAXIS
for i, a in enumerate(self.naxis):
if a is None:
raise UndefinedGridError(f"NAXIS{ndims - i} undefined")

# Fill in any missing CRVAL
self._crval = [0.0 if v is None else v for v in self._crval]
# Fill in any missing CRPIX
self._crpix = [1 if p is None else p for p in self._crpix]
# Fill in any missing CDELT
self._cdelt = [1.0 if d is None else d for d in self._cdelt]
try:
self._ndims = ndims = h["NAXIS"]
axr = tuple(range(1, ndims + 1))
self._naxis = list(reversed([h[f"NAXIS{n}"] for n in axr]))
except KeyError as e:
raise UndefinedGridError(f"{e} undefined") from e

self._ctype = list(reversed([h.get(f"CTYPE{n}") for n in axr]))
self._crpix = list(reversed([h.get(f"CRPIX{n}", 1) for n in axr]))
self._crval = list(reversed([h.get(f"CRVAL{n}", 0.0) for n in axr]))
self._cdelt = list(reversed([h.get(f"CDELT{n}", 1.0) for n in axr]))
self._cunit = list(reversed([h.get(f"CUNIT{n}") for n in axr]))
self._cname = list(reversed([h.get(f"CNAME{n}") for n in axr]))

self._grid = []

Expand All @@ -56,10 +35,38 @@ def __init__(self, header: Mapping):
(pixels - self._crpix[d]) * self._cdelt[d] + self._crval[d]
)

@property
def naxis(self):
return self._naxis

@property
def ndims(self):
return self._ndims

@property
def ctype(self):
return self._ctype

@property
def crpix(self):
return self._crpix

@property
def crval(self):
return self._crval

@property
def cdelt(self):
return self._cdelt

@property
def cunit(self):
return self._cunit

@property
def cname(self):
return self._cname

def name(self, dim: int):
"""Return a name for dimension :code:`dim`"""
if result := self.cname[dim]:
Expand Down

0 comments on commit a8acad8

Please sign in to comment.