Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Boxcar extraction using Trace class #82

Merged
merged 20 commits into from
Feb 25, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
403 changes: 403 additions & 0 deletions notebook_sandbox/jwst_boxcar/boxcar_extraction.ipynb

Large diffs are not rendered by default.

330 changes: 138 additions & 192 deletions specreduce/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,216 +3,162 @@
from dataclasses import dataclass

import numpy as np
import matplotlib.pyplot as plt

from astropy import units as u
from astropy.nddata import StdDevUncertainty

from specreduce.core import SpecreduceOperation
from specutils import Spectrum1D

__all__ = ['BoxcarExtract']


def _get_boxcar_weights(center, hwidth, npix):
"""
Compute weights given an aperture center, half width, and number of pixels
"""
weights = np.zeros((npix))

# 2d
if type(npix) is not tuple:
# pixels with full weight
fullpixels = [max(0, int(center - hwidth + 1)), min(int(center + hwidth), npix)]
weights[fullpixels[0]:fullpixels[1]] = 1.0

# pixels at the edges of the boxcar with partial weight
if fullpixels[0] > 0:
w = hwidth - (center - fullpixels[0] + 0.5)
if w >= 0:
weights[fullpixels[0] - 1] = w
else:
weights[fullpixels[0]] = 1. + w
if fullpixels[1] < npix:
weights[fullpixels[1]] = hwidth - (fullpixels[1] - center - 0.5)
# 3d
else:
# pixels with full weight
fullpixels_x = [max(0, int(center[1] - hwidth + 1)), min(int(center[1] + hwidth), npix[1])]
fullpixels_y = [max(0, int(center[0] - hwidth + 1)), min(int(center[0] + hwidth), npix[0])]
weights[fullpixels_x[0]:fullpixels_x[1], fullpixels_y[0]:fullpixels_y[1]] = 1.0

# not yet handling pixels at the edges of the boxcar

return weights


def _ap_weight_image(trace, width, disp_axis, crossdisp_axis, image_shape):

"""
Create a weight image that defines the desired extraction aperture.

Parameters
----------
trace : Trace
trace object
width : float
width of extraction aperture in pixels
disp_axis : int
dispersion axis
crossdisp_axis : int (2D image) or tuple (3D image)
cross-dispersion axis
image_shape : tuple with 2 or 3 elements
size (shape) of image

Returns
-------
wimage : 2D image
weight image defining the aperture
"""
wimage = np.zeros(image_shape)
hwidth = 0.5 * width

if len(crossdisp_axis) == 1:
# 2d
image_sizes = image_shape[crossdisp_axis[0]]
else:
# 3d
image_shape_array = np.array(image_shape)
crossdisp_axis_array = np.array(crossdisp_axis)
image_sizes = image_shape_array[crossdisp_axis_array]
image_sizes = tuple(image_sizes.tolist())

# loop in dispersion direction and compute weights.
for i in range(image_shape[disp_axis]):
if len(crossdisp_axis) == 1:
# 2d
# TODO trace must handle transposed data (disp_axis == 0)
wimage[:, i] = _get_boxcar_weights(trace[i], hwidth, image_sizes)
else:
# 3d
wimage[i, ::] = _get_boxcar_weights(trace[i], hwidth, image_sizes)

return wimage


@dataclass
class BoxcarExtract(SpecreduceOperation):
"""
Does a standard boxcar extraction

Parameters
----------
img : nddata-compatible image
The input image
trace_object :
The trace of the spectrum to be extracted TODO: define
apwidth : int
The width of the extraction aperture in pixels
skysep : int
The spacing between the aperture and the sky regions
skywidth : int
The width of the sky regions in pixels
skydeg : int
The degree of the polynomial that's fit to the sky
image : nddata-compatible image
image with 2-D spectral image data
width : float
width of extraction aperture in pixels

Returns
-------
spec : `~specutils.Spectrum1D`
The extracted spectrum
skyspec : `~specutils.Spectrum1D`
The sky spectrum used in the extraction process
The extracted 1d spectrum expressed in DN and pixel units
"""
apwidth: int = 8
skysep: int = 3
skywidth: int = 7
skydeg: int = 0

def __call__(self, img, trace_object):
self.last_trace = trace_object
self.last_img = img

if self.apwidth < 1:
raise ValueError('apwidth must be >= 1')
if self.skysep < 1:
raise ValueError('skysep must be >= 1')
if self.skywidth < 1:
raise ValueError('skywidth must be >= 1')

trace_line = trace_object.trace

onedspec = np.zeros_like(trace_line)
skysubflux = np.zeros_like(trace_line)
fluxerr = np.zeros_like(trace_line)
mask = np.zeros_like(trace_line, dtype=bool)

for i in range(0, len(trace_line)):
# if the trace isn't defined at a position (e.g. if it is out of the image boundaries),
# it will be masked. so we propagate that into the output mask and move on.
if np.ma.is_masked(trace_line[i]):
mask[i] = True
continue

# first do the aperture flux
# juuuust in case the trace gets too close to an edge
widthup = self.apwidth / 2.
widthdn = self.apwidth / 2.
if (trace_line[i] + widthup > img.shape[0]):
widthup = img.shape[0] - trace_line[i] - 1.
if (trace_line[i] - widthdn < 0):
widthdn = trace_line[i] - 1.

# extract from box around the trace line
low_end = trace_line[i] - widthdn
high_end = trace_line[i] + widthdn

self._extract_from_box(img, i, low_end, high_end, onedspec)

# now do the sky fit
# Note that we are not including fractional pixels, since we are doing
# a polynomial fit over the sky values.
j1 = self._find_nearest_int(trace_line[i] - self.apwidth/2. -
self.skysep - self.skywidth)
j2 = self._find_nearest_int(trace_line[i] - self.apwidth/2. - self.skysep)
sky_y_1 = np.arange(j1, j2)

j1 = self._find_nearest_int(trace_line[i] + self.apwidth/2. + self.skysep)
j2 = self._find_nearest_int(trace_line[i] + self.apwidth/2. +
self.skysep + self.skywidth)
sky_y_2 = np.arange(j1, j2)

sky_y = np.append(sky_y_1, sky_y_2)

# sky can't be outside image
np_indices = np.indices(img[::, i].shape)
sky_y = np.intersect1d(sky_y, np_indices)

sky_flux = img[sky_y, i]
if (self.skydeg > 0):
# fit a polynomial to the sky in this column
pfit = np.polyfit(sky_y, sky_flux, self.skydeg)
# define the aperture in this column
ap = np.arange(
self._find_nearest_int(trace_line[i] - self.apwidth/2.),
self._find_nearest_int(trace_line[i] + self.apwidth/2.)
)
# evaluate the polynomial across the aperture, and sum
skysubflux[i] = np.nansum(np.polyval(pfit, ap))
elif (self.skydeg == 0):
skysubflux[i] = np.nanmean(sky_flux) * self.apwidth

# finally, compute the error in this pixel
sigma_bkg = np.nanstd(sky_flux) # stddev in the background data
n_bkg = np.float(len(sky_y)) # number of bkgd pixels
n_ap = self.apwidth # number of aperture pixels

# based on aperture phot err description by F. Masci, Caltech:
# http://wise2.ipac.caltech.edu/staff/fmasci/ApPhotUncert.pdf
fluxerr[i] = np.sqrt(
np.nansum(onedspec[i] - skysubflux[i]) + (n_ap + n_ap**2 / n_bkg) * (sigma_bkg**2)
)

img_unit = u.DN
if hasattr(img, 'unit'):
img_unit = img.unit

spec = Spectrum1D(
spectral_axis=np.arange(len(onedspec)) * u.pixel,
flux=onedspec * img_unit,
uncertainty=StdDevUncertainty(fluxerr),
mask=mask
)
skyspec = Spectrum1D(
spectral_axis=np.arange(len(onedspec)) * u.pixel,
flux=skysubflux * img_unit,
mask=mask
)

return spec, skyspec

def _extract_from_box(self, image, wave_index, low_end, high_end, extracted_result):

# compute nearest integer endpoints defining an internal interval,
# and fractional pixel areas that remain outside this interval.
# (taken from the HST STIS pipeline code:
# https://github.com/spacetelescope/hstcal/blob/master/pkg/stis/calstis/cs6/x1dspec.c)
#
# This assumes that the pixel coordinates represent the center of the pixel.
# E.g. pixel at y=15.0 covers the image from y=14.5 to y=15.5

# nearest integer endpoints
j1 = self._find_nearest_int(low_end)
j2 = self._find_nearest_int(high_end)

# fractional pixel areas at the end points
s1 = 0.5 - (low_end - j1)
s2 = 0.5 + high_end - j2

# add up the total flux around the trace_line
extracted_result[wave_index] = np.nansum(image[j1 + 1:j2, wave_index])
extracted_result[wave_index] += np.nansum(image[j1, wave_index]) * s1
extracted_result[wave_index] += np.nansum(image[j2, wave_index]) * s2

def _find_nearest_int(self, end_point):
if (end_point % 1) < 0.5:
return int(end_point)
else:
return int(end_point + 1)

def get_checkplot(self):
trace_line = self.last_trace.line

fig = plt.figure()
plt.imshow(self.last_img, origin='lower', aspect='auto', cmap=plt.cm.Greys_r)
plt.clim(np.percentile(self.last_img, (5, 98)))

plt.plot(np.arange(len(trace_line)), trace_line, c='C0')
plt.fill_between(
np.arange(len(trace_line)),
trace_line + self.apwidth,
trace_line - self.apwidth,
color='C0',
alpha=0.5
)
plt.fill_between(
np.arange(len(trace_line)),
trace_line + self.apwidth + self.skysep,
trace_line + self.apwidth + self.skysep + self.skywidth,
color='C1',
alpha=0.5
)
plt.fill_between(
np.arange(len(trace_line)),
trace_line - self.apwidth - self.skysep,
trace_line - self.apwidth - self.skysep - self.skywidth,
color='C1',
alpha=0.5
)
plt.ylim(
np.min(
trace_line - (self.apwidth + self.skysep + self.skywidth) * 2
),
np.max(
trace_line + (self.apwidth + self.skysep + self.skywidth) * 2
)
)

return fig
# TODO: what is a reasonable default?
# TODO: int or float?
width: int = 5
kecnry marked this conversation as resolved.
Show resolved Hide resolved

# TODO: should disp_axis and crossdisp_axis be defined in the Trace object?
kecnry marked this conversation as resolved.
Show resolved Hide resolved

def __call__(self, image, trace_object, disp_axis=1, crossdisp_axis=(0,)):
"""
Extract the 1D spectrum using the boxcar method.

Parameters
----------
image : nddata-compatible image
image with 2-D spectral image data
trace_object : Trace
object with the trace
disp_axis : int
dispersion axis
crossdisp_axis : tuple (to support both 2D and 3D data)
cross-dispersion axis


Returns
-------
spec : `~specutils.Spectrum1D`
The extracted 1d spectrum expressed in DN and pixel units
"""
# this check only applies to FlatTrace instances
if hasattr(trace_object, 'trace_pos'):
kecnry marked this conversation as resolved.
Show resolved Hide resolved
self.center = trace_object.trace_pos
kecnry marked this conversation as resolved.
Show resolved Hide resolved
for attr in ['center', 'width']:
if getattr(self, attr) < 1:
raise ValueError(f'{attr} must be >= 1')

# images to use for extraction
wimage = _ap_weight_image(
trace_object,
self.width,
disp_axis,
crossdisp_axis,
image.shape)

# extract. Note that, for a cube, this is arbitrarily picking one of the
# spatial axis to collapse. This should be handled by the API somehow.
ext1d = np.sum(image * wimage, axis=crossdisp_axis)

# TODO: add uncertainty and mask to spectrum1D object
spec = Spectrum1D(spectral_axis=np.arange(len(ext1d)) * u.pixel,
flux=ext1d * getattr(image, 'unit', u.DN))
tepickering marked this conversation as resolved.
Show resolved Hide resolved

return spec
Loading