Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fvpsf #403

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions requirements.github_actions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ scipy
astropy
matplotlib
jupyter
numba
joblib

docutils
requests
Expand Down
1 change: 1 addition & 0 deletions requirements.readthedocs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ numpy>=1.16
scipy
matplotlib
astropy
numba

docutils
requests
Expand Down
236 changes: 215 additions & 21 deletions scopesim/effects/psf_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from typing import Tuple, List

import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage as spi
from scipy.interpolate import RectBivariateSpline, griddata
from scipy.ndimage import zoom
import numpy.typing as npt
from astropy import units as u
from astropy.convolution import Gaussian2DKernel
from astropy.io import fits
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from numba import njit, prange
from scipy import ndimage as spi
from scipy.interpolate import RectBivariateSpline, griddata
from scipy.ndimage import zoom

from .. import rc, utils
from .. import utils
from ..optics import image_plane_utils as imp_utils


Expand Down Expand Up @@ -81,9 +84,7 @@ def nmrms_from_strehl_and_wavelength(strehl, wavelength, strehl_hdu,
return nm


def make_strehl_map_from_table(tbl, pixel_scale=1*u.arcsec):


def make_strehl_map_from_table(tbl, pixel_scale=1 * u.arcsec):
# pixel_scale = utils.quantify(pixel_scale, u.um).to(u.deg)
# coords = np.array([tbl["x"], tbl["y"]]).T
#
Expand All @@ -102,7 +103,7 @@ def make_strehl_map_from_table(tbl, pixel_scale=1*u.arcsec):

hdr = imp_utils.header_from_list_of_xy(np.array([-25, 25]) / 3600.,
np.array([-25, 25]) / 3600.,
pixel_scale=1/3600)
pixel_scale=1 / 3600)

map_hdu = fits.ImageHDU(header=hdr, data=map)

Expand All @@ -114,19 +115,19 @@ def rescale_kernel(image, scale_factor, spline_order=None):
spline_order = utils.from_currsys("!SIM.computing.spline_order")
sum_image = np.sum(image)
image = zoom(image, scale_factor, order=spline_order)
image = np.nan_to_num(image, copy=False) # numpy version >=1.13
image = np.nan_to_num(image, copy=False) # numpy version >=1.13

# Re-centre kernel
im_shape = image.shape
dy, dx = np.divmod(np.argmax(image), im_shape[1]) - np.array(im_shape) // 2
if dy > 0:
image = image[2*dy:, :]
image = image[2 * dy:, :]
elif dy < 0:
image = image[:2*dy, :]
image = image[:2 * dy, :]
if dx > 0:
image = image[:, 2*dx:]
image = image[:, 2 * dx:]
elif dx < 0:
image = image[:, :2*dx]
image = image[:, :2 * dx]

sum_new_image = np.sum(image)
image *= sum_image / sum_new_image
Expand All @@ -139,15 +140,14 @@ def cutout_kernel(image, fov_header):
xcen, ycen = 0.5 * w, 0.5 * h
dx = 0.5 * fov_header["NAXIS1"]
dy = 0.5 * fov_header["NAXIS2"]
x0, x1 = max(0, int(xcen-dx)), min(w, int(xcen+dx))
y0, y1 = max(0, int(ycen-dy)), min(w, int(ycen+dy))
x0, x1 = max(0, int(xcen - dx)), min(w, int(xcen + dx))
y0, y1 = max(0, int(ycen - dy)), min(w, int(ycen + dy))
image_cutout = image[y0:y1, x0:x1]

return image_cutout


def get_strehl_cutout(fov_header, strehl_imagehdu):

image = np.zeros((fov_header["NAXIS2"], fov_header["NAXIS1"]))
canvas_hdu = fits.ImageHDU(header=fov_header, data=image)
canvas_hdu = imp_utils.add_imagehdu_to_imagehdu(strehl_imagehdu,
Expand Down Expand Up @@ -197,7 +197,7 @@ def get_psf_wave_exts(hdu_list, wave_key="WAVE0"):
def get_total_wfe_from_table(tbl):
wfes = utils.quantity_from_table("wfe_rms", tbl, "um")
n_surfs = tbl["n_surfaces"]
total_wfe = np.sum(n_surfs * wfes**2)**0.5
total_wfe = np.sum(n_surfs * wfes ** 2) ** 0.5

return total_wfe

Expand All @@ -217,7 +217,7 @@ def wfe2strehl(wfe, wave):
wave = utils.quantify(wave, u.um)
wfe = utils.quantify(wfe, u.um)
x = 2 * 3.1415926526 * wfe / wave
strehl = np.exp(-x**2)
strehl = np.exp(-x ** 2)
return strehl


Expand Down Expand Up @@ -269,6 +269,7 @@ def rotational_blur(image, angle):

return image_rot / n_angles


def get_bkg_level(obj, bg_w):
"""
Determine the background level of image or cube slices
Expand All @@ -289,7 +290,7 @@ def get_bkg_level(obj, bg_w):
else:
mask = np.zeros_like(obj, dtype=np.bool8)
if bg_w > 0:
mask[bg_w:-bg_w,bg_w:-bg_w] = True
mask[bg_w:-bg_w, bg_w:-bg_w] = True
bkg_level = np.ma.median(np.ma.masked_array(obj, mask=mask))

elif obj.ndim == 3:
Expand All @@ -305,3 +306,196 @@ def get_bkg_level(obj, bg_w):
else:
raise ValueError("Unsupported dimension:", obj.ndim)
return bkg_level


@njit()
def kernel_grid_linear_interpolation(kernel_grid: npt.NDArray, position: npt.NDArray) -> npt.NDArray:
"""Bi-linear interpolation of a grid of 2D arrays at a given position.

This function interpolates a grid of 2D arrays at a given position using a weighted mean (i.e. bi-linear
interpolation). The grid object should be of shape (M, N, I, J), with MxN the shape of the grid of arrays and IxJ
the shape of the array at each point.

Parameters
----------
kernel_grid : npt.NDArray
An array with shape `(M, N, I, J)` defining a `MxN` grid of 2D arrays to be interpolated.
position: npt.NDArray
An array containing the position in the `MxN` at which the resulting 2D array is computed.

Returns
-------
npt.NDArray
An IxJ array at the given position obtained by interpolation.
"""
# Grid and kernel dimensions
grid_i, grid_j, kernel_i, kernel_j = kernel_grid.shape

# Find the closest grid points to the given position
x, y = position
x0 = int(x)
y0 = int(y)
x1 = x0 + 1
y1 = y0 + 1

# Get the four closest arrays to the given position
psf00 = kernel_grid[x0, y0, :, :]
psf01 = kernel_grid[x0, y1, :, :]
psf10 = kernel_grid[x1, y0, :, :]
psf11 = kernel_grid[x1, y1, :, :]

# Define the weights for each grid point
dx = x - x0
dy = y - y0
inv_dx = 1 - dx
inv_dy = 1 - dy

a = inv_dx * inv_dy
b = dx * inv_dy
c = inv_dx * dy
d = dx * dy

# Construct support array and retrieve pixel values by interpolating
output = np.empty((kernel_i, kernel_j), dtype=kernel_grid.dtype)
for i in range(kernel_i):
for j in range(kernel_j):
output[i, j] = (
a * psf00[i, j]
+ b * psf01[i, j]
+ c * psf10[i, j]
+ d * psf11[i, j]
)
return output


@njit(parallel=True)
def _convolve2d_varying_kernel(image: npt.NDArray,
kernel_grid: npt.NDArray,
coordinates: Tuple[npt.NDArray, npt.NDArray],
interpolator) -> npt.NDArray:
"""(Helper) Convolve an image with a spatially-varying kernel by interpolating a discrete kernel grid.

Numba JIT function for performing the convolution of an image with a spatially-varying kernel by interpolation of a
kernel grid at each pixel position. Check `convolve2d_varying_kernel` for more information.

Parameters
----------
image: npt.NDArray
The image to be convolved.
kernel_grid : npt.NDArray
An array with shape `(M, N, I, J)` defining an `MxN` grid of 2D kernels.
coordinates : Tuple[npt.ArrayLike, npt.ArrayLike]
A tuple of arrays defining the axis coordinates of each pixel of the image in the kernel grid coordinated in
which the kernel is to be computed.
interpolator
A Numba njit'ted function that performs the interpolation. It's signature should be
`(kernel_grid: npt.NDArray, position: npt.NDArray, check_bounds: bool) -> npt.NDArray`.

Returns
-------
npt.NDArray
The image convolved with the kernel grid interpolated at each pixel.
"""
# [JA] TODO: Allow for kernel center != kernel.shape // 2
# Get image, grid and kernel dimensions
img_i, img_j = image.shape
grid_i, grid_j, kernel_i, kernel_j = kernel_grid.shape

# Add padding to the image (note: Numba doesn't support np.pad)
kernel_ci, kernel_cj = kernel_i // 2, kernel_j // 2
padded_img = np.zeros((img_i + kernel_i - 1, img_j + kernel_j - 1), dtype=image.dtype)
padded_img[kernel_ci:kernel_ci + img_i, kernel_cj:kernel_cj + img_j] = image

# Create output array
output = np.zeros_like(padded_img)
# Compute kernel and convolve for each pixel
for i in prange(img_i):
x = coordinates[0][i]
for j in range(img_j):
pixel_value = image[i, j]
if pixel_value != 0:
y = coordinates[1][j]
# Get kernel for current pixel
position = np.array((x, y))
kernel = interpolator(kernel_grid=kernel_grid,
position=position)

# Apply to image
tmp = np.zeros_like(padded_img)

start_i, start_j = i, j
stop_i, stop_j = start_i + kernel_i, start_j + kernel_j
tmp[start_i:stop_i, start_j:stop_j] += pixel_value * kernel
tmp[start_i:stop_i, start_j:stop_j] = pixel_value * kernel

output += tmp
return output[kernel_ci:kernel_ci + img_i, kernel_cj:kernel_cj + img_j]


def convolve2d_varying_kernel(image: npt.ArrayLike,
kernel_grid: npt.ArrayLike,
coordinates: List[npt.ArrayLike],
*,
mode: str = "linear") -> npt.NDArray:
"""Convolve an image with a spatially-varying kernel by interpolating a discrete kernel grid.

An image is convolved with a spatially-varying kernel, as defined by a discrete kernel grid, by computing, for each
of the image pixels, an effective kernel. The effective kernel is obtained by interpolating the origin kernel grid
at the position of each image pixel.


Parameters
----------
image: npt.Arraylike
The image to be convolved.
kernel_grid : npt.ArrayLike
An array with shape `(M, N, I, J)` defining an `MxN` grid of 2D kernels.
coordinates : List[npt.ArrayLike]
A tuple of arrays defining the axis coordinates of each pixel of the image in the kernel grid coordinated in
which the kernel is to be computed.
mode : str
The interpolation mode to be used to interpolate the convolution kernel (currently only `\"linear\"` - for
bi-linear interpolation - is implemented).

Returns
-------
npt.NDArray
The image convolved with the kernel grid interpolated at each pixel.

Raises
------
ValueError
If the provided axis coordinates are out of bounds with respect to the provided kernel grid.
ValueError
If the provided axis coordinates do not match the image shape.
NotImplementedError
If the interpolation mode (`mode`) is `nearest` (nearest neighbor interpolation).
ValueError
If the interpolation mode (`mode`) is not `nearest` or `linear`.
"""

image = np.array(image)
kernel_grid = np.array(kernel_grid)
x, y = (np.array(axis) for axis in tuple(coordinates))

# Validate coordinates
if np.any((x.max(), y.max()) >= image.shape) or np.any((x.min(), y.min()) < (0, 0)):
raise ValueError("Coordinates out of kernel grid bounds.")

if (x.size, y.size) != image.shape:
raise ValueError("Coordinates provided do not match image shape.")

# Select interpolation mode
mode = str(mode).lower()
if mode == "linear":
interpolation_fn = kernel_grid_linear_interpolation
elif mode == "nearest":
interpolation_fn = None
raise NotImplementedError(f"Mode \'{mode}\' not implemented.")
else:
raise ValueError(f"Invalid interpolation mode \'{mode}\'")

return _convolve2d_varying_kernel(image=image,
kernel_grid=kernel_grid,
coordinates=(x, y),
interpolator=interpolation_fn)
Loading