Skip to content

Commit

Permalink
start of new version
Browse files Browse the repository at this point in the history
  • Loading branch information
jmbhughes committed Aug 11, 2024
1 parent cc937ef commit a89f9f3
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 142 deletions.
4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"url": "https://github.com/punch-mission/regularizepsf",
"icon": "fa-brands fa-github",
"type": "fontawesome",
}
},
],
"show_nav_level": 1,
"show_toc_level": 3,
Expand All @@ -57,4 +57,4 @@


autoapi_dirs = ["../../regularizepsf"]
autoapi_python_class_content = 'both'
autoapi_python_class_content = "both"
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ dev = ["regularizepsf[test, docs]", "pre-commit"]
[tool.ruff]
exclude = ['tests/*']
line-length = 120
# lint.select = ["ALL"]
lint.ignore = [ "FBT001", "FBT002", "ANN401"]

#[tool.ruff.lint]
#select = ["NPY201"]
Expand Down
97 changes: 59 additions & 38 deletions regularizepsf/corrector.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

import abc
from typing import Any, Tuple, Optional
from pathlib import Path
from typing import TYPE_CHECKING, Any

import h5py
import numpy as np
Expand All @@ -12,6 +11,9 @@
from regularizepsf.helper import _correct_image, _precalculate_ffts
from regularizepsf.psf import PointSpreadFunctionABC, SimplePSF, VariedPSF

if TYPE_CHECKING:
from pathlib import Path


class CorrectorABC(metaclass=abc.ABCMeta):
@abc.abstractmethod
Expand All @@ -26,12 +28,13 @@ def save(self, path: str | Path) -> None:
Returns
-------
None
"""

@classmethod
@abc.abstractmethod
def load(cls, path: str | Path) -> CorrectorABC:
"""Loads a model from the path
"""Loads a model from the path.
Parameters
----------
Expand All @@ -46,7 +49,7 @@ def load(cls, path: str | Path) -> CorrectorABC:
@abc.abstractmethod
def correct_image(self, image: np.ndarray, size: int,
alpha: float = 0.5, epsilon: float = 0.05) -> np.ndarray:
"""PSF correct an image according to the model
"""PSF correct an image according to the model.
Parameters
----------
Expand All @@ -65,11 +68,12 @@ def correct_image(self, image: np.ndarray, size: int,
-------
np.ndarray
a image that has been PSF corrected
"""

@abc.abstractmethod
def simulate_observation(self, image: np.ndarray) -> np.ndarray:
"""Simulates on a star field what an observation using this PSF looks like
"""Simulates on a star field what an observation using this PSF looks like.
Parameters
----------
Expand All @@ -80,46 +84,48 @@ def simulate_observation(self, image: np.ndarray) -> np.ndarray:
-------
np.ndarray
an image with the PSF applied
"""


class FunctionalCorrector(CorrectorABC):
"""
A version of the PSF corrector that stores the model as a set of functions.
"""A version of the PSF corrector that stores the model as a set of functions.
For the actual correction, the functions must first
be evaluated to an ArrayCorrector.
"""

def __init__(self, psf: PointSpreadFunctionABC,
target_model: SimplePSF | None) -> None:
"""Initialize a FunctionalCorrector
"""Initialize a FunctionalCorrector.
Parameters
----------
psf : SimplePSF or VariedPSF
the model describing the psf for each patch of the image
target_model : SimplePSF or None
the target PSF to use to establish uniformity across the image
"""
self._psf: PointSpreadFunctionABC = psf
self._variable: bool = isinstance(self._psf, VariedPSF)
self._target_model: SimplePSF = target_model

@property
def is_variable(self) -> bool:
"""
Returns
"""Returns
-------
bool
True if the PSF model is varied (changes across the field-of-view)
and False otherwise
"""
return self._variable

def evaluate_to_array_form(self,
x: np.ndarray,
y: np.ndarray,
size: int) -> ArrayCorrector:
"""Evaluates a FunctionalCorrector to an ArrayCorrector
"""Evaluates a FunctionalCorrector to an ArrayCorrector.
Parameters
----------
Expand All @@ -134,17 +140,23 @@ def evaluate_to_array_form(self,
-------
ArrayCorrector
an array evaluated form of this PSF corrector
"""
if size % 2 != 0:
raise InvalidSizeError(f"size must be even. Found size={size}.")
msg = f"size must be even. Found size={size}."
raise InvalidSizeError(msg)

image_x, image_y = np.meshgrid(np.arange(size), np.arange(size))
evaluations = {}
for xx in x:
for yy in y:
evaluations[(xx, yy)] = self._psf(image_x, image_y)

target_evaluation = self._target_model(image_x, image_y) if self._target_model else np.ones((size, size))
# target_evaluation = self._target_model(image_x, image_y) if self._target_model else np.ones((size, size))
target_evaluation = {}
for xx in x:
for yy in y:
target_evaluation[(xx, yy)] = self._target_model(image_x, image_y) if self._target_model else np.ones((size, size))
return ArrayCorrector(evaluations, target_evaluation)

def correct_image(self, image: np.ndarray, size: int,
Expand All @@ -158,14 +170,16 @@ def correct_image(self, image: np.ndarray, size: int,
epsilon=epsilon)

def save(self, path: str) -> None:
raise NotImplementedError("You cannot save a FunctionalCorrector.")
msg = "You cannot save a FunctionalCorrector."
raise NotImplementedError(msg)

@classmethod
def load(cls, path: str) -> FunctionalCorrector:
raise NotImplementedError("You cannot load a FunctionalCorrector.")
msg = "You cannot load a FunctionalCorrector."
raise NotImplementedError(msg)

def simulate_observation(self, image: np.ndarray, size: int) -> np.ndarray:
"""Simulates on a star field what an observation using this PSF looks like
"""Simulates on a star field what an observation using this PSF looks like.
Parameters
----------
Expand All @@ -178,6 +192,7 @@ def simulate_observation(self, image: np.ndarray, size: int) -> np.ndarray:
-------
np.ndarray
an image with the PSF applied
"""
corners = calculate_covering(image.shape, size)
array_corrector = self.evaluate_to_array_form(corners[:, 0],
Expand All @@ -186,12 +201,13 @@ def simulate_observation(self, image: np.ndarray, size: int) -> np.ndarray:
return array_corrector.simulate_observation(image)



class ArrayCorrector(CorrectorABC):
""" A PSF corrector that is evaluated as array patches
"""
"""A PSF corrector that is evaluated as array patches."""

def __init__(self, evaluations: dict[Any, np.ndarray],
target_evaluation: np.ndarray) -> None:
"""Initialize an ArrayCorrector
target_evaluations: dict[Any, np.ndarray]) -> None:
"""Initialize an ArrayCorrector.
Parameters
----------
Expand All @@ -200,19 +216,25 @@ def __init__(self, evaluations: dict[Any, np.ndarray],
keys should be (x, y) of the lower left
pixel of each patch. values should be the `np.ndarray`
that corresponds to that patch
target_evaluation : np.ndarray
target_evaluations : np.ndarray
evaluated version of the target PSF
"""
self._evaluation_points: list[Any] = list(evaluations.keys())

if not isinstance(evaluations[self._evaluation_points[0]], np.ndarray):
raise TypeError(f"Individual evaluations must be numpy arrays. "
f"Found {type(evaluations[self._evaluation_points[0]])}.")
msg = (
f"Individual evaluations must be numpy arrays. "
f"Found {type(evaluations[self._evaluation_points[0]])}."
)
raise TypeError(msg)
if len(evaluations[self._evaluation_points[0]].shape) != 2:
raise InvalidSizeError("PSF evaluations must be 2-D numpy arrays.")
msg = "PSF evaluations must be 2-D numpy arrays."
raise InvalidSizeError(msg)
self._size = evaluations[self._evaluation_points[0]].shape[0]
if self._size % 2 != 0:
raise InvalidSizeError(f"Size must be even. Found {self._size}")
msg = f"Size must be even. Found {self._size}"
raise InvalidSizeError(msg)

self._evaluations: dict[Any, np.ndarray] = evaluations
for (x, y), evaluation in self._evaluations.items():
Expand All @@ -222,14 +244,15 @@ def __init__(self, evaluations: dict[Any, np.ndarray],
f"Found {evaluation.shape} at {(x, y)}.")
raise EvaluatedModelInconsistentSizeError(msg)

self._target_evaluation = target_evaluation
if self._target_evaluation.shape != (self._size, self._size):
msg = "The target and evaluations must have the same shape."
raise EvaluatedModelInconsistentSizeError(msg)
self._target_evaluations = target_evaluations
# if self._target_evaluation.shape != (self._size, self._size):
# msg = "The target and evaluations must have the same shape."
# raise EvaluatedModelInconsistentSizeError(msg)

normalized_values = np.array(
[v / v.sum() for v in self._evaluations.values()], dtype=float)
normalized_target = target_evaluation / target_evaluation.sum()
normalized_target = np.array(
[v / v.sum() for v in self._target_evaluations.values()], dtype=float)
self.target_fft, self.psf_i_fft = _precalculate_ffts(
normalized_target, normalized_values)

Expand All @@ -241,27 +264,25 @@ def evaluations(self) -> dict[Any, np.ndarray]:
def evaluation_points(self) -> list:
return self._evaluation_points

def correct_image(self, image: np.ndarray, size: Optional[int] = None, # noqa: ARG002
def correct_image(self, image: np.ndarray, size: int | None = None, # noqa: ARG002
alpha: float = 0.5, epsilon: float = 0.05) -> np.ndarray:
if not all(img_dim_i >= psf_dim_i for img_dim_i, psf_dim_i in zip(image.shape,
(self._size,
self._size))):
self._size), strict=False)):
msg = "The image must be at least as large as the PSFs in all dimensions"
raise InvalidSizeError(msg)

x = np.array([x for x, _ in self._evaluations], dtype=int)
y = np.array([y for _, y in self._evaluations], dtype=int)

return _correct_image(image,
self.target_fft,
x, y,
self.psf_i_fft, alpha, epsilon)
return _correct_image(image, self.psf_i_fft, self.target_fft, x, y, alpha, epsilon)

def __getitem__(self, xy: Tuple[int, int]) -> np.ndarray:
def __getitem__(self, xy: tuple[int, int]) -> np.ndarray:
if xy in self._evaluation_points:
return self._evaluations[xy]
else:
raise UnevaluatedPointError(f"Model not evaluated at {xy}.")
msg = f"Model not evaluated at {xy}."
raise UnevaluatedPointError(msg)

def save(self, path: str) -> None:
with h5py.File(path, "w") as f:
Expand Down
Loading

0 comments on commit a89f9f3

Please sign in to comment.