diff --git a/changes/9000.extract_1d.rst b/changes/9000.extract_1d.rst new file mode 100644 index 0000000000..a8453bd25d --- /dev/null +++ b/changes/9000.extract_1d.rst @@ -0,0 +1,3 @@ +Removed many unused functions and methods from SOSS extraction suite that were inaccessible from the top-level pipeline +Added guardrails for bad inputs to many SOSS ATOCA helper functions +Added unit testing suite for SOSS ATOCA extraction \ No newline at end of file diff --git a/docs/jwst/extract_1d/description.rst b/docs/jwst/extract_1d/description.rst index c6f143bebc..50c4149d23 100644 --- a/docs/jwst/extract_1d/description.rst +++ b/docs/jwst/extract_1d/description.rst @@ -392,3 +392,9 @@ the data must be given. The steps to run this correction outside the pipeline ar flux_cor = rf1d(flux, wave, channel=4) where `flux` is the extracted spectral data, and the data are from channel 4 for this example. + +Extraction for NIRISS SOSS Data +------------------------------- +For NIRISS SOSS data, the two spectral orders overlap slightly, so a specialized extraction +algorithm known as ATOCA (Algorithm to Treat Order ContAmination) is used. +See Darveau-Bernier et al. 2022, PASP, DOI:10.1088/1538-3873/ac8a77 for details. diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index 1331c7a7b3..fde21b35ff 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -12,10 +12,7 @@ # General imports. import numpy as np -import warnings from scipy.sparse import issparse, csr_matrix, diags -from scipy.sparse.linalg import spsolve, lsqr, MatrixRankWarning -from scipy.interpolate import interp1d # Local imports. from . import atoca_utils @@ -33,201 +30,145 @@ def __init__(self, message): super().__init__(self.message) -class _BaseOverlap: - """Base class for the ATOCA algorithm (Darveau-Bernier 2021, in prep). - Used to perform an overlapping extraction of the form: - (B_T * B) * f = (data/sig)_T * B - where B is a matrix and f is an array. 
- The matrix multiplication B * f is the 2d model of the detector. - We want to solve for the array f. - The elements of f are labelled by 'k'. - The pixels are labeled by 'i'. - Every pixel 'i' is covered by a set of 'k' for each order - of diffraction. - The classes inheriting from this class should specify the - methods get_w which computes the 'k' associated to each pixel 'i'. - These depends of the type of interpolation used. +class ExtractionEngine: """ + Run the ATOCA algorithm (Darveau-Bernier 2022, PASP, DOI:10.1088/1538-3873/ac8a77). + + The ExtractionEngine is basically a fitter. On instantiation, it generates a model + of the detector, including a mapping between the detector pixels and the wavelength + for each spectral order, the throughput and convolution kernel, and known detector + bad pixels. This does not require any real data. + When called, it ingests data and associated errors, then + generates an output 1-D spectrum that explains the pixel brightnesses in the data + within the constraints of the model. + + The engine can also run in reverse: + The `rebuild` method generates a synthetic 2-D detector 'observation' + from a known or fitted spectrum, and the `compute_likelihood` method + compares the synthetic data to the real data to generate a likelihood. + This allows for a likelihood-based optimization of the spectrum. - # The desired data-type for computations, e.g., 'float32'. 'float64' is recommended. + This version models the pixels of the detector using an oversampled trapezoidal integration. + """ + # The desired data-type for computations. 'float64' is recommended. 
dtype = 'float64' def __init__(self, wave_map, trace_profile, throughput, kernels, - orders=None, global_mask=None, mask_trace_profile=None, - wave_grid=None, wave_bounds=None, n_os=2, - threshold=1e-3, c_kwargs=None): + wave_grid, mask_trace_profile, + global_mask=None, + orders=[1,2], threshold=1e-3): """ Parameters ---------- wave_map : (N_ord, N, M) list or array of 2-D arrays A list or array of the central wavelength position for each - order on the detector. It must have the same (N, M) as `data`. + order on the detector. + It has to have the same (N, M) as `data`. trace_profile : (N_ord, N, M) list or array of 2-D arrays A list or array of the spatial profile for each order - on the detector. It must have the same (N, M) as `data`. + on the detector. It has to have the same (N, M) as `data`. throughput : (N_ord [, N_k]) list of array or callable A list of functions or array of the throughput at each order. If callable, the functions depend on the wavelength. If array, projected on `wave_grid`. - kernels : array, callable or sparse matrix - Convolution kernel to be applied on the spectrum (f_k) for each order. - Can be array of the shape (N_ker, N_k_c). + kernels : callable, sparse matrix, or None. + Convolution kernel to be applied on spectrum (f_k) for each orders. Can be a callable with the form f(x, x0) where x0 is the position of the center of the kernel. In this case, it must return a 1D array (len(x)), so a kernel value - for each pairs of (x, x0). If array or callable, + for each pairs of (x, x0). If callable, it will be passed to `convolution.get_c_matrix` function and the `c_kwargs` can be passed to this function. If sparse, the shape has to be (N_k_c, N_k) and it will be used directly. N_ker is the length of the effective kernel and N_k_c is the length of the spectrum (f_k) convolved. + If None, the kernel is set to 1, i.e., do not do any convolution. + wave_grid : (N_k) array_like, required. + The grid on which f(lambda) will be projected. 
+ mask_trace_profile : (N_ord, N, M) list or array of 2-D arrays[bool], required. + A list or array of the pixel that need to be used for extraction, + for each order on the detector. It has to have the same (N_ord, N, M) as `trace_profile`. global_mask : (N, M) array_like boolean, optional - Boolean Mask of the detector pixels to mask for every extraction. + Boolean Mask of the detector pixels to mask for every extraction, e.g. bad pixels. Should not be related to a specific order (if so, use `mask_trace_profile` instead). - mask_trace_profile : (N_ord, N, M) list or array of 2-D arrays[bool], optional - A list or array of the pixels that need to be used for extraction, - for each order on the detector. It has to have the same (N_ord, N, M) as `trace_profile`. - If not given, `threshold` will be applied on spatial profiles to define the masks. - orders : list, optional: + orders : list, optional List of orders considered. Default is orders = [1, 2] - wave_grid : (N_k) array_like, optional - The grid on which f(lambda) will be projected. - Default is a grid from `utils.get_soss_grid`. - `n_os` will be passed to this function. - wave_bounds : list or array-like (N_ord, 2), optional - Boundary wavelengths covered by each order. - Default is the wavelength covered by `wave_map`. - n_os : int, optional - Oversampling rate. If `wave_grid`is None, it will be used to - generate a grid. Default is 2. threshold : float, optional: The contribution of any order on a pixel is considered significant if its estimated spatial profile is greater than this threshold value. If it is not properly modeled (not covered by the wavelength grid), it will be masked. Default is 1e-3. - c_kwargs : list of N_ord dictionaries or dictionary, optional - Inputs keywords arguments to pass to - `convolution.get_c_matrix` function for each order. - If dictionary, the same c_kwargs will be used for each order. """ - # If no orders specified extract on orders 1 and 2. 
- if orders is None: - orders = [1, 2] - - ########################### - # Save basic parameters - ########################### - - # Spectral orders and number of orders. - self.data_shape = wave_map[0].shape - self.orders = orders - self.n_orders = len(orders) + # Set the attributes and ensure everything has correct dtype + self.wave_map = np.array(wave_map).astype(self.dtype) + self.trace_profile = np.array(trace_profile).astype(self.dtype) + self.mask_trace_profile = np.array(mask_trace_profile).astype(bool) self.threshold = threshold + self.data_shape = self.wave_map[0].shape - # Raise error if the number of orders is not consistent. - if self.n_orders != len(wave_map): - msg = ("The number of orders specified {} and the number of " - "wavelength maps provided {} do not match.") - log.critical(msg.format(self.n_orders, len(wave_map))) - raise ValueError(msg.format(self.n_orders, len(wave_map))) - - # Detector image. - self.data = np.full(self.data_shape, fill_value=np.nan) - - # Error map of each pixels. - self.error = np.ones(self.data_shape) - - # Set all reference file quantities to None. - self.wave_map = None - self.trace_profile = None - self.throughput = None - self.kernels = None - - # Set the wavelength map and trace_profile for each order. - self.update_wave_map(wave_map) - self.update_trace_profile(trace_profile) - - # Set the mask based on trace profiles and save - if mask_trace_profile is None: - # No mask (False everywhere) - log.warning('mask_trace_profile was not given. All detector pixels will be modeled. ' - 'It is preferable to limit the number of modeled pixels by specifying the region ' - 'of interest with mask_trace_profile.') - mask_trace_profile = np.array([np.zeros(self.data_shape, dtype=bool) for _ in orders]) - self.mask_trace_profile = mask_trace_profile - - # Generate a wavelength grid if none was provided. 
- if wave_grid is None: - if self.n_orders == 2: - wave_grid = atoca_utils.get_soss_grid(wave_map, trace_profile, n_os=n_os) - else: - wave_grid, _ = self.grid_from_map() - else: - # Check if the input wave_grid is sorted and strictly increasing. - is_sorted = (np.diff(wave_grid) > 0).all() - - # If not, sort it and make it unique - if not is_sorted: - log.warning('`wave_grid` is not strictly increasing. It will be sorted and unique.') - wave_grid = np.unique(wave_grid) - - # Set the wavelength grid and its size. + # Set wave_grid. Ensure it is sorted and strictly increasing. + is_sorted = (np.diff(wave_grid) > 0).all() + if not is_sorted: + log.warning('`wave_grid` is not strictly increasing. It will be sorted and made unique.') + wave_grid = np.unique(wave_grid) self.wave_grid = wave_grid.astype(self.dtype).copy() self.n_wavepoints = len(wave_grid) - # Set the throughput for each order. - self.update_throughput(throughput) - - ################################### - # Build detector mask - ################################### + # Get wavelengths at the boundaries of each pixel for all orders + wave_p, wave_m = [], [] + for wave in self.wave_map: + lp, lm = atoca_utils.get_wave_p_or_m(wave) + wave_p.append(lp) + wave_m.append(lm) + self.wave_p = np.array(wave_p, dtype=self.dtype) + self.wave_m = np.array(wave_m, dtype=self.dtype) + + # Set orders and ensure that the number of orders is consistent with wave_map length + self.orders = orders + self.n_orders = len(self.orders) + if self.n_orders != len(self.wave_map): + msg = ("The number of orders specified ({}) and the number of " + "wavelength maps provided ({}) do not match.") + log.critical(msg.format(self.n_orders, len(self.wave_map))) + raise ValueError(msg.format(self.n_orders, len(self.wave_map))) - # Assign a first estimate of i_bounds to be able to compute mask. 
- self.i_bounds = [[0, len(wave_grid)] for _ in range(self.n_orders)] + # Set a first estimate of i_bounds to estimate mask + self.i_bounds = [[0, len(self.wave_grid)] for _ in range(self.n_orders)] - # First estimate of a global mask and masks for each orders + # Estimate a global mask and masks for each orders self.mask, self.mask_ord = self._get_masks(global_mask) # Ensure there are adequate good pixels left in each order good_pixels_in_order = np.sum(np.sum(~self.mask_ord, axis=-1), axis=-1) min_good_pixels = 25 # hard-code to qualitatively reasonable value if np.any(good_pixels_in_order < min_good_pixels): - raise MaskOverlapError('At least one order has no valid pixels (mask_trace_profile and mask_wave do not overlap)') + msg = (f'At least one order has less than {min_good_pixels} valid pixels. ' + '(mask_trace_profile and mask_wave have insufficient overlap') + raise MaskOverlapError(msg) + + # Update i_bounds based on masked wavelengths + self.i_bounds = self._get_i_bnds() - # Correct i_bounds if it was not specified - self.i_bounds = self._get_i_bnds(wave_bounds) + # if throughput is given as callable, turn it into an array + # with shape (n_ord, wave_grid.size) + self.update_throughput(throughput) # Re-build global mask and masks for each orders self.mask, self.mask_ord = self._get_masks(global_mask) - - # Save mask here as the general mask, - # since `mask` attribute can be changed. + # Save mask here as the general mask, since `mask` attribute can be changed. self.general_mask = self.mask.copy() - #################################### - # Build convolution matrix - #################################### - self.update_kernels(kernels, c_kwargs) - - ############################# - # Compute integration weights - ############################# - # The weights depend on the integration method used solve - # the integral of the flux over a pixel and are encoded - # in the class method `get_w()`. 
+ # turn kernels into sparse matrix + self.kernels = self._create_kernels(kernels) + + # Compute integration weights. see method self.get_w() for details. self.weights, self.weights_k_idx = self.compute_weights() - ######################### - # Save remaining inputs - ######################### - # Init the pixel mapping (b_n) matrices. Matrices that transforms the 1D spectrum to a the image pixels. self.pixel_mapping = [None for _ in range(self.n_orders)] - self.i_grid = None self.tikho_mat = None self.w_t_wave_c = None - return def get_attributes(self, *args, i_order=None): """Return list of attributes @@ -256,117 +197,61 @@ def get_attributes(self, *args, i_order=None): return out - def update_wave_map(self, wave_map): - """Update internal wave_map - Parameters - ---------- - wave_map : array[float] - Wavelength maps for each order - - Returns - ------- - None - """ - dtype = self.dtype - self.wave_map = [wave_n.astype(dtype).copy() for wave_n in wave_map] - - return - - def update_trace_profile(self, trace_profile): - """Update internal trace_profiles - Parameters - ---------- - trace_profile : array[float] - Trace profiles for each order - - Returns - ------- - None - """ - dtype = self.dtype - - # Update the trace_profile profile. - self.trace_profile = [trace_profile_n.astype(dtype).copy() for trace_profile_n in trace_profile] - - return def update_throughput(self, throughput): """Update internal throughput values + Parameters ---------- throughput : array[float] or callable Throughput values for each order, given either as an array or as a callable function with self.wave_grid as input. - - Returns - ------- - None """ - - # Update the throughput values. throughput_new = [] for throughput_n in throughput: # Loop over orders. if callable(throughput_n): + throughput_n = throughput_n(self.wave_grid) - # Throughput was given as a callable function. 
- throughput_new.append(throughput_n(self.wave_grid)) - - elif throughput_n.shape == self.wave_grid.shape: - - # Throughput was given as an array. - throughput_new.append(throughput_n) - - else: - msg = 'Throughputs must be given as callable or arrays matching the extraction grid.' + msg = 'Throughputs must be given as callable or arrays matching the extraction grid.' + if not isinstance(throughput_n, np.ndarray): + log.critical(msg) + raise ValueError(msg) + if throughput_n.shape != self.wave_grid.shape: log.critical(msg) raise ValueError(msg) - # Set the attribute to the new values. - self.throughput = throughput_new + throughput_new.append(throughput_n) - return + self.throughput = np.array(throughput_new, dtype=self.dtype) + + + def _create_kernels(self, kernels): + """Make sparse matrix from input kernels - def update_kernels(self, kernels, c_kwargs): - """Update internal kernels Parameters ---------- - kernels : array, callable or sparse matrix + kernels : callable, sparse matrix, or None Convolution kernel to be applied on the spectrum (f_k) for each order. - c_kwargs : list of N_ord dictionaries or dictionary, optional - Inputs keywords arguments to pass to - `convolution.get_c_matrix` function for each order. - If dictionary, the same c_kwargs will be used for each order. - - Returns - ------- - None + If None, kernel is set to 1, i.e., do not do any convolution. """ - # Check the c_kwargs inputs - # If not given - if c_kwargs is None: - # Then use the kernels min_value attribute. - # It is a way to make sure that the full kernel - # is used. 
- c_kwargs = [] - for ker in kernels: - # If the min_value not specified, then - # simply take the get_c_matrix defaults - try: - kwargs_ker = {'thresh': ker.min_value} - except AttributeError: - kwargs_ker = dict() - c_kwargs.append(kwargs_ker) - - # ...or same for each orders if only a dictionary was given - elif isinstance(c_kwargs, dict): - c_kwargs = [c_kwargs for _ in kernels] + # Take thresh to be the kernels min_value attribute. + # It is a way to make sure that the full kernel is used. + c_kwargs = [] + for ker in kernels: + try: + kwargs_ker = {'thresh': ker.min_value} + except AttributeError: + # take the get_c_matrix defaults + kwargs_ker = {} + c_kwargs.append(kwargs_ker) # Define convolution sparse matrix. kernels_new = [] for i_order, kernel_n in enumerate(kernels): - + if kernel_n is None: + kernel_n = np.array([1.0]) if not issparse(kernel_n): kernel_n = atoca_utils.get_c_matrix(kernel_n, self.wave_grid, i_bounds=self.i_bounds[i_order], @@ -374,39 +259,12 @@ def update_kernels(self, kernels, c_kwargs): kernels_new.append(kernel_n) - self.kernels = kernels_new - - return - - def get_mask_wave(self, i_order): - """Generate mask bounded by limits of wavelength grid - Parameters - ---------- - i_order : int - Order to select the wave_map on which a mask - will be generated - - Returns - ------- - array[bool] - A mask with True where wave_map is outside the bounds - of wave_grid - """ - - wave = self.wave_map[i_order] - imin, imax = self.i_bounds[i_order] - wave_min = self.wave_grid[imin] - wave_max = self.wave_grid[imax - 1] - - mask = (wave <= wave_min) | (wave >= wave_max) + return kernels_new - return mask def _get_masks(self, global_mask): """Compute a general mask on the detector and for each order. - Depends on the spatial profile, the wavelength grid - and the user defined mask (optional). These are all specified - when initializing the object. + Depends on the trace profile and the wavelength grid. 
Parameters ---------- @@ -423,16 +281,13 @@ def _get_masks(self, global_mask): # Get needed attributes args = ('threshold', 'n_orders', 'mask_trace_profile', 'trace_profile') - needed_attr = self.get_attributes(*args) - threshold, n_orders, mask_trace_profile, trace_profile = needed_attr - - # Convert list to array (easier for coding) - mask_trace_profile = np.array(mask_trace_profile) + threshold, n_orders, mask_trace_profile, trace_profile = self.get_attributes(*args) # Mask pixels not covered by the wavelength grid. mask_wave = np.array([self.get_mask_wave(i_order) for i_order in range(n_orders)]) - # Apply user defined mask. + # combine trace profile mask with wavelength cutoff mask + # and apply detector bad pixel mask if specified if global_mask is None: mask_ord = np.any([mask_trace_profile, mask_wave], axis=0) else: @@ -455,116 +310,47 @@ def _get_masks(self, global_mask): return general_mask, mask_ord - def update_mask(self, mask): - """Update `mask` attribute by combining the `general_mask` - attribute with the input `mask`. Every time the mask is - changed, the integration weights need to be recomputed - since the pixels change. - - Parameters - ---------- - mask : array[bool] - New mask to be combined with internal general_mask and - saved in self.mask. - - Returns - ------- - None - """ - - # Get general mask - general_mask = self.general_mask - - # Complete with the input mask - new_mask = (general_mask | mask) - - # Update attribute - self.mask = new_mask - - # Correct i_bounds if it was not specified - # self.update_i_bnds() - - # Re-compute weights - self.weights, self.weights_k_idx = self.compute_weights() - return + def _get_i_bnds(self): + """Define wavelength boundaries for each order using the order's mask + and the wavelength map. - def _get_i_bnds(self, wave_bounds=None): - """Define wavelength boundaries for each order using the order's mask. 
- Parameters - ---------- - wave_bounds : list[float], optional - Minimum and maximum values of masked wavelength map. If not given, - calculated from internal wave_map and mask_ord for each order. Returns ------- list[float] Wavelength boundaries for each order """ - wave_grid = self.wave_grid - i_bounds = self.i_bounds + # Figure out boundary wavelengths + wave_bounds = [] + for i in range(self.n_orders): + wave = self.wave_map[i][~self.mask_ord[i]] + wave_bounds.append([wave.min(), wave.max()]) - # Check if wave_bounds given - if wave_bounds is None: - wave_bounds = [] - for i in range(self.n_orders): - wave = self.wave_map[i][~self.mask_ord[i]] - wave_bounds.append([wave.min(), wave.max()]) - - # What we need is the boundary position - # on the wavelength grid. + # Determine the boundary position on the wavelength grid. i_bnds_new = [] - for bounds, i_bnds in zip(wave_bounds, i_bounds): + for bounds, i_bnds in zip(wave_bounds, self.i_bounds): - a = np.min(np.where(wave_grid >= bounds[0])[0]) - b = np.max(np.where(wave_grid <= bounds[1])[0]) + 1 + a = np.min(np.where(self.wave_grid >= bounds[0])[0]) + b = np.max(np.where(self.wave_grid <= bounds[1])[0]) + 1 # Take the most restrictive bound a = np.maximum(a, i_bnds[0]) b = np.minimum(b, i_bnds[1]) - - # Keep value - i_bnds_new.append([a, b]) + i_bnds_new.append([int(a), int(b)]) return i_bnds_new - def update_i_bnds(self): - """Update the grid limits for the extraction. - Needs to be done after modification of the mask - """ - - # Get old and new boundaries. - i_bnds_old = self.i_bounds - i_bnds_new = self._get_i_bnds() - - for i_order in range(self.n_orders): - - # Take most restrictive lower bound. - low_bnds = [i_bnds_new[i_order][0], i_bnds_old[i_order][0]] - i_bnds_new[i_order][0] = np.max(low_bnds) - - # Take most restrictive upper bound. - up_bnds = [i_bnds_new[i_order][1], i_bnds_old[i_order][1]] - i_bnds_new[i_order][1] = np.min(up_bnds) - - # Update attribute. 
- self.i_bounds = i_bnds_new - - return def wave_grid_c(self, i_order): - """Return wave_grid for the convolved flux at a given order. + """Return wave_grid for a given order constrained according to the i_bounds + of that order. """ index = slice(*self.i_bounds[i_order]) return self.wave_grid[index] - def get_w(self, i_order): - """Dummy method to init this class""" - - return np.array([]), np.array([]) def compute_weights(self): """ @@ -582,9 +368,10 @@ def compute_weights(self): # Init lists weights, weights_k_idx = [], [] - for i_order in range(self.n_orders): # For each orders + for i_order in range(self.n_orders): - weights_n, k_idx_n = self.get_w(i_order) # Compute weights + # Compute weights + weights_n, k_idx_n = self.get_w(i_order) # Convert to sparse matrix # First get the dimension of the convolved grid @@ -597,17 +384,15 @@ def compute_weights(self): return weights, weights_k_idx def _set_w_t_wave_c(self, i_order, product): - """Save the matrix product of the weighs (w), the throughput (t), + """Save the matrix product of the weights (w), the throughput (t), the wavelength (lam) and the convolution matrix for faster computation. """ if self.w_t_wave_c is None: self.w_t_wave_c = [[] for _ in range(self.n_orders)] - # Assign value self.w_t_wave_c[i_order] = product.copy() - return def grid_from_map(self, i_order=0): """Return the wavelength grid and the columns associated @@ -618,73 +403,19 @@ def grid_from_map(self, i_order=0): wave_map, trace_profile = self.get_attributes(*attrs, i_order=i_order) wave_grid, icol = atoca_utils._grid_from_map(wave_map, trace_profile) - wave_grid = wave_grid.astype(self.dtype) - return wave_grid, icol - def estimate_noise(self, i_order=0, data=None, error=None, mask=None): - """Relative noise estimate over columns. - Parameters - ---------- - i_order : int, optional - index of diffraction order. Default is 0 - data : 2d array, optional - map of the detector image - Default is `self.data`. 
- error : 2d array, optional - map of the estimate of the detector noise. - Default is `self.sig` - mask : 2d array, optional - Bool map of the masked pixels for order `i_order`. - Default is `self.mask_ord[i_order]` - - Returns - ------ - wave_grid : array[float] - The wavelength grid. - noise : array[float] - The associated noise array. + def get_pixel_mapping(self, i_order, error=None, quick=False): """ - - # Use object attributes if not given - if data is None: - data = self.data - - if error is None: - error = self.error - - if mask is None: - mask = self.mask_ord[i_order] - - # Compute noise estimate only on the trace (mask the rest) - noise = np.ma.array(error, mask=mask) - - # RMS over columns - noise = np.sqrt((noise**2).sum(axis=0)) - - # Relative - noise /= np.ma.array(data, mask=mask).sum(axis=0) - - # Convert to array with nans - noise = noise.filled(fill_value=np.nan) - - # Get associated wavelengths - wave_grid, i_col = self.grid_from_map(i_order) - - # Return sorted according to wavelengths - return wave_grid, noise[i_col] - - def get_pixel_mapping(self, i_order, same=False, error=True, quick=False): - """Compute the matrix `b_n = (P/sig).w.T.lambda.c_n` , + Compute the matrix `b_n = (P/sig).w.T.lambda.c_n` , where `P` is the spatial profile matrix (diag), `w` is the integrations weights matrix, `T` is the throughput matrix (diag), `lambda` is the convolved wavelength grid matrix (diag), `c_n` is the convolution kernel. - The model of the detector at order n (`model_n`) - is given by the system: + The model of the detector at order n (`model_n`) is given by the system: model_n = b_n.c_n.f , where f is the incoming flux projected on the wavelength grid. This methods updates the `b_n_list` attribute. @@ -693,15 +424,10 @@ def get_pixel_mapping(self, i_order, same=False, error=True, quick=False): ---------- i_order: integer Label of the order (depending on the initiation of the object). - same: bool, optional - Do not recompute b_n. 
Take the last b_n computed. - Useful to speed up code. Default is False. - error: bool or (N, M) array_like, optional - If 2-d array, `sig` is the new error estimation map. - It is the same shape as `sig` initiation input. If bool, - whether to apply sigma or not. The method will return - b_n/sigma if True or array_like and b_n if False. If True, - the default object attribute `sig` will be used. + error: (N, M) array_like or None, optional. + Estimate of the error on each pixel. Same shape as `data`. + If None, the error is set to 1, which means the method will return + b_n instead of b_n/sigma. Default is None. quick: bool, optional If True, only perform one matrix multiplication instead of the whole system: (P/sig).(w.T.lambda.c_n) @@ -711,100 +437,72 @@ def get_pixel_mapping(self, i_order, same=False, error=True, quick=False): array[float] Sparse matrix of b_n coefficients """ + if (quick) and (self.w_t_wave_c is None): + msg = "Attribute w_t_wave_c of ExtractionEngine must exist if quick=True" + raise AttributeError(msg) - # Force to compute if b_n never computed. - if self.pixel_mapping[i_order] is None: - same = False - - # Take the last b_n computed if nothing changes - if same: - pixel_mapping = self.pixel_mapping[i_order] - - else: - # Special treatment for error map - # Can be bool or array. - if error is False: - # Sigma will have no effect - error = np.ones(self.data_shape) - else: - if error is not True: - # Sigma must be an array so - # update object attribute - self.error = error.copy() - - # Take sigma from object - error = self.error + # Special treatment for error map + # Can be None or array. + if error is None: + # Sigma will have no effect + error = np.ones(self.data_shape) - # Get needed attributes ... - attrs = ['wave_grid', 'mask'] - wave_grid, mask = self.get_attributes(*attrs) + # Get needed attributes ... + attrs = ['wave_grid', 'mask'] + wave_grid, mask = self.get_attributes(*attrs) - # ... 
order dependent attributes - attrs = ['trace_profile', 'throughput', 'kernels', 'weights', 'i_bounds'] - trace_profile_n, throughput_n, kernel_n, weights_n, i_bnds = self.get_attributes(*attrs, i_order=i_order) + # ... order dependent attributes + attrs = ['trace_profile', 'throughput', 'kernels', 'weights', 'i_bounds'] + trace_profile_n, throughput_n, kernel_n, weights_n, i_bnds = self.get_attributes(*attrs, i_order=i_order) - # Keep only valid pixels (P and sig are still 2-D) - # And apply directly 1/sig here (quicker) - trace_profile_n = trace_profile_n[~mask] / error[~mask] + # Keep only valid pixels (P and sig are still 2-D) + # And apply directly 1/sig here (quicker) + trace_profile_n = trace_profile_n[~mask] / error[~mask] - # Compute b_n - # Quick mode if only `p_n` or `sig` has changed - if quick: - # Get pre-computed (right) part of the equation - right = self.w_t_wave_c[i_order] + # Compute b_n + # Quick mode if only `p_n` or `sig` has changed + if quick: + # Get pre-computed (right) part of the equation + right = self.w_t_wave_c[i_order] - # Apply new p_n - pixel_mapping = diags(trace_profile_n).dot(right) + # Apply new p_n + pixel_mapping = diags(trace_profile_n).dot(right) - else: - # First (T * lam) for the convolve axis (n_k_c) - product = (throughput_n * wave_grid)[slice(*i_bnds)] + else: + # First (T * lam) for the convolve axis (n_k_c) + product = (throughput_n * wave_grid)[slice(*i_bnds)] - # then convolution - product = diags(product).dot(kernel_n) + # then convolution + product = diags(product).dot(kernel_n) - # then weights - product = weights_n.dot(product) + # then weights + product = weights_n.dot(product) - # Save this product for quick mode - self._set_w_t_wave_c(i_order, product) + # Save this product for quick mode + self._set_w_t_wave_c(i_order, product) - # Then spatial profile - pixel_mapping = diags(trace_profile_n).dot(product) + # Then spatial profile + pixel_mapping = diags(trace_profile_n).dot(product) - # Save new pixel 
mapping matrix. - self.pixel_mapping[i_order] = pixel_mapping + # Save new pixel mapping matrix. + self.pixel_mapping[i_order] = pixel_mapping return pixel_mapping - def build_sys(self, data=None, error=True, mask=None, trace_profile=None, throughput=None): - """Build linear system arising from the logL maximisation. + + def build_sys(self, data, error): + """ + Build linear system arising from the logL maximisation. TIPS: To be quicker, only specify the psf (`p_list`) in kwargs. There will be only one matrix multiplication: (P/sig).(w.T.lambda.c_n). + Parameters ---------- - data : (N, M) array_like, optional + data : (N, M) array_like A 2-D array of real values representing the detector image. - Default is the object attribute `data`. - error : bool or (N, M) array_like, optional + error : (N, M) array_like Estimate of the error on each pixel. - If 2-d array, `sig` is the new error estimation map. - It is the same shape as `sig` initiation input. If bool, - whether to apply sigma or not. The method will return - b_n/sigma if True or array_like and b_n if False. If True, - the default object attribute `sig` will be use. - mask : (N, M) array_like boolean, optional - Additional mask for a given exposure. Will be added - to the object general mask. - trace_profile : (N_ord, N, M) list or array of 2-D arrays, optional - A list or array of the spatial profile for each order - on the detector. It has to have the same (N, M) as `data`. - Default is the object attribute `p_list` - throughput : (N_ord [, N_k]) list or array of functions, optional - A list or array of the throughput at each order. 
- The functions depend on the wavelength - Default is the object attribute `t_list` Returns ------ @@ -812,7 +510,7 @@ def build_sys(self, data=None, error=True, mask=None, trace_profile=None, throug """ # Get the detector model - b_matrix, data = self.get_detector_model(data, error, mask, trace_profile, throughput) + b_matrix, data = self.get_detector_model(data, error) # (B_T * B) * f = (data/sig)_T * B # (matrix ) * f = result @@ -821,34 +519,17 @@ def build_sys(self, data=None, error=True, mask=None, trace_profile=None, throug return matrix, result.toarray().squeeze() - def get_detector_model(self, data=None, error=True, mask=None, trace_profile=None, throughput=None): - """Get the linear model of the detector pixel, B.dot(flux) = pixels - TIPS: To be quicker, only specify the psf (`p_list`) in kwargs. - There will be only one matrix multiplication: - (P/sig).(w.T.lambda.c_n). + + def get_detector_model(self, data, error): + """ + Get the linear model of the detector pixel, B.dot(flux) = pixels + Parameters ---------- - data : (N, M) array_like, optional + data : (N, M) array_like A 2-D array of real values representing the detector image. - Default is the object attribute `data`. - error: bool or (N, M) array_like, optional + error: (N, M) array_like Estimate of the error on each pixel. - If 2-d array, `sig` is the new error estimation map. - It is the same shape as `sig` initiation input. If bool, - whether to apply sigma or not. The method will return - b_n/sigma if True or array_like and b_n if False. If True, - the default object attribute `sig` will be use. - mask : (N, M) array_like boolean, optional - Additional mask for a given exposure. Will be added - to the object general mask. - trace_profile : (N_ord, N, M) list or array of 2-D arrays, optional - A list or array of the spatial profile for each order - on the detector. It has to have the same (N, M) as `data`. 
- Default is the object attribute `p_list` - throughput : (N_ord [, N_k]) list or array of functions, optional - A list or array of the throughput at each order. - The functions depend on the wavelength - Default is the object attribute `t_list` Returns ------ @@ -856,127 +537,47 @@ def get_detector_model(self, data=None, error=True, mask=None, trace_profile=Non From the linear equation B.dot(flux) = pix_array """ - # Check if inputs are suited for quick mode; - # Quick mode if `t_list` is not specified. - quick = (throughput is None) - - # and if mask doesn't change - quick &= (mask is None) - quick &= (self.w_t_wave_c is not None) # Pre-computed - - # Use data from object as default - if data is None: - data = self.data - else: - # Update data - self.data = data - - # Update mask if given - if mask is not None: - self.update_mask(mask) - - # Take (updated) mask from object - mask = self.mask - - # Get some dimensions infos - n_wavepoints, n_orders = self.n_wavepoints, self.n_orders - - # Update trace_profile maps and throughput values. - if trace_profile is not None: - self.update_trace_profile(trace_profile) - - if throughput is not None: - self.update_throughput(throughput) - - # Calculations + # Check if `w_t_wave_c` is pre-computed + quick = (self.w_t_wave_c is not None) # Build matrix B # Initiate with empty matrix - n_i = (~mask).sum() # n good pixels - b_matrix = csr_matrix((n_i, n_wavepoints)) + n_i = (~self.mask).sum() # n good pixels + b_matrix = csr_matrix((n_i, self.n_wavepoints)) # Sum over orders - for i_order in range(n_orders): + for i_order in range(self.n_orders): # Get sparse pixel mapping matrix. 
- b_matrix += self.get_pixel_mapping(i_order, error=error, quick=quick) + b_matrix += self.get_pixel_mapping(i_order, error, quick=quick) # Build detector pixels' array - # Fisrt get `error` which have been update` - # when calling `get_pixel_mapping` - error = self.error - # Take only valid pixels and apply `error` on data - data = data[~mask] / error[~mask] + data = data[~self.mask] / error[~self.mask] return b_matrix, csr_matrix(data) - def set_tikho_matrix(self, t_mat=None, t_mat_func=None, fargs=None, fkwargs=None): - """Set the tikhonov matrix attribute. - The matrix can be directly specified as an input, or - it can be built using `t_mat_func` - Parameters - ---------- - t_mat : matrix-like, optional - Tikhonov regularization matrix. scipy.sparse matrix - are recommended. - t_mat_func : callable, optional - Function used to generate `t_mat` if not specified. - Will take `fargs` and `fkwargs`as input. - Use the `atoca_utils.get_tikho_matrix` as default. - fargs : tuple, optional - Arguments passed to `t_mat_func`. Default is `(self.wave_grid, )`. - fkwargs : dict, optional - Keyword arguments passed to `t_mat_func`. 
Default is - `{'n_derivative': 1, 'd_grid': True, 'estimate': None, 'pwr_law': 0}` - - Returns - ------- - None + @property + def tikho_mat(self): """ - - # Generate the matrix with the function - if t_mat is None: - - # Default function if not specified - if t_mat_func is None: - # Use the `atoca_utils.get_tikho_matrix` - # The default arguments will return the 1rst derivative - # as the tikhonov matrix - t_mat_func = atoca_utils.get_tikho_matrix - - # Default args - if fargs is None: - # The argument for `atoca_utils.get_tikho_matrix` is the wavelength grid - fargs = (self.wave_grid, ) - if fkwargs is None: - # The kwargs for `atoca_utils.get_tikho_matrix` are - # n_derivative = 1, d_grid = True, estimate = None, pwr_law = 0 - fkwargs = {'n_derivative': 1, 'd_grid': True, 'estimate': None, 'pwr_law': 0} - - # Call function - t_mat = t_mat_func(*fargs, **fkwargs) - - # Set attribute - self.tikho_mat = t_mat - - return - - def get_tikho_matrix(self, **kwargs): - """Return the Tikhonov matrix. - Generate it with `set_tikho_matrix` method - if not defined yet. If so, all arguments are passed - to `set_tikho_matrix`. The result is saved as an attribute. + Return the Tikhonov matrix. """ + if self._tikho_mat is not None: + return self._tikho_mat + + self._tikho_mat = atoca_utils.finite_first_d(self.wave_grid) + return self._tikho_mat + - if self.tikho_mat is None: - self.set_tikho_matrix(**kwargs) + @tikho_mat.setter + def tikho_mat(self, t_mat): + self._tikho_mat = t_mat - return self.tikho_mat def estimate_tikho_factors(self, flux_estimate): - """Estimate an initial guess of the Tikhonov factor. The output factor will + """ + Estimate an initial guess of the Tikhonov factor. The output factor will be used to find the best Tikhonov factor. The flux_estimate is used to generate a factor_guess. The user should construct a grid with this output in log space, e.g. np.logspace(np.log10(flux_estimate)-4, np.log10(flux_estimate)+4, 9). 
@@ -989,8 +590,8 @@ def estimate_tikho_factors(self, flux_estimate): Returns ------- - array[float] - Grid of Tikhonov factors. + float + Estimated Tikhonov factor. """ # Get some values from the object mask, wave_grid = self.get_attributes('mask', 'wave_grid') @@ -1001,11 +602,8 @@ def estimate_tikho_factors(self, flux_estimate): # Project the estimate on the wavelength grid estimate_on_grid = flux_estimate(wave_grid) - # Get the tikhonov matrix - tikho_matrix = self.get_tikho_matrix() - # Estimate the norm-2 of the regularization term - reg_estimate = tikho_matrix.dot(estimate_on_grid) + reg_estimate = self.tikho_mat.dot(estimate_on_grid) reg_estimate = np.nansum(np.array(reg_estimate) ** 2) # Estimate of the factor @@ -1015,8 +613,8 @@ def estimate_tikho_factors(self, flux_estimate): return factor_guess - def get_tikho_tests(self, factors, tikho=None, tikho_kwargs=None, data=None, - error=None, mask=None, trace_profile=None, throughput=None): + + def get_tikho_tests(self, factors, data, error): """ Test different factors for Tikhonov regularization. @@ -1024,45 +622,21 @@ def get_tikho_tests(self, factors, tikho=None, tikho_kwargs=None, data=None, ---------- factors : 1D list or array-like Factors to be tested. - tikho : Tikhonov object, optional - Tikhonov regularization object (see regularization.Tikhonov). - If not given, an object will be initiated using the linear system - from `build_sys` method and kwargs will be passed. - tikho_kwargs : - passed to init Tikhonov object. Possible options - are `t_mat` and `grid` - data : (N, M) array_like, optional + data : (N, M) array_like A 2-D array of real values representing the detector image. - Default is the object attribute `data`. - error : (N, M) array_like, optional - Estimate of the error on each pixel` - Same shape as `data`. - Default is the object attribute `sig`. - mask : (N, M) array_like boolean, optional - Additional mask for a given exposure. Will be added - to the object general mask. 
- trace_profile : (N_ord, N, M) list or array of 2-D arrays, optional - A list or array of the spatial profile for each order - on the detector. It has to have the same (N, M) as `data`. - Default is the object attribute `p_list` - throughput : (N_ord [, N_k]) list or array of functions, optional - A list or array of the throughput at each order. - The functions depend on the wavelength - Default is the object attribute `t_list` + error : (N, M) array_like + Estimate of the error on each pixel. Same shape as `data`. Returns ------ - dictionary of the tests results + tests : dict + dictionary of the test results """ # Build the system to solve - b_matrix, pix_array = self.get_detector_model(data, error, mask, trace_profile, throughput) + b_matrix, pix_array = self.get_detector_model(data, error) - if tikho is None: - t_mat = self.get_tikho_matrix() - if tikho_kwargs is None: - tikho_kwargs = {} - tikho = atoca_utils.Tikhonov(b_matrix, pix_array, t_mat, **tikho_kwargs) + tikho = atoca_utils.Tikhonov(b_matrix, pix_array, self.tikho_mat) # Test all factors tests = tikho.test_factors(factors) @@ -1070,12 +644,10 @@ def get_tikho_tests(self, factors, tikho=None, tikho_kwargs=None, data=None, # Save also grid tests["grid"] = self.wave_grid - # Save as attribute - self.tikho_tests = tests - return tests - def best_tikho_factor(self, tests=None, fit_mode='all', mode_kwargs=None): + + def best_tikho_factor(self, tests, fit_mode): """ Compute the best scale factor for Tikhonov regularization. It is determined by taking the factor giving the lowest reduced chi2 on @@ -1084,52 +656,19 @@ def best_tikho_factor(self, tests=None, fit_mode='all', mode_kwargs=None): Parameters ---------- - tests : dictionary, optional + tests : dictionary Results of Tikhonov extraction tests for different factors. Must have the keys "factors" and "-logl". - If not specified, the tests from self.tikho.tests are used. 
- fit_mode : string, optional + fit_mode : string Which mode is used to find the best Tikhonov factor. Options are 'all', 'curvature', 'chi2', 'd_chi2'. If 'all' is chosen, the best of the three other options will be selected. - mode_kwargs : dictionary-like - Dictionary of keyword arguments to be passed to TikhoTests.best_tikho_factor(). - Example: mode_kwargs = {'curvature': curvature_kwargs}. - Here, curvature_kwargs is also a dictionary. Returns ------- best_fac : float The best Tikhonov factor. - best_mode : str - The mode used to determine the best factor. - results : dict - A dictionary holding the factors computed for each mode requested. """ - - # Use pre-run tests if not specified - if tests is None: - tests = self.tikho_tests - - # TODO Find a way to identify when the solution becomes unstable - # and do nnot use these in the search for the best tikhonov factor. - # The follwing commented bloc was an attemp to do it, but problems - # occur if the chi2 reaches a maximum at large factors. 
-# # Remove all bad factors that are most likely unstable -# min_factor = tests.best_tikho_factor(mode='d_chi2', thresh=1e-8) -# idx_to_keep = min_factor <= tests['factors'] -# print(idx_to_keep) -# # # Keep at least the max factor if None are found -# # if not idx_to_keep.any(): -# # idx_max = np.argmax(tests['factors']) -# # idx_to_keep[idx_max] = True -# # Make new tests with remaining factors -# new_tests = dict() -# for key in tests: -# if key != 'grid': -# new_tests[key] = tests[key][idx_to_keep] -# tests = atoca_utils.TikhoTests(new_tests) - # Modes to be tested if fit_mode == 'all': # Test all modes @@ -1138,30 +677,17 @@ def best_tikho_factor(self, tests=None, fit_mode='all', mode_kwargs=None): # Single mode list_mode = [fit_mode] - # Init the mode_kwargs if None were given - if mode_kwargs is None: - mode_kwargs = dict() - - # Fill the missing value in mode_kwargs - for mode in list_mode: - try: - # Try the mode - mode_kwargs[mode] - except KeyError: - # Init with empty dictionary if it was not given - mode_kwargs[mode] = dict() - # Evaluate best factor with different methods - results = dict() + results = {} for mode in list_mode: - best_fac = tests.best_tikho_factor(mode=mode, **mode_kwargs[mode]) + best_fac = tests.best_factor(mode=mode) results[mode] = best_fac if fit_mode == 'all': # Choose the best factor. - # In a well behave case, the results should be ordered as 'chi2', 'd_chi2', 'curvature' + # In a well-behaved case, the results should be ordered as 'chi2', 'd_chi2', 'curvature' # and 'd_chi2' will be the best criterion determine the best factor. 
- # 'chi2' usually overfitting the solution and 'curvature' may oversmooth the solution + # 'chi2' usually overfits the solution and 'curvature' may oversmooth the solution if results['curvature'] <= results['chi2'] or results['d_chi2'] <= results['chi2']: # In this case, 'chi2' is likely to not overfit the solution, so must be favored best_mode = 'chi2' @@ -1179,25 +705,20 @@ def best_tikho_factor(self, tests=None, fit_mode='all', mode_kwargs=None): # Get the factor of the chosen mode best_fac = results[best_mode] - log.debug(f'Mode chosen to find regularization factor is {best_mode}') - return best_fac, best_mode, results + return best_fac + + + def rebuild(self, spectrum, fill_value=0.0): + """ + Build current model image of the detector. - def rebuild(self, spectrum=None, i_orders=None, same=False, fill_value=0.0): - """Build current model image of the detector. Parameters ---------- - spectrum : callable or array-like, optional + spectrum : callable or array-like flux as a function of wavelength if callable or array of flux values corresponding to self.wave_grid. - If not provided, will find via self.__call__(). - i_orders : list[int], optional - Indices of orders to model. Default is all available orders. - same : bool, optional - If True, do not recompute the pixel_mapping matrix (b_n) - and instead use the most recent pixel_mapping to speed up the computation. - Default is False. fill_value : float or np.nan, optional Pixel value where the detector is masked. Default is 0.0. @@ -1206,75 +727,55 @@ def rebuild(self, spectrum=None, i_orders=None, same=False, fill_value=0.0): array[float] The modeled detector image. """ - - # If no spectrum given compute it. - if spectrum is None: - spectrum = self.__call__() - # If flux is callable, evaluate on the wavelength grid. if callable(spectrum): spectrum = spectrum(self.wave_grid) - # Iterate over all orders by default. - if i_orders is None: - i_orders = range(self.n_orders) - - # Get required class attribute. 
- mask = self.mask + # Iterate over all orders + i_orders = range(self.n_orders) # Evaluate the detector model. model = np.zeros(self.data_shape) for i_order in i_orders: # Compute the pixel mapping matrix (b_n) for the current order. - pixel_mapping = self.get_pixel_mapping(i_order, error=False, same=same) + pixel_mapping = self.get_pixel_mapping(i_order, error=None) # Evaluate the model of the current order. - model[~mask] += pixel_mapping.dot(spectrum) + model[~self.mask] += pixel_mapping.dot(spectrum) # Assign masked values - model[mask] = fill_value - + model[self.mask] = fill_value return model - def compute_likelihood(self, spectrum=None, same=False): + + def compute_likelihood(self, spectrum, data, error): """Return the log likelihood associated with a particular spectrum. Parameters ---------- - spectrum : array[float] or callable, optional + spectrum : array[float] or callable Flux as a function of wavelength if callable or array of flux values corresponding to self.wave_grid. - If not given it will be computed by calling self.__call__(). - same : bool, optional - If True, do not recompute the pixel_mapping matrix (b_n) - and instead use the most recent pixel_mapping to speed up the computation. - Default is False. + data : (N, M) array_like + A 2-D array of real values representing the detector image. + error : (N, M) array_like + Estimate of the error on each pixel. + Same shape as `data`. Returns ------- array[float] The log-likelihood of the spectrum. """ - - # If no spectrum given compute it. - if spectrum is None: - spectrum = self.__call__() - # Evaluate the model image for the spectrum. - model = self.rebuild(spectrum, same=same) - - # Get data and error attributes. - data = self.data - error = self.error - mask = self.mask + model = self.rebuild(spectrum) # Compute the log-likelihood for the spectrum. 
with np.errstate(divide='ignore'): logl = (model - data) / error - logl = -np.nansum((logl[~mask])**2) + return -np.nansum((logl[~self.mask])**2) - return logl @staticmethod def _solve(matrix, result): @@ -1291,16 +792,8 @@ def _solve(matrix, result): # Only solve for valid indices, i.e. wavelengths that are # covered by the pixels on the detector. # It will be a singular matrix otherwise. - with warnings.catch_warnings(): - warnings.filterwarnings(action='error', category=MatrixRankWarning) - try: - sln[idx] = spsolve(matrix[idx, :][:, idx], result[idx]) - except MatrixRankWarning: - # on rare occasions spsolve's approximation of the matrix is not appropriate - # and fails on good input data. revert to different solver - log.info('ATOCA matrix solve failed with spsolve. Retrying with least-squares.') - sln[idx] = lsqr(matrix[idx, :][:, idx], result[idx])[0] - + matrix = matrix[idx, :][:, idx] + sln[idx] = atoca_utils.try_solve_two_methods(matrix, result[idx]) return sln @staticmethod @@ -1312,66 +805,55 @@ def _solve_tikho(matrix, result, t_mat, **kwargs): return tikho.solve(**kwargs) - def __call__(self, tikhonov=False, tikho_kwargs=None, factor=None, **kwargs): + def __call__(self, data, error, tikhonov=False, factor=None): """ Extract underlying flux on the detector. - All parameters are passed to `build_sys` method. + + Performs an overlapping extraction of the form: + (B_T * B) * f = (data/sig)_T * B + where B is a matrix and f is an array. + The matrix multiplication B * f is the 2d model of the detector. + We want to solve for the array f. + The elements of f are labelled by 'k'. + The pixels are labeled by 'i'. + Every pixel 'i' is covered by a set of 'k' for each order + of diffraction. + TIPS: To be quicker, only specify the psf (`p_list`) in kwargs. There will be only one matrix multiplication: (P/sig).(w.T.lambda.c_n). Parameters ---------- + data : (N, M) array_like + A 2-D array of real values representing the detector image. 
+ error : (N, M) array_like + Estimate of the error on each pixel` + Same shape as `data`. tikhonov : bool, optional Whether to use Tikhonov extraction Default is False. - tikho_kwargs : dictionary or None, optional - Arguments passed to `tikho_solve`. factor : the Tikhonov factor to use if tikhonov is True - data : (N, M) array_like, optional - A 2-D array of real values representing the detector image. - Default is the object attribute `data`. - error : (N, M) array_like, optional - Estimate of the error on each pixel` - Same shape as `data`. - Default is the object attribute `sig`. - mask : (N, M) array_like boolean, optional - Additional mask for a given exposure. Will be added - to the object general mask. - trace_profile : (N_ord, N, M) list or array of 2-D arrays, optional - A list or array of the spatial profile for each order - on the detector. It has to have the same (N, M) as `data`. - Default is the object attribute `p_list` - throughput : (N_ord [, N_k]) list or array of functions, optional - A list or array of the throughput at each order. - The functions depend on the wavelength - Default is the object attribute `t_list` Returns ----- spectrum (f_k): solution of the linear system """ - # Solve with the specified solver. if tikhonov: - # Build the system to solve - b_matrix, pix_array = self.get_detector_model(**kwargs) - if factor is None: msg = "Please specify tikhonov `factor`." log.critical(msg) raise ValueError(msg) - t_mat = self.get_tikho_matrix() - - if tikho_kwargs is None: - tikho_kwargs = {} + # Build the system to solve + b_matrix, pix_array = self.get_detector_model(data, error) - spectrum = self._solve_tikho(b_matrix, pix_array, t_mat, factor=factor, **tikho_kwargs) + spectrum = self._solve_tikho(b_matrix, pix_array, self.tikho_mat, factor=factor) else: # Build the system to solve - matrix, result = self.build_sys(**kwargs) + matrix, result = self.build_sys(data, error) # Only solve for valid range `i_grid` (on the detector). 
# It will be a singular matrix otherwise. @@ -1379,200 +861,8 @@ def __call__(self, tikhonov=False, tikho_kwargs=None, factor=None, **kwargs): return spectrum - def bin_to_pixel(self, i_order=0, grid_pix=None, grid_f_k=None, convolved_spectrum=None, - spectrum=None, bounds_error=False, throughput=None, **kwargs): - """Integrate the convolved_spectrum (f_k_c) over a pixel grid using the trapezoidal rule. - The convolved spectrum (f_k_c) is interpolated using scipy.interpolate.interp1d and the - kwargs and bounds_error are passed to interp1d. - i_order : int, optional - index of the order to be integrated, default is 0, so - the first order specified. - grid_pix : tuple or array, optional - If a tuple of 2 arrays is given, assume it is the lower and upper - integration ranges. If 1d-array, assume it is the center - of the pixels. If not given, the wavelength map and the psf map - of `i_order` will be used to compute a pixel grid. - grid_f_k : 1d array, optional - grid on which the convolved flux is projected. - Default is the wavelength grid for `i_order`. - convolved_spectrum : 1d array, optional - Convolved flux (f_k_c) to be integrated. If not given, `spectrum` - will be used (and convolved to `i_order` resolution) - spectrum : 1d array, optional - non-convolved flux (f_k, result of the `extract` method). - Not used if `convolved_spectrum` is specified. - bounds_error : bool, optional - passed to interp1d function to interpolate the convolved_spectrum. - Default is False - throughput : callable, optional - Spectral throughput for a given order (ì_ord). - Default is given by the list of throughput saved as - the attribute `t_list`. - kwargs : iterable, optional - If provided, will be passed to interp1d function. - Returns - ------- - pix_center, bin_val : array[float] - The pixel centers and the associated integrated values. - """ - # Take the value from the order if not given... - - # ... for the flux grid ... 
- if grid_f_k is None: - grid_f_k = self.wave_grid_c(i_order) - - # ... for the convolved flux ... - if convolved_spectrum is None: - # Use the spectrum (f_k) if the convolved_spectrum (f_k_c) not given. - if spectrum is None: - raise ValueError("`spectrum` or `convolved_spectrum` must be specified.") - else: - # Convolve the spectrum (f_k). - convolved_spectrum = self.kernels[i_order].dot(spectrum) - - # ... and for the pixel bins - if grid_pix is None: - pix_center, _ = self.grid_from_map(i_order) - - # Get pixels borders (plus and minus) - pix_p, pix_m = atoca_utils.get_wave_p_or_m(pix_center) - - else: # Else, unpack grid_pix - - # Could be a scalar or a 2-elements object) - if len(grid_pix) == 2: - - # 2-elements object, so we have the borders - pix_m, pix_p = grid_pix - - # Need to compute pixel center - d_pix = (pix_p - pix_m) - pix_center = grid_pix[0] + d_pix - else: - - # 1-element object, so we have the pix centers - pix_center = grid_pix - - # Need to compute the borders - pix_p, pix_m = atoca_utils.get_wave_p_or_m(pix_center) - - # Set the throughput to object attribute - # if not given - if throughput is None: - - # Need to interpolate - x, y = self.wave_grid, self.throughput[i_order] - throughput = interp1d(x, y) - - # Apply throughput on flux - convolved_spectrum = convolved_spectrum * throughput(grid_f_k) - - # Interpolate - kwargs['bounds_error'] = bounds_error - fct_f_k = interp1d(grid_f_k, convolved_spectrum, **kwargs) - - # Intergrate over each bins - bin_val = [] - for x1, x2 in zip(pix_m, pix_p): - - # Grid points that fall inside the pixel range - i_grid = (x1 < grid_f_k) & (grid_f_k < x2) - x_grid = grid_f_k[i_grid] - - # Add boundaries values to the integration grid - x_grid = np.concatenate([[x1], x_grid, [x2]]) - - # Integrate - integrand = fct_f_k(x_grid) * x_grid - bin_val.append(np.trapezoid(integrand, x_grid)) - - # Convert to array and return with the pixel centers. 
- return pix_center, np.array(bin_val) - - -class ExtractionEngine(_BaseOverlap): - """ - Run the ATOCA algorithm (Darveau-Bernier 2021, in prep). - - This version models the pixels of the detector using an oversampled trapezoidal integration. - """ - - def __init__(self, wave_map, trace_profile, *args, **kwargs): - """ - Parameters - ---------- - trace_profile : (N_ord, N, M) list or array of 2-D arrays - A list or array of the spatial profile for each order - on the detector. It has to have the same (N, M) as `data`. - wave_map : (N_ord, N, M) list or array of 2-D arrays - A list or array of the central wavelength position for each - order on the detector. - It has to have the same (N, M) as `data`. - throughput : (N_ord [, N_k]) list of array or callable - A list of functions or array of the throughput at each order. - If callable, the functions depend on the wavelength. - If array, projected on `wave_grid`. - kernels : array, callable or sparse matrix - Convolution kernel to be applied on spectrum (f_k) for each orders. - Can be array of the shape (N_ker, N_k_c). - Can be a callable with the form f(x, x0) where x0 is - the position of the center of the kernel. In this case, it must - return a 1D array (len(x)), so a kernel value - for each pairs of (x, x0). If array or callable, - it will be passed to `convolution.get_c_matrix` function - and the `c_kwargs` can be passed to this function. - If sparse, the shape has to be (N_k_c, N_k) and it will - be used directly. N_ker is the length of the effective kernel - and N_k_c is the length of the spectrum (f_k) convolved. - data : (N, M) array_like, optional - A 2-D array of real values representing the detector image. - error : (N, M) array_like, optional - Estimate of the error on each pixel. Default is one everywhere. - mask : (N, M) array_like boolean, optional - Boolean Mask of the detector pixels to mask for every extraction. - Should not be related to a specific order (if so, use `mask_trace_profile` instead). 
- mask_trace_profile : (N_ord, N, M) list or array of 2-D arrays[bool], optional - A list or array of the pixel that need to be used for extraction, - for each order on the detector. It has to have the same (N_ord, N, M) as `trace_profile`. - If not given, `threshold` will be applied on spatial profiles to define the masks. - orders : list, optional - List of orders considered. Default is orders = [1, 2] - wave_grid : (N_k) array_like, optional - The grid on which f(lambda) will be projected. - Default still has to be improved. - wave_bounds : list or array-like (N_ord, 2), optional - Boundary wavelengths covered by each orders. - Default is the wavelength covered by `wave_map`. - n_os : int, optional - Oversampling rate. If `wave_grid`is None, it will be used to - generate a grid. Default is 2. - threshold : float, optional: - The contribution of any order on a pixel is considered significant if - its estimated spatial profile is greater than this threshold value. - If it is not properly modeled (not covered by the wavelength grid), - it will be masked. Default is 1e-3. - c_kwargs : list of N_ord dictionaries or dictionary, optional - Inputs keywords arguments to pass to - `convolution.get_c_matrix` function for each order. - If dictionary, the same c_kwargs will be used for each order. - """ - - # Get wavelength at the boundary of each pixel - wave_p, wave_m = [], [] - for wave in wave_map: # For each order - lp, lm = atoca_utils.get_wave_p_or_m(wave) # Lambda plus or minus - # Make sure it is the good precision - wave_p.append(lp.astype(self.dtype)) - wave_m.append(lm.astype(self.dtype)) - - # Save values - self.wave_p, self.wave_m = wave_p, wave_m - - # Init upper class - super().__init__(wave_map, trace_profile, *args, **kwargs) - - def _get_lo_hi(self, grid, i_order): + def _get_lo_hi(self, grid, wave_p, wave_m, mask): """ Find the lowest (lo) and highest (hi) index of wave_grid for each pixels and orders. 
@@ -1581,8 +871,10 @@ def _get_lo_hi(self, grid, i_order): ---------- grid : array[float] Wave_grid to check. - i_order : int - Order to check values. + wave_p : array[float] + Wavelengths on the higher side of each pixel. + wave_m : array[float] + Wavelengths on the lower side of each pixel. Returns ------- @@ -1592,31 +884,21 @@ def _get_lo_hi(self, grid, i_order): log.debug('Computing lowest and highest indices of wave_grid.') - # Get needed attributes - mask = self.mask - - # ... order dependent attributes - attrs = ['wave_p', 'wave_m', 'mask_ord'] - wave_p, wave_m, mask_ord = self.get_attributes(*attrs, i_order=i_order) - - # Compute only for valid pixels - wave_p = wave_p[~mask] - wave_m = wave_m[~mask] - # Find lower (lo) index in the pixel - lo = np.searchsorted(grid, wave_m, side='right') + lo = np.searchsorted(grid, wave_m, side='right') - 1 # Find higher (hi) index in the pixel hi = np.searchsorted(grid, wave_p) - 1 - # Set invalid pixels for this order to lo=-1 and hi=-2 - ma = mask_ord[~mask] - lo[ma], hi[ma] = -1, -2 + # Set invalid pixels negative + lo[mask], hi[mask] = -1, -2 + + return lo, hi - return lo, hi def get_mask_wave(self, i_order): """Generate mask bounded by limits of wavelength grid + Parameters ---------- i_order : int @@ -1635,13 +917,12 @@ def get_mask_wave(self, i_order): wave_min = self.wave_grid[i_bnds[0]] wave_max = self.wave_grid[i_bnds[1] - 1] - mask = (wave_m < wave_min) | (wave_p > wave_max) + return (wave_m < wave_min) | (wave_p > wave_max) - return mask def get_w(self, i_order): - """Compute integration weights for each grid points and each pixels. - Depends on the order `n`. + """Compute integration weights 'k' for each grid point and pixel 'i'. + These depend on the type of interpolation used, i.e. the order `n`. Parameters ---------- @@ -1660,31 +941,27 @@ def get_w(self, i_order): log.debug('Computing weights and k.') - # Get needed attributes - wave_grid, mask = self.get_attributes('wave_grid', 'mask') - - # ... 
order dependent attributes + # get order dependent attributes attrs = ['wave_p', 'wave_m', 'mask_ord', 'i_bounds'] wave_p, wave_m, mask_ord, i_bnds = self.get_attributes(*attrs, i_order=i_order) # Use the convolved grid (depends on the order) - wave_grid = wave_grid[i_bnds[0]:i_bnds[1]] + wave_grid = self.wave_grid[i_bnds[0]:i_bnds[1]] # Compute the wavelength coverage of the grid d_grid = np.diff(wave_grid) - # Get lo hi - lo, hi = self._get_lo_hi(wave_grid, i_order) # Get indexes - # Compute only valid pixels - wave_p, wave_m = wave_p[~mask], wave_m[~mask] - ma = mask_ord[~mask] + wave_p, wave_m = wave_p[~self.mask], wave_m[~self.mask] + ma = mask_ord[~self.mask] + + # Get lo hi + lo, hi = self._get_lo_hi(wave_grid, wave_p, wave_m, ma) # Get indices # Number of used pixels n_i = len(lo) i = np.arange(n_i) - # Define first and last index of wave_grid - # for each pixel + # Define first and last index of wave_grid for each pixel k_first, k_last = -1 * np.ones(n_i), -1 * np.ones(n_i) # If lowest value close enough to the exact grid value, @@ -1713,12 +990,12 @@ def get_w(self, i_order): k_last[cond & ~ma] = hi[cond & ~ma] wave_p[cond & ~ma] = wave_grid[hi[cond & ~ma]] - # else, need hi_i + 1 - k_last[~cond & ~ma] = hi[~cond & ~ma] + 1 + # else, need hi_i + k_last[~cond & ~ma] = hi[~cond & ~ma] # Generate array of all k_i. Set to -1 if not valid - k_n, bad = atoca_utils.arange_2d(k_first, k_last + 1, dtype=int) - k_n[bad] = -1 + k_n = atoca_utils.arange_2d(k_first, k_last + 1) + bad = k_n == -1 # Number of valid k per pixel n_k = np.sum(~bad, axis=-1) @@ -1726,19 +1003,16 @@ def get_w(self, i_order): # Compute array of all w_i. 
Set to np.nan if not valid # Initialize w_n = np.zeros(k_n.shape, dtype=float) - #################### + #################### # 4 different cases #################### - #################### # Valid for every cases w_n[:, 0] = wave_grid[k_n[:, 1]] - wave_m w_n[i, n_k - 1] = wave_p - wave_grid[k_n[i, n_k - 2]] - ################## # Case 1, n_k == 2 - ################## case = (n_k == 2) & ~ma if case.any(): @@ -1757,9 +1031,7 @@ def get_w(self, i_order): part2 = d_grid[k_n[case, 0]] w_n[case, :] *= (part1 / part2)[:, None] - ################## # Case 2, n_k >= 3 - ################## case = (n_k >= 3) & ~ma if case.any(): @@ -1785,9 +1057,7 @@ def get_w(self, i_order): w_n[cond, n_ki - 1] *= (nume1 / deno) w_n[cond, n_ki - 2] += (nume1 * nume2 / deno) - ################## # Case 3, n_k >= 4 - ################## case = (n_k >= 4) & ~ma if case.any(): log.debug('n_k = 4 in get_w().') @@ -1796,9 +1066,7 @@ def get_w(self, i_order): w_n[case, n_ki - 2] += (wave_grid[k_n[case, n_ki - 2]] - wave_grid[k_n[case, n_ki - 3]]) - ################## # Case 4, n_k > 4 - ################## case = (n_k > 4) & ~ma if case.any(): log.debug('n_k > 4 in get_w().') diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index 19a0144539..4c14749522 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -7,47 +7,38 @@ """ import numpy as np -from scipy.sparse import find, diags, csr_matrix -from scipy.sparse.linalg import spsolve +from numpy.polynomial import Polynomial +import warnings +from scipy.sparse import diags, csr_matrix +from scipy.sparse.linalg import spsolve, lsqr, MatrixRankWarning from scipy.interpolate import interp1d, RectBivariateSpline, Akima1DInterpolator from scipy.optimize import minimize_scalar, brentq +from scipy.interpolate import make_interp_spline import logging log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) -# 
============================================================================== -# Code for generating indices on the oversampled wavelength grid. -# ============================================================================== - -def arange_2d(starts, stops, dtype=None): - """Create a 2D array containing a series of ranges. The ranges do not have - to be of equal length. +def arange_2d(starts, stops): + """ + Code for generating indices on the oversampled wavelength grid. + Creates a 2D array containing a series of ranges. + The ranges do not have to be of equal length. Parameters ---------- - starts : int or array[int] + starts : array[int] Start values for each range. - stops : int or array[int] + stops : array[int] End values for each range. - dtype : str - Type of the output values. Returns ------- - out : array[int] - 2D array of ranges. - mask : array[bool] - Mask indicating valid elements. + out : array[int16] + 2D array of ranges with invalid values set to -1 """ - - # Ensure starts and stops are arrays. - starts = np.asarray(starts) - stops = np.asarray(stops) - - # Check input for starts and stops is valid. - if starts.shape != stops.shape and starts.shape != (): + if starts.shape != stops.shape: msg = ('Shapes of starts and stops are not compatible, ' 'they must either have the same shape or starts must be scalar.') log.critical(msg) @@ -58,34 +49,23 @@ def arange_2d(starts, stops, dtype=None): log.critical(msg) raise ValueError(msg) - # If starts was given as a scalar match its shape to stops. - if starts.shape == (): - starts = starts * np.ones_like(stops) - # Compute the length of each range. lengths = (stops - starts).astype(int) - # Initialize the output arrays. + # Initialize the output arrays with invalid value nrows = len(stops) ncols = np.amax(lengths) - out = np.ones((nrows, ncols), dtype=dtype) - mask = np.ones((nrows, ncols), dtype='bool') + out = np.ones((nrows, ncols), dtype=np.int16)*-1 # Compute the indices. 
for irow in range(nrows): out[irow, :lengths[irow]] = np.arange(starts[irow], stops[irow]) - mask[irow, :lengths[irow]] = False - - return out, mask - - -# ============================================================================== -# Code for converting to a sparse matrix and back. -# ============================================================================== + return out def sparse_k(val, k, n_k): - """Transform a 2D array `val` to a sparse matrix. + """ + Transform a 2D array `val` to a sparse matrix. Parameters ---------- @@ -114,56 +94,11 @@ def sparse_k(val, k, n_k): col = k[k >= 0] data = val[k >= 0] - mat = csr_matrix((data, (row, col)), shape=(n_i, n_k)) - - return mat - - -def unsparse(matrix, fill_value=np.nan): - """Convert a sparse matrix to a 2D array of values and a 2D array of position. - - Parameters - ---------- - matrix : csr_matrix - The input sparse matrix. - fill_value : float - Value to fill 2D array for undefined positions; default to np.nan - - Returns - ------ - out : 2d array - values of the matrix. The shape of the array is given by: - (matrix.shape[0], maximum number of defined value in a column). - col_out : 2d array - position of the columns. Same shape as `out`. - """ - - col, row, val = find(matrix.T) - n_row, n_col = matrix.shape - - good_rows, counts = np.unique(row, return_counts=True) - - # Define the new position in columns - i_col = np.indices((n_row, counts.max()))[1] - i_col = i_col[good_rows] - i_col = i_col[i_col < counts[:, None]] - - # Create outputs and assign values - col_out = np.ones((n_row, counts.max()), dtype=int) * -1 - col_out[row, i_col] = col - out = np.ones((n_row, counts.max())) * fill_value - out[row, i_col] = val - - return out, col_out - - -# ============================================================================== -# Code for building wavelength grids. 
-# ============================================================================== + return csr_matrix((data, (row, col)), shape=(n_i, n_k)) def get_wave_p_or_m(wave_map, dispersion_axis=1): - """ Compute upper and lower boundaries of a pixel map, + """Compute upper and lower boundaries of a pixel map, given the pixel central value. Parameters ---------- @@ -178,7 +113,7 @@ def get_wave_p_or_m(wave_map, dispersion_axis=1): The wavelength upper and lower boundaries of each pixel, given the central value. """ # Get wavelength boundaries of each pixels - wave_left, wave_right = get_wv_map_bounds(wave_map, dispersion_axis=dispersion_axis) + wave_left, wave_right = _get_wv_map_bounds(wave_map, dispersion_axis=dispersion_axis) # The outputs depend on the direction of the spectral axis. invalid = (wave_map == 0) @@ -194,7 +129,7 @@ def get_wave_p_or_m(wave_map, dispersion_axis=1): return wave_plus, wave_minus -def get_wv_map_bounds(wave_map, dispersion_axis=1): +def _get_wv_map_bounds(wave_map, dispersion_axis=1): """ Compute boundaries of a pixel map, given the pixel central value. Parameters ---------- @@ -209,6 +144,17 @@ def get_wv_map_bounds(wave_map, dispersion_axis=1): Wavelength of top edge for each pixel wave_bottom : array[float] Wavelength of bottom edge for each pixel + + Notes + ----- + Handling of invalid pixels may lead to unexpected results as follows: + Bad pixels are completely ignored when computing pixel-to-pixel differences, so + wv_map=[2,4,6,NaN,NaN,12,14,16] will give wave_top=[1,3,5,0,0,9,13,15] + because the difference at index 5 was calculated as 12-(12-6)/2=9, + i.e., as though index 2 and 5 were next to each other. + A human (or a smarter linear interpolation) would figure out the slope is 2 and + determine the value of wave_top[5] should most likely be 11. + This is found not to matter in practice for the current use cases. 
""" if dispersion_axis == 1: # Simpler to use transpose @@ -222,25 +168,25 @@ def get_wv_map_bounds(wave_map, dispersion_axis=1): wave_top = np.zeros_like(wave_map) wave_bottom = np.zeros_like(wave_map) - (n_row, n_col) = wave_map.shape + # for loop is needed to compute diff in just one spatial direction + # while skipping invalid values- not trivial to do with array comprehension even + # using masked arrays + n_col = wave_map.shape[1] for idx in range(n_col): wave_col = wave_map[:, idx] # Compute the change in wavelength for valid cols - idx_valid = np.isfinite(wave_col) - idx_valid &= (wave_col >= 0) + idx_valid = np.isfinite(wave_col) & (wave_col >= 0) wv_col_valid = wave_col[idx_valid] - delta_wave = np.diff(wv_col_valid) + delta_wave = np.diff(wv_col_valid) / 2 - # Init values - wv_col_top = np.zeros_like(wv_col_valid) - wv_col_bottom = np.zeros_like(wv_col_valid) + # handle edge effects using a constant-difference rule + delta_wave_top = np.insert(delta_wave,0,delta_wave[0]) + delta_wave_bottom = np.append(delta_wave,delta_wave[-1]) # Compute the wavelength values on the top and bottom edges of each pixel. - wv_col_top[1:] = wv_col_valid[:-1] + delta_wave / 2 # TODO check this logic. - wv_col_top[0] = wv_col_valid[0] - delta_wave[0] / 2 - wv_col_bottom[:-1] = wv_col_valid[:-1] + delta_wave / 2 - wv_col_bottom[-1] = wv_col_valid[-1] + delta_wave[-1] / 2 + wv_col_top = wv_col_valid - delta_wave_top + wv_col_bottom = wv_col_valid + delta_wave_bottom wave_top[idx_valid, idx] = wv_col_top wave_bottom[idx_valid, idx] = wv_col_bottom @@ -252,8 +198,9 @@ def get_wv_map_bounds(wave_map, dispersion_axis=1): return wave_top, wave_bottom -def oversample_grid(wave_grid, n_os=1): - """Create an oversampled version of the input 1D wavelength grid. +def oversample_grid(wave_grid, n_os): + """ + Create an oversampled version of the input 1D wavelength grid. Parameters ---------- @@ -270,46 +217,26 @@ def oversample_grid(wave_grid, n_os=1): The oversampled wavelength grid. 
""" - # Convert n_os to an array. + # Convert n_os to an array of size len(wave_grid) - 1. n_os = np.asarray(n_os) - - # n_os needs to have the dimension: len(wave_grid) - 1. if n_os.ndim == 0: - - # A scalar was given, repeat the value. n_os = np.repeat(n_os, len(wave_grid) - 1) - elif len(n_os) != (len(wave_grid) - 1): - # An array of incorrect size was given. msg = 'n_os must be a scalar or an array of size len(wave_grid) - 1.' log.critical(msg) raise ValueError(msg) - - # Grid intervals. - delta_wave = np.diff(wave_grid) - - # Initialize the new oversampled wavelength grid. - wave_grid_os = wave_grid.copy() - - # Iterate over oversampling factors to generate new grid points. - for i_os in range(1, n_os.max()): - - # Consider only intervals that are not complete yet. - mask = n_os > i_os - - # Compute the new grid points. - sub_grid = wave_grid[:-1][mask] + (i_os * delta_wave[mask] / n_os[mask]) - - # Add the grid points to the oversampled wavelength grid. - wave_grid_os = np.concatenate([wave_grid_os, sub_grid]) + + # Compute the oversampled grid. + intervals = 1/n_os + intervals = np.insert(np.repeat(intervals, n_os),0,0) + grid = np.cumsum(intervals) + wave_grid_os = np.interp(grid, np.arange(wave_grid.size), wave_grid) # Take only unique values and sort them. - wave_grid_os = np.unique(wave_grid_os) - - return wave_grid_os + return np.unique(wave_grid_os) -def extrapolate_grid(wave_grid, wave_range, poly_ord): +def _extrapolate_grid(wave_grid, wave_range, poly_ord=1): """Extrapolate the given 1D wavelength grid to cover a given range of values by fitting the derivative with a polynomial of a given order and using it to compute subsequent values at both ends of the grid. @@ -328,27 +255,42 @@ def extrapolate_grid(wave_grid, wave_range, poly_ord): wave_grid_ext : array[float] The extrapolated 1D wavelength grid. """ + if wave_range[0] >= wave_range[-1]: + msg = 'wave_range must be in order [short, long].' 
+ log.critical(msg) + raise ValueError(msg) + if wave_range[0] > wave_grid.max() or wave_range[-1] < wave_grid.min(): + msg = 'wave_range must overlap with wave_grid.' + log.critical(msg) + raise ValueError(msg) + if wave_range[0] > wave_grid.min() and wave_range[-1] < wave_grid.max(): + return wave_grid # Define delta_wave as a function of wavelength by fitting a polynomial. delta_wave = np.diff(wave_grid) - pars = np.polyfit(wave_grid[:-1], delta_wave, poly_ord) - f_delta = np.poly1d(pars) + f_delta = Polynomial.fit(wave_grid[:-1], delta_wave, poly_ord) + + # Set a minimum delta value to avoid running forever + min_delta = delta_wave.min()/10 # Extrapolate out-of-bound values on the left-side of the grid. grid_left = [] if wave_range[0] < wave_grid.min(): - # Compute the first extrapolated grid point. - grid_left = [wave_grid.min() - f_delta(wave_grid.min())] + # Initialize extrapolated grid with the first value of input grid. + # This point gets double-counted in the final grid, but then unique is called. + grid_left = [wave_grid.min(),] # Iterate until the end of wave_range is reached. while True: - next_val = grid_left[-1] - f_delta(grid_left[-1]) + next_delta = f_delta(grid_left[-1]) + next_val = grid_left[-1] - next_delta + grid_left.append(next_val) if next_val < wave_range[0]: break - else: - grid_left.append(next_val) + if next_delta < min_delta: + raise RuntimeError('Extrapolation failed to converge.') # Sort extrapolated vales (and keep only unique). grid_left = np.unique(grid_left) @@ -357,25 +299,26 @@ def extrapolate_grid(wave_grid, wave_range, poly_ord): grid_right = [] if wave_range[-1] > wave_grid.max(): - # Compute the first extrapolated grid point. - grid_right = [wave_grid.max() + f_delta(wave_grid.max())] + # Initialize extrapolated grid with the last value of input grid. + # This point gets double-counted in the final grid, but then unique is called. + grid_right = [wave_grid.max(),] # Iterate until the end of wave_range is reached. 
while True: - next_val = grid_right[-1] + f_delta(grid_right[-1]) + next_delta = f_delta(grid_right[-1]) + next_val = grid_right[-1] + next_delta + grid_right.append(next_val) if next_val > wave_range[-1]: break - else: - grid_right.append(next_val) - - # Sort extrapolated vales (and keep only unique). + if next_delta < min_delta: + raise RuntimeError('Extrapolation failed to converge.') + + # Sort extrapolated values (and keep only unique) grid_right = np.unique(grid_right) # Combine the extrapolated sections with the original grid. - wave_grid_ext = np.concatenate([grid_left, wave_grid, grid_right]) - - return wave_grid_ext + return np.concatenate([grid_left, wave_grid, grid_right]) def _grid_from_map(wave_map, trace_profile): @@ -397,27 +340,26 @@ def _grid_from_map(wave_map, trace_profile): Column indices used. """ - # Use only valid columns. - mask = (trace_profile > 0).any(axis=0) & (wave_map > 0).any(axis=0) + # Use only valid values by setting weights to zero + trace_profile[trace_profile < 0] = 0 + trace_profile[wave_map <= 0] = 0 - # Get central wavelength using PSF as weights. - num = (trace_profile * wave_map).sum(axis=0) - denom = trace_profile.sum(axis=0) - center_wv = num[mask] / denom[mask] + # handle case where all values are invalid for a given wavelength + # np.average cannot process sum(weights) = 0, so set them to unity then set NaN afterward + bad_wls = np.sum(trace_profile, axis=0) == 0 + trace_profile[:,bad_wls] = 1 + center_wv = np.average(wave_map, weights=trace_profile, axis=0) + center_wv[bad_wls] = np.nan + center_wv = center_wv[~np.isnan(center_wv)] # Make sure the wavelength values are in ascending order. 
- sort = np.argsort(center_wv) - grid = center_wv[sort] - - icols, = np.where(mask) - return grid, icols[sort] + return np.sort(center_wv) -def grid_from_map(wave_map, trace_profile, wave_range=None, n_os=1, poly_ord=1): +def grid_from_map(wave_map, trace_profile, wave_range=None, n_os=1): """Define a wavelength grid by taking the central wavelength at each columns given by the center of mass of the spatial profile (so one wavelength per - column). If wave_range is outside of the wave_map, extrapolate with a - polynomial of order poly_ord. + column). If wave_range is outside of the wave_map, extrapolate. Parameters ---------- @@ -428,140 +370,64 @@ def grid_from_map(wave_map, trace_profile, wave_range=None, n_os=1, poly_ord=1): wave_range : list[float] Minimum and maximum boundary of the grid to generate, in microns. Wave_range must include some wavelengths of wave_map. - n_os : int or list[int] - Oversampling of the grid compare to the pixel sampling. Can be - specified for each order if a list is given. If a single value is given - it will be used for all orders. - poly_ord : int - Order of the polynomial use to extrapolate the grid. + Note wave_range is exclusive, in the sense that wave_range[0] and wave_range[1] + will not be between min(output) and max(output). Instead, min(output) will be + the smallest value in the extrapolated grid that is greater than wave_range[0] + and max(output) will be the largest value that is less than wave_range[1]. + n_os : int + Oversampling of the grid compared to the pixel sampling. Returns ------- grid_os : array[float] Wavelength grid with oversampling applied """ + if wave_map.shape != trace_profile.shape: + msg = 'wave_map and trace_profile must have the same shape.' + log.critical(msg) + raise ValueError(msg) # Different treatment if wave_range is given. if wave_range is None: - out, _ = _grid_from_map(wave_map, trace_profile) + grid = _grid_from_map(wave_map, trace_profile) else: # Get an initial estimate of the grid. 
- grid, icols = _grid_from_map(wave_map, trace_profile) + grid = _grid_from_map(wave_map, trace_profile) - # Check if extrapolation needed. If so, out_col must be False. - extrapolate = (wave_range[0] < grid.min()) | (wave_range[1] > grid.max()) + # Extrapolate values out of the wv_map if needed + grid = _extrapolate_grid(grid, wave_range, poly_ord=1) - # Make sure grid is between the range - mask = (wave_range[0] <= grid) & (grid <= wave_range[-1]) + # Constrain grid to be within wave_range + grid = grid[grid>=wave_range[0]] + grid = grid[grid<=wave_range[-1]] # Check if grid and wv_range are compatible - if not mask.any(): + if len(grid) == 0: msg = "Invalid wave_map or wv_range." log.critical(msg) raise ValueError(msg) - grid, icols = grid[mask], icols[mask] - - # Extrapolate values out of the wv_map if needed - if extrapolate: - grid = extrapolate_grid(grid, wave_range, poly_ord) - - out = grid - # Apply oversampling - grid_os = oversample_grid(out, n_os=n_os) + return oversample_grid(grid, n_os=n_os) - return grid_os - - -def get_soss_grid(wave_maps, trace_profiles, wave_min=0.55, wave_max=3.0, n_os=None): - """Create a wavelength grid specific to NIRISS SOSS mode observations. - Assumes 2 orders are given, use grid_from_map if only one order is needed. - - Parameters - ---------- - wave_maps : array[float] - Array containing the pixel wavelengths for order 1 and 2. - trace_profiles : array[float] - Array containing the spatial profiles for order 1 and 2. - wave_min : float - Minimum wavelength the output grid should cover. - wave_max : float - Maximum wavelength the output grid should cover. - n_os : int or list[int] - Oversampling of the grid compared to the pixel sampling. Can be - specified for each order if a list is given. If a single value is given - it will be used for all orders. - Returns - ------- - wave_grid_soss : array[float] - Wavelength grid optimized for extracting SOSS spectra across - order 1 and order 2. 
+def _trim_grids(all_grids, grid_range): """ - - # Check n_os input, default value is 2 for all orders. - if n_os is None: - n_os = [2, 2] - elif np.ndim(n_os) == 0: - n_os = [n_os, n_os] - elif len(n_os) != 2: - msg = (f"n_os must be an integer or a 2 element list or array of " - f"integers, got {n_os} instead") - log.critical(msg) - raise ValueError(msg) - - # Generate a wavelength range for each order. - # Order 1 covers the reddest part of the spectrum, - # so apply wave_max on order 1 and vice versa for order 2. - - # Take the most restrictive wave_min for order 1 - wave_min_o1 = np.maximum(wave_maps[0].min(), wave_min) - - # Take the most restrictive wave_max for order 2. - wave_max_o2 = np.minimum(wave_maps[1].max(), wave_max) - - # Now generate range for each orders - range_list = [[wave_min_o1, wave_max], - [wave_min, wave_max_o2]] - - # Use grid_from_map to construct separate oversampled grids for both orders. - wave_grid_o1 = grid_from_map(wave_maps[0], trace_profiles[0], - wave_range=range_list[0], n_os=n_os[0]) - wave_grid_o2 = grid_from_map(wave_maps[1], trace_profiles[1], - wave_range=range_list[1], n_os=n_os[1]) - - # Keep only wavelengths in order 1 that aren't covered by order 2. - mask = wave_grid_o1 > wave_grid_o2.max() - wave_grid_o1 = wave_grid_o1[mask] - - # Combine the order 1 and order 2 grids. - wave_grid_soss = np.concatenate([wave_grid_o1, wave_grid_o2]) - - # Sort values (and keep only unique). - wave_grid_soss = np.unique(wave_grid_soss) - - return wave_grid_soss - - -def _trim_grids(all_grids, grid_range=None): - """ Remove all parts of the grids that are not in range + Remove all parts of the grids that are not in range or that are already covered by grids with higher priority, i.e. preceding in the list. 
""" grids_trimmed = [] for grid in all_grids: # Remove parts of the grid that are not in the wavelength range - if grid_range is not None: - # Find where the limit values fall on the grid - i_min = np.searchsorted(grid, grid_range[0], side='right') - i_max = np.searchsorted(grid, grid_range[1], side='left') - # Make sure it is a valid value and take one grid point past the limit - # since the oversampling could squeeze some nodes near the limits - i_min = np.max([i_min - 1, 0]) - i_max = np.min([i_max, len(grid) - 1]) - # Trim the grid - grid = grid[i_min:i_max + 1] + i_min = np.searchsorted(grid, grid_range[0], side='right') + i_max = np.searchsorted(grid, grid_range[1], side='left') + # Make sure it is a valid value and take one grid point past the limit + # since the oversampling could squeeze some nodes near the limits + i_min = np.max([i_min - 1, 0]) + i_max = np.min([i_max, len(grid) - 1]) + # Trim the grid + grid = grid[i_min:i_max + 1] # Remove parts of the grid that are already covered if len(grids_trimmed) > 0: @@ -571,24 +437,19 @@ def _trim_grids(all_grids, grid_range=None): is_below = grid < np.min(conca_grid) is_above = grid > np.max(conca_grid) - # Do nothing yet if it surrounds the previous grid - if is_below.any() and is_above.any(): - msg = 'Grid surrounds another grid, better to split in 2 parts.' - log.warning(msg) - # Remove values already covered, but keep one # index past the limit - elif is_below.any(): + if is_below.any(): idx = np.max(np.nonzero(is_below)) idx = np.min([idx + 1, len(grid) - 1]) grid = grid[:idx + 1] - elif is_above.any(): + if is_above.any(): idx = np.min(np.nonzero(is_above)) idx = np.max([idx - 1, 0]) grid = grid[idx:] # If all is covered, no need to do it again, so empty grid. 
- else: + if not is_below.any() and not is_above.any(): grid = np.array([]) # Save trimmed grid @@ -597,9 +458,10 @@ def _trim_grids(all_grids, grid_range=None): return grids_trimmed -def make_combined_adaptive_grid(all_grids, all_estimate, grid_range=None, - max_iter=10, rtol=10e-6, tol=0.0, max_total_size=1000000): - """Return an irregular oversampled grid needed to reach a +def make_combined_adaptive_grid(all_grids, all_estimates, grid_range, + max_iter=10, rtol=10e-6, max_total_size=1000000): + """ + Return an irregular oversampled grid needed to reach a given precision when integrating over each intervals of `grid`. The grid is built by subdividing iteratively each intervals that did not reach the required precision. @@ -608,33 +470,32 @@ def make_combined_adaptive_grid(all_grids, all_estimate, grid_range=None, Parameters ---------- - all_grid : list[array] - List of grid (arrays) to pass to adapt_grid, in order of importance. - all_estimate : list[callable] + all_grids : list[array] + List of grid (arrays) to pass to _adapt_grid, in order of importance. + all_estimates : list[callable] List of function (callable) to estimate the precision needed to oversample the grid. Must match the corresponding `grid` in `all_grid`. + grid_range : list[float] + Wavelength range the new grid should cover. max_iter : int, optional Number of times the intervals can be subdivided. The smallest subdivison of the grid if max_iter is reached will then be given - by delta_grid / 2^max_iter. Needs to be greater then zero. + by delta_grid / 2^max_iter. Needs to be greater than zero. Default is 10. rtol : float, optional The desired relative tolerance. Default is 10e-6, so 10 ppm. - tol : float, optional - The desired absolute tolerance. Default is 0 to prioritize `rtol`. max_total_size : int, optional maximum size of the output grid. Default is 1 000 000. 
+ Returns ------- os_grid : 1D array Oversampled combined grid which minimizes the integration error based on Romberg's method """ - # Save parameters for adapt_grid - kwargs = dict(max_iter=max_iter, rtol=rtol, tol=tol) # Remove unneeded parts of the grids - all_grids = _trim_grids(all_grids, grid_range=grid_range) + all_grids = _trim_grids(all_grids, grid_range) # Save native size of each grids (use later to adjust max_grid_size) all_sizes = [len(grid) for grid in all_grids] @@ -643,8 +504,6 @@ def make_combined_adaptive_grid(all_grids, all_estimate, grid_range=None, combined_grid = np.array([]) # Init with empty array for i_grid, grid in enumerate(all_grids): - estimate = all_estimate[i_grid] - # Get the max_grid_size, considering the other grids # First, remove length already used max_grid_size = max_total_size - combined_grid.size @@ -653,19 +512,23 @@ def make_combined_adaptive_grid(all_grids, all_estimate, grid_range=None, if i_size > i_grid: max_grid_size = max_grid_size - size # Make sure it is at least the size of the native grid. 
- kwargs['max_grid_size'] = np.max([max_grid_size, all_sizes[i_grid]]) + max_grid_size = np.max([max_grid_size, all_sizes[i_grid]]) # Oversample the grid based on tolerance required - grid, is_converged = adapt_grid(grid, estimate, **kwargs) + grid, is_converged = _adapt_grid(grid, + all_estimates[i_grid], + max_grid_size=max_grid_size, + max_iter=max_iter, + rtol=rtol) # Update grid sizes all_sizes[i_grid] = grid.size # Check convergence if not is_converged: - msg = 'Precision cannot be garanteed:' - if grid.size < kwargs['max_grid_size']: - msg += (f' smallest subdivision 1/{2 ** kwargs["max_iter"]:2.1e}' + msg = 'Precision cannot be guaranteed:' + if grid.size < max_grid_size: + msg += (f' smallest subdivision 1/{2 ** max_iter:2.1e}' f' was reached for grid index = {i_grid}') else: total_size = np.sum(all_sizes) @@ -674,25 +537,13 @@ def make_combined_adaptive_grid(all_grids, all_estimate, grid_range=None, msg += f' = {total_size} was reached for grid index = {i_grid}.' log.warning(msg) - # Remove regions already covered in the output grid - if len(combined_grid) > 0: - idx_covered = (np.min(combined_grid) <= grid) - idx_covered &= (grid <= np.max(combined_grid)) - grid = grid[~idx_covered] - # Combine grids combined_grid = np.concatenate([combined_grid, grid]) # Sort values (and keep only unique). - combined_grid = np.unique(combined_grid) - - # Final trim to make sure it respects the range - if grid_range is not None: - idx_in_range = (grid_range[0] <= combined_grid) - idx_in_range &= (combined_grid <= grid_range[-1]) - combined_grid = combined_grid[idx_in_range] - - return combined_grid + # This is necessary because trim_grids allows lowest index of one grid to + # equal highest index of another grid. + return np.unique(combined_grid) def _romberg_diff(b, c, k): @@ -713,11 +564,7 @@ def _romberg_diff(b, c, k): R(n, m) : float or array[float] Difference between integral estimates of Rombergs method. 
""" - - tmp = 4.0**k - diff = (tmp * c - b) / (tmp - 1.0) - - return diff + return (4.0**k * c - b) / (4.0**k - 1.0) def _difftrap(fct, intervals, numtraps): @@ -786,116 +633,9 @@ def _difftrap(fct, intervals, numtraps): return ordsum -def get_n_nodes(grid, fct, divmax=10, tol=1.48e-4, rtol=1.48e-4): - """Refine parts of a grid to reach a specified integration precision - based on Romberg integration of a callable function or method. - Returns the number of nodes needed in each intervals of - the input grid to reach the specified tolerance over the integral - of `fct` (a function of one variable). - - Note: This function is based on scipy.integrate.quadrature.romberg. The - difference between it and the scipy version is that it is vectorized to deal - with multiple intervals separately. It also returns the number of nodes - needed to reached the required precision instead of returning the value of - the integral. - - Parameters - ---------- - grid : array[float] - Grid for integration. Each section of this grid is treated as a - separate integral; if grid has length N, N-1 integrals are optimized. - fct : callable - Function to be integrated. - divmax : int - Maximum order of extrapolation. - tol : float - The desired absolute tolerance. - rtol : float - The desired relative tolerance. - - Returns - ------- - n_grid : array[int] - Number of nodes needed on each distinct intervals in the grid to reach - the specified tolerance. - residual : array[float] - Estimate of the error in each intervals. Same length as n_grid. +def _estim_integration_err(grid, fct): """ - - # Initialize some variables. - n_intervals = len(grid) - 1 - i_bad = np.arange(n_intervals) - n_grid = np.repeat(-1, n_intervals) - residual = np.repeat(np.nan, n_intervals) - - # Change the 1D grid into a 2D set of intervals. - intervals = np.array([grid[:-1], grid[1:]]) - intrange = np.diff(grid) - err = np.inf - - # First estimate without subdivision. 
- numtraps = 1 - ordsum = _difftrap(fct, intervals, numtraps) - results = intrange * ordsum - last_row = [results] - - for i_div in range(1, divmax + 1): - - # Increase the number of trapezoids by factors of 2. - numtraps *= 2 - - # Evaluate trapz integration for intervals that are not converged. - ordsum += _difftrap(fct, intervals[:, i_bad], numtraps) - row = [intrange[i_bad] * ordsum / numtraps] - - # Compute Romberg for each of the computed sub grids. - for k in range(i_div): - romb_k = _romberg_diff(last_row[k], row[k], k + 1) - row = np.vstack([row, romb_k]) - - # Save R(n,n) and R(n-1, n-1) from Romberg method. - results = row[i_div] - lastresults = last_row[i_div - 1] - - # Estimate error. - err = np.abs(results - lastresults) - - # Find intervals that are converged. - conv = (err < tol) | (err < rtol * np.abs(results)) - - # Save number of nodes for these intervals. - n_grid[i_bad[conv]] = numtraps - - # Save residuals. - residual[i_bad] = err - - # Stop if all intervals have converged. - if conv.all(): - break - - # Find intervals not converged. - i_bad = i_bad[~conv] - - # Save last_row and ordsum for the next iteration for non-converged - # intervals. - ordsum = ordsum[~conv] - last_row = row[:, ~conv] - - else: - # Warn that convergence is not reached everywhere. - log.warning(f"divmax {divmax} exceeded. Latest difference = {err.max()}") - - # Make sure all values of n_grid where assigned during the process. - if (n_grid == -1).any(): - msg = f"Values where not assigned at grid position: {np.where(n_grid == -1)}" - log.critical(msg) - raise ValueError(msg) - - return n_grid, residual - - -def estim_integration_err(grid, fct): - """Estimate the integration error on each intervals + Estimate the integration error on each intervals of the grid using 1rst order Romberg integration. 
Parameters @@ -909,7 +649,10 @@ def estim_integration_err(grid, fct): Returns ------- - err, rel_err: error and relative error of each integrations, with length = length(grid) - 1 + err: + absolute error of each integration, with length = length(grid) - 1 + rel_err: + relative error of each integration, with length = length(grid) - 1 """ # Change the 1D grid into a 2D set of intervals. @@ -939,8 +682,9 @@ def estim_integration_err(grid, fct): return err, rel_err -def adapt_grid(grid, fct, max_iter=10, rtol=10e-6, tol=0.0, max_grid_size=None): - """Return an irregular oversampled grid needed to reach a +def _adapt_grid(grid, fct, max_grid_size, max_iter=10, rtol=10e-6, atol=1e-6): + """ + Return an irregular oversampled grid needed to reach a given precision when integrating over each intervals of `grid`. The grid is built by subdividing iteratively each intervals that did not reach the required precision. @@ -955,6 +699,8 @@ def adapt_grid(grid, fct, max_iter=10, rtol=10e-6, tol=0.0, max_grid_size=None): optimized. fct: callable Function to be integrated. Must be a function of `grid` + max_grid_size: int, required. + maximum size of the output grid. max_iter: int, optional Number of times the intervals can be subdivided. The smallest subdivison of the grid if max_iter is reached will then be given @@ -962,10 +708,9 @@ def adapt_grid(grid, fct, max_iter=10, rtol=10e-6, tol=0.0, max_grid_size=None): Default is 10. rtol: float, optional The desired relative tolerance. Default is 10e-6, so 10 ppm. - tol: float, optional - The desired absolute tolerance. Default is 0 to prioritize `rtol`. - max_grid_size: int, optional - maximum size of the output grid. Default is None, so no constraint. + atol: float, optional + The desired absolute tolerance. Default is 1e-6. 
+ Returns ------- os_grid : 1D array @@ -973,6 +718,7 @@ def adapt_grid(grid, fct, max_iter=10, rtol=10e-6, tol=0.0, max_grid_size=None): Romberg's method convergence_flag: bool Whether the estimated tolerance was reach everywhere or not. + See Also -------- scipy.integrate.quadrature.romberg @@ -981,30 +727,27 @@ def adapt_grid(grid, fct, max_iter=10, rtol=10e-6, tol=0.0, max_grid_size=None): [1] 'Romberg's method' https://en.wikipedia.org/wiki/Romberg%27s_method """ - # No limit of max_grid_size not given - if max_grid_size is None: - max_grid_size = np.inf - # Init some flags max_size_reached = (grid.size >= max_grid_size) + if max_size_reached: + raise ValueError('max_grid_size is too small for the input grid.') - # Iterate until precision is reached of max_iter + # Iterate until precision is reached or max_iter for _ in range(max_iter): # Estimate error using Romberg integration - err, rel_err = estim_integration_err(grid, fct) + abs_err, rel_err = _estim_integration_err(grid, fct) # Check where precision is reached - converged = (err < tol) | (rel_err < rtol) + converged = (rel_err < rtol) | (abs_err < atol) is_converged = converged.all() - # Check if max grid size was reached + # Stop iterating if max grid size was reached if max_size_reached or is_converged: - # Then stop iteration break # Intervals that didn't reach the precision will be subdivided - n_oversample = np.full(err.shape, 2, dtype=int) + n_oversample = np.full(rel_err.shape, 2, dtype=int) # No subdivision for the converged ones n_oversample[converged] = 1 @@ -1013,6 +756,8 @@ def adapt_grid(grid, fct, max_iter=10, rtol=10e-6, tol=0.0, max_grid_size=None): # to reach the maximum size os_grid_size = n_oversample.sum() if os_grid_size > max_grid_size: + max_size_reached = True + # How many nodes can be added to reach max? 
n_nodes_remaining = max_grid_size - grid.size @@ -1020,85 +765,79 @@ def adapt_grid(grid, fct, max_iter=10, rtol=10e-6, tol=0.0, max_grid_size=None): idx_largest_err = np.argsort(rel_err)[-n_nodes_remaining:] # Build new oversample array and assign only largest errors - n_oversample = np.ones(err.shape, dtype=int) + n_oversample = np.ones(rel_err.shape, dtype=int) n_oversample[idx_largest_err] = 2 - # Flag to stop iterations - max_size_reached = True - - # Generate oversampled grid (subdivide) + # Generate oversampled grid (subdivide). Returns sorted and unique grid. grid = oversample_grid(grid, n_os=n_oversample) - # Make sure sorted and unique. - grid = np.unique(grid) - return grid, is_converged -# ============================================================================== -# Code for handling the throughput and kernels. -# ============================================================================== - - -class ThroughputSOSS(interp1d): - - def __init__(self, wavelength, throughput): - """Create an instance of scipy.interpolate.interp1d to handle the - throughput values. +def ThroughputSOSS(wavelength, throughput): + """ + Parameters + ---------- + wavelength : array[float] + A wavelength array. + throughput : array[float] + The throughput values corresponding to the wavelengths. - Parameters - ---------- - wavelength : array[float] - A wavelength array. - throughput : array[float] - The throughput values corresponding to the wavelengths. - """ + Returns + ------- + interpolator : callable + A function that interpolates the throughput values. Accepts an array + of wavelengths and returns the interpolated throughput values. - # Interpolate - super().__init__(wavelength, throughput, kind='cubic', fill_value=0, - bounds_error=False) + Notes + ----- + Throughput is always zero at min, max of wavelength. 
+ """ + wavelength = np.sort(wavelength) + wl_min, wl_max = np.min(wavelength), np.max(wavelength) + throughput[0] = 0.0 + throughput[-1] = 0.0 + interp = make_interp_spline(wavelength, throughput, k=3, bc_type=("clamped", "clamped")) + def interpolator(wv): + wv = np.clip(wv, wl_min, wl_max) + return interp(wv) + return interpolator -class WebbKernel: # TODO could probably be cleaned-up somewhat, may need further adjustment. +class WebbKernel: - def __init__(self, wave_kernels, kernels, wave_map, n_os, n_pix, # TODO kernels may need to be flipped? - bounds_error=False, fill_value="extrapolate"): - """A handler for the kernel values. + def __init__(self, wave_kernels, kernels, wave_trace, n_pix): + """ + Initialize the kernel object. Parameters ---------- wave_kernels : array[float] - Kernels for wavelength array. + Wavelength array for the kernel. Must have same shape as kernels. kernels : array[float] - Kernels for throughput array. - wave_map : array[float] - Wavelength map of the detector. Since WebbPSF returns kernels in - the pixel space, we need a wave_map to convert to wavelength space. - n_os : int - Oversampling of the kernels. + Kernel for throughput array. + Dimensions are (wavelength, oversampled pixels). + Center (~max throughput) of the kernel is at the center of the 2nd axis. + wave_trace : array[float] + 1-D trace of the detector central wavelengths for the given order. + Since WebbPSF returns kernels in the pixel space, this is used to + convert to wavelength space. n_pix : int - Length of the kernels in pixels. - bounds_error : bool - If True, raise an error when trying to call the function out of the - interpolation range. If False, the values will be extrapolated. - fill_value : str - How to extrapolate when needed. Only default "extrapolate" - currently implemented. + Number of detector pixels spanned by the kernel. Second axis of kernels + has shape (n_os * n_pix) - (n_os - 1), where n_os is the + spectral oversampling factor. 
""" + self.n_pix = n_pix - # Mask where wv_map is equal to 0 - wave_map = np.ma.array(wave_map, mask=(wave_map == 0)) - - # Force wv_map to have the red wavelengths - # at the end of the detector - if np.diff(wave_map, axis=-1).mean() < 0: - wave_map = np.flip(wave_map, axis=-1) + # Mask where trace is equal to 0 + wave_trace = np.ma.array(wave_trace, mask=(wave_trace == 0)) - # Number of columns - ncols = wave_map.shape[-1] + # Force trace to have the red wavelengths at the end of the detector + if np.diff(wave_trace).mean() < 0: + wave_trace = np.flip(wave_trace) - # Create oversampled pixel position array - pixels = np.arange(-(n_pix // 2), n_pix // 2 + (1 / n_os), (1 / n_os)) + # Create oversampled pixel position array. Center index should have value 0. + self.pixels = np.linspace(-(n_pix // 2), n_pix // 2, wave_kernels.shape[0]) # `wave_kernel` has only the value of the central wavelength # of the kernel at each points because it's a function @@ -1106,27 +845,25 @@ def __init__(self, wave_kernels, kernels, wave_map, n_os, n_pix, # TODO kernels wave_center = wave_kernels[0, :] # Use the wavelength solution to create a mapping between pixels and wavelengths - # First find the all kernels that fall on the detector. - wave_min = np.amin(wave_map[wave_map > 0]) - wave_max = np.amax(wave_map[wave_map > 0]) + wave_min = np.amin(wave_trace[wave_trace > 0]) + wave_max = np.amax(wave_trace[wave_trace > 0]) i_min = np.searchsorted(wave_center, wave_min) i_max = np.searchsorted(wave_center, wave_max) - 1 - # Use the next kernels at each extremities to define the - # boundaries of the interpolation to use in the class - # RectBivariateSpline (at the end) + # i_min, i_max correspond to the min, max indices of the kernel that are represented + # on the detector. 
Use those to define the boundaries of the interpolation to use + # in the RectBivariateSpline interpolation bbox = [None, None, wave_center[np.maximum(i_min - 1, 0)], wave_center[np.minimum(i_max + 1, len(wave_center) - 1)]] - ####################### # Keep only kernels that fall on the detector. - kernels = kernels[:, i_min:i_max + 1].copy() + self.kernels = kernels[:, i_min:i_max + 1].copy() wave_kernels = wave_kernels[:, i_min:i_max + 1].copy() - wave_center = np.array(wave_kernels[0, :]) + wave_center = np.array(wave_kernels[0]) # Save minimum kernel value (greater than zero) - kernels_min = np.min(kernels[(kernels > 0.0)]) + self.min_value = np.min(self.kernels[(self.kernels > 0.0)]) # Then find the pixel closest to each kernel center # and use the surrounding pixels (columns) @@ -1141,49 +878,42 @@ def __init__(self, wave_kernels, kernels, wave_map, n_os, n_pix, # TODO kernels wv = np.ma.masked_all(i_surround.shape) # Closest pixel wv - i_row, i_col = np.unravel_index( - np.argmin(np.abs(wave_map - wv_c)), wave_map.shape - ) + i_col = np.argmin(np.abs(wave_trace - wv_c)) # Update wavelength center value # (take the nearest pixel center value) - wave_center[i_cen] = wave_map[i_row, i_col] + wave_center[i_cen] = wave_trace[i_col] # Surrounding columns index = i_col + i_surround # Make sure it's on the detector - i_good = (index >= 0) & (index < ncols) + i_good = (index >= 0) & (index < wave_trace.size) # Assign wv values - wv[i_good] = wave_map[i_row, index[i_good]] + wv[i_good] = wave_trace[index[i_good]] # Fit n=1 polynomial - poly_i = np.polyfit(i_surround[~wv.mask], wv[~wv.mask], 1) + f = Polynomial.fit(i_surround[~wv.mask], wv[~wv.mask], 1) + poly_i = f.coef[::-1] # Reverse order to match old behavior from legacy np.polyval # Project on os pixel grid - wave_kernels[:, i_cen] = np.poly1d(poly_i)(pixels) + wave_kernels[:, i_cen] = f(self.pixels) # Save coeffs poly.append(poly_i) - # Save attributes - self.n_pix = n_pix - self.n_os = n_os + # Save computed 
attributes self.wave_kernels = wave_kernels - self.kernels = kernels - self.pixels = pixels self.wave_center = wave_center self.poly = np.array(poly) - self.fill_value = fill_value - self.bounds_error = bounds_error - self.min_value = kernels_min - # 2d Interpolate - self.f_ker = RectBivariateSpline(pixels, wave_center, kernels, bbox=bbox) + # 2D Interpolate + self.f_ker = RectBivariateSpline(self.pixels, self.wave_center, self.kernels, bbox=bbox) def __call__(self, wave, wave_c): - """Returns the kernel value, given the wavelength and the kernel central - wavelength. + """ + Returns the kernel value, given the wavelength and the kernel central + wavelength. Wavelengths that are out of bounds will be extrapolated. Parameters ---------- @@ -1191,6 +921,7 @@ def __call__(self, wave, wave_c): Wavelength where the kernel is projected. wave_c : array[float] Central wavelength of the kernel. + Returns ------- out : array[float] @@ -1199,33 +930,16 @@ def __call__(self, wave, wave_c): wave_center = self.wave_center poly = self.poly - fill_value = self.fill_value - bounds_error = self.bounds_error n_wv_c = len(wave_center) - f_ker = self.f_ker - n_pix = self.n_pix - min_value = self.min_value - # ################################# - # First, convert wv value in pixels - # using a linear interpolation - # ################################# + # First, convert wavelength value into pixels using self.poly to interpolate # Find corresponding interval i_wv_c = np.searchsorted(wave_center, wave_c) - 1 - # Deal with values out of bounds - if bounds_error: - message = "Value of wv center out of interpolation range" - log.critical(message) - raise ValueError(message) - elif fill_value == "extrapolate": - i_wv_c[i_wv_c < 0] = 0 - i_wv_c[i_wv_c >= (n_wv_c - 1)] = n_wv_c - 2 - else: - message = f"`fill_value`={fill_value} is not an valid option." 
- log.critical(message) - raise ValueError(message) + # Extrapolate values out of bounds + i_wv_c[i_wv_c < 0] = 0 + i_wv_c[i_wv_c >= (n_wv_c - 1)] = n_wv_c - 2 # Compute coefficients that interpolate along wv_centers d_wv_c = wave_center[i_wv_c + 1] - wave_center[i_wv_c] @@ -1241,83 +955,24 @@ def __call__(self, wave, wave_c): # Compute pixel values pix = a_pix * wave + b_pix - # ###################################### - # Second, compute kernel value on the - # interpolation grid (pixel x wv_center) - # ###################################### - - webbker = f_ker(pix, wave_c, grid=False) + # Second, compute kernel value on the interpolation grid (pixel x wv_center) + webbker = self.f_ker(pix, wave_c, grid=False) - # Make sure it's not negative and greater than the min value - webbker = np.clip(webbker, min_value, None) - - # and put out-of-range values to zero. - webbker[pix > n_pix // 2] = 0 - webbker[pix < -(n_pix // 2)] = 0 + # Make sure it's not negative and greater than the min value, set pixels outside range to zero + webbker = np.clip(webbker, self.min_value, None) + webbker[pix > self.n_pix // 2] = 0 + webbker[pix < -(self.n_pix // 2)] = 0 return webbker -# ============================================================================== -# Code for building the convolution matrix (c matrix). -# ============================================================================== - - -def gaussians(x, x0, sig, amp=None): - """Gaussian function +def _constant_kernel_to_2d(c, grid_range): + """Build a 2d kernel array with a constant 1D kernel as input Parameters ---------- - x : array[float] - Array of points over which gaussian to be defined. - x0 : float - Center of the gaussian. - sig : float - Standard deviation of the gaussian. - amp : float - Value of the gaussian at the center. - - Returns - ------- - values : array[float] - Array of gaussian values for input x. - """ - - # Amplitude term - if amp is None: - amp = 1. / np.sqrt(2. * np.pi * sig**2.) 
- - values = amp * np.exp(-0.5 * ((x - x0) / sig) ** 2.) - - return values - - -def fwhm2sigma(fwhm): - """Convert a full width half max to a standard deviation, assuming a gaussian - - Parameters - ---------- - fwhm : float - Full-width half-max of a gaussian. - - Returns - ------- - sigma : float - Standard deviation of a gaussian. - """ - - sigma = fwhm / np.sqrt(8. * np.log(2.)) - - return sigma - - -def to_2d(kernel, grid_range): - """ Build a 2d kernel array with a constant 1D kernel (input) - - Parameters - ---------- - kernel : array[float] - Input 1D kernel. + c : float or size-1 np.ndarray + Constant value to expand into a 2-D kernel grid_range : list[int] Indices over which convolution is defined on grid. @@ -1335,9 +990,7 @@ def to_2d(kernel, grid_range): n_k_c = b - a # Return a 2D array with this length - kernel_2d = np.tile(kernel, (n_k_c, 1)).T - - return kernel_2d + return np.tile(np.atleast_1d(c), (n_k_c, 1)).T def _get_wings(fct, grid, h_len, i_a, i_b): @@ -1407,8 +1060,9 @@ def _get_wings(fct, grid, h_len, i_a, i_b): return left, right -def trpz_weight(grid, length, shape, i_a, i_b): - """Compute weights due to trapezoidal integration +def _trpz_weight(grid, length, shape, i_a, i_b): + """ + Compute weights due to trapezoidal integration Parameters ---------- @@ -1456,8 +1110,9 @@ def trpz_weight(grid, length, shape, i_a, i_b): return out -def fct_to_array(fct, grid, grid_range, thresh=1e-5, length=None): - """Build a compact kernel 2d array based on a kernel function +def _fct_to_array(fct, grid, grid_range, thresh): + """ + Build a compact kernel 2d array based on a kernel function and a grid to project the kernel Parameters @@ -1473,11 +1128,9 @@ def fct_to_array(fct, grid, grid_range, thresh=1e-5, length=None): Indices of the grid where to apply the convolution. Once the convolution applied, the convolved grid will be equal to grid[grid_range[0]:grid_range[1]]. - thresh : float, optional - Threshold to cut the kernel wings. 
If `length` is specified, - `thresh` will be ignored. - length : int, optional - Length of the kernel. Must be odd. + thresh : float, required + Threshold to define the maximum length of the kernel. + Truncate when `kernel` < `thresh`. Returns ------- @@ -1488,153 +1141,51 @@ def fct_to_array(fct, grid, grid_range, thresh=1e-5, length=None): # Assign range where the convolution is defined on the grid i_a, i_b = grid_range - # Init with the value at kernel's center - out = fct(grid, grid)[i_a:i_b] - - # Add wings - if length is None: - # Generate a 2D array of the grid iteratively until - # thresh is reached everywhere. - - # Init parameters - length = 1 - h_len = 0 # Half length - - # Add value on each sides until thresh is reached - while True: - # Already update half-length - h_len += 1 - - # Compute next left and right ends of the kernel - left, right = _get_wings(fct, grid, h_len, i_a, i_b) - - # Check if they are all below threshold. - if (left < thresh).all() and (right < thresh).all(): - break # Stop iteration - else: - # Update kernel length - length += 2 - - # Set value to zero if smaller than threshold - left[left < thresh] = 0. - right[right < thresh] = 0. - - # add new values to output - out = np.vstack([left, out, right]) + # Init 2-D array with first dimension length 1, with the value at kernel's center + out = fct(grid, grid)[i_a:i_b][np.newaxis,...] - # Weights due to integration (from the convolution) - weights = trpz_weight(grid, length, out.shape, i_a, i_b) + # Add wings: Generate a 2D array of the grid iteratively until + # thresh is reached everywhere. + length = 1 + h_len = 0 # Half length + while True: + h_len += 1 - elif (length % 2) == 1: # length needs to be odd - # Generate a 2D array of the grid iteratively until - # specified length is reached. 
+ # Compute next left and right ends of the kernel + left, right = _get_wings(fct, grid, h_len, i_a, i_b) - # Compute number of half-length - n_h_len = (length - 1) // 2 + # Check if they are all below threshold. + if (left < thresh).all() and (right < thresh).all(): + break # Stop iteration + else: + # Update kernel length + length += 2 - # Simply iterate to compute needed wings - for h_len in range(1, n_h_len + 1): - # Compute next left and right ends of the kernel - left, right = _get_wings(fct, grid, h_len, i_a, i_b) + # Set value to zero if smaller than threshold + left[left < thresh] = 0. + right[right < thresh] = 0. - # Add new kernel values + # add new values to output out = np.vstack([left, out, right]) - # Weights due to integration (from the convolution) - weights = trpz_weight(grid, length, out.shape, i_a, i_b) - - else: - msg = "`length` provided to `fct_to_array` must be odd." - log.critical(msg) - raise ValueError(msg) - - kern_array = (out * weights) - return kern_array - + # Weights due to integration (from the convolution) + weights = _trpz_weight(grid, length, out.shape, i_a, i_b) -def cut_ker(ker, n_out=None, thresh=None): - """Apply a cut on the convolution matrix boundaries. + return (out * weights) - Parameters - ---------- - ker : array[float] - convolution kernel in compact form, so - shape = (N_ker, N_k_convolved) - n_out : int, list[int] or tuple[int] - Number of kernel's grid point to keep on the boundaries. - If an int is given, the same number of points will be - kept on each boundaries of the kernel (left and right). - If 2 elements are given, it corresponds to the left and right - boundaries. - thresh : float - threshold used to determine the boundaries cut. - If n_out is specified, this is ignored. - Returns - ------ - ker : array[float] - The same kernel matrix as the input ker, but with the cut applied. 
+def _sparse_c(ker, n_k, i_zero): """ - - # Assign kernel length and number of kernels - n_ker, n_k_c = ker.shape - - # Assign half-length of the kernel - h_len = (n_ker - 1) // 2 - - # Determine n_out with thresh if not given - if n_out is None: - - if thresh is None: - # No cut to apply - return ker - else: - # Find where to cut the kernel according to thresh - i_left = np.where(ker[:, 0] >= thresh)[0][0] - i_right = np.where(ker[:, -1] >= thresh)[0][-1] - - # Make sure it is on the good wing. Take center if not. - i_left = np.minimum(i_left, h_len) - i_right = np.maximum(i_right, h_len) - - # Else, unpack n_out - else: - # Could be a scalar or a 2-elements object) - try: - i_left, i_right = n_out - except TypeError: - i_left, i_right = n_out, n_out - - # Find the position where to cut the kernel - # Make sure it is not out of the kernel grid, - # so i_left >= 0 and i_right <= len(kernel) - i_left = np.maximum(h_len - i_left, 0) - i_right = np.minimum(h_len + i_right, n_ker - 1) - - # Apply the cut - for i_k in range(0, i_left): - # Add condition in case the kernel is larger - # than the grid where it's projected. - if i_k < n_k_c: - ker[:i_left - i_k, i_k] = 0 - - for i_k in range(i_right + 1 - n_ker, 0): - # Add condition in case the kernel is larger - # than the grid where it's projected. - if -i_k <= n_k_c: - ker[i_right - n_ker - i_k:, i_k] = 0 - - return ker - - -def sparse_c(ker, n_k, i_zero=0): - """Convert a convolution kernel in compact form (N_ker, N_k_convolved) + Convert a convolution kernel in compact form (N_ker, N_k_convolved) to sparse form (N_k_convolved, N_k) + TODO: why is all the formalism for defining the diagonal necessary? why can't csr_matrix be + called directly? there must be a reason, but add documentation! 
+ Parameters ---------- ker : array[float] - Convolution kernel in compact form, with shape (N_kernel, N_kc) + Convolution kernel with shape (N_kernel, N_kc) n_k : int Length of the original grid i_zero : int @@ -1652,7 +1203,7 @@ def sparse_c(ker, n_k, i_zero=0): # Algorithm works for odd kernel grid if n_ker % 2 != 1: - err_msg = "Length of the convolution kernel given to sparse_c should be odd." + err_msg = "Length of the convolution kernel given to _sparse_c should be odd." log.critical(err_msg) raise ValueError(err_msg) @@ -1673,71 +1224,48 @@ def sparse_c(ker, n_k, i_zero=0): offset.append(i_k) # Build convolution matrix - matrix = diags(diag_val, offset, shape=(n_k_c, n_k), format="csr") + return diags(diag_val, offset, shape=(n_k_c, n_k), format="csr") - return matrix - -def get_c_matrix(kernel, grid, bounds=None, i_bounds=None, norm=True, - sparse=True, n_out=None, thresh_out=None, **kwargs): - """Return a convolution matrix - Can return a sparse matrix (N_k_convolved, N_k) - or a matrix in the compact form (N_ker, N_k_convolved). +def get_c_matrix(kernel, grid, i_bounds=None, thresh=1e-5): + """ + Return a convolution matrix + Returns a sparse matrix (N_k_convolved, N_k). N_k is the length of the grid on which the convolution will be applied, N_k_convolved is the length of the grid after convolution and N_ker is the maximum length of - the kernel. If the default sparse matrix option is chosen, - the convolution can be applied on an array f | f = fct(grid) + the kernel. + The convolution can be applied on an array f | f = fct(grid) by a simple matrix multiplication: f_convolved = c_matrix.dot(f) Parameters ---------- - kernel: ndarray (1D or 2D), callable + kernel: ndarray (2D) or callable Convolution kernel. Can be already 2D (N_ker, N_k_convolved), giving the kernel for each items of the convolved grid. - Can be 1D (N_ker), so the kernel is the same. 
Can be a callable + Can be a callable with the form f(x, x0) where x0 is the position of the center of the kernel. Must return a 1D array (len(x)), so a kernel value - for each pairs of (x, x0). If kernel is callable, the additional - kwargs `thresh` and `length` will be used to project the kernel. - grid: one-d-array: + for each pairs of (x, x0). + grid: 1D np.array: The grid on which the convolution will be applied. For example, if C is the convolution matrix, f_convolved = C.f(grid) - bounds: 2-elements object + i_bounds: 2-elements object, optional, default None. The bounds of the grid on which the convolution is defined. For example, if bounds = (a,b), then grid_convolved = grid[a <= grid <= b]. - It dictates also the dimension of f_convolved - sparse: bool, optional - return a sparse matrix (N_k_convolved, N_k) if True. - return a matrix (N_ker, N_k_convolved) if False. - n_out: integer or 2-integer object, optional - Specify how to deal with the ends of the convolved grid. - `n_out` points will be used outside from the convolved - grid. Can be different for each ends if 2-elements are given. - thresh_out: float, optional - Specify how to deal with the ends of the convolved grid. - Points with a kernel value less then `thresh_out` will - not be used outside from the convolved grid. + It dictates also the dimension of f_convolved. + If None, the convolution is defined on the whole grid. thresh: float, optional Only used when `kernel` is callable to define the maximum length of the kernel. Truncate when `kernel` < `thresh` - length: int, optional - Only used when `kernel` is callable to define the maximum - length of the kernel. """ # Define range where the convolution is defined on the grid. - # If `i_bounds` is not specified, try with `bounds`. 
if i_bounds is None: - - if bounds is None: - a, b = 0, len(grid) - else: - a = np.min(np.where(grid >= bounds[0])[0]) - b = np.max(np.where(grid <= bounds[1])[0]) + 1 + a, b = 0, len(grid) else: # Make sure it is absolute index, not relative @@ -1747,100 +1275,27 @@ def get_c_matrix(kernel, grid, bounds=None, i_bounds=None, norm=True, a, b = i_bounds - # Generate a 2D kernel depending on the input + # Generate a 2D kernel of shape (N_kernel x N_kc) if callable(kernel): - kernel = fct_to_array(kernel, grid, [a, b], **kwargs) - elif kernel.ndim == 1: - kernel = to_2d(kernel, [a, b]) + kernel = _fct_to_array(kernel, grid, [a, b], thresh) - if kernel.ndim != 2: + elif kernel.size == 1: + kernel = _constant_kernel_to_2d(kernel, [a, b]) + + elif kernel.ndim != 2: msg = ("Input kernel to get_c_matrix must be callable or" - " array with one or two dimensions.") + "2-dimensional array.") log.critical(msg) raise ValueError(msg) - # Kernel should now be a 2-D array (N_kernel x N_kc) - - # Normalize if specified - if norm: - kernel = kernel / np.nansum(kernel, axis=0) - - # Apply cut for kernel at boundaries - kernel = cut_ker(kernel, n_out, thresh_out) - - if sparse: - # Convert to a sparse matrix. - kernel = sparse_c(kernel, len(grid), a) - - return kernel - - -class NyquistKer: - """Define a gaussian convolution kernel at the nyquist - sampling. For a given point on the grid x_i, the kernel - is given by a gaussian with - FWHM = n_sampling * (dx_(i-1) + dx_i) / 2. - The FWHM is computed for each elements of the grid except - the extremities (not defined). We can then generate FWHM as - a function of the grid and interpolate/extrapolate to get - the kernel as a function of its position relative to the grid. - """ - - def __init__(self, grid, n_sampling=2, bounds_error=False, - fill_value="extrapolate", **kwargs): - """Parameters - ---------- - grid : array[float] - Grid used to define the kernels - n_sampling : int, optional - Sampling of the grid. 
- bounds_error : bool - Argument for `interp1d` to get FWHM as a function of the grid. - fill_value : str - Argument for `interp1d` to choose fill method to get FWHM. - """ - - # Delta grid - d_grid = np.diff(grid) - - # The full width half max is n_sampling - # times the mean of d_grid - fwhm = (d_grid[:-1] + d_grid[1:]) / 2 - fwhm *= n_sampling - - # What we really want is sigma, not FWHM - sig = fwhm2sigma(fwhm) - - # Now put sigma as a function of the grid - sig = interp1d(grid[1:-1], sig, bounds_error=bounds_error, - fill_value=fill_value, **kwargs) - - self.fct_sig = sig - - def __call__(self, x, x0): - """Parameters - ---------- - x : array[float] - position where the kernel is evaluated - x0 : array[float] - position of the kernel center for each x. - - Returns - ------- - Value of the gaussian kernel for each set of (x, x0) - """ - - # Get the sigma of each gaussian - sig = self.fct_sig(x0) - return gaussians(x, x0, sig) + # Normalize + kernel = kernel / np.nansum(kernel, axis=0) + # Convert to a sparse matrix. + return _sparse_c(kernel, len(grid), a) -# ============================================================================== -# Code for doing Tikhonov regularisation. -# ============================================================================== - -def finite_diff(x): +def _finite_diff(x): """Returns the finite difference matrix operator based on x. Parameters @@ -1855,48 +1310,11 @@ def finite_diff(x): the result is the same as np.diff(x) """ n_x = len(x) - - # Build matrix diff_matrix = diags([-1.], shape=(n_x - 1, n_x)) diff_matrix += diags([1.], 1, shape=(n_x - 1, n_x)) - return diff_matrix -def finite_second_d(grid): - """Returns the second derivative operator based on grid - - Parameters - ---------- - grid : array[float] - 1D array where the second derivative will be computed. 
- - Returns - ------- - second_d : array[float] - Operator to compute the second derivative, so that - f" = second_d.dot(f), where f is a function - projected on `grid`. - """ - - # Finite difference operator - d_matrix = finite_diff(grid) - - # Delta lambda - d_grid = d_matrix.dot(grid) - - # First derivative operator - first_d = diags(1. / d_grid).dot(d_matrix) - - # Second derivative operator - second_d = finite_diff(grid[:-1]).dot(first_d) - - # don't forget the delta lambda - second_d = diags(1. / d_grid[:-1]).dot(second_d) - - return second_d - - def finite_first_d(grid): """Returns the first derivative operator based on grid @@ -1908,101 +1326,22 @@ def finite_first_d(grid): Returns ------- first_d : array[float] - Operator to compute the second derivative, so that + Operator to compute the first derivative, so that f' = first_d.dot(f), where f is a function projected on `grid`. """ # Finite difference operator - d_matrix = finite_diff(grid) + d_matrix = _finite_diff(grid) # Delta lambda d_grid = d_matrix.dot(grid) # First derivative operator - first_d = diags(1. / d_grid).dot(d_matrix) - - return first_d - - -def get_tikho_matrix(grid, n_derivative=1, d_grid=True, estimate=None, pwr_law=0): - """Wrapper to return the tikhonov matrix given a grid and the derivative degree. - - Parameters - ---------- - grid : array[float] - 1D grid where the Tikhonov matrix is projected - n_derivative : int, optional - Degree of derivative. Possible values are 1 or 2 - d_grid : bool, optional - Whether to divide the differential operator by the grid differences, - which corresponds to an actual approximation of the derivative or not. - estimate : callable (preferably scipy.interpolate.UnivariateSpline), optional - Estimate of the solution on which the tikhonov matrix is applied. - Must be a function of `grid`. If UnivariateSpline, then the derivatives - are given directly (so best option), otherwise the tikhonov matrix will be - applied to `estimate(grid)`. 
Note that it is better to use `d_grid=True` - pwr_law: float, optional - Power law applied to the scale differentiated estimate, so the estimate - of tikhonov_matrix.dot(solution). It will be applied as follows: - norm_factor * scale_factor.dot(tikhonov_matrix) - where scale_factor = 1/(estimate_derivative)**pwr_law - and norm_factor = 1/sum(scale_factor) - Returns - ------- - t_mat : array[float] - The tikhonov matrix. - """ - if d_grid: - input_grid = grid - else: - input_grid = np.arange(len(grid)) - - if n_derivative == 1: - t_mat = finite_first_d(input_grid) - elif n_derivative == 2: - t_mat = finite_second_d(input_grid) - else: - msg = "`n_derivative` must be 1 or 2." - log.critical(msg) - raise ValueError(msg) - - if estimate is not None: - if hasattr(estimate, 'derivative'): - # Get the derivatives directly from the spline - if n_derivative == 1: - derivative = estimate.derivative(n=n_derivative) - tikho_factor_scale = derivative(grid[:-1]) - elif n_derivative == 2: - derivative = estimate.derivative(n=n_derivative) - tikho_factor_scale = derivative(grid[1:-1]) - else: - # Apply tikho matrix on estimate - tikho_factor_scale = t_mat.dot(estimate(grid)) + return diags(1. / d_grid).dot(d_matrix) - # Make sure all positive - tikho_factor_scale = np.abs(tikho_factor_scale) - # Apply power law - # (similar to 'kunasz1973'?) 
- tikho_factor_scale = np.power(tikho_factor_scale, -pwr_law) - # Normalize - valid = np.isfinite(tikho_factor_scale) - tikho_factor_scale /= np.sum(tikho_factor_scale[valid]) - # If some values are not finite, set to the max value - # so it will be more regularized - valid = np.isfinite(tikho_factor_scale) - if not valid.all(): - value = np.max(tikho_factor_scale[valid]) - tikho_factor_scale[~valid] = value - - # Apply to tikhonov matrix - t_mat = diags(tikho_factor_scale).dot(t_mat) - - return t_mat - - -def curvature_finite(factors, log_reg2, log_chi2): +def _curvature_finite(factors, log_reg2, log_chi2): """Compute the curvature in log space using finite differences Parameters @@ -2019,15 +1358,15 @@ def curvature_finite(factors, log_reg2, log_chi2): factors : array[float] Sorted and cut version of input factors array. curvature : array[float] - + Second derivative of the log10 of the regularized chi2 """ # Make sure it is sorted according to the factors idx = np.argsort(factors) factors, log_chi2, log_reg2 = factors[idx], log_chi2[idx], log_reg2[idx] # Get first and second derivatives - chi2_deriv = get_finite_derivatives(factors, log_chi2) - reg2_deriv = get_finite_derivatives(factors, log_reg2) + chi2_deriv = _get_finite_derivatives(factors, log_chi2) + reg2_deriv = _get_finite_derivatives(factors, log_reg2) # Compute the curvature according to Hansen 2001 # @@ -2046,7 +1385,7 @@ def curvature_finite(factors, log_reg2, log_chi2): return factors, curv -def get_finite_derivatives(x_array, y_array): +def _get_finite_derivatives(x_array, y_array): """ Compute first and second finite derivatives Parameters ---------- @@ -2099,12 +1438,10 @@ def _get_interp_idx_array(idx, relative_range, max_length): abs_range[-1] = np.min([abs_range[-1], max_length]) # Convert to slice - out = np.arange(*abs_range, 1) + return np.arange(*abs_range, 1) - return out - -def _minimize_on_grid(factors, val_to_minimize, interpolate, interp_index=None): +def _minimize_on_grid(factors, 
val_to_minimize, interpolate=True, interp_index=None): """ Find minimum of a grid using akima spline interpolation to get a finer estimate Parameters @@ -2124,7 +1461,6 @@ def _minimize_on_grid(factors, val_to_minimize, interpolate, interp_index=None): min_fac : float The factor with minimized error/curvature. """ - if interp_index is None: interp_index = [-2, 4] @@ -2169,7 +1505,7 @@ def _minimize_on_grid(factors, val_to_minimize, interpolate, interp_index=None): return min_fac -def _find_intersect(factors, y_val, thresh, interpolate, search_range=None): +def _find_intersect(factors, y_val, thresh, interpolate=True, search_range=None): """ Find the root of y_val - thresh (so the intersection between thresh and y_val) Parameters ---------- @@ -2180,12 +1516,11 @@ def _find_intersect(factors, y_val, thresh, interpolate, search_range=None): thresh: float Threshold use in 'd_chi2' mode. Find the highest factor where the derivative of the chi2 derivative is below thresh. - interpolate: bool, optional + interpolate: bool, optional, default True. If True, use interpolation to find a finer minimum; otherwise, return minimum value in array. - search_range : iterable[int], optional + search_range : iterable[int], optional, default [0,3] Relative range of grid indices around the value to interpolate. - If not specified, defaults to [0,3]. Returns ------- @@ -2193,7 +1528,6 @@ def _find_intersect(factors, y_val, thresh, interpolate, search_range=None): Factor corresponding to the best approximation of the intersection point. 
""" - if search_range is None: search_range = [0, 3] @@ -2247,61 +1581,46 @@ def _find_intersect(factors, y_val, thresh, interpolate, search_range=None): return best_val -def soft_l1(z): +def _soft_l1(z): return 2 * ((1 + z)**0.5 - 1) -def cauchy(z): +def _cauchy(z): return np.log(1 + z) -def linear(z): +def _linear(z): return z -LOSS_FUNCTIONS = {'soft_l1': soft_l1, 'cauchy': cauchy, 'linear': linear} - +LOSS_FUNCTIONS = {'soft_l1': _soft_l1, 'cauchy': _cauchy, 'linear': _linear} +DEFAULT_THRESH_DERIVATIVE = {'chi2':1e-5, + 'chi2_soft_l1':1e-4, + 'chi2_cauchy':1e-3} class TikhoTests(dict): """ Class to save Tikhonov tests for different factors. All the tests are stored in the attribute `tests` as a dictionary - - Parameters - ---------- - test_dict : dict - Dictionary holding arrays for `factors`, `solution`, `error`, and `reg` - by default. """ - DEFAULT_TRESH_DERIVATIVE = (('chi2', 1e-5), - ('chi2_soft_l1', 1e-4), - ('chi2_cauchy', 1e-3)) - - def __init__(self, test_dict=None, default_chi2='chi2_cauchy'): + def __init__(self, test_dict, default_chi2='chi2_cauchy'): """ Parameters ---------- test_dict : dict - Dictionary holding arrays for `factors`, `solution`, `error`, and `reg` - by default. + Dictionary holding arrays for `factors`, `solution`, `error`, `reg`, and `grid`. default_chi2: string Type of chi2 loss used by default. Options are chi2, chi2_soft_l1, chi2_cauchy. """ # Define the number of data points # (length of the "b" vector in the tikhonov regularisation) - if test_dict is None: - print('Unable to get the number of data points. 
Setting `n_points` to 1') - n_points = 1 - else: - n_points = len(test_dict['error'][0].squeeze()) + n_points = len(test_dict['error'][0].squeeze()) # Save attributes self.n_points = n_points self.default_chi2 = default_chi2 - self.default_thresh = {chi2_type: thresh - for (chi2_type, thresh) - in self.DEFAULT_TRESH_DERIVATIVE} + self.default_thresh = DEFAULT_THRESH_DERIVATIVE # Initialize so it behaves like a dictionary super().__init__(test_dict) @@ -2314,56 +1633,40 @@ def __init__(self, test_dict=None, default_chi2='chi2_cauchy'): # Save the chi2 self[chi2_type] except KeyError: - self[chi2_type] = self.compute_chi2(loss=loss) -# # Save different loss function for chi2 -# self['chi2_soft_l1'] = self.compute_chi2(loss='soft_l1') -# self['chi2_cauchy'] = self.compute_chi2(loss='cauchy') + self[chi2_type] = self._compute_chi2(loss) - def compute_chi2(self, tests=None, n_points=None, loss='linear'): - """ Calculates the reduced chi squared statistic + + def _compute_chi2(self, loss): + """ + Calculates the reduced chi squared statistic Parameters ---------- - tests : dict, optional - Dictionary from which we take the error array; if not provided, - self is used - n_points : int, optional - Number of data points; if not provided, self.n_points is used + loss: str + Type of loss function to use. Options are 'linear', 'soft_l1', 'cauchy'. Returns ------- float Sum of the squared error array divided by the number of data points """ - # If not given, take the tests from the object - if tests is None: - tests = self - - # Get the loss function - if isinstance(loss, str): - try: - loss = LOSS_FUNCTIONS[loss] - except KeyError as e: - keys = [key for key in LOSS_FUNCTIONS.keys()] - msg = f'loss={loss} not a valid key. Must be one of {keys} or callable.' - raise e(msg) - elif not callable(loss): - raise ValueError('Invalid value for loss.') + # retrieve loss function + try: + loss = LOSS_FUNCTIONS[loss] + except KeyError as e: + msg = (f"loss={loss} not a valid key." 
+ f"Must be one of {[LOSS_FUNCTIONS.keys()]} or callable.") + raise e(msg) # Compute the reduced chi^2 for all tests - chi2 = np.nanmean(loss(tests['error']**2), axis=-1) + chi2 = np.nanmean(loss(self['error']**2), axis=-1) # Remove residual dimensions - chi2 = chi2.squeeze() + return chi2.squeeze() - return chi2 - def get_chi2_derivative(self, key=None): - """ Compute derivative of the chi2 with respect to log10(factors) - - Parameters - ---------- - key: str - which chi2 is used for computations. Default is self.default_chi2. + def _get_chi2_derivative(self): + """ + Compute derivative of the chi2 with respect to log10(factors) Returns ------- @@ -2372,8 +1675,7 @@ def get_chi2_derivative(self, key=None): d_chi2 : array[float] derivative of chi squared array with respect to log10(factors) """ - if key is None: - key = self.default_chi2 + key = self.default_chi2 # Compute finite derivative fac_log = np.log10(self['factors']) @@ -2385,86 +1687,64 @@ def get_chi2_derivative(self, key=None): return factors_leftd, d_chi2 - def compute_curvature(self, tests=None, key=None): - - if key is None: - key = self.default_chi2 - # If not given, take the tests from the object - if tests is None: - tests = self + def _compute_curvature(self): + """ + TODO: add docstring + """ + key = self.default_chi2 # Compute the curvature... # Get the norm-2 of the regularisation term - reg2 = np.nansum(tests['reg'] ** 2, axis=-1) + reg2 = np.nansum(self['reg'] ** 2, axis=-1) - factors, curv = curvature_finite(tests['factors'], + return _curvature_finite(self['factors'], np.log10(self[key]), np.log10(reg2)) - return factors, curv - def best_tikho_factor(self, tests=None, interpolate=True, interp_index=None, - mode='curvature', key=None, thresh=None): - """Compute the best scale factor for Tikhonov regularisation. + def best_factor(self, mode='curvature'): + """ + Compute the best scale factor for Tikhonov regularisation. 
It is determined by taking the factor giving the highest logL on the detector or the highest curvature of the l-curve, depending on the chosen mode. + Parameters ---------- - tests : dictionary, optional - Results of tikhonov extraction tests for different factors. - Must have the keys "factors" and "-logl". If not specified, - the tests from self.tikho.tests are used. - interpolate : bool, optional - If True, use spline interpolation to find a finer minimum. - Default is true. - interp_index : list, optional - Relative range of grid indices around the minimum value to - interpolate across. If not specified, defaults to [-2,4]. mode : string How to find the best factor: 'chi2', 'curvature' or 'd_chi2'. - thresh : float - Threshold for use in 'd_chi2' mode. Find the highest factor where - the derivative of the chi2 derivative is below thresh. Returns ------- float Best scale factor as determined by the selected algorithm """ - if key is None: - key = self.default_chi2 - - if thresh is None: - thresh = self.default_thresh[key] - - # Use pre-run tests if not specified - if tests is None: - tests = self + key = self.default_chi2 + thresh = self.default_thresh[key] # Number of factors - n_fac = len(tests['factors']) + n_fac = len(self['factors']) # Determine the mode (what do we minimize?) 
if mode == 'curvature' and n_fac > 2: # Compute the curvature - factors, curv = tests.compute_curvature() + factors, curv = self._compute_curvature() # Find min factor - best_fac = _minimize_on_grid(factors, curv, interpolate, interp_index) + best_fac = _minimize_on_grid(factors, curv) elif mode == 'chi2': # Simply take the chi2 and factors - factors = tests['factors'] - y_val = tests[key] + factors = self['factors'] + y_val = self[key] # Find min factor - best_fac = _minimize_on_grid(factors, y_val, interpolate, interp_index) + best_fac = _minimize_on_grid(factors, y_val) elif mode == 'd_chi2' and n_fac > 1: # Compute the derivative of the chi2 - factors, y_val = tests.get_chi2_derivative() + factors, y_val = self._get_chi2_derivative() # Remove values for the higher factors that # are not already below thresh. If not _find_intersect @@ -2486,17 +1766,18 @@ def best_tikho_factor(self, tests=None, interpolate=True, interp_index=None, idx = slice(None) # Find intersection with threshold - best_fac = _find_intersect(factors[idx], y_val[idx], thresh, interpolate, interp_index) + best_fac = _find_intersect(factors[idx], y_val[idx], thresh) elif mode in ['curvature', 'd_chi2', 'chi2']: - best_fac = np.max(tests['factors']) - msg = (f'Could not compute {mode} because number of factor={n_fac}. ' + best_fac = np.max(self['factors']) + msg = (f'Could not compute {mode} because number of factors {n_fac} ' + 'is too small for that mode.' f'Setting best factor to max factor: {best_fac:.5e}') log.warning(msg) else: msg = (f'`mode`={mode} is not a valid option for ' - f'TikhoTests.best_tikho_factor().') + f'TikhoTests.best_factor().') log.critical(msg) raise ValueError(msg) @@ -2504,6 +1785,18 @@ def best_tikho_factor(self, tests=None, interpolate=True, interp_index=None, return best_fac +def try_solve_two_methods(matrix, result): + """on rare occasions spsolve's approximation of the matrix is not appropriate + and fails on good input data. 
revert to different solver""" + with warnings.catch_warnings(): + warnings.filterwarnings(action='error', category=MatrixRankWarning) + try: + return spsolve(matrix, result) + except MatrixRankWarning: + log.info('ATOCA matrix solve failed with spsolve. Retrying with least-squares.') + return lsqr(matrix, result)[0] + + class Tikhonov: """ Tikhonov regularization to solve the ill-posed problem A.x = b, where @@ -2513,7 +1806,7 @@ class Tikhonov: where gamma is the Tikhonov regularization matrix. """ - def __init__(self, a_mat, b_vec, t_mat, valid=True): + def __init__(self, a_mat, b_vec, t_mat): """ Parameters ---------- @@ -2523,9 +1816,6 @@ def __init__(self, a_mat, b_vec, t_mat, valid=True): vector b in the system to solve A.x = b t_mat : matrix-like object (2d) Tikhonov regularisation matrix to be applied on b_vec. - valid : bool, optional - If True, solve the system only for valid indices. The - invalid values will be set to np.nan. Default is True. """ # Save input matrix @@ -2534,22 +1824,14 @@ def __init__(self, a_mat, b_vec, t_mat, valid=True): self.t_mat = t_mat # Pre-compute some matrix for the linear system to solve - t_mat_2 = (t_mat.T).dot(t_mat) # squared tikhonov matrix - a_mat_2 = a_mat.T.dot(a_mat) # squared model matrix - result = (a_mat.T).dot(b_vec.T) - idx_valid = (result.toarray() != 0).squeeze() # valid indices to use if `valid` is True - - # Save pre-computed matrix - self.t_mat_2 = t_mat_2 - self.a_mat_2 = a_mat_2 - self.result = result - self.idx_valid = idx_valid + self.t_mat_2 = (t_mat.T).dot(t_mat) # squared tikhonov matrix + self.a_mat_2 = a_mat.T.dot(a_mat) # squared model matrix + self.result = (a_mat.T).dot(b_vec.T) + self.idx_valid = (self.result.toarray() != 0).squeeze() # valid indices to use if `valid` is True # Save other attributes - self.valid = valid self.test = None - return def solve(self, factor=1.0): """ @@ -2571,8 +1853,7 @@ def solve(self, factor=1.0): a_mat_2 = self.a_mat_2 result = self.result t_mat_2 = self.t_mat_2 
- valid = self.valid - idx_valid = self.idx_valid + idx = self.idx_valid # Matrix gamma squared (with scale factor) gamma_2 = factor ** 2 * t_mat_2 @@ -2583,19 +1864,14 @@ def solve(self, factor=1.0): # Initialize solution solution = np.full(matrix.shape[0], np.nan) - # Consider only valid indices if in valid mode - if valid: - idx = idx_valid - else: - idx = np.full(len(solution), True) - # Solve matrix = matrix[idx, :][:, idx] result = result[idx] - solution[idx] = spsolve(matrix, result) + solution[idx] = try_solve_two_methods(matrix, result) return solution + def test_factors(self, factors): """ Test multiple factors @@ -2628,7 +1904,9 @@ def test_factors(self, factors): sln.append(self.solve(factor)) # Save error A.x - b - err.append(a_mat.dot(sln[-1]) - b_vec) + this_err = a_mat.dot(sln[-1]) - b_vec + # initially this is a np.matrix of shape (1, n_pixels); flatten and make array + err.append(np.array(this_err).flatten()) # Save regularization term reg_i = t_mat.dot(sln[-1]) @@ -2638,7 +1916,7 @@ def test_factors(self, factors): message = '{}/{}'.format(i_fac, len(factors)) log.info(message) - # Final print + # Final message output message = '{}/{}'.format(i_fac + 1, len(factors)) log.info(message) @@ -2648,10 +1926,7 @@ def test_factors(self, factors): reg = np.array(reg) # Save in a dictionary - - tests = TikhoTests({'factors': factors, + return TikhoTests({'factors': factors, 'solution': sln, 'error': err, 'reg': reg}) - - return tests diff --git a/jwst/extract_1d/soss_extract/pastasoss.py b/jwst/extract_1d/soss_extract/pastasoss.py index 4030df2224..1c6a9f0b60 100644 --- a/jwst/extract_1d/soss_extract/pastasoss.py +++ b/jwst/extract_1d/soss_extract/pastasoss.py @@ -14,7 +14,7 @@ WAVEMAP_NWL = 5001 SUBARRAY_YMIN = 2048 - 256 -def get_wavelengths(refmodel, x, pwcpos, order): +def _get_wavelengths(refmodel, x, pwcpos, order): """Get the associated wavelength values for a given spectral order""" if order == 1: wavelengths = wavecal_model_order1_poly(refmodel, 
x, pwcpos) @@ -24,7 +24,7 @@ def get_wavelengths(refmodel, x, pwcpos, order): return wavelengths -def min_max_scaler(x, x_min, x_max): +def _min_max_scaler(x, x_min, x_max): """ Apply min-max scaling to input values. @@ -71,7 +71,7 @@ def wavecal_model_order1_poly(refmodel, x, pwcpos): to rotate the model """ x_scaler = partial( - min_max_scaler, + _min_max_scaler, **{ "x_min": refmodel.wavecal_models[0].scale_extents[0][0], "x_max": refmodel.wavecal_models[0].scale_extents[1][0], @@ -79,7 +79,7 @@ def wavecal_model_order1_poly(refmodel, x, pwcpos): ) pwcpos_offset_scaler = partial( - min_max_scaler, + _min_max_scaler, **{ "x_min": refmodel.wavecal_models[0].scale_extents[0][1], "x_max": refmodel.wavecal_models[0].scale_extents[1][1], @@ -148,7 +148,7 @@ def wavecal_model_order2_poly(refmodel, x, pwcpos): to rotate the model """ x_scaler = partial( - min_max_scaler, + _min_max_scaler, **{ "x_min": refmodel.wavecal_models[1].scale_extents[0][0], "x_max": refmodel.wavecal_models[1].scale_extents[1][0], @@ -156,7 +156,7 @@ def wavecal_model_order2_poly(refmodel, x, pwcpos): ) pwcpos_offset_scaler = partial( - min_max_scaler, + _min_max_scaler, **{ "x_min": refmodel.wavecal_models[1].scale_extents[0][1], "x_max": refmodel.wavecal_models[1].scale_extents[1][1], @@ -196,7 +196,7 @@ def get_poly_features(x, offset): return wavelengths -def rotate(x, y, angle, origin=(0, 0), interp=True): +def _rotate(x, y, angle, origin=(0, 0)): """ Applies a rotation transformation to a set of 2D points. @@ -210,9 +210,6 @@ def rotate(x, y, angle, origin=(0, 0), interp=True): The angle (in degrees) by which to rotate the points. origin : Tuple[float, float], optional The point about which to rotate the points. Default is (0, 0). - interp : bool, optional - Whether to interpolate the rotated positions onto the original x-pixel - column values. Default is True. 
Returns ------- @@ -223,7 +220,7 @@ def rotate(x, y, angle, origin=(0, 0), interp=True): -------- >>> x = np.array([0, 1, 2, 3]) >>> y = np.array([0, 1, 2, 3]) - >>> x_rot, y_rot = rotate(x, y, 90) + >>> x_rot, y_rot = _rotate(x, y, 90) """ # shift to rotate about center @@ -238,20 +235,18 @@ def rotate(x, y, angle, origin=(0, 0), interp=True): # apply transformation x_new, y_new = R @ (xy - xy_center) + xy_center - # interpolate rotated positions onto x-pixel column values (default) - if interp: - # interpolate new coordinates onto original x values and mask values - # outside of the domain of the image 0<=x<=2047 and 0<=y<=255. - y_new = interp1d(x_new, y_new, fill_value="extrapolate")(x) - mask = np.where(y_new <= 255.0) - x = x[mask] - y_new = y_new[mask] - return x, y_new + # interpolate rotated positions onto x-pixel column values + # interpolate new coordinates onto original x values and mask values + # outside of the domain of the image 0<=x<=2047 and 0<=y<=255. + y_new = interp1d(x_new, y_new, fill_value="extrapolate")(x) + mask = np.where(y_new <= 255.0) + x = x[mask] + y_new = y_new[mask] + return x, y_new - return x_new, y_new -def find_spectral_order_index(refmodel, order): +def _find_spectral_order_index(refmodel, order): """Return index of trace and wavecal dict corresponding to order Parameters @@ -268,6 +263,10 @@ def find_spectral_order_index(refmodel, order): The index to provide the reference file lists of traces and wavecal models to retrieve the arrays for the desired spectral order """ + if order not in [1,2]: + error_message = f"Order {order} is not supported at this time." 
+ log.error(error_message) + raise ValueError(error_message) for i, entry in enumerate(refmodel.traces): if entry.spectral_order == order: @@ -277,7 +276,7 @@ def find_spectral_order_index(refmodel, order): return -1 -def get_soss_traces(refmodel, pwcpos, order, subarray, interp=True): +def _get_soss_traces(refmodel, pwcpos, order, subarray): """Generate the traces given a pupil wheel position. This is the primary method for generating the gr700xd trace position given a @@ -295,14 +294,11 @@ def get_soss_traces(refmodel, pwcpos, order, subarray, interp=True): pwcpos : float The pupil wheel positions angle provided in the FITS header under keyword PWCPOS. - order : str + order : str or int The spectral order for which a trace is computed. Order 3 is currently unsupported. subarray : str Name of subarray in use, typically 'SUBSTRIP96' or 'SUBSTRIP256'. - interp : bool, optional - Whether to interpolate the rotated positions onto the original x-pixel - column values. Default is True. Returns ------- @@ -311,39 +307,31 @@ def get_soss_traces(refmodel, pwcpos, order, subarray, interp=True): points for the first spectral order. If `order` is '2', a tuple of the x and y coordinates of the rotated points for the second spectral order. - If `order` is '3' or a combination of '1', '2', and '3', a list of - tuples of the x and y coordinates of the rotated points for each - spectral order. Raises ------ ValueError - If `order` is not '1', '2', '3', or a combination of '1', '2', and '3'. + If `order` is not in ['1', '2']. """ - spectral_order_index = find_spectral_order_index(refmodel, int(order)) + spectral_order_index = _find_spectral_order_index(refmodel, int(order)) - if spectral_order_index < 0: - error_message = f"Order {order} is not supported at this time." 
- log.error(error_message) - raise ValueError(error_message) - else: - # reference trace data - x, y = refmodel.traces[spectral_order_index].trace.T.copy() - origin = refmodel.traces[spectral_order_index].pivot_x, refmodel.traces[spectral_order_index].pivot_y + # reference trace data + x, y = refmodel.traces[spectral_order_index].trace.T.copy() + origin = refmodel.traces[spectral_order_index].pivot_x, refmodel.traces[spectral_order_index].pivot_y - # Offset for SUBSTRIP96 - if subarray == 'SUBSTRIP96': - y -= 10 - # rotated reference trace - x_new, y_new = rotate(x, y, pwcpos - refmodel.meta.pwcpos_cmd, origin, interp=interp) + # Offset for SUBSTRIP96 + if subarray == 'SUBSTRIP96': + y -= 10 + # rotated reference trace + x_new, y_new = _rotate(x, y, pwcpos - refmodel.meta.pwcpos_cmd, origin) - # wavelength associated to trace at given pwcpos value - wavelengths = get_wavelengths(refmodel, x_new, pwcpos, int(order)) + # wavelength associated to trace at given pwcpos value + wavelengths = _get_wavelengths(refmodel, x_new, pwcpos, int(order)) - return order, x_new, y_new, wavelengths + return order, x_new, y_new, wavelengths -def extrapolate_to_wavegrid(w_grid, wavelength, quantity): +def _extrapolate_to_wavegrid(w_grid, wavelength, quantity): """ Extrapolates quantities on the right and the left of a given array of quantity @@ -361,9 +349,9 @@ def extrapolate_to_wavegrid(w_grid, wavelength, quantity): Array The interpolated quantities """ - sorted = np.argsort(wavelength) - q = quantity[sorted] - w = wavelength[sorted] + sort_i = np.argsort(wavelength) + q = quantity[sort_i] + w = wavelength[sort_i] # Determine the slope on the right of the array slope_right = (q[-1] - q[-2]) / (w[-1] - w[-2]) @@ -380,12 +368,10 @@ def extrapolate_to_wavegrid(w_grid, wavelength, quantity): q = np.concatenate((q_left, q, q_right)) # resample at the w_grid everywhere - q_grid = np.interp(w_grid, w, q) + return np.interp(w_grid, w, q) - return q_grid - -def calc_2d_wave_map(wave_grid, 
x_dms, y_dms, tilt, oversample=2, padding=0, maxiter=5, dtol=1e-2): +def _calc_2d_wave_map(wave_grid, x_dms, y_dms, tilt, oversample=2, padding=0, maxiter=5, dtol=1e-2): """Compute the 2D wavelength map on the detector. Parameters @@ -483,8 +469,8 @@ def get_soss_wavemaps(refmodel, pwcpos, subarray, padding=False, padsize=0, spec Array, Array The 2D wavemaps and corresponding 1D spectraces """ - _, order1_x, order1_y, order1_wl = get_soss_traces(refmodel, pwcpos, order='1', subarray=subarray, interp=True) - _, order2_x, order2_y, order2_wl = get_soss_traces(refmodel, pwcpos, order='2', subarray=subarray, interp=True) + _, order1_x, order1_y, order1_wl = _get_soss_traces(refmodel, pwcpos, order='1', subarray=subarray) + _, order2_x, order2_y, order2_wl = _get_soss_traces(refmodel, pwcpos, order='2', subarray=subarray) # Make wavemap from trace center wavelengths, padding to shape (296, 2088) wavemin = WAVEMAP_WLMIN @@ -493,8 +479,8 @@ def get_soss_wavemaps(refmodel, pwcpos, subarray, padding=False, padsize=0, spec wave_grid = np.linspace(wavemin, wavemax, nwave) # Extrapolate wavelengths for order 1 trace - xtrace_order1 = extrapolate_to_wavegrid(wave_grid, order1_wl, order1_x) - ytrace_order1 = extrapolate_to_wavegrid(wave_grid, order1_wl, order1_y) + xtrace_order1 = _extrapolate_to_wavegrid(wave_grid, order1_wl, order1_x) + ytrace_order1 = _extrapolate_to_wavegrid(wave_grid, order1_wl, order1_y) spectrace_1 = np.array([xtrace_order1, ytrace_order1, wave_grid]) # Set cutoff for order 2 where it runs off the detector @@ -517,15 +503,15 @@ def get_soss_wavemaps(refmodel, pwcpos, subarray, padding=False, padsize=0, spec y_o2[o2_cutoff:] = y_o2[o2_cutoff - 1] + m * dx # Extrapolate wavelengths for order 2 trace - xtrace_order2 = extrapolate_to_wavegrid(wave_grid, w_o2, x_o2) - ytrace_order2 = extrapolate_to_wavegrid(wave_grid, w_o2, y_o2) + xtrace_order2 = _extrapolate_to_wavegrid(wave_grid, w_o2, x_o2) + ytrace_order2 = _extrapolate_to_wavegrid(wave_grid, w_o2, y_o2) 
spectrace_2 = np.array([xtrace_order2, ytrace_order2, wave_grid]) # Make wavemap from wavelength solution for order 1 - wavemap_1 = calc_2d_wave_map(wave_grid, xtrace_order1, ytrace_order1, np.zeros_like(xtrace_order1), oversample=1, padding=padsize) + wavemap_1 = _calc_2d_wave_map(wave_grid, xtrace_order1, ytrace_order1, np.zeros_like(xtrace_order1), oversample=1, padding=padsize) # Make wavemap from wavelength solution for order 2 - wavemap_2 = calc_2d_wave_map(wave_grid, xtrace_order2, ytrace_order2, np.zeros_like(xtrace_order2), oversample=1, padding=padsize) + wavemap_2 = _calc_2d_wave_map(wave_grid, xtrace_order2, ytrace_order2, np.zeros_like(xtrace_order2), oversample=1, padding=padsize) # Extrapolate wavemap to FULL frame wavemap_1[:SUBARRAY_YMIN - padsize, :] = wavemap_1[SUBARRAY_YMIN - padsize] @@ -546,6 +532,4 @@ def get_soss_wavemaps(refmodel, pwcpos, subarray, padding=False, padsize=0, spec if spectraces: return np.array([wavemap_1, wavemap_2]), np.array([spectrace_1, spectrace_2]) - - else: - return np.array([wavemap_1, wavemap_2]) + return np.array([wavemap_1, wavemap_2]) diff --git a/jwst/extract_1d/soss_extract/soss_boxextract.py b/jwst/extract_1d/soss_extract/soss_boxextract.py index f499b21d7b..65d21124ab 100644 --- a/jwst/extract_1d/soss_extract/soss_boxextract.py +++ b/jwst/extract_1d/soss_extract/soss_boxextract.py @@ -5,8 +5,9 @@ log.setLevel(logging.DEBUG) -def get_box_weights(centroid, n_pix, shape, cols=None): - """ Return the weights of a box aperture given the centroid and the width of +def get_box_weights(centroid, n_pix, shape, cols): + """ + Return the weights of a box aperture given the centroid and the width of the box in pixels. All pixels will have the same weights except at the ends of the box aperture. @@ -27,12 +28,7 @@ def get_box_weights(centroid, n_pix, shape, cols=None): weights : array[float] An array of pixel weights to use with the box extraction. 
""" - - nrows, ncols = shape - - # Use all columns if not specified - if cols is None: - cols = np.arange(ncols) + nrows, _ = shape # Row centers of all pixels. rows = np.indices((nrows, len(cols)))[0] @@ -59,8 +55,9 @@ def get_box_weights(centroid, n_pix, shape, cols=None): return out -def box_extract(scidata, scierr, scimask, box_weights, cols=None): - """ Perform a box extraction. +def box_extract(scidata, scierr, scimask, box_weights): + """ + Perform a box extraction. Parameters ---------- @@ -73,8 +70,6 @@ def box_extract(scidata, scierr, scimask, box_weights, cols=None): box_weights : array[float] 2d array of pre-computed weights for box extraction, with same shape as scidata - cols : array[int] - 1d integer array of column numbers to extract Returns ------- @@ -85,12 +80,7 @@ def box_extract(scidata, scierr, scimask, box_weights, cols=None): flux_var : array[float] The variance of the flux in each column """ - - nrows, ncols = scidata.shape - - # Use all columns if not specified - if cols is None: - cols = np.arange(ncols) + cols = np.arange(scidata.shape[1]) # Keep only required columns and make a copy. data = scidata[:, cols].copy() @@ -151,6 +141,7 @@ def estim_error_nearest_data(err, data, pix_to_estim, valid_pix): Map of the pixels where the uncertainty needs to be estimated. valid_pix : 2d array[bool] Map of valid pixels to be used to find the error empirically. 
+ Returns ------- err_filled : 2d array[float] @@ -161,9 +152,6 @@ def estim_error_nearest_data(err, data, pix_to_estim, valid_pix): err_valid = err[valid_pix] data_valid = data[valid_pix] - # - # Use np.searchsorted for efficiency - # # Need to sort the arrays used to find similar values idx_sort = np.argsort(data_valid) err_valid = err_valid[idx_sort] diff --git a/jwst/extract_1d/soss_extract/soss_centroids.py b/jwst/extract_1d/soss_extract/soss_centroids.py deleted file mode 100644 index 2474c571d0..0000000000 --- a/jwst/extract_1d/soss_extract/soss_centroids.py +++ /dev/null @@ -1,150 +0,0 @@ -import logging -import numpy as np -import warnings - -from .soss_utils import robust_polyfit, get_image_dim - -log = logging.getLogger(__name__) -log.setLevel(logging.DEBUG) - - -def center_of_mass(column, ypos, halfwidth): - """Compute a windowed center-of-mass along a column. - - Parameters - ---------- - column : array[float] - The column values on which to compute the windowed center of mass. - ypos : float - The position along the column to center the window on. - halfwidth : int - The half-size of the window in pixels. - - Returns - -------- - ycom : float - The center-of-mass of the pixels within the window. - """ - - # Get the column shape and create a corresponding array of positions. - dimy, = column.shape - ypix = np.arange(dimy) - - # Find the indices of the window. - miny = int(np.fmax(np.around(ypos - halfwidth), 0)) - maxy = int(np.fmin(np.around(ypos + halfwidth + 1), dimy)) - - # Compute the center of mass on the window. - with np.errstate(invalid='ignore'): - ycom = (np.nansum(column[miny:maxy] * ypix[miny:maxy]) / - np.nansum(column[miny:maxy])) - - return ycom - - -def get_centroids_com(scidata_bkg, header=None, mask=None, poly_order=11): - """Determine the x, y coordinates of the trace using a center-of-mass - analysis. Works for either order if there is no contamination, or for - order 1 on a detector where the two orders are overlapping. 
- - Parameters - ---------- - scidata_bkg : array[float] - A background subtracted observation. - header : astropy.io.fits.Header - The header from one of the SOSS reference files. - mask : array[bool] - A boolean array of the same shape as image. Pixels corresponding to - True values will be masked. - poly_order : None or int - Order of the polynomial to fit to the extracted trace positions. - - Returns - -------- - xtrace : array[float] - The x coordinates of trace as computed from the best fit polynomial. - ytrace : array[float] - The y coordinates of trace as computed from the best fit polynomial. - param : array[float] - The best-fit polynomial parameters. - """ - - # If no mask was given use all pixels. - if mask is None: - mask = np.zeros_like(scidata_bkg, dtype='bool') - - # Call the script that determines the dimensions of the stack. - dimx, dimy, xos, yos, xnative, ynative, padding, refpix_mask = get_image_dim(scidata_bkg, header=header) - - # Replace masked pixel values with NaNs. - scidata_bkg_masked = np.where(mask | ~refpix_mask, np.nan, scidata_bkg) - - # Find centroid - first pass, use all pixels in the column. - - # Normalize each column - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=RuntimeWarning, message="All-NaN") - maxvals = np.nanmax(scidata_bkg_masked, axis=0) - scidata_norm = scidata_bkg_masked / maxvals - - # Create 2D Array of pixel positions. - xpix = np.arange(dimx) - ypix = np.arange(dimy) - _, ygrid = np.meshgrid(xpix, ypix) - - # CoM analysis to find initial positions using all rows. - with np.errstate(invalid='ignore'): - ytrace = np.nansum(scidata_norm * ygrid, axis=0) / np.nansum(scidata_norm, axis=0) - ytrace = np.where(np.abs(ytrace) == np.inf, np.nan, ytrace) - - # Second pass - use a windowed CoM at the previous position. 
- halfwidth = 30 * yos - for icol in range(dimx): - - ycom = center_of_mass(scidata_norm[:, icol], ytrace[icol], halfwidth) - - # If NaN was returned or centroid is out of bounds, we are done. - if not np.isfinite(ycom) or (ycom > (ynative - 1) * yos) or (ycom < 0): - ytrace[icol] = np.nan - continue - - # If the pixel at the centroid is below the local mean we are likely - # mid-way between orders and we should shift the window downward to - # get a reliable centroid for order 1. - irow = int(np.around(ycom)) - miny = int(np.fmax(np.around(ycom) - halfwidth, 0)) - maxy = int(np.fmin(np.around(ycom) + halfwidth + 1, dimy)) - if scidata_norm[irow, icol] < np.nanmean(scidata_norm[miny:maxy, icol]): - ycom = center_of_mass(scidata_norm[:, icol], ycom - halfwidth, halfwidth) - - # If the updated position is too close to the array edge, use NaN. - if not np.isfinite(ycom) or (ycom <= 5 * yos) or (ycom >= (ynative - 6) * yos): - ytrace[icol] = np.nan - continue - - # Update the position if the above checks were successful. - ytrace[icol] = ycom - - # Third pass - fine tuning using a smaller window. - halfwidth = 16 * yos - for icol in range(dimx): - - ytrace[icol] = center_of_mass(scidata_norm[:, icol], ytrace[icol], - halfwidth) - - # Fit the y-positions with a polynomial and use result as true y-positions. - xtrace = np.arange(dimx) - mask = np.isfinite(ytrace) - - # For padded arrays ignore padding for consistency with real data - if padding != 0: - mask = mask & (xtrace >= xos * padding) & (xtrace < (dimx - xos * padding)) - - # If no polynomial order was given return the raw measurements. 
- if poly_order is None: - param = [] - else: - param = robust_polyfit(xtrace[mask], ytrace[mask], poly_order) - ytrace = np.polyval(param, xtrace) - - return xtrace, ytrace, param diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index 2f3ad35052..2ec8a8f41c 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -1,13 +1,14 @@ import logging import numpy as np + from scipy.interpolate import UnivariateSpline, CubicSpline from stdatamodels.jwst import datamodels from stdatamodels.jwst.datamodels import dqflags, SossWaveGridModel -from ..extract import populate_time_keywords -from ...lib import pipe_utils +from jwst.extract_1d.extract import populate_time_keywords +from jwst.lib import pipe_utils from astropy.nddata.bitmask import bitfield_to_boolean_mask from .soss_syscor import make_background_mask, soss_background @@ -21,9 +22,13 @@ log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) +ORDER2_SHORT_CUTOFF = 0.58 + def get_ref_file_args(ref_files): - """Prepare the reference files for the extraction engine. + """ + Prepare the reference files for the extraction engine. + Parameters ---------- ref_files : dict @@ -44,9 +49,9 @@ def get_ref_file_args(ref_files): else: do_padding = False - (wavemap_o1, wavemap_o2), (spectrace_o1, spectrace_o2) = \ + (wavemap_o1, wavemap_o2) = \ get_soss_wavemaps(pastasoss_ref, pwcpos=ref_files['pwcpos'], subarray=ref_files['subarray'], - padding=do_padding, padsize=pad, spectraces=True) + padding=do_padding, padsize=pad, spectraces=False) # The spectral profiles for order 1 and 2. specprofile_ref = ref_files['specprofile'] @@ -76,7 +81,7 @@ def get_ref_file_args(ref_files): # The throughput curves for order 1 and 2. 
- throughput_index_dict = dict() + throughput_index_dict = {} for i, throughput in enumerate(pastasoss_ref.throughputs): throughput_index_dict[throughput.spectral_order] = i @@ -87,28 +92,29 @@ def get_ref_file_args(ref_files): # The spectral kernels. speckernel_ref = ref_files['speckernel'] - ovs = speckernel_ref.meta.spectral_oversampling n_pix = 2 * speckernel_ref.meta.halfwidth + 1 # Take the centroid of each trace as a grid to project the WebbKernel # WebbKer needs a 2d input, so artificially add axis wave_maps = [wavemap_o1, wavemap_o2] - centroid = dict() + centroid = {} for wv_map, order in zip(wave_maps, [1, 2]): - # Needs the same number of columns as the detector. Put zeros where not define. - wv_cent = np.zeros((1, wv_map.shape[1])) + wv_cent = np.zeros((wv_map.shape[1])) + # Get central wavelength as a function of columns - col, _, wv = get_trace_1d(ref_files, order) - wv_cent[:, col] = wv + col, _, wv = _get_trace_1d(ref_files, order) + wv_cent[col] = wv + # Set invalid values to zero idx_invalid = ~np.isfinite(wv_cent) wv_cent[idx_invalid] = 0.0 centroid[order] = wv_cent + # Get kernels - kernels_o1 = WebbKernel(speckernel_ref.wavelengths, speckernel_ref.kernels, centroid[1], ovs, n_pix) - kernels_o2 = WebbKernel(speckernel_ref.wavelengths, speckernel_ref.kernels, centroid[2], ovs, n_pix) + kernels_o1 = WebbKernel(speckernel_ref.wavelengths, speckernel_ref.kernels, centroid[1], n_pix) + kernels_o2 = WebbKernel(speckernel_ref.wavelengths, speckernel_ref.kernels, centroid[2], n_pix) - # Temporary(?) fix to make sure that the kernels can cover the wavelength maps + # Make sure that the kernels cover the wavelength maps speckernel_wv_range = [np.min(speckernel_ref.wavelengths), np.max(speckernel_ref.wavelengths)] valid_wavemap = (speckernel_wv_range[0] <= wavemap_o1) & (wavemap_o1 <= speckernel_wv_range[1]) wavemap_o1 = np.where(valid_wavemap, wavemap_o1, 0.) 
@@ -119,8 +125,9 @@ def get_ref_file_args(ref_files): [throughput_o1, throughput_o2], [kernels_o1, kernels_o2] -def get_trace_1d(ref_files, order): +def _get_trace_1d(ref_files, order): """Get the x, y, wavelength of the trace after applying the transform. + Parameters ---------- ref_files : dict @@ -167,8 +174,11 @@ def get_trace_1d(ref_files, order): return xtrace, ytrace, wavetrace -def estim_flux_first_order(scidata_bkg, scierr, scimask, ref_file_args, mask_trace_profile, threshold=1e-4): +def _estim_flux_first_order(scidata_bkg, scierr, scimask, ref_file_args, mask_trace_profile, threshold=1e-4): """ + Roughly estimate the underlying flux of the target spectrum by simply masking + out order 2 and retrieving the flux from order 1. + Parameters ---------- scidata_bkg : array @@ -184,6 +194,7 @@ def estim_flux_first_order(scidata_bkg, scierr, scimask, ref_file_args, mask_tra threshold : float, optional: The pixels with an aperture[order 2] > `threshold` are considered contaminated and will be masked. Default is 1e-4. 
+ Returns ------- func @@ -193,49 +204,46 @@ def estim_flux_first_order(scidata_bkg, scierr, scimask, ref_file_args, mask_tra # Unpack ref_file arguments wave_maps, spat_pros, thrpts, _ = ref_file_args - # Oversampling of 1 to make sure the solution will be stable - n_os = 1 - # Define wavelength grid based on order 1 only (so first index) - wave_grid = grid_from_map(wave_maps[0], spat_pros[0], n_os=n_os) + wave_grid = grid_from_map(wave_maps[0], spat_pros[0], n_os=1) # Mask parts contaminated by order 2 based on its spatial profile - mask = ((spat_pros[1] >= threshold) | mask_trace_profile | scimask) + mask = ((spat_pros[1] >= threshold) | mask_trace_profile[0]) # Init extraction without convolution kernel (so extract the spectrum at order 1 resolution) - ref_file_args = [wave_maps[0]], [spat_pros[0]], [thrpts[0]], [np.array([1.])] - kwargs = {'wave_grid': wave_grid, - 'orders': [1], - 'mask_trace_profile': [mask]} - engine = ExtractionEngine(*ref_file_args, **kwargs) + ref_file_args = [wave_maps[0]], [spat_pros[0]], [thrpts[0]], [None] + engine = ExtractionEngine(*ref_file_args, wave_grid, [mask], global_mask=scimask, orders=[1]) # Extract estimate - spec_estimate = engine.__call__(data=scidata_bkg, error=scierr) + spec_estimate = engine(scidata_bkg, scierr) # Interpolate idx = np.isfinite(spec_estimate) - estimate_spl = UnivariateSpline(wave_grid[idx], spec_estimate[idx], k=3, s=0, ext=0) - - return estimate_spl + return UnivariateSpline(wave_grid[idx], spec_estimate[idx], k=3, s=0, ext=0) -def get_native_grid_from_trace(ref_files, spectral_order): +def _get_native_grid_from_trace(ref_files, spectral_order): """ Make a 1d-grid of the pixels boundary and ready for ATOCA ExtractionEngine, based on the wavelength solution. + Parameters ---------- ref_files: dict A dictionary of the reference file DataModels. spectral_order: int The spectral order for which to return the trace parameters. 
+ Returns ------- - Grid of the pixels boundaries at the native sampling (1d array) + wave : + Grid of the pixels boundaries at the native sampling (1d array) + col : + The column number of the pixel """ - # From wavelenght solution - col, _, wave = get_trace_1d(ref_files, spectral_order) + # From wavelength solution + col, _, wave = _get_trace_1d(ref_files, spectral_order) # Keep only valid solution ... idx_valid = np.isfinite(wave) @@ -256,22 +264,27 @@ def get_native_grid_from_trace(ref_files, spectral_order): return wave, col -def get_grid_from_trace(ref_files, spectral_order, n_os=1): +def _get_grid_from_trace(ref_files, spectral_order, n_os): """ Make a 1d-grid of the pixels boundary and ready for ATOCA ExtractionEngine, based on the wavelength solution. + Parameters ---------- ref_files: dict A dictionary of the reference file DataModels. spectral_order: int The spectral order for which to return the trace parameters. + n_os: int or array + The oversampling factor of the wavelength grid used when solving for + the uncontaminated flux. + Returns ------- Grid of the pixels boundaries at the native sampling (1d array) """ - wave, _ = get_native_grid_from_trace(ref_files, spectral_order) + wave, _ = _get_native_grid_from_trace(ref_files, spectral_order) # Use pixel boundaries instead of the center values wv_upper_bnd, wv_lower_bnd = get_wave_p_or_m(wave[None, :]) @@ -283,26 +296,44 @@ def get_grid_from_trace(ref_files, spectral_order, n_os=1): wave_grid = np.append(wv_lower_bnd, wv_upper_bnd[-1]) # Oversample as needed - wave_grid = oversample_grid(wave_grid, n_os=n_os) + return oversample_grid(wave_grid, n_os=n_os) - return wave_grid - -def make_decontamination_grid(ref_files, rtol, max_grid_size, estimate, n_os, wv_range=None): - ''' Create the grid use for the simultaneous extraction of order 1 and 2. 
+def _make_decontamination_grid(ref_files, rtol, max_grid_size, estimate, n_os): + """ + Create the grid to use for the simultaneous extraction of order 1 and 2. The grid is made by: 1) requiring that it satisfies the oversampling n_os 2) trying to reach the specified tolerance for the spectral range shared between order 1 and 2 3) trying to reach the specified tolerance in the rest of spectral range The max_grid_size overrules steps 2) and 3), so the precision may not be reached if the grid size needed is too large. - ''' + + Parameters + ---------- + ref_files : dict + A dictionary of the reference file DataModels. + rtol : float + The relative tolerance needed on a pixel model. + max_grid_size : int + Maximum grid size allowed. + estimate : UnivariateSpline + Estimate of the target flux as a function of wavelength in microns. + n_os : int + The oversampling factor of the wavelength grid used when solving for + the uncontaminated flux. + + Returns + ------- + wave_grid : 1d array + The combined wavelength grid used for the simultaneous extraction of order 1 and 2. + """ # Build native grid for each orders. spectral_orders = [2, 1] - grids_ord = dict() + grids_ord = {} for sp_ord in spectral_orders: - grids_ord[sp_ord] = get_grid_from_trace(ref_files, sp_ord, n_os=n_os) + grids_ord[sp_ord] = _get_grid_from_trace(ref_files, sp_ord, n_os=n_os) # Build the list of grids given to make_combined_grid. # It must be ordered in increasing priority. @@ -317,39 +348,33 @@ def make_decontamination_grid(ref_files, rtol, max_grid_size, estimate, n_os, wv # And make grid list all_grids = [grids_ord[2][is_shared], grids_ord[1], grids_ord[2][~is_shared]] - # Set wavelength range if not given - if wv_range is None: - # Cut order 2 at 0.77 (not smaller than that) - # because there is no contamination there. Can be extracted afterward. - # In the red, no cut. - wv_range = [0.77, np.max(grids_ord[1])] + # Cut order 2 at 0.77 (not smaller than that) + # because there is no contamination there.
Can be extracted afterward. + # In the red, no cut. + wv_range = [0.77, np.max(grids_ord[1])] # Finally, build the list of corresponding estimates. # The estimate for the overlapping part is the order 1 estimate. # There is no estimate yet for the blue part of order 2, so give a flat spectrum. def flat_fct(wv): return np.ones_like(wv) - all_estimates = [estimate, estimate, flat_fct] # Generate the combined grid - kwargs = dict(rtol=rtol, max_total_size=max_grid_size, max_iter=30, grid_range=wv_range) - combined_grid = make_combined_adaptive_grid(all_grids, all_estimates, **kwargs) - - return combined_grid - + kwargs = {"rtol":rtol, "max_total_size":max_grid_size, "max_iter":30} + return make_combined_adaptive_grid(all_grids, all_estimates, wv_range, **kwargs) -def append_tiktests(test_a, test_b): - out = dict() +def _append_tiktests(test_a, test_b): + out = {} for key in test_a: out[key] = np.append(test_a[key], test_b[key], axis=0) return out -def populate_tikho_attr(spec, tiktests, idx, sp_ord): +def _populate_tikho_attr(spec, tiktests, idx, sp_ord): spec.spectral_order = sp_ord spec.meta.soss_extract1d.type = 'TEST' @@ -360,11 +385,36 @@ def populate_tikho_attr(spec, tiktests, idx, sp_ord): spec.meta.soss_extract1d.factor = tiktests['factors'][idx] spec.int_num = 0 - return +def _f_to_spec(f_order, grid_order, ref_file_args, pixel_grid, mask, sp_ord): + """ + Bin the flux to the pixel grid and build a SpecModel. + + Parameters + ---------- + f_order : np.array + The solution f_k of the linear system. + + grid_order : np.array + The wavelength grid of the solution, usually oversampled compared to the pixel grid. + + ref_file_args : list + The reference file arguments used by the ExtractionEngine. -def f_to_spec(f_order, grid_order, ref_file_args, pixel_grid, mask, sp_ord=0): + pixel_grid : np.array + The pixel grid to which the flux should be binned. + mask : np.array + The mask of the pixels to be extracted. + + sp_ord : int + The spectral order of the flux. 
+ + Returns + ------- + spec : SpecModel + + """ # Make sure the input is not modified ref_file_args = ref_file_args.copy() @@ -380,6 +430,7 @@ def f_to_spec(f_order, grid_order, ref_file_args, pixel_grid, mask, sp_ord=0): pixel_grid = np.squeeze(pixel_grid) f_binned = np.squeeze(f_binned) + # Remove Nans to save space is_valid = np.isfinite(f_binned) table_size = np.sum(is_valid) @@ -423,10 +474,25 @@ def _build_tracemodel_order(engine, ref_file_args, f_k, i_order, mask, ref_files # Project on detector and save in dictionary tracemodel_ord = model.rebuild(flux_order, fill_value=np.nan) + # import matplotlib.pyplot as plt + # fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 12)) + # ax1.imshow(engine.trace_profile[0], origin='lower', aspect='auto', cmap='viridis') + # ax1.set_title("Trace Profile") + # ax2.plot(grid_order, flux_order) + # ax2.set_title("flux vs wl") + # cim = ax3.imshow(engine.wave_map[i_order], origin='lower', aspect='auto', cmap='viridis', vmin=0.7, vmax=0.9) + # ax3.set_title("wavelengths") + # fig.colorbar(cim, ax=ax3) + # ax4.imshow(tracemodel_ord, origin='lower', aspect='auto', cmap='viridis', vmin=0, vmax=80) + # ax4.set_title("trace model") + # for ax in [ax1, ax3, ax4]: + # ax.set_xlim([1100, 1400]) + # plt.show() + # Build 1d spectrum integrated over pixels - pixel_wave_grid, valid_cols = get_native_grid_from_trace(ref_files, sp_ord) - spec_ord = f_to_spec(flux_order, grid_order, ref_file_order, pixel_wave_grid, - np.all(mask, axis=0)[valid_cols], sp_ord=sp_ord) + pixel_wave_grid, valid_cols = _get_native_grid_from_trace(ref_files, sp_ord) + spec_ord = _f_to_spec(flux_order, grid_order, ref_file_order, pixel_wave_grid, + np.all(mask, axis=0)[valid_cols], sp_ord) return tracemodel_ord, spec_ord @@ -446,7 +512,7 @@ def _build_null_spec_table(wave_grid): Null SpecModel. 
Flux values are NaN, DQ flags are 1, but note that DQ gets overwritten at end of run_extract1d """ - wave_grid_cut = wave_grid[wave_grid > 0.58] # same cutoff applied for valid data + wave_grid_cut = wave_grid[wave_grid > ORDER2_SHORT_CUTOFF] spec = datamodels.SpecModel() spec.spectral_order = 2 spec.meta.soss_extract1d.type = 'OBSERVATION' @@ -460,7 +526,7 @@ def _build_null_spec_table(wave_grid): return spec -def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, +def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, tikfac=None, threshold=1e-4, n_os=2, wave_grid=None, estimate=None, rtol=1e-3, max_grid_size=1000000): """Perform the spectral extraction on a single image. @@ -491,20 +557,22 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, n_os : int, optional The oversampling factor of the wavelength grid used when solving for the uncontaminated flux. If not specified, defaults to 2. - wave_grid : str or SossWaveGridModel or None - Filename of reference file or SossWaveGridModel containing the wavelength grid used by ATOCA - to model each pixel valid pixel of the detector. If not given, the grid is determined - based on an estimate of the flux (estimate), the relative tolerance (rtol) - required on each pixel model and the maximum grid size (max_grid_size). + wave_grid : np.ndarray, optional + Wavelength grid used by ATOCA to model each pixel valid pixel of the detector. + If not given, the grid is determined based on an estimate of the flux (estimate), + the relative tolerance (rtol) required on each pixel model and + the maximum grid size (max_grid_size). + # TODO: none of the options specified on main work + # Should we add support for these? If not, is SossWaveGridModel used for anything, + # and can that be removed from stdatamodels and as a valid argument to soss_wave_grid_in? 
estimate : UnivariateSpline or None Estimate of the target flux as a function of wavelength in microns. rtol : float The relative tolerance needed on a pixel model. It is used to determine the sampling - of the soss_wave_grid when not directly given. Default is 1e-3. + of wave_grid when the input wave_grid is None. Default is 1e-3. max_grid_size : int - Maximum grid size allowed. It is used when soss_wave_grid is not directly - to make sure the computation time or the memory used stays reasonable. - Default is 1000000 + Maximum grid size allowed when wave_grid is None. + Default is 1000000. Returns ------- @@ -515,15 +583,13 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, logl : float Log likelihood value associated with the Tikhonov factor selected. wave_grid : 1d array - Same as wave_grid input + The wavelengths at which the spectra were extracted. Same as wave_grid + if specified as input. spec_list : list of SpecModel List of the underlying spectra for each integration and order. The tikhonov tests are also included. 
""" - # Init list of atoca 1d spectra - spec_list = [] - # Generate list of orders to simulate from pastasoss trace list order_list = [] for trace in ref_files['pastasoss'].traces: @@ -536,66 +602,60 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, scimask = scimask | ~(scierr > 0) # Define mask based on box aperture (we want to model each contaminated pixels that will be extracted) - mask_trace_profile = [~(box_weights[order] > 0) for order in order_list] + mask_trace_profile = [(~(box_weights[order] > 0)) | (refmask) for order in order_list] # Define mask of pixel to model (all pixels inside box aperture) - global_mask = np.all(mask_trace_profile, axis=0) | refmask + global_mask = np.all(mask_trace_profile, axis=0).astype(bool) # Rough estimate of the underlying flux - # Note: estim_flux func is not strictly necessary and factors could be a simple logspace - - # dq mask caused issues here and this may need a try/except wrap. - # Dev suggested np.logspace(-19, -10, 10) if (tikfac is None or wave_grid is None) and estimate is None: - estimate = estim_flux_first_order(scidata_bkg, scierr, scimask, - ref_file_args, mask_trace_profile[0]) + estimate = _estim_flux_first_order(scidata_bkg, scierr, scimask, + ref_file_args, mask_trace_profile) # Generate grid based on estimate if not given if wave_grid is None: log.info(f'wave_grid not given: generating grid based on rtol={rtol}') - wave_grid = make_decontamination_grid(ref_files, rtol, max_grid_size, estimate, n_os) - log.debug(f'wave_grid covering from {wave_grid.min()} to {wave_grid.max()}') + wave_grid = _make_decontamination_grid(ref_files, rtol, max_grid_size, estimate, n_os) + log.debug(f'wave_grid covering from {wave_grid.min()} to {wave_grid.max()}' + f' with {wave_grid.size} points') else: log.info('Using previously computed or user specified wavelength grid.') - # Set the c_kwargs using the minimum value of the kernels - c_kwargs = [{'thresh': webb_ker.min_value} for webb_ker 
in ref_file_args[3]] - # Initialize the Engine. engine = ExtractionEngine(*ref_file_args, wave_grid=wave_grid, mask_trace_profile=mask_trace_profile, global_mask=scimask, - threshold=threshold, - c_kwargs=c_kwargs) + threshold=threshold) + spec_list = [] if tikfac is None: log.info('Solving for the optimal Tikhonov factor.') + save_tiktests = True # Find the tikhonov factor. # Initial pass 8 orders of magnitude with 10 grid points. guess_factor = engine.estimate_tikho_factors(estimate) log_guess = np.log10(guess_factor) factors = np.logspace(log_guess - 4, log_guess + 4, 10) - all_tests = engine.get_tikho_tests(factors, data=scidata_bkg, error=scierr) - tikfac, mode, _ = engine.best_tikho_factor(tests=all_tests, fit_mode='all') + all_tests = engine.get_tikho_tests(factors, scidata_bkg, scierr) + tikfac = engine.best_tikho_factor(all_tests, fit_mode='all') # Refine across 4 orders of magnitude. tikfac = np.log10(tikfac) factors = np.logspace(tikfac - 2, tikfac + 2, 20) - tiktests = engine.get_tikho_tests(factors, data=scidata_bkg, error=scierr) - tikfac, mode, _ = engine.best_tikho_factor(tests=tiktests, fit_mode='d_chi2') - # Add all theses tests to previous ones - all_tests = append_tiktests(all_tests, tiktests) + tiktests = engine.get_tikho_tests(factors, scidata_bkg, scierr) + tikfac = engine.best_tikho_factor(tiktests, fit_mode='d_chi2') + all_tests = _append_tiktests(all_tests, tiktests) # Save spectra in a list of SingleSpecModels for optional output - save_tiktests = True for i_order, order in enumerate(order_list): for idx in range(len(all_tests['factors'])): f_k = all_tests['solution'][idx, :] args = (engine, ref_file_args, f_k, i_order, global_mask, ref_files) _, spec_ord = _build_tracemodel_order(*args) - populate_tikho_attr(spec_ord, all_tests, idx, i_order + 1) + _populate_tikho_attr(spec_ord, all_tests, idx, i_order + 1) spec_ord.meta.soss_extract1d.color_range = 'RED' # Add the result to spec_list @@ -606,17 +666,17 @@ def model_image(scidata_bkg, 
scierr, scimask, refmask, ref_files, box_weights, log.info('Using a Tikhonov factor of {}'.format(tikfac)) # Run the extract method of the Engine. - f_k = engine.__call__(data=scidata_bkg, error=scierr, tikhonov=True, factor=tikfac) + f_k = engine(scidata_bkg, scierr, tikhonov=True, factor=tikfac) # Compute the log-likelihood of the best fit. - logl = engine.compute_likelihood(f_k, same=False) + logl = engine.compute_likelihood(f_k, scidata_bkg, scierr) log.info('Optimal solution has a log-likelihood of {}'.format(logl)) # Create a new instance of the engine for evaluating the trace model. # This allows bad pixels and pixels below the threshold to be reconstructed as well. # Model the order 1 and order 2 trace separately. - tracemodels = dict() + tracemodels = {} for i_order, order in enumerate(order_list): @@ -634,9 +694,7 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, # Add the result to spec_list spec_list.append(spec_ord) - # ############################### - # Model remaining part of order 2 - # ############################### + # Model the remaining part of order 2 if ref_files['subarray'] != 'SUBSTRIP96': idx_order2 = 1 order = idx_order2 + 1 @@ -648,12 +706,9 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, # Mask for the fit. All valid pixels inside box aperture mask_fit = mask_trace_profile[idx_order2] | scimask -# # and extract only what was not already modeled -# already_modeled = np.isfinite(tracemodels[order_str]) -# mask_fit |= already_modeled # Build 1d spectrum integrated over pixels - pixel_wave_grid, valid_cols = get_native_grid_from_trace(ref_files, order) + pixel_wave_grid, valid_cols = _get_native_grid_from_trace(ref_files, order) # Hardcode wavelength highest boundary as well. 
# Must overlap with lower limit in make_decontamination_grid @@ -665,14 +720,14 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, # Model the remaining part of order 2 with atoca try: - model, spec_ord = model_single_order(scidata_bkg, scierr, ref_file_order, + model, spec_ord = _model_single_order(scidata_bkg, scierr, ref_file_order, mask_fit, global_mask, order, - pixel_wave_grid, valid_cols, save_tiktests, - tikfac_log_range=tikfac_log_range) + pixel_wave_grid, valid_cols, + tikfac_log_range, save_tiktests=save_tiktests) except MaskOverlapError: log.error('Not enough unmasked pixels to model the remaining part of order 2.' - 'Model and spectrum will be NaN in that spectral region.') + ' Model and spectrum will be NaN in that spectral region.') spec_ord = [_build_null_spec_table(pixel_wave_grid)] model = np.nan * np.ones_like(scidata_bkg) @@ -682,7 +737,9 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, model = np.where(already_modeled, 0., model) # Add to tracemodels + both_nan = np.isnan(tracemodels[order_str]) & np.isnan(model) tracemodels[order_str] = np.nansum([tracemodels[order_str], model], axis=0) + tracemodels[order_str][both_nan] = np.nan # Add the result to spec_list for sp in spec_ord: @@ -692,7 +749,8 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, return tracemodels, tikfac, logl, wave_grid, spec_list -def compute_box_weights(ref_files, shape, width=40.): +def _compute_box_weights(ref_files, shape, width): + """Determine the weights for the box extraction.""" # Generate list of orders from pastasoss trace list order_list = [] @@ -700,8 +758,7 @@ def compute_box_weights(ref_files, shape, width=40.): order_list.append(trace.spectral_order) # Extract each order from order list - box_weights = dict() - wavelengths = dict() + box_weights, wavelengths = {}, {} order_str = {order: f'Order {order}' for order in order_list} for order_integer in order_list: # 
Order string-name is used more often than integer-name @@ -710,13 +767,13 @@ def compute_box_weights(ref_files, shape, width=40.): log.debug(f'Compute box weights for {order}.') # Define the box aperture - xtrace, ytrace, wavelengths[order] = get_trace_1d(ref_files, order_integer) + xtrace, ytrace, wavelengths[order] = _get_trace_1d(ref_files, order_integer) box_weights[order] = get_box_weights(ytrace, width, shape, cols=xtrace) return box_weights, wavelengths -def decontaminate_image(scidata_bkg, tracemodels, subarray): +def _decontaminate_image(scidata_bkg, tracemodels, subarray): """Perform decontamination of the image based on the trace models""" # Which orders to extract. if subarray == 'SUBSTRIP96': @@ -730,7 +787,7 @@ def decontaminate_image(scidata_bkg, tracemodels, subarray): mod_order_list = tracemodels.keys() # Create dictionaries for the output images. - decontaminated_data = dict() + decontaminated_data = {} log.debug('Performing the decontamination.') @@ -753,71 +810,96 @@ def decontaminate_image(scidata_bkg, tracemodels, subarray): return decontaminated_data -# TODO Add docstring -def model_single_order(data_order, err_order, ref_file_args, mask_fit, - mask_rebuild, order, wave_grid, valid_cols, save_tiktests=False, tikfac_log_range=None): +def _model_single_order(data_order, err_order, ref_file_args, mask_fit, + mask_rebuild, order, wave_grid, valid_cols, + tikfac_log_range, save_tiktests=False): + """ + Extract an output spectrum for a single spectral order using the ATOCA + algorithm, testing a range of Tikhonov factors. + The Tikhonov factor is derived in two stages: first, ten factors are tested + spanning tikfac_log_range, and then a further 20 factors are tested across + 2 orders of magnitude in each direction around the best factor from the first + stage. + The best-fitting model and spectrum are reconstructed using the best-fit Tikhonov factor + and respecting mask_rebuild. 
+ + Parameters + ---------- + data_order : np.array + The 2D data array for the spectral order to be extracted. + err_order : np.array + The 2D error array for the spectral order to be extracted. + ref_file_args : list + The reference file arguments used by the ExtractionEngine. + mask_fit : np.array + Mask determining the aperture used for extraction. This typically includes + detector bad pixels and any pixels that are not part of the trace. + mask_rebuild : np.array + Mask determining the aperture used for rebuilding the trace. This typically includes + only pixels that do not belong to either spectral trace, i.e., regions of the detector + where no real data could exist. + order : int + The spectral order to be extracted. + wave_grid : np.array + The wavelength grid used to model the data. + valid_cols : np.array + The columns of the detector that are valid for extraction. + tikfac_log_range : list + The range of Tikhonov factors to test, in log space. + save_tiktests : bool, optional + If True, save the intermediate models and spectra for each Tikhonov factor tested. + + Returns + ------- + model : np.array + Model derived from the best Tikhonov factor, same shape as data_order. + spec_list : list of SpecModel + If save_tiktests is True, returns a list of the model spectra for each Tikhonov factor tested, + with the best-fitting spectrum last in the list. + If save_tiktests is False, returns a one-element list with the best-fitting spectrum. + + Notes + ----- + The last spectrum in the list of SpecModels lacks the "chi2", "chi2_soft_l1", "chi2_cauchy", and "reg" + attributes, as these are only calculated for the intermediate models. The last spectrum is not + necessarily identical to any of the spectra in the list, as it is reconstructed according to + mask_rebuild instead of fit respecting mask_fit; that is, bad pixels are included. + + # TODO are all of these behaviors for the last spec in the list the desired ones?
+ """ # The throughput and kernel is not needed here; set them so they have no effect on the extraction. def throughput(wavelength): return np.ones_like(wavelength) kernel = np.array([1.]) - - # Set reference file arguments ref_file_args[2] = [throughput] ref_file_args[3] = [kernel] - # ########################### - # First, generate an estimate - # (only if the initial guess of tikhonov factor range is not given) - # ########################### - - if tikfac_log_range is None: - # Initialize the engine - engine = ExtractionEngine(*ref_file_args, - wave_grid=wave_grid, - orders=[order], - mask_trace_profile=[mask_fit]) - - # Extract estimate - spec_estimate = engine.__call__(data=data_order, error=err_order) - - # Interpolate - idx = np.isfinite(spec_estimate) - estimate_spl = UnivariateSpline(wave_grid[idx], spec_estimate[idx], k=3, s=0, ext=0) - - # ################################################## - # Second, do the extraction to get the best estimate - # ################################################## # Define wavelength grid with oversampling of 3 (should be enough) wave_grid_os = oversample_grid(wave_grid, n_os=3) - wave_grid_os = wave_grid_os[wave_grid_os > 0.58] + wave_grid_os = wave_grid_os[wave_grid_os > ORDER2_SHORT_CUTOFF] # Initialize the Engine. engine = ExtractionEngine(*ref_file_args, wave_grid=wave_grid_os, - orders=[order], - mask_trace_profile=[mask_fit]) + mask_trace_profile=[mask_fit], + orders=[order],) # Find the tikhonov factor. # Initial pass with tikfac_range. 
- if tikfac_log_range is None: - guess_factor = engine.estimate_tikho_factors(estimate_spl) - log_guess = np.log10(guess_factor) - factors = np.log_range(log_guess - 2, log_guess + 8, 10) - else: - factors = np.logspace(tikfac_log_range[0], tikfac_log_range[-1] + 8, 10) - all_tests = engine.get_tikho_tests(factors, data=data_order, error=err_order) - tikfac, mode, _ = engine.best_tikho_factor(tests=all_tests, fit_mode='all') + factors = np.logspace(tikfac_log_range[0], tikfac_log_range[-1], 10) + all_tests = engine.get_tikho_tests(factors, data_order, err_order) + tikfac = engine.best_tikho_factor(tests=all_tests, fit_mode='all') # Refine across 4 orders of magnitude. tikfac = np.log10(tikfac) factors = np.logspace(tikfac - 2, tikfac + 2, 20) - tiktests = engine.get_tikho_tests(factors, data=data_order, error=err_order) - tikfac, mode, _ = engine.best_tikho_factor(tests=tiktests, fit_mode='d_chi2') - all_tests = append_tiktests(all_tests, tiktests) + tiktests = engine.get_tikho_tests(factors, data_order, err_order) + tikfac = engine.best_tikho_factor(tiktests, fit_mode='d_chi2') + all_tests = _append_tiktests(all_tests, tiktests) # Run the extract method of the Engine. 
- f_k_final = engine.__call__(data=data_order, error=err_order, tikhonov=True, factor=tikfac) + f_k_final = engine(data_order, err_order, tikhonov=True, factor=tikfac) # Save binned spectra in a list of SingleSpecModels for optional output spec_list = [] @@ -826,41 +908,38 @@ def throughput(wavelength): f_k = all_tests['solution'][idx, :] # Build 1d spectrum integrated over pixels - spec_ord = f_to_spec(f_k, wave_grid_os, ref_file_args, wave_grid, - np.all(mask_rebuild, axis=0)[valid_cols], sp_ord=order) - populate_tikho_attr(spec_ord, all_tests, idx, order) + spec_ord = _f_to_spec(f_k, wave_grid_os, ref_file_args, wave_grid, + np.all(mask_rebuild, axis=0)[valid_cols], order) + _populate_tikho_attr(spec_ord, all_tests, idx, order) # Add the result to spec_list spec_list.append(spec_ord) - # ########################################## - # Third, rebuild trace, including bad pixels - # ########################################## - # Initialize the Engine. + + # Rebuild trace, including bad pixels engine = ExtractionEngine(*ref_file_args, wave_grid=wave_grid_os, - orders=[order], - mask_trace_profile=[mask_rebuild]) - - # Project on detector and save in dictionary + mask_trace_profile=[mask_rebuild], + orders=[order],) model = engine.rebuild(f_k_final, fill_value=np.nan) # Build 1d spectrum integrated over pixels - spec_ord = f_to_spec(f_k_final, wave_grid_os, ref_file_args, wave_grid, - np.all(mask_rebuild, axis=0)[valid_cols], sp_ord=order) + spec_ord = _f_to_spec(f_k_final, wave_grid_os, ref_file_args, wave_grid, + np.all(mask_rebuild, axis=0)[valid_cols], order) spec_ord.meta.soss_extract1d.factor = tikfac spec_ord.meta.soss_extract1d.type = 'OBSERVATION' # Add the result to spec_list spec_list.append(spec_ord) - return model, spec_list # Remove bad pixels that are not modeled for pixel number -def extract_image(decontaminated_data, scierr, scimask, box_weights, bad_pix='model', tracemodels=None): - """Perform the box-extraction on the image, while using the trace 
model to +def _extract_image(decontaminated_data, scierr, scimask, box_weights, bad_pix='model', tracemodels=None): + """ + Perform the box-extraction on the image, while using the trace model to correct for contamination. + Parameters ---------- decontaminated_data : array[float] @@ -879,6 +958,7 @@ def extract_image(decontaminated_data, scierr, scimask, box_weights, bad_pix='mo 'model' option uses `tracemodels` to replace the bad pixels. tracemodels : dict Dictionary of the modeled detector images for each order. + Returns ------- fluxes, fluxerrs, npixels : dict @@ -886,15 +966,13 @@ def extract_image(decontaminated_data, scierr, scimask, box_weights, bad_pix='mo """ # Init models with an empty dictionary if not given if tracemodels is None: - tracemodels = dict() + tracemodels = {} # Which orders to extract (extract the ones with given box aperture). order_list = box_weights.keys() # Create dictionaries for the output spectra. - fluxes = dict() - fluxerrs = dict() - npixels = dict() + fluxes, fluxerrs, npixels = {}, {}, {} log.info('Performing the box extraction.') @@ -955,6 +1033,7 @@ def run_extract1d(input_model, pastasoss_ref_name, specprofile_ref_name, speckernel_ref_name, subarray, soss_filter, soss_kwargs): """Run the spectral extraction on NIRISS SOSS data. + Parameters ---------- input_model : DataModel @@ -966,7 +1045,7 @@ def run_extract1d(input_model, pastasoss_ref_name, speckernel_ref_name : str Name of the speckernel reference file. subarray : str - Subarray on which the data were recorded; one of 'SUBSTRIPT96', + Subarray on which the data were recorded; one of 'SUBSTRIP96', 'SUBSTRIP256' or 'FULL'. soss_filter : str Filter in place during observations; one of 'CLEAR' or 'F277W'. 
@@ -990,7 +1069,7 @@ def run_extract1d(input_model, pastasoss_ref_name, specprofile_ref = datamodels.SpecProfileModel(specprofile_ref_name) speckernel_ref = datamodels.SpecKernelModel(speckernel_ref_name) - ref_files = dict() + ref_files = {} ref_files['pastasoss'] = pastasoss_ref ref_files['specprofile'] = specprofile_ref ref_files['speckernel'] = speckernel_ref @@ -1002,15 +1081,13 @@ def run_extract1d(input_model, pastasoss_ref_name, if wave_grid_in is not None: log.info(f'Loading wavelength grid from {wave_grid_in}.') wave_grid = datamodels.SossWaveGridModel(wave_grid_in).wavegrid - # Make sure it as the correct precision + # Make sure it has the correct precision wave_grid = wave_grid.astype('float64') else: - # wave_grid will be estimated later in the first call of `model_image` + # wave_grid will be estimated later in the first call of `_model_image` log.info('Wavelength grid was not specified. Setting `wave_grid` to None.') wave_grid = None - # TODO: Maybe not unpack yet. Use SpecModel attributes - # to allow for multiple orders? Create unpacking function. # Convert estimate to cubic spline if given. 
# It should be a SpecModel or a file name (string) estimate = soss_kwargs.pop('estimate') @@ -1037,8 +1114,7 @@ def run_extract1d(input_model, pastasoss_ref_name, output_references = datamodels.SossExtractModel() output_references.update(input_model) - all_tracemodels = dict() - all_box_weights = dict() + all_tracemodels, all_box_weights = {}, {} # Convert to Cube if datamodels is an ImageModel if isinstance(input_model, datamodels.ImageModel): @@ -1086,7 +1162,7 @@ def run_extract1d(input_model, pastasoss_ref_name, if soss_kwargs['subtract_background']: log.info('Applying background subtraction.') bkg_mask = make_background_mask(scidata, width=40) - scidata_bkg, col_bkg, npix_bkg = soss_background(scidata, scimask, bkg_mask=bkg_mask) + scidata_bkg, col_bkg = soss_background(scidata, scimask, bkg_mask) else: log.info('Skip background subtraction.') scidata_bkg = scidata @@ -1094,7 +1170,9 @@ def run_extract1d(input_model, pastasoss_ref_name, # Pre-compute the weights for box extraction (used in modeling and extraction) args = (ref_files, scidata_bkg.shape) - box_weights, wavelengths = compute_box_weights(*args, width=soss_kwargs['width']) + box_weights, wavelengths = _compute_box_weights(*args, width=soss_kwargs['width']) + + # FIXME: hardcoding the substrip96 weights to unity is a band-aid solution if subarray == 'SUBSTRIP96': box_weights['Order 2'] = np.ones((96, 2048)) @@ -1102,7 +1180,7 @@ def run_extract1d(input_model, pastasoss_ref_name, if soss_filter == 'CLEAR' and generate_model: # Model the image. 
- kwargs = dict() + kwargs = {} kwargs['estimate'] = estimate kwargs['tikfac'] = soss_kwargs['tikfac'] kwargs['max_grid_size'] = soss_kwargs['max_grid_size'] @@ -1111,8 +1189,8 @@ def run_extract1d(input_model, pastasoss_ref_name, kwargs['wave_grid'] = wave_grid kwargs['threshold'] = soss_kwargs['threshold'] - result = model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, **kwargs) - tracemodels, soss_kwargs['tikfac'], logl, wave_grid, spec_list = result + result = _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, **kwargs) + tracemodels, soss_kwargs['tikfac'], _, wave_grid, spec_list = result # Add atoca spectra to multispec for output for spec in spec_list: @@ -1129,24 +1207,23 @@ def run_extract1d(input_model, pastasoss_ref_name, raise ValueError(msg) else: # Return empty tracemodels and no spec_list - tracemodels = dict() + tracemodels = {} spec_list = None # Decontaminate the data using trace models (if tracemodels not empty) - data_to_extract = decontaminate_image(scidata_bkg, tracemodels, subarray) + data_to_extract = _decontaminate_image(scidata_bkg, tracemodels, subarray) if soss_kwargs['bad_pix'] == 'model': # Generate new trace models for each individual decontaminated orders - # TODO: Use the sum of tracemodels so it can be applied even w/o decontamination bad_pix_models = tracemodels else: bad_pix_models = None # Use the bad pixel models to perform a de-contaminated extraction. 
- kwargs = dict() + kwargs = {} kwargs['bad_pix'] = soss_kwargs['bad_pix'] kwargs['tracemodels'] = bad_pix_models - result = extract_image(data_to_extract, scierr, scimask, box_weights, **kwargs) + result = _extract_image(data_to_extract, scierr, scimask, box_weights, **kwargs) fluxes, fluxerrs, npixels = result # Save trace models for output reference diff --git a/jwst/extract_1d/soss_extract/soss_syscor.py b/jwst/extract_1d/soss_extract/soss_syscor.py index 36d66aa4f4..45589c851c 100644 --- a/jwst/extract_1d/soss_extract/soss_syscor.py +++ b/jwst/extract_1d/soss_extract/soss_syscor.py @@ -6,64 +6,7 @@ log.setLevel(logging.DEBUG) -def make_profile_mask(ref_2d_profile, threshold=1e-3): - """Build a mask of the trace based on the 2D profile reference file. - - Parameters - ---------- - ref_2d_profile : array[float] - The 2d trace profile reference. - threshold : float - Threshold value for excluding pixels based on ref_2d_profile. - - Returns - ------- - bkg_mask : array[bool] - Pixel mask in the trace based on the 2d profile reference file. - """ - - bkg_mask = (ref_2d_profile > threshold) - - return bkg_mask - - -def aperture_mask(xref, yref, halfwidth, shape): - """Build a mask of the trace based on the trace positions. - - Parameters - ---------- - xref : array[float] - The reference x-positions. - yref : array[float] - The reference y-positions. - halfwidth : float - Size of the aperture mask used when extracting the trace - positions from the data. - shape : Tuple(int, int) - The shape of the array to be masked. - - Returns - ------- - aper_mask : array[bool] - Pixel mask in the trace based on the given trace positions. - """ - - # Create a coordinate grid. - x = np.arange(shape[1]) - y = np.arange(shape[0]) - xx, yy = np.meshgrid(x, y) - - # Interpolate the trace positions onto the grid. - sort = np.argsort(xref) - ytrace = np.interp(x, xref[sort], yref[sort]) - - # Compute the aperture mask. 
- aper_mask = np.abs(yy - ytrace) > halfwidth - - return aper_mask - - -def soss_background(scidata, scimask, bkg_mask=None): +def soss_background(scidata, scimask, bkg_mask): """Compute a columnwise background for a SOSS observation. Parameters @@ -74,7 +17,7 @@ def soss_background(scidata, scimask, bkg_mask=None): Boolean mask of pixels to be excluded. bkg_mask : array[bool] Boolean mask of pixels to be excluded because they are in - the trace, typically constructed with make_profile_mask. + the trace, typically constructed with make_background_mask. Returns ------- @@ -82,30 +25,18 @@ def soss_background(scidata, scimask, bkg_mask=None): Background-subtracted image col_bkg : array[float] Column-wise background values - npix_bkg : array[float] - Number of pixels used to calculate each column value in col_bkg """ # Check the validity of the input. data_shape = scidata.shape - if scimask.shape != data_shape: - msg = 'scidata and scimask must have the same shape.' + if (scimask.shape != data_shape) or (bkg_mask.shape != data_shape): + msg = 'scidata, scimask, and bkg_mask must all have the same shape.' log.critical(msg) raise ValueError(msg) - if bkg_mask is not None: - if bkg_mask.shape != data_shape: - msg = 'scidata and bkg_mask must have the same shape.' - log.critical(msg) - raise ValueError(msg) - # Combine the masks and create a masked array. - if bkg_mask is not None: - mask = scimask | bkg_mask - else: - mask = scimask - + mask = scimask | bkg_mask scidata_masked = np.ma.array(scidata, mask=mask) # Mask additional pixels using sigma-clipping. @@ -115,15 +46,14 @@ def soss_background(scidata, scimask, bkg_mask=None): # Compute the mean for each column and record the number of pixels used. col_bkg = scidata_clipped.mean(axis=0) col_bkg = np.where(np.all(scidata_clipped.mask, axis=0), 0., col_bkg) - npix_bkg = (~scidata_clipped.mask).sum(axis=0) # Background subtract the science data. 
scidata_bkg = scidata - col_bkg - return scidata_bkg, col_bkg, npix_bkg + return scidata_bkg, col_bkg -def make_background_mask(deepstack, width=28): +def make_background_mask(deepstack, width): """Build a mask of the pixels considered to contain the majority of the flux, and should therefore not be used to compute the background. @@ -141,19 +71,16 @@ def make_background_mask(deepstack, width=28): bkg_mask : array[bool] Pixel mask in the trace based on the deepstack or non-finite in the image. - :rtype: array[bool] """ # Get the dimensions of the input image. - nrows, ncols = np.shape(deepstack) + nrows, _ = np.shape(deepstack) # Set the appropriate quantile for masking based on the subarray size. if nrows == 96: # SUBSTRIP96 - quantile = 100 * (1 - width / 96) # Mask 1 order worth of pixels. - elif nrows == 256: # SUBSTRIP256 - quantile = 100 * (1 - 2 * width / 256) # Mask 2 orders worth of pixels. - elif nrows == 2048: # FULL - quantile = 100 * (1 - 2 * width / 2048) # Mask 2 orders worth of pixels. + quantile = 100 * (1 - width / nrows) # Mask 1 order worth of pixels. + elif nrows in [256, 2048]: # SUBSTRIP256, FULL + quantile = 100 * (1 - 2 * width / nrows) # Mask 2 orders worth of pixels. else: msg = (f'Unexpected image dimensions, expected nrows = 96, 256 or 2048, ' f'got nrows = {nrows}.') @@ -165,92 +92,4 @@ def make_background_mask(deepstack, width=28): # Mask pixels above the threshold value. with np.errstate(invalid='ignore'): - bkg_mask = (deepstack > threshold) | ~np.isfinite(deepstack) - - return bkg_mask - - -def soss_oneoverf_correction(scidata, scimask, deepstack, bkg_mask=None, - zero_bias=False): - """Compute a column-wise correction to the 1/f noise on the difference image - of an individual SOSS integration (i.e. an individual integration - a deep - image of the same observation). - - Parameters - ---------- - scidata : array[float] - Image of the SOSS trace. - scimask : array[boo] - Boolean mask of pixels to be excluded based on the DQ values. 
- deepstack : array[float] - Deep image of the trace constructed by combining - individual integrations of the observation. - bkg_mask : array[bool] - Boolean mask of pixels to be excluded because they are in the trace, - typically constructed with make_profile_mask. - zero_bias : bool - If True, the corrections to individual columns will be - adjusted so that their mean is zero. - - Returns - ------- - scidata_cor : array[float] - The 1/f-corrected image - col_cor : array[float] - The column-wise correction values - npix_cor : array[float] - Number of pixels used in each column - bias : float - Net change to the image, if zero_bias was False - """ - - # Check the validity of the input. - data_shape = scidata.shape - - if scimask.shape != data_shape: - msg = 'scidata and scimask must have the same shape.' - log.critical(msg) - raise ValueError(msg) - - if deepstack.shape != data_shape: - msg = 'scidata and deepstack must have the same shape.' - log.critical(msg) - raise ValueError(msg) - - if bkg_mask is not None: - - if bkg_mask.shape != data_shape: - msg = 'scidata and bkg_mask must have the same shape.' - log.critical(msg) - raise ValueError(msg) - - # Subtract the deep stack from the image. - diffimage = scidata - deepstack - - # Combine the masks and create a masked array. - mask = scimask | ~np.isfinite(deepstack) - - if bkg_mask is not None: - mask = mask | bkg_mask - - diffimage_masked = np.ma.array(diffimage, mask=mask) - - # Mask additional pixels using sigma-clipping. - sigclip = SigmaClip(sigma=3, maxiters=None, cenfunc='mean') - diffimage_clipped = sigclip(diffimage_masked, axis=0) - - # Compute the mean for each column and record the number of pixels used. - col_cor = diffimage_clipped.mean(axis=0) - npix_cor = (~diffimage_clipped.mask).sum(axis=0) - - # Compute the net change to the image. - bias = np.nanmean(col_cor) - - # Set the net bias to zero. - if zero_bias: - col_cor = col_cor - bias - - # Apply the 1/f correction to the image. 
- scidata_cor = scidata - col_cor - - return scidata_cor, col_cor, npix_cor, bias + return (deepstack > threshold) | ~np.isfinite(deepstack) diff --git a/jwst/extract_1d/soss_extract/soss_utils.py b/jwst/extract_1d/soss_extract/soss_utils.py deleted file mode 100644 index 88c8b3212a..0000000000 --- a/jwst/extract_1d/soss_extract/soss_utils.py +++ /dev/null @@ -1,177 +0,0 @@ -import numpy as np -import logging - -log = logging.getLogger(__name__) -log.setLevel(logging.DEBUG) - - -def zero_roll(a, shift): - """Like np.roll but the wrapped around part is set to zero. - Only works along the first axis of the array. - - Parameters - ---------- - a : array - The input array. - shift : int - The number of rows to shift by. - - Returns - ------- - result : array - The array with the rows shifted. - """ - - result = np.zeros_like(a) - if shift >= 0: - result[shift:] = a[:-shift] - else: - result[:shift] = a[-shift:] - - return result - - -def robust_polyfit(x, y, order, maxiter=5, nstd=3.): - """Perform a robust polynomial fit. - - Parameters - ---------- - x : array[float] - x data to fit. - y : array[float] - y data to fit. - order : int - polynomial order to use. - maxiter : int, optional - number of iterations for rejecting outliers. - nstd : float, optional - number of standard deviations to use when rejecting outliers. - - Returns - ------- - param : array[float] - best-fit polynomial parameters. - """ - - mask = np.ones_like(x, dtype='bool') - for niter in range(maxiter): - - # Fit the data and evaluate the best-fit model. - param = np.polyfit(x[mask], y[mask], order) - yfit = np.polyval(param, x) - - # Compute residuals and mask outliers. - res = y - yfit - stddev = np.std(res) - mask = np.abs(res) <= nstd * stddev - - return param - - -def get_image_dim(image, header=None): - """Determine the properties of the image array. - - Parameters - ---------- - image : array[float] - A 2D image of the detector. 
- header : astropy.io.fits.Header object, optional - The header from one of the SOSS reference files. - - Returns - ------- - dimx, dimy : int - X and Y dimensions of the stack array. - xos, yos : int - Oversampling factors in x and y dimensions of the stack array. - xnative, ynative : int - Size of stack image x and y dimensions, in native pixels. - padding : int - Amount of padding around the image, in native pixels. - refpix_mask : array[bool] - Boolean array indicating which pixels are light-sensitive (True) - and which are reference pixels (False). - """ - - # Dimensions of the subarray. - dimy, dimx = np.shape(image) - - # If no header was passed we have to check all possible sizes. - if header is None: - - # Initialize padding to zero in this case because it is not a reference file. - padding = 0 - - # Assume the stack is a valid SOSS subarray. - # FULL: 2048x2048 or 2040x2040 (working pixels) or multiple if oversampled. - # SUBSTRIP96: 2048x96 or 2040x96 (working pixels) or multiple if oversampled. - # SUBSTRIP256: 2048x256 or 2040x252 (working pixels) or multiple if oversampled. - - # Check if the size of the x-axis is valid. - if (dimx % 2048) == 0: - xnative = 2048 - xos = int(dimx // 2048) - - elif (dimx % 2040) == 0: - xnative = 2040 - xos = int(dimx // 2040) - - else: - log_message = f'Stack X dimension has unrecognized size of {dimx}. Accepts 2048, 2040 or multiple of.' - log.critical(log_message) - raise ValueError(log_message) - - # Check if the y-axis is consistent with the x-axis. - if int(dimy / xos) in [96, 256, 252, 2040, 2048]: - yos = np.copy(xos) - ynative = int(dimy / yos) - - else: - log_message = f'Stack Y dimension ({dimy}) is inconsistent with stack X' \ - f'dimension ({dimx}) for acceptable SOSS arrays' - log.critical(log_message) - raise ValueError(log_message) - - # Create a boolean mask indicating which pixels are not reference pixels. 
- refpix_mask = np.ones_like(image, dtype='bool') - if xnative == 2048: - # Mask out the left and right columns of reference pixels. - refpix_mask[:, :xos * 4] = False - refpix_mask[:, -xos * 4:] = False - - if ynative == 2048: - # Mask out the top and bottom rows of reference pixels. - refpix_mask[:yos * 4, :] = False - refpix_mask[-yos * 4:, :] = False - - if ynative == 256: - # Mask the top rows of reference pixels. - refpix_mask[-yos * 4:, :] = False - - else: - # Read the oversampling and padding from the header. - padding = int(header['PADDING']) - xos, yos = int(header['OVERSAMP']), int(header['OVERSAMP']) - - # Check that the stack respects its intended format. - if (dimx / xos - 2 * padding) not in [2048]: - log_message = 'The header passed is inconsistent with the X dimension of the stack.' - log.critical(log_message) - raise ValueError(log_message) - else: - xnative = 2048 - - if (dimy / yos - 2 * padding) not in [96, 256, 2048]: - log_message = 'The header passed is inconsistent with the Y dimension of the stack.' - log.critical(log_message) - raise ValueError(log_message) - else: - ynative = int(dimy / yos - 2 * padding) - - # The trace file contains no reference pixels so all pixels are good. 
- refpix_mask = np.ones_like(image, dtype='bool') - - log.debug('Data dimensions:') - log.debug(f'dimx={dimx}, dimy={dimy}, xos={xos}, yos={yos}, xnative={xnative}, ynative={ynative}') - - return dimx, dimy, xos, yos, xnative, ynative, padding, refpix_mask diff --git a/jwst/extract_1d/soss_extract/tests/__init__.py b/jwst/extract_1d/soss_extract/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jwst/extract_1d/soss_extract/tests/conftest.py b/jwst/extract_1d/soss_extract/tests/conftest.py new file mode 100644 index 0000000000..3e48416145 --- /dev/null +++ b/jwst/extract_1d/soss_extract/tests/conftest.py @@ -0,0 +1,318 @@ +import pytest +import numpy as np +from scipy.signal import savgol_filter +from functools import partial +from jwst.extract_1d.soss_extract import atoca +from jwst.extract_1d.soss_extract import atoca_utils as au +from stdatamodels.jwst.datamodels import PastasossModel + +""" +Create a miniature, slightly simplified model of the SOSS detector/optics. +Use those to instantiate an extraction engine, use the engine to create mock data +from a known input spectrum, and then check if the engine can retrieve that spectrum +from the data. 
+ +The model has the following features: +- Factor-of-10 smaller along each dimension +- Similar wavelengths for each of the two orders +- Partially overlapping traces for the two orders +- Randomly-selected bad pixels in the data +- Wave grid of size ~100 with varying resolution +- Triangle function throughput for each spectral order +- Kernel is also a triangle function peaking at the center, or else unity for certain tests +- (partial) Mock of the Pastasoss reference model +""" + +PWCPOS = 245.85932900002442 +DATA_SHAPE = (25,200) +WAVE_BNDS_O1 = [2.8, 0.8] +WAVE_BNDS_O2 = [1.4, 0.5] +WAVE_BNDS_GRID = [0.7, 2.7] +ORDER1_SCALING = 20.0 +ORDER2_SCALING = 2.0 +TRACE_END_IDX = [DATA_SHAPE[1],180] +SPECTRAL_SLOPE = 2 + + +@pytest.fixture(scope="package") +def wave_map(): + wave_ord1 = np.linspace(WAVE_BNDS_O1[0], WAVE_BNDS_O1[1], DATA_SHAPE[1]) + wave_ord1 = np.ones(DATA_SHAPE)*wave_ord1[np.newaxis, :] + + wave_ord2 = np.linspace(WAVE_BNDS_O2[0], WAVE_BNDS_O2[1], DATA_SHAPE[1]) + wave_ord2 = np.ones(DATA_SHAPE)*wave_ord2[np.newaxis,:] + # add a small region of zeros to mimic what is input to the step from ref files + wave_ord2[:,TRACE_END_IDX[1]:] = 0.0 + + return [wave_ord1, wave_ord2] + + +@pytest.fixture(scope="package") +def trace_profile(wave_map): + """order 2 is partially on top of, partially not on top of order 1 + give order 2 some slope to simulate that""" + # order 1 + DATA_SHAPE = wave_map[0].shape + ord1 = np.zeros((DATA_SHAPE[0])) + ord1[3:9] = 0.2 + ord1[2] = 0.1 + ord1[9] = 0.1 + profile_ord1 = np.ones(DATA_SHAPE)*ord1[:, np.newaxis] + + # order 2 + yy, xx = np.meshgrid(np.arange(DATA_SHAPE[0]), np.arange(DATA_SHAPE[1])) + yy = yy.astype(np.float32) - xx.astype(np.float32)*0.08 + yy = yy.T + + profile_ord2 = np.zeros_like(yy) + full = (yy >= 3) & (yy < 9) + half0 = (yy >= 9) & (yy < 11) + half1 = (yy >= 1) & (yy < 3) + profile_ord2[full] = 0.2 + profile_ord2[half0] = 0.1 + profile_ord2[half1] = 0.1 + + return [profile_ord1, profile_ord2] + + 
+@pytest.fixture(scope="package") +def trace1d(wave_map, trace_profile): + """For each order, return tuple (xtrace, ytrace, wavetrace)""" + + trace_list = [] + for order in [0,1]: + + profile = trace_profile[order] + wave2d = wave_map[order].copy() #avoid modifying wave_map, it's needed elsewhere! + end_idx = TRACE_END_IDX[order] + + # find mean y-index at each x where trace_profile is nonzero + shp = profile.shape + xx, yy = np.mgrid[:shp[0], :shp[1]] + xx = xx.astype("float") + xx[profile==0] = np.nan + mean_trace = np.nanmean(xx, axis=0) + # same strategy for wavelength + wave2d[profile==0] = np.nan + mean_wave = np.nanmean(wave2d, axis=0) + + # smooth it. we know it should be linear so use 1st order poly and large box size + ytrace = savgol_filter(mean_trace, mean_trace.size-1, 1) + wavetrace = savgol_filter(mean_wave, mean_wave.size-1, 1) + + # apply cutoff + wavetrace = wavetrace[:end_idx] + ytrace = ytrace[:end_idx] + xtrace = np.arange(0, end_idx) + + trace_list.append((xtrace, ytrace, wavetrace)) + + return trace_list + + +@pytest.fixture(scope="package") +def wave_grid(): + """wave_grid has smaller spacings in some places than others + and is not backwards order like the wave map + Two duplicates are in there on purpose for testing""" + lo0 = np.linspace(WAVE_BNDS_GRID[0], 1.2, 16) + hi = np.linspace(1.2, 1.7, 46) + lo2 = np.linspace(1.7, WAVE_BNDS_GRID[1], 31) + return np.concatenate([lo0, hi, lo2]) + + +@pytest.fixture(scope="package") +def throughput(): + """make a triangle function for each order but with different peak wavelength + """ + + def filter_function(wl, wl_max): + """Set free parameters to roughly mimic throughput functions on main""" + maxthru = 0.4 + thresh = 0.01 + scaling = 0.3 + dist = np.abs(wl - wl_max) + thru = maxthru - dist*scaling + thru[thru0] = 0 + if cut_low is not None: + trace[:,:cut_low] = 1 + if cut_hi is not None: + trace[:,cut_hi:] = 1 + return trace.astype(bool) + + trace_o1 = mask_from_trace(trace_profile[0], 
cut_low=0, cut_hi=199) + trace_o2 = mask_from_trace(trace_profile[1], cut_low=0, cut_hi=175) + return [trace_o1, trace_o2] + + +@pytest.fixture(scope="package") +def detector_mask(): + """Add a few random bad pixels""" + rng = np.random.default_rng(42) + mask = np.zeros(DATA_SHAPE, dtype=bool) + bad = rng.choice(mask.size, 100) + bad = np.unravel_index(bad, DATA_SHAPE) + mask[bad] = 1 + return mask + + +@pytest.fixture(scope="package") +def engine(wave_map, + trace_profile, + throughput, + kernels_unity, + wave_grid, + mask_trace_profile, + detector_mask, +): + return atoca.ExtractionEngine(wave_map, + trace_profile, + throughput, + kernels_unity, + wave_grid, + mask_trace_profile, + global_mask=detector_mask) + + +def f_lam(wl, m=SPECTRAL_SLOPE, b=0): + """ + Estimator for flux as function of wavelength + Returns linear function of wl with slope m and intercept b + + This function is also used in this test suite as the known input spectrum passed to engine.rebuild + """ + return m*wl + b + + +@pytest.fixture(scope="package") +def imagemodel(engine, detector_mask): + """ + use engine.rebuild to make an image model from an expected f(lambda). + Then we can ensure it round-trips + """ + + rng = np.random.default_rng(seed=42) + shp = engine.trace_profile[0].shape + + # make the detector bad values NaN, but leave the trace masks alone + # in reality, of course, this is backward: detector bad values + # would be determined from data + data = engine.rebuild(f_lam, fill_value=0.0) + data[detector_mask] = np.nan + + # add noise + noise_scaling = 3e-5 + data += noise_scaling*rng.standard_normal(shp) + + # error random, all positive, but also always larger than a certain value + # to avoid very large values of data / error + error = noise_scaling*(rng.standard_normal(shp)**2 + 0.5) + + # TODO: Why does the data here have some kind of beat frequency? + # TODO: why does the data here have one deep negative bar at the end of each spectral order? 
+ + return data, error + + +@pytest.fixture(scope="module") +def refmodel(trace1d): + """Mock Pastasoss reference model with spatial dimensions scaled + down by a factor of 10. Since the traces are just linear, the polynomials + also have coefficients equal to 0 except for the constant and linear terms""" + model = PastasossModel() + model.meta.pwcpos_cmd = 245.76 + + trace0 = {"pivot_x": 189.0, + "pivot_y": 5.0, + "spectral_order": 1, + "trace": np.array([trace1d[0][0], trace1d[0][1]], dtype=np.float64).T, + "padding": 0,} + trace1 = {"pivot_x": 168.0, + "pivot_y": 20.0, + "spectral_order": 2, + "trace": np.array([trace1d[1][0], trace1d[1][1]], dtype=np.float64).T,} + model.traces = [trace0, trace1] + + wavecal0 = {"coefficients": [WAVE_BNDS_O1[0], -2.0,] + [0.0 for i in range(19)], + "polynomial_degree": 5, + "scale_extents": [[0, -1.03552000e-01], [DATA_SHAPE[1], 1.62882080e-01]]} + wavecal1 = {"coefficients": [WAVE_BNDS_O2[0], -1.0,] + [0.0 for i in range(8)], + "polynomial_degree": 3, + "scale_extents": [[0, 245.5929], [DATA_SHAPE[1], 245.9271]]} + model.wavecal_models = [wavecal0, wavecal1] + + thru0 = {"spectral_order": 1, + "wavelength": np.linspace(0.5, 5.5, 501), + "throughput": np.ones((501,)),} #peaks around 1.22 at value 0.37 + thru1 = {"spectral_order": 2, + "wavelength": np.linspace(0.5, 5.5, 501), + "throughput": np.ones((501,)),} #peaks around 0.7 at value of 0.16 + model.throughputs = [thru0, thru1] + + return model + + +@pytest.fixture +def ref_files(refmodel): + ref_files = {"pastasoss": refmodel} + ref_files["subarray"] = "SUBSTRIP256" + ref_files["pwcpos"] = PWCPOS + return ref_files \ No newline at end of file diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca.py b/jwst/extract_1d/soss_extract/tests/test_atoca.py new file mode 100644 index 0000000000..bf0d965a16 --- /dev/null +++ b/jwst/extract_1d/soss_extract/tests/test_atoca.py @@ -0,0 +1,417 @@ +import pytest +import numpy as np +from functools import partial +from scipy.sparse 
import csr_matrix +from jwst.extract_1d.soss_extract import atoca +from jwst.extract_1d.soss_extract.tests.conftest import ( + SPECTRAL_SLOPE, f_lam, DATA_SHAPE, WAVE_BNDS_O1, WAVE_BNDS_O2) + +"""Tests for the ATOCA extraction engine, taking advantage of the miniature +model set up by conftest.py. +The test_call() function ensures that the engine can retrieve the spectrum +with SPECTRAL_SLOPE that we put into the data, which implicitly checks +a lot of the matrix math.""" + + +def test_extraction_engine_init( + wave_map, + trace_profile, + throughput, + wave_grid, + mask_trace_profile, + detector_mask, + engine, +): + """Test the init of the engine with default/good inputs""" + + # test wave_grid became unique + assert engine.wave_grid.dtype == np.float64 + unq = np.unique(wave_grid) + assert np.allclose(engine.wave_grid, unq) + assert engine.n_wavepoints == unq.size + + for order in [0,1]: + # test assignment of attributes and conversion to expected float64 dtype + assert engine.wave_map[order].dtype == np.float64 + assert engine.trace_profile[order].dtype == np.float64 + assert engine.kernels[order].dtype == np.float64 + assert engine.mask_trace_profile[order].dtype == np.bool_ + + assert np.allclose(engine.wave_map[order], wave_map[order]) + assert np.allclose(engine.trace_profile[order], trace_profile[order]) + assert np.allclose(engine.mask_trace_profile[order], mask_trace_profile[order]) + + # test derived attributes + assert engine.data_shape == DATA_SHAPE + assert engine.n_orders == 2 + + # test wave_p and wave_m. 
separate unit test for their calculation + for att in ["wave_p", "wave_m"]: + wave = getattr(engine, att) + assert wave.dtype == np.float64 + assert wave.shape == (2,)+DATA_SHAPE + + # test _get_i_bounds + assert len(engine.i_bounds) == 2 + for order in [0,1]: + assert len(engine.i_bounds[order]) == 2 + assert engine.i_bounds[order][0] >= 0 + assert engine.i_bounds[order][1] < DATA_SHAPE[1] + + # test to ensure that wave_map is considered for bounds + # in order 1 the wave_map is more restrictive on the shortwave end + # in order 2 the wave_grid is more restrictive on the longwave end + # in order 1 no restriction on the longwave end so get the full extent + # in order 2 no restriction on the shortwave end so get the full extent + assert engine.i_bounds[0][0] > 0 + assert engine.i_bounds[1][1] < engine.n_wavepoints + + # TODO: off-by-one error here. why does this fail? + # check what this looks like on a real run on main + # assert engine.i_bounds[0][1] == engine.n_wavepoints + # assert engine.i_bounds[1][0] == 0 + + + # test _get_masks + # ensure they all include the input detector bad pixel mask + for mask in [engine.mask_ord[0], engine.mask_ord[1], engine.mask, engine.general_mask]: + assert mask.dtype == np.bool_ + assert mask.shape == DATA_SHAPE + assert np.all(mask[detector_mask == 1]) + + # general_mask should be a copy of mask + assert np.allclose(engine.mask, engine.general_mask) + + wave_bnds = [WAVE_BNDS_O1, WAVE_BNDS_O2] + for order in [0,1]: + # ensure wavelength bounds from wave_grid are respected in mask_ord + mask = engine.mask_ord[order] + wls = wave_map[order] + lo, hi = wave_bnds[order] + outside = (wls > lo) | (wls < hi) + assert np.all(mask[outside]) + + # a bit paradoxically, engine.mask_ord does not contain a single order's mask_trace_profile + # instead, it's mask_trace_profile[0] AND mask_trace_profile[1], i.e., + # the trace profiles of BOTH orders are UNmasked in the masks of each order + # the general mask and wavelength bounds are then 
applied, so the + # only difference between mask_ord[0] and mask_ord[1] are the wavelength bounds + # test that NOT all locations masked by mask_trace_profile are masked in mask_ord + assert not np.all(mask[mask_trace_profile[order]]) + # test that all locations masked by both profiles are masked in mask_ord + combined_profile = mask_trace_profile[0] & mask_trace_profile[1] + assert np.all(mask[combined_profile]) + + # test throughput function conversion to array + for order in [0,1]: + thru = engine.throughput[order] + assert thru.size == engine.n_wavepoints + assert np.all(thru >= 0) + assert thru.dtype == np.float64 + + # test kernel is cast to proper shape for input (trivial) kernel + for order in [0,1]: + n_valid = engine.i_bounds[order][1] - engine.i_bounds[order][0] + expected_shape = (n_valid, engine.n_wavepoints) + assert engine.kernels[order].shape == expected_shape + # for trivial kernel only one element per row is nonzero + assert engine.kernels[order].count_nonzero() == expected_shape[0] + + # test weights. see separate unit tests to ensure the calculation is correct + for order in [0,1]: + n_valid = engine.i_bounds[order][1] - engine.i_bounds[order][0] + weights = engine.weights[order] + k_idx = engine.weights_k_idx[order] + assert weights.dtype == np.float64 + assert np.issubdtype(k_idx.dtype, np.integer) + + # TODO: why is weights the same size as ~engine.mask, and not ~engine.mask_ord? 
def test_extraction_engine_bad_inputs(
    wave_map,
    trace_profile,
    throughput,
    kernels_unity,
    wave_grid,
    mask_trace_profile,
    detector_mask,
):
    """
    Test that the engine rejects inputs it cannot work with.

    Two independent failure modes are checked: a detector mask that leaves
    too few good pixels (MaskOverlapError), and an ``orders`` list whose
    length does not match the other per-order inputs (ValueError).
    """
    # not enough good pixels in order
    # Build the mask outside the raises-block and under its own name, so that
    # only the constructor call is expected to raise and the fixture value
    # stays available (unmodified) for the next check.
    mostly_bad_mask = np.ones_like(detector_mask)
    mostly_bad_mask[5:7, 50:55] = 0  # still a few good pixels but very few
    with pytest.raises(atoca.MaskOverlapError):
        atoca.ExtractionEngine(wave_map,
                               trace_profile,
                               throughput,
                               kernels_unity,
                               wave_grid,
                               mask_trace_profile,
                               global_mask=mostly_bad_mask)

    # wrong number of orders
    # Uses the original (mostly good) detector mask so that the mismatch in
    # the number of orders is the only thing wrong with the inputs.
    with pytest.raises(ValueError):
        atoca.ExtractionEngine(wave_map,
                               trace_profile,
                               throughput,
                               kernels_unity,
                               wave_grid,
                               mask_trace_profile,
                               global_mask=detector_mask,
                               orders=[0,])
def test_create_kernels(webb_kernels, engine):
    """
    Check that _create_kernels forwards its options correctly.

    test_atoca_utils.test_get_c_matrix already tests the creation of
    individual kernels for different input types; here we only make sure
    that both kinds of input yield one float64 CSR matrix per order.
    """
    kernel_inputs = (webb_kernels, [None, None])
    created = [engine._create_kernels(kernel_input) for kernel_input in kernel_inputs]

    for kernels in created:
        assert len(kernels) == 2
        for order in range(2):
            matrix = kernels[order]
            assert isinstance(matrix, csr_matrix)
            assert matrix.dtype == np.float64
def test_rebuild(engine):
    """
    Check the synthetic detector image produced by engine.rebuild.

    Verifies the dtype and shape of the output, that a callable spectrum and
    the same spectrum sampled on the wavelength grid give the same image,
    and that a NaN fill value lands exactly where general_mask is set.
    """
    from_callable = engine.rebuild(f_lam)
    assert from_callable.dtype == np.float64
    assert from_callable.shape == engine.wave_map[0].shape

    # the input spectrum may be given either as a callable or as an array
    sampled_spectrum = f_lam(engine.wave_grid)
    from_array = engine.rebuild(sampled_spectrum)
    assert np.allclose(from_callable, from_array)

    # the fill value must mark the same pixels as the general mask
    with_nans = engine.rebuild(f_lam, fill_value=np.nan)
    nan_locations = np.isnan(with_nans)
    assert np.allclose(nan_locations, engine.general_mask)
def test_get_tikho_tests(tikho_tests, engine):
    """
    Check the shapes and dtypes of the outputs of engine.get_tikho_tests.

    The ``tikho_tests`` fixture provides the factors that were tried and the
    resulting collection of test outputs.
    """
    factors, tests = tikho_tests
    n_factors = len(factors)
    n_wave = engine.n_wavepoints
    n_unmasked = np.sum(~engine.mask)

    # the factors must be stored unchanged
    assert np.allclose(tests["factors"], factors)

    # expected shape of every per-factor output
    expected_shapes = {
        "solution": (n_factors, n_wave),
        "error": (n_factors, n_unmasked),
        "reg": (n_factors, n_wave - 1),
        "chi2": (n_factors,),
        "chi2_soft_l1": (n_factors,),
        "chi2_cauchy": (n_factors,),
    }
    for name, shape in expected_shapes.items():
        assert tests[name].shape == shape

    assert np.allclose(tests["grid"], engine.wave_grid)

    # the data type must be preserved through the solve
    float64 = np.dtype("float64")
    for name in tests.keys():
        assert tests[name].dtype == float64
def test_compute_likelihood(engine, imagemodel):
    """
    Ensure the log-likelihood is highest for the correct slope.

    A grid of candidate slopes is evaluated against the synthetic data; the
    slope that maximizes the log-likelihood must be SPECTRAL_SLOPE, and all
    log-likelihood values must be negative.
    """
    data, error = imagemodel
    test_slopes = np.arange(0, 5, 0.5)
    logl = np.array([
        engine.compute_likelihood(partial(f_lam, m=slope), data, error)
        for slope in test_slopes
    ])

    # The true slope has to be one of the candidates, otherwise the check
    # below could never succeed; fail with a clear message in that case.
    assert np.any(np.isclose(test_slopes, SPECTRAL_SLOPE)), \
        "SPECTRAL_SLOPE is not on the grid of tested slopes"

    # Compare slope values rather than an index against np.argwhere output:
    # argwhere returns a 2-D array (empty when there is no match), whose
    # truth value is not a reliable assertion.
    best_slope = test_slopes[np.argmax(logl)]
    assert np.isclose(best_slope, SPECTRAL_SLOPE)
    assert np.all(logl < 0)
@pytest.mark.parametrize("dispersion_axis", [0,1])
def test_get_wv_map_bounds(wave_map, dispersion_axis):
    """
    Check the per-pixel wavelength bounds for both dispersion axes.

    Top is the low-wavelength end, bottom is the high-wavelength end.
    The map is transposed for dispersion_axis=0 and the outputs are
    transposed back, so the same expectations serve both cases.
    """
    order1_map = wave_map[0].copy()
    order1_map[1, 3] = -1  # bad value that is expected to be skipped
    first_row = order1_map[0]

    transposed = dispersion_axis == 0
    map_in = order1_map.T if transposed else order1_map
    top, bottom = au._get_wv_map_bounds(map_in, dispersion_axis=dispersion_axis)

    # undo the transposition so the checks below are axis-independent
    if transposed:
        top, bottom = top.T, bottom.T

    # expected bounds: half the spacing to the neighbouring pixel, with the
    # edge pixels re-using the spacing of their only neighbour
    half_step = np.diff(first_row) / 2
    expected_top = first_row - np.concatenate(([half_step[0]], half_step))
    expected_bottom = first_row + np.concatenate((half_step, [half_step[-1]]))

    # basic test
    expected_shape = (order1_map.shape[0],) + first_row.shape
    assert top.shape == expected_shape
    assert bottom.shape == expected_shape
    assert np.allclose(top[0], expected_top)
    assert np.allclose(bottom[0], expected_bottom)

    # the bad pixel is skipped, leaving zeros in both outputs
    assert top[1, 3] == 0
    assert bottom[1, 3] == 0

    # an invalid dispersion axis is rejected
    with pytest.raises(ValueError):
        au._get_wv_map_bounds(map_in, dispersion_axis=2)
def test_get_wave_p_or_m_not_ascending(wave_map):
    """
    Check that a wavelength map that is not strictly monotonic is rejected.

    One pixel of the first row is overwritten so that the wavelengths along
    the dispersion axis are no longer strictly ascending; get_wave_p_or_m is
    then expected to raise a ValueError.
    """
    wave_map = wave_map[0].copy()
    # Corrupt the map before entering the raises-block, so that only the
    # call under test can satisfy the expected exception.
    wave_map[0, 5] = 2  # make it not strictly ascending
    with pytest.raises(ValueError):
        au.get_wave_p_or_m(wave_map, dispersion_axis=1)
So the output should be 4 times the size of FIBONACCI minus 1 + assert oversample.size == n_os*(FIBONACCI.size - 1) - (n_os-1) + assert oversample.min() == FIBONACCI.min() + assert oversample.max() == FIBONACCI.max() + + # test whether np.interp could have been used instead + grid = np.arange(0, FIBONACCI.size, 1/n_os) + wls = np.unique(np.interp(grid, np.arange(FIBONACCI.size), FIBONACCI)) + assert np.allclose(oversample, wls) + + +@pytest.mark.parametrize("os_factor", [1,2,5]) +def test_oversample_irregular(os_factor): + """Test oversampling to a grid with irregular spacing""" + # oversampling function removes duplicates, + # this is tested in previous test, and just complicates counting for this test + # for FIBONACCI, unique is just removing zeroth element + fib_unq = np.unique(FIBONACCI) + n_os = np.ones((fib_unq.size-1,), dtype=int) + n_os[2:5] = os_factor + n_os[3] = os_factor*2 + # this gives n_os = [1 1 2 4 2] for os_factor = 2 + + oversample = au.oversample_grid(fib_unq, n_os) + + # test no oversampling was done on the elements where not requested + assert np.allclose(oversample[0:2], fib_unq[0:2]) + assert np.allclose(oversample[-1:], fib_unq[-1:]) + + # test output shape. 
@pytest.mark.parametrize("wave_range", [(2.1, 3.9), (1.8, 4.5)])
def test_extrapolate_grid(wave_range):
    """
    Check that the extrapolated grid covers the requested range.

    The output must extend beyond both ends of wave_range and be sorted in
    non-decreasing order. When the input grid already covers the range, the
    very same array object is expected back.
    """
    range_min, range_max = wave_range[0], wave_range[1]
    result = au._extrapolate_grid(WAVELENGTHS, wave_range, 1)

    assert result.max() > range_max
    assert result.min() < range_min
    steps = np.diff(result)
    assert np.all(steps >= 0)

    # nothing to extrapolate on either side: the original is returned as-is
    already_covered = wave_range == (2.1, 3.9)
    if already_covered:
        assert result is WAVELENGTHS
def xsinx(x):
    """Return x*sin(x), the test function used by the grid-adaptation tests."""
    sine = np.sin(x)
    return np.multiply(x, sine)
def test_adapt_grid_bad_inputs():
    """Check that an input grid larger than max_grid_size is rejected."""
    oversized_grid = np.array([1,2,3])
    max_grid_size = 2
    with pytest.raises(ValueError):
        au._adapt_grid(oversized_grid, xsinx, max_grid_size)
def test_make_combined_adaptive_grid():
    """
    Check that several input grids are merged into one adaptive grid.

    See also the tests of _adapt_grid and _trim_grids for more detailed
    tests. Here the combined grid must have no duplicated points, and a
    trapezoidal integration of x*sin(x) on it must match the analytic value
    (pi) to within rtol.
    """
    grid_range = (0, np.pi)
    grid0 = np.linspace(0, np.pi/2, 6)  # kept entirely.
    grid1 = np.linspace(0, np.pi/2, 15)  # removed entirely. Finer spacing doesn't matter, preceded by grid0
    grid2 = np.linspace(np.pi/2, np.pi, 11)  # kept from pi/2 to pi

    # purposely make same lower index for grid2 as upper index for grid0 to test uniqueness of output

    all_grids = [grid0, grid1, grid2]
    all_estimate = [xsinx, xsinx, xsinx]

    rtol = 1e-3
    combined_grid = au.make_combined_adaptive_grid(all_grids, all_estimate, grid_range,
                                                   max_iter=10, rtol=rtol, max_total_size=100)

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
    # fall back to the old name only on NumPy versions that lack the new one.
    trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
    numerical_integral = trapezoid(xsinx(combined_grid), combined_grid)

    assert np.unique(combined_grid).size == combined_grid.size
    assert np.isclose(numerical_integral, np.pi, rtol=rtol)
wave_trace = wave_map[0][0] + min_trace, max_trace = np.min(wave_trace), np.max(wave_trace) + kern = webb_kernels[0] + + # basic ensure that the input is stored and shapes + assert kern.wave_kernels.shape == kern.kernels.shape + + # test that pixels and wave_kernels are both monotonic + assert np.all(np.diff(kern.pixels) > 0) + assert np.all(np.diff(kern.wave_kernels) > 0) + + # test that pixels is mirrored around the center and has zero at center + assert np.allclose(kern.pixels + kern.pixels[::-1], 0) + assert kern.pixels[kern.pixels.size//2] == 0 + + # test that wave_center has same shape as wavelength axis of wave_kernel + # but contains values that are in wave_trace + assert kern.wave_center.size == kern.wave_kernels.shape[1] + assert all(np.isin(kern.wave_center, wave_trace)) + + # test min value + assert kern.min_value > 0 + assert np.isin(kern.min_value, kern.kernels) + assert isinstance(kern.min_value, float) + + # test the polynomial fit has the proper shape. hard-coded to a first-order, i.e., linear fit + # since the throughput is constant in wavelength, the slopes should be close to zero + # and the y-intercepts should be close to kern.wave_center + # especially with so few points. 
def test_get_c_matrix(kernels_unity, webb_kernels, wave_grid):
    """
    Check the convolution matrix built by get_c_matrix for several inputs.

    See also test_fct_to_array and test_sparse_c for more detailed tests
    of functions called by this one. Covered here: a callable kernel, the
    same kernel pre-computed as a 2-D array, a size-1 (unity) kernel,
    non-trivial i_bounds, and two invalid array kernels.
    """
    # only need to test one order
    kern = webb_kernels[0]
    matrix = au.get_c_matrix(kern, wave_grid, i_bounds=None)

    # ensure proper shape
    assert matrix.shape == (wave_grid.size, wave_grid.size)
    assert matrix.dtype == np.float64

    # ensure normalized
    # The total is a sum of floating-point values, so compare with a
    # tolerance instead of demanding exact equality with the row count.
    assert np.isclose(matrix.sum(), matrix.shape[0])

    # test where input kernel is a 2-D array instead of callable
    i_bounds = [0, len(wave_grid)]
    kern_array = au._fct_to_array(kern, wave_grid, i_bounds, 1e-5)
    matrix_from_array = au.get_c_matrix(kern_array, wave_grid, i_bounds=i_bounds)
    assert np.allclose(matrix.toarray(), matrix_from_array.toarray())

    # test where input kernel is size 1
    kern_unity = kernels_unity[0]
    matrix_from_unity = au.get_c_matrix(kern_unity, wave_grid, i_bounds=i_bounds)
    assert matrix_from_unity.shape == (wave_grid.size, wave_grid.size)

    # test where i_bounds is not None
    i_bounds = [10, wave_grid.size-10]
    matrix_ibnds = au.get_c_matrix(kern, wave_grid, i_bounds=i_bounds)
    expected_shape = (wave_grid[i_bounds[0]:i_bounds[1]].size, wave_grid.size)
    assert matrix_ibnds.shape == expected_shape

    # The bad kernels are built outside the raises-blocks so that only the
    # call under test can satisfy the expected exception.

    # Test invalid kernel input (wrong dimensions)
    kern_array_3d = kern_array[np.newaxis, ...]
    with pytest.raises(ValueError):
        au.get_c_matrix(kern_array_3d, wave_grid, i_bounds=i_bounds)

    # Test invalid kernel input (odd shape)
    kern_array_trimmed = kern_array[1:, 1:]
    with pytest.raises(ValueError):
        au.get_c_matrix(kern_array_trimmed, wave_grid, i_bounds=i_bounds)
def test_get_soss_traces(refmodel):
    """
    Check the traces returned by _get_soss_traces for both orders and subarrays.

    The x coordinates must match the reference trace, the wavelengths must
    match _get_wavelengths evaluated on those x coordinates, and y must have
    the same shape as the wavelengths.
    """
    subarrays = ("SUBSTRIP96", "SUBSTRIP256")
    for idx, order in enumerate(("1", "2")):
        for subarray in subarrays:
            result = _get_soss_traces(refmodel, PWCPOS, order, subarray)
            order_out, x_new, y_new, wavelengths = result

            assert str(order_out) == order

            # since always interpolated back to original x, x_new should equal x
            x_in, _ = refmodel.traces[idx].trace.T.copy()
            assert np.allclose(x_new, x_in)

            # and wavelengths are same as what you get from _get_wavelengths on x_in
            expected_wavelengths = _get_wavelengths(refmodel, x_in, PWCPOS, int(order))
            assert np.allclose(wavelengths, expected_wavelengths)

            # the y coordinate is the tricky one. it was rotated by
            # pwcpos - refmodel.meta.pwcpos_cmd about pivot_x, pivot_y
            assert y_new.shape == wavelengths.shape
            # TODO: add meaningful tests of y
a/jwst/extract_1d/soss_extract/tests/test_soss_boxextract.py b/jwst/extract_1d/soss_extract/tests/test_soss_boxextract.py new file mode 100644 index 0000000000..da6f074cc7 --- /dev/null +++ b/jwst/extract_1d/soss_extract/tests/test_soss_boxextract.py @@ -0,0 +1,117 @@ +import pytest +import numpy as np + +from .conftest import DATA_SHAPE +from jwst.extract_1d.soss_extract.soss_boxextract import ( + get_box_weights, box_extract, estim_error_nearest_data) + + +WIDTH = 5.1 + +@pytest.fixture() +def box_weights(trace1d): + + weights_list = [] + for order in [0,1]: + tracex, tracey, wavetrace = trace1d[order] + weights_list.append(get_box_weights(tracey, WIDTH, DATA_SHAPE, tracex)) + return weights_list + + +def test_get_box_weights(trace1d, box_weights): + """ + Order 1 is easy because tracex.size and tracey.size are equal to data_shape[1] + Order 2 tests the case where they are not equal + """ + + for order in [0,1]: + + tracex, tracey, wavetrace = trace1d[order] + weights = box_weights[order] + + # check weights are between zero and 1 + assert weights.shape == DATA_SHAPE + assert np.max(weights) == 1 + assert np.min(weights) == 0 + + weight_sum = np.sum(weights, axis=0) + # some weights are zero because of the trace profile cutoff in order 2 + assert np.sum(weight_sum == 0) == (DATA_SHAPE[1] - tracey.size) + + # check sum of weights across y-axis is width + weight_sum = weight_sum[np.nonzero(weight_sum)] + assert np.allclose(weight_sum, WIDTH) + + # check at least some partial weights exist + assert np.sum(weight_sum == 1) < weight_sum.size + + # TODO: need a test, maybe regtest, for subarray sub96 problem + # see https://github.com/spacetelescope/jwst/issues/8780 + + +def test_box_extract(trace1d, box_weights, imagemodel): + + data, err = imagemodel + mask = np.isnan(data) + + for order in [0,1]: + weights = box_weights[order] + cols, flux, flux_err, npix = box_extract(data, err, mask, weights) + + # test that cols just represents the data + assert np.allclose(cols, 
np.arange(data.shape[1])) + + # test flux and flux_err are NaN where order 2 is cut off, but have good values elsewhere + xtrace = trace1d[order][0] + for f in [flux, flux_err]: + assert np.sum(~np.isnan(flux)) == xtrace.size + # test npix is zero there too + assert np.count_nonzero(npix) == xtrace.size + + # test that most of npix are equal to width (Not all, because of NaN mask, but NaN fraction) + # is small enough that it should still be the most represented count for such a small width + unique, counts = np.unique(npix, return_counts=True) + assert np.isclose(unique[np.argmax(counts)], WIDTH) + + # TODO: somehow check the fluxes retrieved look like what we would expect from data + # although this is hard because the wavelengths are not extracted here + + # TODO: why does flux have very low values at edges even after cutting by good? + + +def test_estim_error_nearest_data(imagemodel, mask_trace_profile): + + data, err = imagemodel + + for order in [0, 1]: + # this has bad pixels set to 1, ONLY within the spectral trace. + # everything else is zero, i.e., regions outside trace and good data + pix_to_estim = np.zeros(data.shape, dtype="bool") + pix_to_estim[np.isnan(data)] = 1 + pix_to_estim[mask_trace_profile[order] == 1] = 0 + + # this has bad pixels set to 0, and regions outside trace set to 0, and good data 1 + valid_pix = ~mask_trace_profile[order] + valid_pix[pix_to_estim] = 0 + + err_out = estim_error_nearest_data(err, data, pix_to_estim, valid_pix) + + # test that all replaced values are positive and no NaNs are left + assert np.sum(np.isnan(err_out)) == 0 + assert np.all(err_out > 0) + + # test that the replaced pixels are not statistical outliers c.f. the other pixels + # replaced_pix = err_out[pix_to_estim] + # original_pix = err_out[valid_pix] + # diff = np.mean(replaced_pix)/np.mean(original_pix) + # assert np.isclose(diff, 1, rtol=0.5) # assert False + + # TODO: why does this fail? 
+ # In both orders, the errors on the replaced pixels are roughly + # half of the errors on the original good pixels + # There are enough replaced pixels here (~30) that small-number statistics cannot account for this + # The reason is because the code chooses the lower error between the two nearest-flux + # data points, and since the errors in our tests are uncorrelated with the flux values, + # this leads to a factor-of-2 decrease + # It's not clear to me that picking the smaller of two error values is the right thing to do + # but that behavior is documented diff --git a/jwst/extract_1d/soss_extract/tests/test_soss_extract.py b/jwst/extract_1d/soss_extract/tests/test_soss_extract.py new file mode 100644 index 0000000000..ad82ca4f6f --- /dev/null +++ b/jwst/extract_1d/soss_extract/tests/test_soss_extract.py @@ -0,0 +1,172 @@ +from functools import partial +import pytest +import numpy as np +from stdatamodels.jwst.datamodels import SpecModel, SossWaveGridModel + +from jwst.extract_1d.soss_extract.soss_extract import ( + _model_image, _compute_box_weights, +) +from .conftest import DATA_SHAPE + + +@pytest.fixture +def monkeypatch_setup(monkeypatch, + wave_map, + trace_profile, + throughput, + webb_kernels, + trace1d,): + """Monkeypatch get_ref_file_args and get_trace_1d to return the miniature model detector""" + + def mock_get_ref_file_args(wave, trace, thru, kern, reffiles): + """Return the arrays from conftest instead of querying CRDS""" + return [wave, trace, thru, kern] + + def mock_trace1d(trace, reffiles, order): + """Return the traces from conftest instead of doing math that requires a full-sized detector""" + return trace[int(order)-1] + + monkeypatch.setattr("jwst.extract_1d.soss_extract.soss_extract.get_ref_file_args", + partial(mock_get_ref_file_args, wave_map, trace_profile, throughput, webb_kernels)) + monkeypatch.setattr("jwst.extract_1d.soss_extract.soss_extract._get_trace_1d", + partial(mock_trace1d, trace1d)) + + +# slow because tests multiple 
Tikhonov factors +@pytest.mark.slow +def test_model_image(monkeypatch_setup, + imagemodel, + detector_mask, + ref_files,): + scidata, scierr = imagemodel + + refmask = np.zeros_like(detector_mask) + box_width = 5.0 + box_weights, wavelengths = _compute_box_weights(ref_files, DATA_SHAPE, box_width) + + tracemodels, tikfac, logl, wave_grid, spec_list = _model_image( + scidata, scierr, detector_mask, refmask, ref_files, box_weights, + tikfac=None, threshold=1e-4, n_os=2, wave_grid=None, + estimate=None, rtol=1e-3, max_grid_size=1000000, + ) + + # check output basics, types and shapes + assert len(tracemodels) == 2 + for order in tracemodels: + tm = tracemodels[order] + assert tm.dtype == np.float64 + assert tm.shape == DATA_SHAPE + # should be some nans in the trace model but not all + assert 0 < np.sum(np.isfinite(tm)) < tm.size + for x in [tikfac, logl]: + assert isinstance(x, float) + assert np.isfinite(x) + assert logl < 0 + assert wave_grid.dtype == np.float64 + for spec in spec_list: + assert isinstance(spec, SpecModel) + + factors = np.array([getattr(spec.meta.soss_extract1d, "factor", np.nan) for spec in spec_list]) + chi2s = np.array([getattr(spec.meta.soss_extract1d, "chi2", np.nan) for spec in spec_list]) + orders = np.array([spec.spectral_order for spec in spec_list]) + colors = np.array([spec.meta.soss_extract1d.color_range for spec in spec_list]) + + assert tikfac in factors + + # ensure outputs have the shapes we expect for each order and blue/red + n_good = [] + for order in [1,2]: + for color in ["RED", "BLUE"]: + good = (order == orders) & (color == colors) + # check that there's at least one good spectrum for all valid order-color combinations + if not np.any(good): + assert order == 1 + assert color == "BLUE" + continue + n_good.append(np.sum(good)) + + this_factors = factors[good] + this_chi2s = chi2s[good] + this_spec = np.array(spec_list)[good] + nochi = np.isnan(this_chi2s) + + # _model_single_order is set up so that the final/best spectrum is 
last in the list + # it lacks chi2 calculations + assert np.sum(nochi) == 1 + assert np.where(nochi)[0][0] == len(this_chi2s) - 1 + + # it represents the best tikhonov factor for that order-color combination + # which is not necessarily the same as the top-level tikfac for the blue part of order 2 + # but it is the same for the red part of order 1 and the red part of order 2 + if color == "RED": + assert this_factors[-1] == tikfac + + # check that the output spectra contain good data + for spec in this_spec: + spec = np.array([[s[0], s[1]] for s in spec.spec_table]) + assert np.sum(np.isfinite(spec)) == spec.size + + + # check that all order-color combinations have the same number of spectra + n_good = np.array(n_good) + assert np.all(n_good >= 1) + assert np.all(n_good - n_good[0] == 0) + + +def test_model_image_tikfac_specified(monkeypatch_setup, + imagemodel, + detector_mask, + ref_files,): + """Ensure spec_list is a single-element list per order if tikfac is specified""" + scidata, scierr = imagemodel + + refmask = np.zeros_like(detector_mask) + box_width = 5.0 + box_weights, wavelengths = _compute_box_weights(ref_files, DATA_SHAPE, box_width) + + tikfac_in = 1e-7 + tracemodels, tikfac, logl, wave_grid, spec_list = _model_image( + scidata, scierr, detector_mask, refmask, ref_files, box_weights, + tikfac=tikfac_in, threshold=1e-4, n_os=2, wave_grid=None, + estimate=None, rtol=1e-3, max_grid_size=1000000, +) + # check that spec_list is a single-element list per order in this case + assert len(spec_list) == 3 + assert tikfac == tikfac_in + + +def test_model_image_wavegrid_specified(monkeypatch_setup, + imagemodel, + detector_mask, + ref_files,): + """Ensure wave_grid is used if specified. + Also specify tikfac because it makes the code run faster to not have to re-derive it. + + Note the failure with SossWaveGridModel passed as input. What should be done about that? 
+ """ + scidata, scierr = imagemodel + + refmask = np.zeros_like(detector_mask) + box_width = 5.0 + box_weights, wavelengths = _compute_box_weights(ref_files, DATA_SHAPE, box_width) + + tikfac_in = 1e-7 + # test np.array input + wave_grid_in = np.linspace(1.0, 2.5, 100) + tracemodels, tikfac, logl, wave_grid, spec_list = _model_image( + scidata, scierr, detector_mask, refmask, ref_files, box_weights, + tikfac=tikfac_in, threshold=1e-4, n_os=2, wave_grid=wave_grid_in, + estimate=None, rtol=1e-3, max_grid_size=1000000, + ) + assert np.allclose(wave_grid, wave_grid_in) + + # test SossWaveGridModel input + # the docs on main say this works, but I don't think it does even on main + with pytest.raises(ValueError): + wave_grid_in = SossWaveGridModel() + wave_grid_in.wavegrid = np.linspace(1.0, 2.5, 100) + tracemodels, tikfac, logl, wave_grid, spec_list = _model_image( + scidata, scierr, detector_mask, refmask, ref_files, box_weights, + tikfac=tikfac_in, threshold=1e-4, n_os=2, wave_grid=wave_grid_in, + estimate=None, rtol=1e-3, max_grid_size=1000000, + ) diff --git a/jwst/extract_1d/soss_extract/tests/test_soss_syscor.py b/jwst/extract_1d/soss_extract/tests/test_soss_syscor.py new file mode 100644 index 0000000000..6576c878c3 --- /dev/null +++ b/jwst/extract_1d/soss_extract/tests/test_soss_syscor.py @@ -0,0 +1,56 @@ +import pytest +import numpy as np + +from jwst.extract_1d.soss_extract.soss_syscor import ( + soss_background, + make_background_mask, +) + + +def test_soss_background(imagemodel, detector_mask, mask_trace_profile): + + data, err = imagemodel + bkg_mask = ~mask_trace_profile[0] | ~mask_trace_profile[1] | detector_mask + + data_bkg, col_bkg = soss_background(data, detector_mask, bkg_mask) + assert data_bkg.shape == data.shape + assert col_bkg.size == data.shape[1] + + # check background now has mean zero + mean_bkg = np.mean(data_bkg[~bkg_mask]) + assert np.isclose(mean_bkg, 0.0) + + # check col_bkg are at least close to the non-sigma-clipped version which is 
much easier to calculate + # For the test case, there are no outliers so this should be quite a close match + data[bkg_mask] = np.nan + col_bkg_unclipped = np.nanmean(data, axis=0) + assert np.allclose(col_bkg, col_bkg_unclipped) + + +def test_make_background_mask(): + + rng = np.random.default_rng(seed=42) + for sub in [96, 256, 2048]: + + shape = (sub, 2048) + width = int(sub/4) + data = rng.normal(0.0, 1.0, shape) + + mask = make_background_mask(data, width) + + if sub == 96: + expected_bad_frac = 1/4 + else: + expected_bad_frac = 1/2 + + bad_frac = np.sum(mask)/mask.size + # test that bad fraction is computed properly for all modes + assert np.isclose(bad_frac, expected_bad_frac) + + # test that mask=True is the high-flux pixels + assert np.mean(data[mask]) > np.mean(data) + + # test unrecognized shape + with pytest.raises(ValueError): + data = rng.normal(0.0, 1.0, (40, 2048)) + make_background_mask(data, width) diff --git a/jwst/regtest/test_niriss_soss.py b/jwst/regtest/test_niriss_soss.py index 451b15adf7..c6d2951c43 100644 --- a/jwst/regtest/test_niriss_soss.py +++ b/jwst/regtest/test_niriss_soss.py @@ -154,3 +154,31 @@ def test_extract1d_null_order2(rtdata_module, run_extract1d_null_order2, fitsdif diff = FITSDiff(rtdata.output, rtdata.truth, **fitsdiff_default_kwargs) assert diff.identical, diff.report() + + +@pytest.fixture(scope='module') +def run_spec2_substrip96(rtdata_module): + """Run stage 2 pipeline on substrip96 data. + Solving for the optimal Tikhonov factor is time-consuming, and the code to do + so is identical between substrip96 and substrip256 data. 
Therefore just set + it to a reasonable value here.""" + rtdata = rtdata_module + rtdata.get_data("niriss/soss/jw03596001001_03102_00001-seg001_nis_ints0-2_rateints.fits") + args = ["calwebb_spec2", rtdata.input, + "--steps.extract_1d.soss_tikfac=1.0e-16",] + Step.from_cmdline(args) + + +@pytest.mark.bigdata +@pytest.mark.parametrize("suffix", ["calints", "x1dints"]) +def test_spec2_substrip96(rtdata_module, run_spec2_substrip96, fitsdiff_default_kwargs, suffix): + """Regression test of tso-spec2 pipeline performed on NIRISS SOSS data.""" + rtdata = rtdata_module + + output = f"jw03596001001_03102_00001-seg001_nis_ints0-2_{suffix}.fits" + rtdata.output = output + + rtdata.get_truth(f"truth/test_niriss_soss_stages/{output}") + + diff = FITSDiff(rtdata.output, rtdata.truth, **fitsdiff_default_kwargs) + assert diff.identical, diff.report()