From a1a6ca3106214c8aca8a7b1caaba8482ca771e80 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 25 Oct 2024 14:39:30 -0400 Subject: [PATCH 01/35] added unit tests for wv_map_bounds and arange_2d --- jwst/extract_1d/soss_extract/atoca.py | 7 +- jwst/extract_1d/soss_extract/atoca_utils.py | 393 +++++------------- jwst/extract_1d/soss_extract/soss_extract.py | 14 +- .../extract_1d/soss_extract/tests/__init__.py | 0 .../soss_extract/tests/test_atoca_utils.py | 81 ++++ 5 files changed, 196 insertions(+), 299 deletions(-) create mode 100644 jwst/extract_1d/soss_extract/tests/__init__.py create mode 100644 jwst/extract_1d/soss_extract/tests/test_atoca_utils.py diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index 1331c7a7b3..8fa6908a5b 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -376,7 +376,6 @@ def update_kernels(self, kernels, c_kwargs): self.kernels = kernels_new - return def get_mask_wave(self, i_order): """Generate mask bounded by limits of wavelength grid @@ -1716,9 +1715,9 @@ def get_w(self, i_order): # else, need hi_i + 1 k_last[~cond & ~ma] = hi[~cond & ~ma] + 1 - # Generate array of all k_i. Set to -1 if not valid - k_n, bad = atoca_utils.arange_2d(k_first, k_last + 1, dtype=int) - k_n[bad] = -1 + # Generate array of all k_i. Set to max value of uint16 if not valid + k_n = atoca_utils.arange_2d(k_first, k_last + 1) + bad = k_n == np.iinfo(k_n.dtype).max # Number of valid k per pixel n_k = np.sum(~bad, axis=-1) diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index 19a0144539..4960d1b33d 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -7,7 +7,7 @@ """ import numpy as np -from scipy.sparse import find, diags, csr_matrix +from scipy.sparse import diags, csr_matrix from scipy.sparse.linalg import spsolve from scipy.interpolate import interp1d, RectBivariateSpline, Akima1DInterpolator from scipy.optimize import minimize_scalar, brentq @@ -16,38 +16,26 @@ log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) -# ============================================================================== -# Code for generating indices on the oversampled wavelength grid. -# ============================================================================== - -def arange_2d(starts, stops, dtype=None): - """Create a 2D array containing a series of ranges. The ranges do not have - to be of equal length. +def arange_2d(starts, stops): + """ + Code for generating indices on the oversampled wavelength grid. + Creates a 2D array containing a series of ranges. + The ranges do not have to be of equal length. Parameters ---------- - starts : int or array[int] + starts : array[int] Start values for each range. - stops : int or array[int] + stops : array[int] End values for each range. - dtype : str - Type of the output values. Returns ------- - out : array[int] - 2D array of ranges. - mask : array[bool] - Mask indicating valid elements. + out : array[uint16] + 2D array of ranges with invalid values set to max uint16 value, 65535 """ - - # Ensure starts and stops are arrays. - starts = np.asarray(starts) - stops = np.asarray(stops) - - # Check input for starts and stops is valid. - if starts.shape != stops.shape and starts.shape != (): + if starts.shape != stops.shape: msg = ('Shapes of starts and stops are not compatible, ' 'they must either have the same shape or starts must be scalar.') log.critical(msg) @@ -58,34 +46,25 @@ def arange_2d(starts, stops, dtype=None): log.critical(msg) raise ValueError(msg) - # If starts was given as a scalar match its shape to stops. - if starts.shape == (): - starts = starts * np.ones_like(stops) - # Compute the length of each range. lengths = (stops - starts).astype(int) - # Initialize the output arrays. + # Initialize the output arrays with invalid value nrows = len(stops) ncols = np.amax(lengths) - out = np.ones((nrows, ncols), dtype=dtype) - mask = np.ones((nrows, ncols), dtype='bool') + out = np.ones((nrows, ncols), dtype=np.uint16)*np.iinfo(np.uint16).max # Compute the indices. for irow in range(nrows): out[irow, :lengths[irow]] = np.arange(starts[irow], stops[irow]) - mask[irow, :lengths[irow]] = False - - return out, mask - - -# ============================================================================== -# Code for converting to a sparse matrix and back. -# ============================================================================== + return out def sparse_k(val, k, n_k): - """Transform a 2D array `val` to a sparse matrix. + """ + TODO: ensure test coverage. Probably sufficient to have test for compute_weights + in atoca.py + Transform a 2D array `val` to a sparse matrix. Parameters ---------- @@ -114,56 +93,11 @@ def sparse_k(val, k, n_k): col = k[k >= 0] data = val[k >= 0] - mat = csr_matrix((data, (row, col)), shape=(n_i, n_k)) - - return mat - - -def unsparse(matrix, fill_value=np.nan): - """Convert a sparse matrix to a 2D array of values and a 2D array of position. - - Parameters - ---------- - matrix : csr_matrix - The input sparse matrix. - fill_value : float - Value to fill 2D array for undefined positions; default to np.nan - - Returns - ------ - out : 2d array - values of the matrix. The shape of the array is given by: - (matrix.shape[0], maximum number of defined value in a column). - col_out : 2d array - position of the columns. Same shape as `out`. - """ - - col, row, val = find(matrix.T) - n_row, n_col = matrix.shape - - good_rows, counts = np.unique(row, return_counts=True) - - # Define the new position in columns - i_col = np.indices((n_row, counts.max()))[1] - i_col = i_col[good_rows] - i_col = i_col[i_col < counts[:, None]] - - # Create outputs and assign values - col_out = np.ones((n_row, counts.max()), dtype=int) * -1 - col_out[row, i_col] = col - out = np.ones((n_row, counts.max())) * fill_value - out[row, i_col] = val - - return out, col_out - - -# ============================================================================== -# Code for building wavelength grids. -# ============================================================================== + return csr_matrix((data, (row, col)), shape=(n_i, n_k)) def get_wave_p_or_m(wave_map, dispersion_axis=1): - """ Compute upper and lower boundaries of a pixel map, + """Compute upper and lower boundaries of a pixel map, given the pixel central value. Parameters ---------- @@ -209,6 +143,19 @@ def get_wv_map_bounds(wave_map, dispersion_axis=1): Wavelength of top edge for each pixel wave_bottom : array[float] Wavelength of bottom edge for each pixel + + Notes + ----- + Handling of invalid pixels may lead to unexpected results as follows: + Bad pixels are completely ignored when computing pixel-to-pixel differences, so + wv_map=[2,4,6,NaN,NaN,12,14,16] will give wave_top=[1,3,5,0,0,9,13,15] + because the difference at index 5 was calculated as 12-(12-6)/2=9, + i.e., as though index 2 and 5 were next to each other. + A human (or a smarter linear interpolation) would figure out the slope is 2 and + determine the value of wave_top[5] should most likely be 11. + + TODO: the above note probably doesn't matter in practice, but see if SOSS team wants + to change this behavior. """ if dispersion_axis == 1: # Simpler to use transpose @@ -222,25 +169,25 @@ def get_wv_map_bounds(wave_map, dispersion_axis=1): wave_top = np.zeros_like(wave_map) wave_bottom = np.zeros_like(wave_map) - (n_row, n_col) = wave_map.shape + # for loop is needed to compute diff in just one spatial direction + # while skipping invalid values- not trivial to do with array comprehension even + # using masked arrays + n_col = wave_map.shape[1] for idx in range(n_col): wave_col = wave_map[:, idx] # Compute the change in wavelength for valid cols - idx_valid = np.isfinite(wave_col) - idx_valid &= (wave_col >= 0) + idx_valid = np.isfinite(wave_col) & (wave_col >= 0) wv_col_valid = wave_col[idx_valid] - delta_wave = np.diff(wv_col_valid) + delta_wave = np.diff(wv_col_valid) / 2 - # Init values - wv_col_top = np.zeros_like(wv_col_valid) - wv_col_bottom = np.zeros_like(wv_col_valid) + # handle edge effects using a constant-difference rule + delta_wave_top = np.insert(delta_wave,0,delta_wave[0]) + delta_wave_bottom = np.append(delta_wave,delta_wave[-1]) # Compute the wavelength values on the top and bottom edges of each pixel. - wv_col_top[1:] = wv_col_valid[:-1] + delta_wave / 2 # TODO check this logic. - wv_col_top[0] = wv_col_valid[0] - delta_wave[0] / 2 - wv_col_bottom[:-1] = wv_col_valid[:-1] + delta_wave / 2 - wv_col_bottom[-1] = wv_col_valid[-1] + delta_wave[-1] / 2 + wv_col_top = wv_col_valid - delta_wave_top + wv_col_bottom = wv_col_valid + delta_wave_bottom wave_top[idx_valid, idx] = wv_col_top wave_bottom[idx_valid, idx] = wv_col_bottom @@ -253,7 +200,10 @@ def get_wv_map_bounds(wave_map, dispersion_axis=1): def oversample_grid(wave_grid, n_os=1): - """Create an oversampled version of the input 1D wavelength grid. + """ + TODO: can this be replaced by np.interp or similar? + + Create an oversampled version of the input 1D wavelength grid. Parameters ---------- @@ -304,13 +254,18 @@ def oversample_grid(wave_grid, n_os=1): wave_grid_os = np.concatenate([wave_grid_os, sub_grid]) # Take only unique values and sort them. - wave_grid_os = np.unique(wave_grid_os) - - return wave_grid_os + return np.unique(wave_grid_os) def extrapolate_grid(wave_grid, wave_range, poly_ord): - """Extrapolate the given 1D wavelength grid to cover a given range of values + """ + TODO: It looks like these while loops and if statements can be removed and + replaced by something that operates on the whole array at once + e.g. the p.linspace function in numpy.polynomial.polynomial (see link below) + TODO: np.polyfit is considered legacy now, should be replaced by + https://numpy.org/doc/stable/reference/routines.polynomials-package.html + + Extrapolate the given 1D wavelength grid to cover a given range of values by fitting the derivative with a polynomial of a given order and using it to compute subsequent values at both ends of the grid. @@ -373,9 +328,7 @@ def extrapolate_grid(wave_grid, wave_range, poly_ord): grid_right = np.unique(grid_right) # Combine the extrapolated sections with the original grid. - wave_grid_ext = np.concatenate([grid_left, wave_grid, grid_right]) - - return wave_grid_ext + return np.concatenate([grid_left, wave_grid, grid_right]) def _grid_from_map(wave_map, trace_profile): @@ -469,9 +422,7 @@ def grid_from_map(wave_map, trace_profile, wave_range=None, n_os=1, poly_ord=1): out = grid # Apply oversampling - grid_os = oversample_grid(out, n_os=n_os) - - return grid_os + return oversample_grid(out, n_os=n_os) def get_soss_grid(wave_maps, trace_profiles, wave_min=0.55, wave_max=3.0, n_os=None): @@ -599,7 +550,10 @@ def _trim_grids(all_grids, grid_range=None): def make_combined_adaptive_grid(all_grids, all_estimate, grid_range=None, max_iter=10, rtol=10e-6, tol=0.0, max_total_size=1000000): - """Return an irregular oversampled grid needed to reach a + """ + TODO: can this be a class? e.g., class AdaptiveGrid? + + Return an irregular oversampled grid needed to reach a given precision when integrating over each intervals of `grid`. The grid is built by subdividing iteratively each intervals that did not reach the required precision. @@ -713,11 +667,7 @@ def _romberg_diff(b, c, k): R(n, m) : float or array[float] Difference between integral estimates of Rombergs method. """ - - tmp = 4.0**k - diff = (tmp * c - b) / (tmp - 1.0) - - return diff + return (4.0**k * c - b) / (4.0**k - 1.0) def _difftrap(fct, intervals, numtraps): @@ -786,114 +736,6 @@ def _difftrap(fct, intervals, numtraps): return ordsum -def get_n_nodes(grid, fct, divmax=10, tol=1.48e-4, rtol=1.48e-4): - """Refine parts of a grid to reach a specified integration precision - based on Romberg integration of a callable function or method. - Returns the number of nodes needed in each intervals of - the input grid to reach the specified tolerance over the integral - of `fct` (a function of one variable). - - Note: This function is based on scipy.integrate.quadrature.romberg. The - difference between it and the scipy version is that it is vectorized to deal - with multiple intervals separately. It also returns the number of nodes - needed to reached the required precision instead of returning the value of - the integral. - - Parameters - ---------- - grid : array[float] - Grid for integration. Each section of this grid is treated as a - separate integral; if grid has length N, N-1 integrals are optimized. - fct : callable - Function to be integrated. - divmax : int - Maximum order of extrapolation. - tol : float - The desired absolute tolerance. - rtol : float - The desired relative tolerance. - - Returns - ------- - n_grid : array[int] - Number of nodes needed on each distinct intervals in the grid to reach - the specified tolerance. - residual : array[float] - Estimate of the error in each intervals. Same length as n_grid. - """ - - # Initialize some variables. - n_intervals = len(grid) - 1 - i_bad = np.arange(n_intervals) - n_grid = np.repeat(-1, n_intervals) - residual = np.repeat(np.nan, n_intervals) - - # Change the 1D grid into a 2D set of intervals. - intervals = np.array([grid[:-1], grid[1:]]) - intrange = np.diff(grid) - err = np.inf - - # First estimate without subdivision. - numtraps = 1 - ordsum = _difftrap(fct, intervals, numtraps) - results = intrange * ordsum - last_row = [results] - - for i_div in range(1, divmax + 1): - - # Increase the number of trapezoids by factors of 2. - numtraps *= 2 - - # Evaluate trapz integration for intervals that are not converged. - ordsum += _difftrap(fct, intervals[:, i_bad], numtraps) - row = [intrange[i_bad] * ordsum / numtraps] - - # Compute Romberg for each of the computed sub grids. - for k in range(i_div): - romb_k = _romberg_diff(last_row[k], row[k], k + 1) - row = np.vstack([row, romb_k]) - - # Save R(n,n) and R(n-1, n-1) from Romberg method. - results = row[i_div] - lastresults = last_row[i_div - 1] - - # Estimate error. - err = np.abs(results - lastresults) - - # Find intervals that are converged. - conv = (err < tol) | (err < rtol * np.abs(results)) - - # Save number of nodes for these intervals. - n_grid[i_bad[conv]] = numtraps - - # Save residuals. - residual[i_bad] = err - - # Stop if all intervals have converged. - if conv.all(): - break - - # Find intervals not converged. - i_bad = i_bad[~conv] - - # Save last_row and ordsum for the next iteration for non-converged - # intervals. - ordsum = ordsum[~conv] - last_row = row[:, ~conv] - - else: - # Warn that convergence is not reached everywhere. - log.warning(f"divmax {divmax} exceeded. Latest difference = {err.max()}") - - # Make sure all values of n_grid where assigned during the process. - if (n_grid == -1).any(): - msg = f"Values where not assigned at grid position: {np.where(n_grid == -1)}" - log.critical(msg) - raise ValueError(msg) - - return n_grid, residual - - def estim_integration_err(grid, fct): """Estimate the integration error on each intervals of the grid using 1rst order Romberg integration. @@ -988,7 +830,7 @@ def adapt_grid(grid, fct, max_iter=10, rtol=10e-6, tol=0.0, max_grid_size=None): # Init some flags max_size_reached = (grid.size >= max_grid_size) - # Iterate until precision is reached of max_iter + # Iterate until precision is reached or max_iter for _ in range(max_iter): # Estimate error using Romberg integration @@ -1035,30 +877,6 @@ def adapt_grid(grid, fct, max_iter=10, rtol=10e-6, tol=0.0, max_grid_size=None): return grid, is_converged -# ============================================================================== -# Code for handling the throughput and kernels. -# ============================================================================== - - -class ThroughputSOSS(interp1d): - - def __init__(self, wavelength, throughput): - """Create an instance of scipy.interpolate.interp1d to handle the - throughput values. - - Parameters - ---------- - wavelength : array[float] - A wavelength array. - throughput : array[float] - The throughput values corresponding to the wavelengths. - """ - - # Interpolate - super().__init__(wavelength, throughput, kind='cubic', fill_value=0, - bounds_error=False) - - class WebbKernel: # TODO could probably be cleaned-up somewhat, may need further adjustment. def __init__(self, wave_kernels, kernels, wave_map, n_os, n_pix, # TODO kernels may need to be flipped? @@ -1258,13 +1076,11 @@ def __call__(self, wave, wave_c): return webbker -# ============================================================================== -# Code for building the convolution matrix (c matrix). -# ============================================================================== - - def gaussians(x, x0, sig, amp=None): - """Gaussian function + """ + TODO: can this be replaced by something in scipy or numpy, + e.g., scipy.signal.windows.gaussian? + Gaussian function Parameters ---------- @@ -1282,14 +1098,9 @@ def gaussians(x, x0, sig, amp=None): values : array[float] Array of gaussian values for input x. """ - - # Amplitude term if amp is None: amp = 1. / np.sqrt(2. * np.pi * sig**2.) - - values = amp * np.exp(-0.5 * ((x - x0) / sig) ** 2.) - - return values + return amp * np.exp(-0.5 * ((x - x0) / sig) ** 2.) def fwhm2sigma(fwhm): @@ -1305,14 +1116,11 @@ def fwhm2sigma(fwhm): sigma : float Standard deviation of a gaussian. """ - - sigma = fwhm / np.sqrt(8. * np.log(2.)) - - return sigma + return fwhm / np.sqrt(8. * np.log(2.)) def to_2d(kernel, grid_range): - """ Build a 2d kernel array with a constant 1D kernel (input) + """Build a 2d kernel array with a constant 1D kernel (input) Parameters ---------- @@ -1335,9 +1143,7 @@ def to_2d(kernel, grid_range): n_k_c = b - a # Return a 2D array with this length - kernel_2d = np.tile(kernel, (n_k_c, 1)).T - - return kernel_2d + return np.tile(kernel, (n_k_c, 1)).T def _get_wings(fct, grid, h_len, i_a, i_b): @@ -1408,7 +1214,10 @@ def _get_wings(fct, grid, h_len, i_a, i_b): def trpz_weight(grid, length, shape, i_a, i_b): - """Compute weights due to trapezoidal integration + """ + TODO: add to some integration class? + + Compute weights due to trapezoidal integration Parameters ---------- @@ -1457,7 +1266,10 @@ def trpz_weight(grid, length, shape, i_a, i_b): def fct_to_array(fct, grid, grid_range, thresh=1e-5, length=None): - """Build a compact kernel 2d array based on a kernel function + """ + TODO: can't scipy do this? + + Build a compact kernel 2d array based on a kernel function and a grid to project the kernel Parameters @@ -1775,7 +1587,11 @@ def get_c_matrix(kernel, grid, bounds=None, i_bounds=None, norm=True, class NyquistKer: - """Define a gaussian convolution kernel at the nyquist + """ + TODO: look into whether custom Gaussian function is needed, or if + something like scipy.ndimage.gaussian_filter1d could be used. + + Define a gaussian convolution kernel at the nyquist sampling. For a given point on the grid x_i, the kernel is given by a gaussian with FWHM = n_sampling * (dx_(i-1) + dx_i) / 2. @@ -1835,11 +1651,6 @@ def __call__(self, x, x0): return gaussians(x, x0, sig) -# ============================================================================== -# Code for doing Tikhonov regularisation. -# ============================================================================== - - def finite_diff(x): """Returns the finite difference matrix operator based on x. @@ -1892,9 +1703,7 @@ def finite_second_d(grid): second_d = finite_diff(grid[:-1]).dot(first_d) # don't forget the delta lambda - second_d = diags(1. / d_grid[:-1]).dot(second_d) - - return second_d + return diags(1. / d_grid[:-1]).dot(second_d) def finite_first_d(grid): @@ -1920,13 +1729,14 @@ def finite_first_d(grid): d_grid = d_matrix.dot(grid) # First derivative operator - first_d = diags(1. / d_grid).dot(d_matrix) - - return first_d + return diags(1. / d_grid).dot(d_matrix) def get_tikho_matrix(grid, n_derivative=1, d_grid=True, estimate=None, pwr_law=0): - """Wrapper to return the tikhonov matrix given a grid and the derivative degree. + """ + TODO: can all this Tikhonov stuff go into the classes? + + Wrapper to return the tikhonov matrix given a grid and the derivative degree. Parameters ---------- @@ -2099,9 +1909,7 @@ def _get_interp_idx_array(idx, relative_range, max_length): abs_range[-1] = np.min([abs_range[-1], max_length]) # Convert to slice - out = np.arange(*abs_range, 1) - - return out + return np.arange(*abs_range, 1) def _minimize_on_grid(factors, val_to_minimize, interpolate, interp_index=None): @@ -2315,12 +2123,12 @@ def __init__(self, test_dict=None, default_chi2='chi2_cauchy'): self[chi2_type] except KeyError: self[chi2_type] = self.compute_chi2(loss=loss) -# # Save different loss function for chi2 -# self['chi2_soft_l1'] = self.compute_chi2(loss='soft_l1') -# self['chi2_cauchy'] = self.compute_chi2(loss='cauchy') + def compute_chi2(self, tests=None, n_points=None, loss='linear'): - """ Calculates the reduced chi squared statistic + """ + TODO: is there a scipy builtin that does this? + Calculates the reduced chi squared statistic Parameters ---------- @@ -2358,7 +2166,9 @@ def compute_chi2(self, tests=None, n_points=None, loss='linear'): return chi2 def get_chi2_derivative(self, key=None): - """ Compute derivative of the chi2 with respect to log10(factors) + """ + TODO: is there a scipy builtin that does this? + Compute derivative of the chi2 with respect to log10(factors) Parameters ---------- @@ -2506,6 +2316,9 @@ def best_tikho_factor(self, tests=None, interpolate=True, interp_index=None, class Tikhonov: """ + TODO: can we avoid all of this by using scipy.optimize.least_squares + like this? https://stackoverflow.com/questions/62768131/how-to-add-tikhonov-regularization-in-scipy-optimize-least-squares + Tikhonov regularization to solve the ill-posed problem A.x = b, where A is accidentally singular or close to singularity. Tikhonov regularization adds a regularization term in the equation and aim to minimize the diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index 77a0284463..9d780f0fdd 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -1,6 +1,7 @@ import logging import numpy as np +from scipy.interpolate import make_interp_spline from scipy.interpolate import UnivariateSpline, CubicSpline from stdatamodels.jwst import datamodels @@ -12,7 +13,7 @@ from .soss_syscor import make_background_mask, soss_background from .atoca import ExtractionEngine, MaskOverlapError -from .atoca_utils import (ThroughputSOSS, WebbKernel, grid_from_map, +from .atoca_utils import (WebbKernel, grid_from_map, make_combined_adaptive_grid, get_wave_p_or_m, oversample_grid) from .soss_boxextract import get_box_weights, box_extract, estim_error_nearest_data from .pastasoss import get_soss_wavemaps, XTRACE_ORD1_LEN, XTRACE_ORD2_LEN @@ -80,10 +81,13 @@ def get_ref_file_args(ref_files): for i, throughput in enumerate(pastasoss_ref.throughputs): throughput_index_dict[throughput.spectral_order] = i - throughput_o1 = ThroughputSOSS(pastasoss_ref.throughputs[throughput_index_dict[1]].wavelength[:], - pastasoss_ref.throughputs[throughput_index_dict[1]].throughput[:]) - throughput_o2 = ThroughputSOSS(pastasoss_ref.throughputs[throughput_index_dict[2]].wavelength[:], - pastasoss_ref.throughputs[throughput_index_dict[2]].throughput[:]) + # TODO: check that replacing legacy interp1d with make_interp_spline works right + throughput_o1 = make_interp_spline(pastasoss_ref.throughputs[throughput_index_dict[1]].wavelength[:], + pastasoss_ref.throughputs[throughput_index_dict[1]].throughput[:], + k=3, fill_value=0.0, bounds_error=False) + throughput_o2 = make_interp_spline(pastasoss_ref.throughputs[throughput_index_dict[2]].wavelength[:], + pastasoss_ref.throughputs[throughput_index_dict[2]].throughput[:], + k=3, fill_value=0.0, bounds_error=False) # The spectral kernels. speckernel_ref = ref_files['speckernel'] diff --git a/jwst/extract_1d/soss_extract/tests/__init__.py b/jwst/extract_1d/soss_extract/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py new file mode 100644 index 0000000000..d7bf689edf --- /dev/null +++ b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py @@ -0,0 +1,81 @@ + +import pytest +from jwst.extract_1d.soss_extract import atoca_utils as au +import numpy as np + +def test_arange_2d(): + + starts = np.array([3,4,5]) + stops = np.ones(starts.shape)*7 + out = au.arange_2d(starts, stops) + + bad = 65535 + expected_out = np.array([ + [3,4,5,6], + [4,5,6,bad], + [5,6,bad,bad] + ]) + assert np.allclose(out, expected_out) + + # test bad input catches + starts_wrong_shape = starts[1:] + with pytest.raises(ValueError): + au.arange_2d(starts_wrong_shape, stops) + + stops_too_small = np.copy(stops) + stops_too_small[2] = 4 + with pytest.raises(ValueError): + au.arange_2d(starts, stops_too_small) + +FIBONACCI = np.array([1,1,2,3,5,8,13], dtype=float) +@pytest.fixture(scope="module") +def wave_map(): + wave_map = np.array([ + FIBONACCI, + FIBONACCI+1, + + ]) + wave_map[1,3] = -1 #test skip of bad value + return wave_map + + +@pytest.mark.parametrize("dispersion_axis", [0,1]) +def test_get_wv_map_bounds(wave_map, dispersion_axis): + """ + top is the low-wavelength end, bottom is high-wavelength end + """ + if dispersion_axis == 0: + wave_flip = wave_map.T + else: + wave_flip = wave_map + wave_top, wave_bottom = au.get_wv_map_bounds(wave_flip, dispersion_axis=dispersion_axis) + + # flip the results back so we can re-use the same tests + if dispersion_axis == 0: + wave_top = wave_top.T + wave_bottom = wave_bottom.T + + diff = (FIBONACCI[1:]-FIBONACCI[:-1])/2 + diff_lower = np.insert(diff,0,diff[0]) + diff_upper = np.append(diff,diff[-1]) + wave_top_expected = FIBONACCI-diff_lower + wave_bottom_expected = FIBONACCI+diff_upper + + # basic test + assert wave_top.shape == wave_bottom.shape == (2,)+FIBONACCI.shape + assert np.allclose(wave_top[0], wave_top_expected) + assert np.allclose(wave_bottom[0], wave_bottom_expected) + + # test skip bad pixel + assert wave_top[1,3] == 0 + assert wave_bottom[1,3] == 0 + + # test bad input error raises + with pytest.raises(ValueError): + au.get_wv_map_bounds(wave_flip, dispersion_axis=2) + + + +@pytest.mark.parametrize("dispersion_axis", [0,1]) +def test_get_wave_p_or_m(dispersion_axis): + return \ No newline at end of file From fbd4521d91142ab74fb6f184968f706cefaef5e9 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 25 Oct 2024 17:11:24 -0400 Subject: [PATCH 02/35] simplify oversample_grid and write tests --- jwst/extract_1d/soss_extract/atoca_utils.py | 42 +++------ .../soss_extract/tests/test_atoca_utils.py | 85 ++++++++++++++++++- 2 files changed, 93 insertions(+), 34 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index 4960d1b33d..06a40f1437 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -199,10 +199,8 @@ def get_wv_map_bounds(wave_map, dispersion_axis=1): return wave_top, wave_bottom -def oversample_grid(wave_grid, n_os=1): +def oversample_grid(wave_grid, n_os): """ - TODO: can this be replaced by np.interp or similar? - Create an oversampled version of the input 1D wavelength grid. Parameters @@ -220,38 +218,20 @@ def oversample_grid(wave_grid, n_os=1): The oversampled wavelength grid. """ - # Convert n_os to an array. + # Convert n_os to an array of size len(wave_grid) - 1. n_os = np.asarray(n_os) - - # n_os needs to have the dimension: len(wave_grid) - 1. if n_os.ndim == 0: - - # A scalar was given, repeat the value. n_os = np.repeat(n_os, len(wave_grid) - 1) - elif len(n_os) != (len(wave_grid) - 1): - # An array of incorrect size was given. msg = 'n_os must be a scalar or an array of size len(wave_grid) - 1.' log.critical(msg) raise ValueError(msg) - - # Grid intervals. - delta_wave = np.diff(wave_grid) - - # Initialize the new oversampled wavelength grid. - wave_grid_os = wave_grid.copy() - - # Iterate over oversampling factors to generate new grid points. - for i_os in range(1, n_os.max()): - - # Consider only intervals that are not complete yet. - mask = n_os > i_os - - # Compute the new grid points. - sub_grid = wave_grid[:-1][mask] + (i_os * delta_wave[mask] / n_os[mask]) - - # Add the grid points to the oversampled wavelength grid. - wave_grid_os = np.concatenate([wave_grid_os, sub_grid]) + + # Compute the oversampled grid. + intervals = 1/n_os + intervals = np.insert(np.repeat(intervals, n_os),0,0) + grid = np.cumsum(intervals) + wave_grid_os = np.interp(grid, np.arange(wave_grid.size), wave_grid) # Take only unique values and sort them. return np.unique(wave_grid_os) @@ -381,10 +361,8 @@ def grid_from_map(wave_map, trace_profile, wave_range=None, n_os=1, poly_ord=1): wave_range : list[float] Minimum and maximum boundary of the grid to generate, in microns. Wave_range must include some wavelengths of wave_map. - n_os : int or list[int] - Oversampling of the grid compare to the pixel sampling. Can be - specified for each order if a list is given. If a single value is given - it will be used for all orders. + n_os : int + Oversampling of the grid compare to the pixel sampling. poly_ord : int Order of the polynomial use to extrapolate the grid. diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py index d7bf689edf..07240301e0 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py @@ -27,6 +27,7 @@ def test_arange_2d(): with pytest.raises(ValueError): au.arange_2d(starts, stops_too_small) + FIBONACCI = np.array([1,1,2,3,5,8,13], dtype=float) @pytest.fixture(scope="module") def wave_map(): @@ -77,5 +78,85 @@ def test_get_wv_map_bounds(wave_map, dispersion_axis): @pytest.mark.parametrize("dispersion_axis", [0,1]) -def test_get_wave_p_or_m(dispersion_axis): - return \ No newline at end of file +def test_get_wave_p_or_m(wave_map, dispersion_axis): + """ + Check that the plus and minus side is correctly identified + for strictly ascending and strictly descending wavelengths. + """ + wave_reverse = np.fliplr(wave_map) + if dispersion_axis == 0: + wave_flip = wave_map.T + wave_reverse = wave_reverse.T + else: + wave_flip = wave_map + + wave_p_0, wave_m_0 = au.get_wave_p_or_m(wave_flip, dispersion_axis=dispersion_axis) + wave_p_1, wave_m_1 = au.get_wave_p_or_m(wave_reverse, dispersion_axis=dispersion_axis) + + if dispersion_axis==0: + wave_p_0 = wave_p_0.T + wave_m_0 = wave_m_0.T + wave_p_1 = wave_p_1.T + wave_m_1 = wave_m_1.T + assert np.all(wave_p_0 >= wave_m_0) + assert np.allclose(wave_p_0, np.fliplr(wave_p_1)) + assert np.allclose(wave_m_0, np.fliplr(wave_m_1)) + + +def test_get_wave_p_or_m_not_ascending(wave_map): + with pytest.raises(ValueError): + wave_map[0,5] = 2 # make it not strictly ascending + au.get_wave_p_or_m(wave_map, dispersion_axis=1) + + +@pytest.mark.parametrize("n_os", [1,5]) +def test_oversample_grid(n_os): + + oversample = au.oversample_grid(FIBONACCI, n_os) + + # oversample_grid is supposed to remove any duplicates, and there is a duplicate + # in FIBONACCI. So the output should be 4 times the size of FIBONACCI minus 1 + assert oversample.size == n_os*(FIBONACCI.size - 1) - (n_os-1) + assert oversample.min() == FIBONACCI.min() + assert oversample.max() == FIBONACCI.max() + + # test whether np.interp could have been used instead + grid = np.arange(0, FIBONACCI.size, 1/n_os) + wls = np.unique(np.interp(grid, np.arange(FIBONACCI.size), FIBONACCI)) + assert np.allclose(oversample, wls) + + +@pytest.mark.parametrize("os_factor", [1,2,5]) +def test_oversample_irregular(os_factor): + """Test oversampling to a grid with irregular spacing""" + # oversampling function removes duplicates, + # this is tested in previous test, and just complicates counting for this test + # for FIBONACCI, unique is just removing zeroth element + fib_unq = np.unique(FIBONACCI) + n_os = np.ones((fib_unq.size-1,), dtype=int) + n_os[2:5] = os_factor + n_os[3] = os_factor*2 + # this gives n_os = [1 1 2 4 2] for os_factor = 2 + + oversample = au.oversample_grid(fib_unq, n_os) + + # test no oversampling was done on the elements where not requested + assert np.allclose(oversample[0:2], fib_unq[0:2]) + assert np.allclose(oversample[-1:], fib_unq[-1:]) + + # test output shape. + assert oversample.size == np.sum(n_os)+1 + + # test that this could have been done easily with np.interp + intervals = 1/n_os + intervals = np.insert(np.repeat(intervals, n_os),0,0) + grid = np.cumsum(intervals) + wls = np.interp(grid, np.arange(fib_unq.size), fib_unq) + assert wls.size == oversample.size + assert np.allclose(oversample, wls) + + # test that n_os shape must match input shape - 1 + with pytest.raises(ValueError): + au.oversample_grid(fib_unq, n_os[:-1]) + + From 7cb20b1af7862075f68473b1fa933ec80d06e87e Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 25 Oct 2024 19:57:34 -0400 Subject: [PATCH 03/35] fixed edge case where extrapolate_grid could run forever, added unit tests for that function --- jwst/extract_1d/soss_extract/atoca_utils.py | 36 ++++++++++--------- .../soss_extract/tests/test_atoca_utils.py | 22 ++++++++++++ 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index 06a40f1437..edb61a54f9 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -237,18 +237,14 @@ def oversample_grid(wave_grid, n_os): return np.unique(wave_grid_os) -def extrapolate_grid(wave_grid, wave_range, poly_ord): - """ - TODO: It looks like these while loops and if statements can be removed and - replaced by something that operates on the whole array at once - e.g. the p.linspace function in numpy.polynomial.polynomial (see link below) - TODO: np.polyfit is considered legacy now, should be replaced by - https://numpy.org/doc/stable/reference/routines.polynomials-package.html - - Extrapolate the given 1D wavelength grid to cover a given range of values +def extrapolate_grid(wave_grid, wave_range, poly_ord=1): + """Extrapolate the given 1D wavelength grid to cover a given range of values by fitting the derivative with a polynomial of a given order and using it to compute subsequent values at both ends of the grid. + TODO: np.polyfit is considered legacy now, should be replaced by + https://numpy.org/doc/stable/reference/routines.polynomials-package.html + Parameters ---------- wave_grid : array[float] @@ -263,12 +259,14 @@ def extrapolate_grid(wave_grid, wave_range, poly_ord): wave_grid_ext : array[float] The extrapolated 1D wavelength grid. """ - # Define delta_wave as a function of wavelength by fitting a polynomial. delta_wave = np.diff(wave_grid) pars = np.polyfit(wave_grid[:-1], delta_wave, poly_ord) f_delta = np.poly1d(pars) + # Set a minimum delta value to avoid running forever + min_delta = delta_wave.min()/10 + # Extrapolate out-of-bound values on the left-side of the grid. grid_left = [] if wave_range[0] < wave_grid.min(): @@ -278,12 +276,14 @@ def extrapolate_grid(wave_grid, wave_range, poly_ord): # Iterate until the end of wave_range is reached. while True: - next_val = grid_left[-1] - f_delta(grid_left[-1]) + next_delta = f_delta(grid_left[-1]) + next_val = grid_left[-1] - next_delta + grid_left.append(next_val) if next_val < wave_range[0]: break - else: - grid_left.append(next_val) + if next_delta < min_delta: + raise RuntimeError('Extrapolation failed to converge.') # Sort extrapolated vales (and keep only unique). grid_left = np.unique(grid_left) @@ -297,13 +297,15 @@ def extrapolate_grid(wave_grid, wave_range, poly_ord): # Iterate until the end of wave_range is reached. while True: - next_val = grid_right[-1] + f_delta(grid_right[-1]) + next_delta = f_delta(grid_right[-1]) + next_val = grid_right[-1] + next_delta + grid_right.append(next_val) if next_val > wave_range[-1]: break - else: - grid_right.append(next_val) - + if next_delta < min_delta: + raise RuntimeError('Extrapolation failed to converge.') + # Sort extrapolated vales (and keep only unique). grid_right = np.unique(grid_right) diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py index 07240301e0..62c36b1b2e 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py @@ -160,3 +160,25 @@ def test_oversample_irregular(os_factor): au.oversample_grid(fib_unq, n_os[:-1]) +@pytest.mark.parametrize("wave_range", [(2.1, 2.9), (1.8, 3.5)]) +def test_extrapolate_grid(wave_range): + + # give wavelengths some non-linearity + n=50 + wavelengths = np.linspace(2.0, 3.0, n) + np.sin(np.linspace(0, np.pi/2, n)) + poly_ord = 1 + extrapolated = au.extrapolate_grid(wavelengths, wave_range, poly_ord) + + assert extrapolated.max() > wave_range[1] + assert extrapolated.min() < wave_range[0] + assert np.all(extrapolated[1:] >= extrapolated[:-1]) + + +def test_extrapolate_catch_failed_converge(): + # give wavelengths some non-linearity + n=50 + wavelengths = np.linspace(2.0, 3.0, n) + np.sin(np.linspace(0, np.pi/2, n)) + wave_range = wavelengths.min(), wavelengths.max()+2.0 + poly_ord = 1 + with pytest.raises(RuntimeError): + au.extrapolate_grid(wavelengths, wave_range, poly_ord) \ No newline at end of file From a69282a6c532a0bfe17216f97de018f124d9b74a Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Mon, 28 Oct 2024 11:10:49 -0400 Subject: [PATCH 04/35] replace legacy scipy interp1d with make_interp_spline --- jwst/extract_1d/soss_extract/atoca_utils.py | 66 +++++++++++++++---- jwst/extract_1d/soss_extract/soss_extract.py | 21 +++--- .../soss_extract/tests/test_atoca_utils.py | 36 ++++++---- 3 files changed, 92 insertions(+), 31 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index edb61a54f9..14b80bcf67 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -11,6 +11,7 @@ from scipy.sparse.linalg import spsolve from scipy.interpolate import interp1d, RectBivariateSpline, Akima1DInterpolator from scipy.optimize import minimize_scalar, brentq +from scipy.interpolate import make_interp_spline import logging log = logging.getLogger(__name__) @@ -259,6 +260,17 @@ def extrapolate_grid(wave_grid, wave_range, poly_ord=1): wave_grid_ext : array[float] The extrapolated 1D wavelength grid. """ + if wave_range[0] >= wave_range[-1]: + msg = 'wave_range must be in order [short, long].' + log.critical(msg) + raise ValueError(msg) + if wave_range[0] > wave_grid.max() or wave_range[-1] < wave_grid.min(): + msg = 'wave_range must overlap with wave_grid.' + log.critical(msg) + raise ValueError(msg) + if wave_range[0] > wave_grid.min() and wave_range[-1] < wave_grid.max(): + return wave_grid + # Define delta_wave as a function of wavelength by fitting a polynomial. delta_wave = np.diff(wave_grid) pars = np.polyfit(wave_grid[:-1], delta_wave, poly_ord) @@ -348,11 +360,10 @@ def _grid_from_map(wave_map, trace_profile): return grid, icols[sort] -def grid_from_map(wave_map, trace_profile, wave_range=None, n_os=1, poly_ord=1): +def grid_from_map(wave_map, trace_profile, wave_range=None, n_os=1): """Define a wavelength grid by taking the central wavelength at each columns given by the center of mass of the spatial profile (so one wavelength per - column). If wave_range is outside of the wave_map, extrapolate with a - polynomial of order poly_ord. + column). If wave_range is outside of the wave_map, extrapolate. Parameters ---------- @@ -365,14 +376,17 @@ def grid_from_map(wave_map, trace_profile, wave_range=None, n_os=1, poly_ord=1): Wave_range must include some wavelengths of wave_map. n_os : int Oversampling of the grid compare to the pixel sampling. - poly_ord : int - Order of the polynomial use to extrapolate the grid. Returns ------- grid_os : array[float] Wavelength grid with oversampling applied """ + import matplotlib.pyplot as plt + plt.imshow(wave_map, origin = 'lower') + plt.show() + plt.imshow(trace_profile, origin = 'lower') + plt.show() # Different treatment if wave_range is given. if wave_range is None: @@ -381,9 +395,6 @@ def grid_from_map(wave_map, trace_profile, wave_range=None, n_os=1, poly_ord=1): # Get an initial estimate of the grid. grid, icols = _grid_from_map(wave_map, trace_profile) - # Check if extrapolation needed. If so, out_col must be False. - extrapolate = (wave_range[0] < grid.min()) | (wave_range[1] > grid.max()) - # Make sure grid is between the range mask = (wave_range[0] <= grid) & (grid <= wave_range[-1]) @@ -396,10 +407,8 @@ def grid_from_map(wave_map, trace_profile, wave_range=None, n_os=1, poly_ord=1): grid, icols = grid[mask], icols[mask] # Extrapolate values out of the wv_map if needed - if extrapolate: - grid = extrapolate_grid(grid, wave_range, poly_ord) - - out = grid + # if not needed, this will return the original grid + out = extrapolate_grid(grid, wave_range, poly_ord=1) # Apply oversampling return oversample_grid(out, n_os=n_os) @@ -857,6 +866,39 @@ def adapt_grid(grid, fct, max_iter=10, rtol=10e-6, tol=0.0, max_grid_size=None): return grid, is_converged +def ThroughputSOSS(wavelength, throughput): + """ + Parameters + ---------- + wavelength : array[float] + A wavelength array. + + throughput : array[float] + The throughput values corresponding to the wavelengths. + + Returns + ------- + interpolator : callable + A function that interpolates the throughput values. Accepts an array + of wavelengths and returns the interpolated throughput values. + + Notes + ----- + Clamped boundary condition corresponds to a first derivative of zero, + which ensures a smooth curve given zero-padding outside the range. + """ + wl_min, wl_max = np.min(wavelength), np.max(wavelength) + spline = make_interp_spline(wavelength, throughput, k=3, bc_type=("clamped", "clamped")) + + def interpolator(wl): + thru = np.zeros_like(wl) + inside = (wl > wl_min) & (wl < wl_max) + thru[inside] = spline(wl[inside]) + return thru + + return interpolator + + class WebbKernel: # TODO could probably be cleaned-up somewhat, may need further adjustment. def __init__(self, wave_kernels, kernels, wave_map, n_os, n_pix, # TODO kernels may need to be flipped? diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index 9d780f0fdd..7605afedf1 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -1,7 +1,7 @@ import logging import numpy as np -from scipy.interpolate import make_interp_spline + from scipy.interpolate import UnivariateSpline, CubicSpline from stdatamodels.jwst import datamodels @@ -13,7 +13,7 @@ from .soss_syscor import make_background_mask, soss_background from .atoca import ExtractionEngine, MaskOverlapError -from .atoca_utils import (WebbKernel, grid_from_map, +from .atoca_utils import (ThroughputSOSS, WebbKernel, grid_from_map, make_combined_adaptive_grid, get_wave_p_or_m, oversample_grid) from .soss_boxextract import get_box_weights, box_extract, estim_error_nearest_data from .pastasoss import get_soss_wavemaps, XTRACE_ORD1_LEN, XTRACE_ORD2_LEN @@ -82,12 +82,17 @@ def get_ref_file_args(ref_files): throughput_index_dict[throughput.spectral_order] = i # TODO: check that replacing legacy interp1d with make_interp_spline works right - throughput_o1 = make_interp_spline(pastasoss_ref.throughputs[throughput_index_dict[1]].wavelength[:], - pastasoss_ref.throughputs[throughput_index_dict[1]].throughput[:], - k=3, fill_value=0.0, bounds_error=False) - throughput_o2 = make_interp_spline(pastasoss_ref.throughputs[throughput_index_dict[2]].wavelength[:], - pastasoss_ref.throughputs[throughput_index_dict[2]].throughput[:], - k=3, fill_value=0.0, bounds_error=False) + # this is equivalent to order 3 interpolation with scipy interp1d except that + # fill_value=0.0 has been removed because make_interp_spline does not support it + # was the fill_value ever used? since this is a throughput there's likely nothing wrong with it + # What are the appropriate boundary conditions to pass to make_interp_spline bc_type? + # since expectation is all zeros outside range, first derivative should be zero at boundaries + # one idea is to retain ThroughputSOSS class and make it return a function that calls + # output of make_interp_spline when inside range and returns zero outside range + throughput_o1 = ThroughputSOSS(pastasoss_ref.throughputs[throughput_index_dict[1]].wavelength[:], + pastasoss_ref.throughputs[throughput_index_dict[1]].throughput[:]) + throughput_o2 = ThroughputSOSS(pastasoss_ref.throughputs[throughput_index_dict[2]].wavelength[:], + pastasoss_ref.throughputs[throughput_index_dict[2]].throughput[:]) # The spectral kernels. speckernel_ref = ref_files['speckernel'] diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py index 62c36b1b2e..addaea7d73 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py @@ -160,25 +160,39 @@ def test_oversample_irregular(os_factor): au.oversample_grid(fib_unq, n_os[:-1]) -@pytest.mark.parametrize("wave_range", [(2.1, 2.9), (1.8, 3.5)]) +# wavelengths have min, max (2.0, 4.0) and a bit of non-linearity +WAVELENGTHS = np.linspace(2.0, 3.0, 50) + np.sin(np.linspace(0, np.pi/2, 50)) +@pytest.mark.parametrize("wave_range", [(2.1, 3.9), (1.8, 4.5)]) def test_extrapolate_grid(wave_range): - # give wavelengths some non-linearity - n=50 - wavelengths = np.linspace(2.0, 3.0, n) + np.sin(np.linspace(0, np.pi/2, n)) - poly_ord = 1 - extrapolated = au.extrapolate_grid(wavelengths, wave_range, poly_ord) + extrapolated = au.extrapolate_grid(WAVELENGTHS, wave_range, 1) assert extrapolated.max() > wave_range[1] assert extrapolated.min() < wave_range[0] assert np.all(extrapolated[1:] >= extrapolated[:-1]) + # if interpolation not needed on either side, should return the original + if wave_range == (2.1, 3.9): + assert extrapolated is WAVELENGTHS + def test_extrapolate_catch_failed_converge(): # give wavelengths some non-linearity - n=50 - wavelengths = np.linspace(2.0, 3.0, n) + np.sin(np.linspace(0, np.pi/2, n)) - wave_range = wavelengths.min(), wavelengths.max()+2.0 - poly_ord = 1 + wave_range = WAVELENGTHS.min(), WAVELENGTHS.max()+2.0 with pytest.raises(RuntimeError): - au.extrapolate_grid(wavelengths, wave_range, poly_ord) \ No newline at end of file + au.extrapolate_grid(WAVELENGTHS, wave_range, 1) + + +def test_extrapolate_bad_inputs(): + with pytest.raises(ValueError): + au.extrapolate_grid(WAVELENGTHS, (2.9, 2.1)) + with pytest.raises(ValueError): + au.extrapolate_grid(WAVELENGTHS, (4.1, 4.2)) + + +def test_grid_from_map(wave_map): + + pass + #wave_grid = au.grid_from_map(wave_map, trace_profile) + + From 2dc62d330b22198011c9a0b2472e41b9b96cf9f7 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Mon, 28 Oct 2024 17:20:17 -0400 Subject: [PATCH 05/35] add unit tests to get_soss_grid and helpers --- jwst/extract_1d/soss_extract/atoca_utils.py | 125 +++++++++-------- .../soss_extract/tests/test_atoca_utils.py | 132 ++++++++++++++++-- 2 files changed, 187 insertions(+), 70 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index 14b80bcf67..1f87756056 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -6,6 +6,7 @@ @authors: Antoine Darveau-Bernier, Geert Jan Talens """ +import warnings import numpy as np from scipy.sparse import diags, csr_matrix from scipy.sparse.linalg import spsolve @@ -283,8 +284,9 @@ def extrapolate_grid(wave_grid, wave_range, poly_ord=1): grid_left = [] if wave_range[0] < wave_grid.min(): - # Compute the first extrapolated grid point. - grid_left = [wave_grid.min() - f_delta(wave_grid.min())] + # Initialize extrapolated grid with the first value of input grid. + # This point gets double-counted in the final grid, but then unique is called. + grid_left = [wave_grid.min(),] # Iterate until the end of wave_range is reached. while True: @@ -304,8 +306,9 @@ def extrapolate_grid(wave_grid, wave_range, poly_ord=1): grid_right = [] if wave_range[-1] > wave_grid.max(): - # Compute the first extrapolated grid point. - grid_right = [wave_grid.max() + f_delta(wave_grid.max())] + # Initialize extrapolated grid with the last value of input grid. + # This point gets double-counted in the final grid, but then unique is called. + grid_right = [wave_grid.max(),] # Iterate until the end of wave_range is reached. while True: @@ -318,7 +321,7 @@ def extrapolate_grid(wave_grid, wave_range, poly_ord=1): if next_delta < min_delta: raise RuntimeError('Extrapolation failed to converge.') - # Sort extrapolated vales (and keep only unique). + # Sort extrapolated values (and keep only unique) grid_right = np.unique(grid_right) # Combine the extrapolated sections with the original grid. @@ -344,26 +347,28 @@ def _grid_from_map(wave_map, trace_profile): Column indices used. """ - # Use only valid columns. - mask = (trace_profile > 0).any(axis=0) & (wave_map > 0).any(axis=0) + # Use only valid values by setting weights to zero + trace_profile[trace_profile < 0] = 0 + trace_profile[wave_map <= 0] = 0 - # Get central wavelength using PSF as weights. - num = (trace_profile * wave_map).sum(axis=0) - denom = trace_profile.sum(axis=0) - center_wv = num[mask] / denom[mask] + # handle case where all values are invalid for a given wavelength + # np.average cannot process sum(weights) = 0, so set them to unity then set NaN afterward + bad_wls = np.sum(trace_profile, axis=0) == 0 + trace_profile[:,bad_wls] = 1 + center_wv = np.average(wave_map, weights=trace_profile, axis=0) + center_wv[bad_wls] = np.nan + center_wv = center_wv[~np.isnan(center_wv)] # Make sure the wavelength values are in ascending order. - sort = np.argsort(center_wv) - grid = center_wv[sort] - - icols, = np.where(mask) - return grid, icols[sort] + return np.sort(center_wv) def grid_from_map(wave_map, trace_profile, wave_range=None, n_os=1): """Define a wavelength grid by taking the central wavelength at each columns given by the center of mass of the spatial profile (so one wavelength per column). If wave_range is outside of the wave_map, extrapolate. + TODO: question for SOSS team: I doubt it matters much, but is the current + behavior of wave_range ok or should it be inclusive? Parameters ---------- @@ -374,6 +379,10 @@ def grid_from_map(wave_map, trace_profile, wave_range=None, n_os=1): wave_range : list[float] Minimum and maximum boundary of the grid to generate, in microns. Wave_range must include some wavelengths of wave_map. + Note wave_range is exclusive, in the sense that wave_range[0] and wave_range[1] + will not be between min(output) and mad(output). Instead, min(output) will be + the smallest value in the extrapolated grid that is greater than wave_range[0] + and max(output) will be the largest value that is less than wave_range[1]. n_os : int Oversampling of the grid compare to the pixel sampling. @@ -382,39 +391,36 @@ def grid_from_map(wave_map, trace_profile, wave_range=None, n_os=1): grid_os : array[float] Wavelength grid with oversampling applied """ - import matplotlib.pyplot as plt - plt.imshow(wave_map, origin = 'lower') - plt.show() - plt.imshow(trace_profile, origin = 'lower') - plt.show() + if wave_map.shape != trace_profile.shape: + msg = 'wave_map and trace_profile must have the same shape.' + log.critical(msg) + raise ValueError(msg) # Different treatment if wave_range is given. if wave_range is None: - out, _ = _grid_from_map(wave_map, trace_profile) + grid = _grid_from_map(wave_map, trace_profile) else: # Get an initial estimate of the grid. - grid, icols = _grid_from_map(wave_map, trace_profile) + grid = _grid_from_map(wave_map, trace_profile) + + # Extrapolate values out of the wv_map if needed + grid = extrapolate_grid(grid, wave_range, poly_ord=1) - # Make sure grid is between the range - mask = (wave_range[0] <= grid) & (grid <= wave_range[-1]) + # Constrain grid to be within wave_range + grid = grid[grid>=wave_range[0]] + grid = grid[grid<=wave_range[-1]] # Check if grid and wv_range are compatible - if not mask.any(): + if len(grid) == 0: msg = "Invalid wave_map or wv_range." log.critical(msg) raise ValueError(msg) - grid, icols = grid[mask], icols[mask] - - # Extrapolate values out of the wv_map if needed - # if not needed, this will return the original grid - out = extrapolate_grid(grid, wave_range, poly_ord=1) - # Apply oversampling - return oversample_grid(out, n_os=n_os) + return oversample_grid(grid, n_os=n_os) -def get_soss_grid(wave_maps, trace_profiles, wave_min=0.55, wave_max=3.0, n_os=None): +def get_soss_grid(wave_maps, trace_profiles, wave_min=0.55, wave_max=3.0, n_os=2): """Create a wavelength grid specific to NIRISS SOSS mode observations. Assumes 2 orders are given, use grid_from_map if only one order is needed. @@ -422,16 +428,20 @@ def get_soss_grid(wave_maps, trace_profiles, wave_min=0.55, wave_max=3.0, n_os=N ---------- wave_maps : array[float] Array containing the pixel wavelengths for order 1 and 2. + trace_profiles : array[float] Array containing the spatial profiles for order 1 and 2. + wave_min : float Minimum wavelength the output grid should cover. + wave_max : float Maximum wavelength the output grid should cover. + n_os : int or list[int] Oversampling of the grid compared to the pixel sampling. Can be specified for each order if a list is given. If a single value is given - it will be used for all orders. + it will be used for all orders. Default is 2 for all orders. Returns ------- @@ -439,37 +449,44 @@ def get_soss_grid(wave_maps, trace_profiles, wave_min=0.55, wave_max=3.0, n_os=N Wavelength grid optimized for extracting SOSS spectra across order 1 and order 2. """ - - # Check n_os input, default value is 2 for all orders. - if n_os is None: - n_os = [2, 2] - elif np.ndim(n_os) == 0: + if np.ndim(n_os) == 0: n_os = [n_os, n_os] elif len(n_os) != 2: msg = (f"n_os must be an integer or a 2 element list or array of " f"integers, got {n_os} instead") log.critical(msg) raise ValueError(msg) + if (wave_maps.shape[0] != 2) or (trace_profiles.shape[0] != 2): + msg = 'wave_maps and trace_profiles must have shape (2, detector_x, detector_y).' + log.critical(msg) + raise ValueError(msg) + + wave_maps[wave_maps <= 0] = np.nan + trace_profiles[trace_profiles < 0] = 0 + if np.any(np.isnan(wave_maps)): + msg = 'Encountered NaN values in wave_maps, ignoring them.' + log.warning(msg) # Generate a wavelength range for each order. # Order 1 covers the reddest part of the spectrum, # so apply wave_max on order 1 and vice versa for order 2. + # TODO: This doesn't know about throughput, so it's possible these min/max wavelengths + # would be chosen such that the algorithm attemps to extract a wavelength from an order + # that has very low (but non-zero) throughput at that wavelength - # Take the most restrictive wave_min for order 1 - wave_min_o1 = np.maximum(wave_maps[0].min(), wave_min) + # current test data has min(wave_map) != min(wave_grid) because of throughput! + # how to fix this? - # Take the most restrictive wave_max for order 2. - wave_max_o2 = np.minimum(wave_maps[1].max(), wave_max) - - # Now generate range for each orders - range_list = [[wave_min_o1, wave_max], - [wave_min, wave_max_o2]] + wave_min_o1 = np.maximum(np.nanmin(wave_maps[0]), wave_min) + wave_max_o2 = np.minimum(np.nanmax(wave_maps[1]), wave_max) # Use grid_from_map to construct separate oversampled grids for both orders. wave_grid_o1 = grid_from_map(wave_maps[0], trace_profiles[0], - wave_range=range_list[0], n_os=n_os[0]) + wave_range=[wave_min_o1, wave_max], + n_os=n_os[0]) wave_grid_o2 = grid_from_map(wave_maps[1], trace_profiles[1], - wave_range=range_list[1], n_os=n_os[1]) + wave_range=[wave_min, wave_max_o2], + n_os=n_os[1]) # Keep only wavelengths in order 1 that aren't covered by order 2. mask = wave_grid_o1 > wave_grid_o2.max() @@ -479,9 +496,7 @@ def get_soss_grid(wave_maps, trace_profiles, wave_min=0.55, wave_max=3.0, n_os=N wave_grid_soss = np.concatenate([wave_grid_o1, wave_grid_o2]) # Sort values (and keep only unique). - wave_grid_soss = np.unique(wave_grid_soss) - - return wave_grid_soss + return np.unique(wave_grid_soss) def _trim_grids(all_grids, grid_range=None): @@ -2121,7 +2136,7 @@ def __init__(self, test_dict=None, default_chi2='chi2_cauchy'): # Define the number of data points # (length of the "b" vector in the tikhonov regularisation) if test_dict is None: - print('Unable to get the number of data points. Setting `n_points` to 1') + log.warning('Unable to get the number of data points. Setting `n_points` to 1') n_points = 1 else: n_points = len(test_dict['error'][0].squeeze()) @@ -2473,7 +2488,7 @@ def test_factors(self, factors): message = '{}/{}'.format(i_fac, len(factors)) log.info(message) - # Final print + # Final message output message = '{}/{}'.format(i_fac + 1, len(factors)) log.info(message) diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py index addaea7d73..f0b96a33cd 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py @@ -28,18 +28,40 @@ def test_arange_2d(): au.arange_2d(starts, stops_too_small) -FIBONACCI = np.array([1,1,2,3,5,8,13], dtype=float) -@pytest.fixture(scope="module") +# wavelengths have min, max (1.5, 4.0) and a bit of non-linearity +WAVELENGTHS = np.linspace(1.5, 3.0, 50) + np.sin(np.linspace(0, np.pi/2, 50)) +@pytest.fixture(scope="function") def wave_map(): wave_map = np.array([ - FIBONACCI, - FIBONACCI+1, - + WAVELENGTHS, + WAVELENGTHS+0.2, + WAVELENGTHS+0.2, + WAVELENGTHS+0.2, + WAVELENGTHS+0.4, ]) wave_map[1,3] = -1 #test skip of bad value return wave_map +@pytest.fixture(scope="function") +def wave_map_o2(wave_map): + return np.copy(wave_map) - 1.0 + + +@pytest.fixture(scope="function") +def trace_profile(wave_map): + thrpt = np.array([0.01, 0.95, 1.0, 0.8, 0.01]) + trace_profile = np.ones_like(wave_map) + return trace_profile*thrpt[:,None] + + +@pytest.fixture(scope="function") +def trace_profile_o2(wave_map_o2): + thrpt = np.array([0.001, 0.01, 0.01, 0.2, 0.99]) + trace_profile = np.ones_like(wave_map_o2) + return trace_profile*thrpt[:,None] + + @pytest.mark.parametrize("dispersion_axis", [0,1]) def test_get_wv_map_bounds(wave_map, dispersion_axis): """ @@ -56,14 +78,14 @@ def test_get_wv_map_bounds(wave_map, dispersion_axis): wave_top = wave_top.T wave_bottom = wave_bottom.T - diff = (FIBONACCI[1:]-FIBONACCI[:-1])/2 + diff = (WAVELENGTHS[1:]-WAVELENGTHS[:-1])/2 diff_lower = np.insert(diff,0,diff[0]) diff_upper = np.append(diff,diff[-1]) - wave_top_expected = FIBONACCI-diff_lower - wave_bottom_expected = FIBONACCI+diff_upper + wave_top_expected = WAVELENGTHS-diff_lower + wave_bottom_expected = WAVELENGTHS+diff_upper # basic test - assert wave_top.shape == wave_bottom.shape == (2,)+FIBONACCI.shape + assert wave_top.shape == wave_bottom.shape == (wave_map.shape[0],)+WAVELENGTHS.shape assert np.allclose(wave_top[0], wave_top_expected) assert np.allclose(wave_bottom[0], wave_bottom_expected) @@ -109,6 +131,7 @@ def test_get_wave_p_or_m_not_ascending(wave_map): au.get_wave_p_or_m(wave_map, dispersion_axis=1) +FIBONACCI = np.array([1,1,2,3,5,8,13,21,35], dtype=float) @pytest.mark.parametrize("n_os", [1,5]) def test_oversample_grid(n_os): @@ -160,8 +183,6 @@ def test_oversample_irregular(os_factor): au.oversample_grid(fib_unq, n_os[:-1]) -# wavelengths have min, max (2.0, 4.0) and a bit of non-linearity -WAVELENGTHS = np.linspace(2.0, 3.0, 50) + np.sin(np.linspace(0, np.pi/2, 50)) @pytest.mark.parametrize("wave_range", [(2.1, 3.9), (1.8, 4.5)]) def test_extrapolate_grid(wave_range): @@ -178,7 +199,7 @@ def test_extrapolate_grid(wave_range): def test_extrapolate_catch_failed_converge(): # give wavelengths some non-linearity - wave_range = WAVELENGTHS.min(), WAVELENGTHS.max()+2.0 + wave_range = WAVELENGTHS.min(), WAVELENGTHS.max()+4.0 with pytest.raises(RuntimeError): au.extrapolate_grid(WAVELENGTHS, wave_range, 1) @@ -190,9 +211,90 @@ def test_extrapolate_bad_inputs(): au.extrapolate_grid(WAVELENGTHS, (4.1, 4.2)) -def test_grid_from_map(wave_map): +def test_grid_from_map(wave_map, trace_profile): + """Covers expected behavior of grid_from_map, including coverage of a previous bug + where bad wavelengths were not being ignored properly""" + + wave_grid = au.grid_from_map(wave_map, trace_profile, wave_range=None) + + # expected output is very near WAVELENGTHS+0.2 because that's what all the high-weight + # rows of the wave_map are set to. + assert np.allclose(wave_grid, wave_map[2]) + + # test custom wave_range + wave_range = [wave_map[2,2], wave_map[2,-2]+0.01] + wave_grid = au.grid_from_map(wave_map, trace_profile, wave_range=wave_range) + assert np.allclose(wave_grid, wave_map[2,2:-1]) + + # test custom wave_range with extrapolation + wave_range = [wave_map[2,2], wave_map[2,-1]+1] + wave_grid = au.grid_from_map(wave_map, trace_profile, wave_range=wave_range) + assert len(wave_grid) > len(wave_map[2,2:]) + n_inside = wave_map[2,2:].size + assert np.allclose(wave_grid[:n_inside], wave_map[2,2:]) + + with pytest.raises(ValueError): + au.grid_from_map(wave_map, trace_profile, wave_range=[0.5,0.9]) - pass - #wave_grid = au.grid_from_map(wave_map, trace_profile) + +@pytest.mark.parametrize("n_os", ([4,1], 1)) +def test_get_soss_grid(n_os, wave_map, trace_profile, wave_map_o2, trace_profile_o2): + """ + wave_map has min, max wavelength of 1.5, 4.0 but throughput makes this 1.7, 4.2 + wave_map_o2 has min, max wavelength of 0.5, 3.0 but throughput makes this 0.9, 3.4 + """ + + # choices of wave_min, wave_max force extrapolation on low end + # and also makes sure both orders matter + wave_min = 0.55 + wave_max = 4.0 + wave_maps = np.array([wave_map, wave_map_o2]) + trace_profiles = np.array([trace_profile, trace_profile_o2]) + wave_grid = au.get_soss_grid(wave_maps, trace_profiles, wave_min, wave_max, n_os) + + delta_lower = wave_grid[1]-wave_grid[0] + delta_upper = wave_grid[-1]-wave_grid[-2] + + # ensure no duplicates and strictly ascending + assert wave_grid.size == np.unique(wave_grid).size + assert np.all(wave_grid[1:] > wave_grid[:-1]) + + # esnure grid is within bounds but just abutting bounds + assert wave_grid.min() >= wave_min + assert wave_grid.min() <= wave_min + delta_lower + assert wave_grid.max() <= wave_max + assert wave_grid.max() >= wave_max - delta_upper + + # ensure oversample factor changes for different n_os + # this is a bit complicated because the wavelength spacing is nonlinear in wave_map + # by a factor of 2 to begin with, so check that the ratio of the wavelength spacing + # in the upper vs lower end of the wl ranges look approx like what was input, modulo + # the oversample factor of the two orders + if n_os == 1: + n_os = [1,1] + og_spacing_lower = WAVELENGTHS[1]-WAVELENGTHS[0] + og_spacing_upper = WAVELENGTHS[-1]-WAVELENGTHS[-2] + expected_ratio = int((n_os[0]/n_os[1])*(np.around(og_spacing_lower/og_spacing_upper))) + + spacing_lower = np.mean(wave_grid[1:6]-wave_grid[:5]) + spacing_upper = np.mean(wave_grid[-5:]-wave_grid[-6:-1]) + actual_ratio = int(np.around((spacing_lower/spacing_upper))) + # for n=1 case we expect 2, for n=(4,1) case we expect 8 + assert expected_ratio == actual_ratio + + +def test_get_soss_grid_bad_inputs(wave_map, trace_profile): + with pytest.raises(ValueError): + # test bad input shapes + au.get_soss_grid(wave_map, trace_profile, 0.5, 0.9, 1) + + wave_maps = np.array([wave_map, wave_map]) + trace_profiles = np.array([trace_profile, trace_profile]) + + with pytest.raises(ValueError): + # test bad n_os shape + au.get_soss_grid(wave_maps, trace_profiles, 0.5, 0.9, [1,1,1]) + + From 4197b38a922c0efb04ed046f05f637535c5bd9f1 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Mon, 28 Oct 2024 18:35:17 -0400 Subject: [PATCH 06/35] remove unused soss utilities --- jwst/extract_1d/soss_extract/atoca.py | 125 ++----------- jwst/extract_1d/soss_extract/atoca_utils.py | 1 + jwst/extract_1d/soss_extract/pastasoss.py | 52 ++--- .../extract_1d/soss_extract/soss_centroids.py | 150 --------------- jwst/extract_1d/soss_extract/soss_extract.py | 9 +- jwst/extract_1d/soss_extract/soss_syscor.py | 143 -------------- jwst/extract_1d/soss_extract/soss_utils.py | 177 ------------------ 7 files changed, 43 insertions(+), 614 deletions(-) delete mode 100644 jwst/extract_1d/soss_extract/soss_centroids.py delete mode 100644 jwst/extract_1d/soss_extract/soss_utils.py diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index 8fa6908a5b..e8b5c20301 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -561,7 +561,9 @@ def wave_grid_c(self, i_order): return self.wave_grid[index] def get_w(self, i_order): - """Dummy method to init this class""" + """Dummy method to init this class + TODO: so is this an abstract base class? is it ever actually subclassed more than once? + """ return np.array([]), np.array([]) @@ -963,7 +965,10 @@ def set_tikho_matrix(self, t_mat=None, t_mat_func=None, fargs=None, fkwargs=None return def get_tikho_matrix(self, **kwargs): - """Return the Tikhonov matrix. + """ + #TODO: why is there also an atoca_utils.get_tikho_matrix function? + + Return the Tikhonov matrix. Generate it with `set_tikho_matrix` method if not defined yet. If so, all arguments are passed to `set_tikho_matrix`. The result is saved as an attribute. @@ -1319,6 +1324,9 @@ def __call__(self, tikhonov=False, tikho_kwargs=None, factor=None, **kwargs): There will be only one matrix multiplication: (P/sig).(w.T.lambda.c_n). + TODO: is there any way to avoid the need for this __call__ method given + that it's only a thin wrapper to the Tikhonov class? + Parameters ---------- tikhonov : bool, optional @@ -1378,117 +1386,6 @@ def __call__(self, tikhonov=False, tikho_kwargs=None, factor=None, **kwargs): return spectrum - def bin_to_pixel(self, i_order=0, grid_pix=None, grid_f_k=None, convolved_spectrum=None, - spectrum=None, bounds_error=False, throughput=None, **kwargs): - """Integrate the convolved_spectrum (f_k_c) over a pixel grid using the trapezoidal rule. - The convolved spectrum (f_k_c) is interpolated using scipy.interpolate.interp1d and the - kwargs and bounds_error are passed to interp1d. - i_order : int, optional - index of the order to be integrated, default is 0, so - the first order specified. - grid_pix : tuple or array, optional - If a tuple of 2 arrays is given, assume it is the lower and upper - integration ranges. If 1d-array, assume it is the center - of the pixels. If not given, the wavelength map and the psf map - of `i_order` will be used to compute a pixel grid. - grid_f_k : 1d array, optional - grid on which the convolved flux is projected. - Default is the wavelength grid for `i_order`. - convolved_spectrum : 1d array, optional - Convolved flux (f_k_c) to be integrated. If not given, `spectrum` - will be used (and convolved to `i_order` resolution) - spectrum : 1d array, optional - non-convolved flux (f_k, result of the `extract` method). - Not used if `convolved_spectrum` is specified. - bounds_error : bool, optional - passed to interp1d function to interpolate the convolved_spectrum. - Default is False - throughput : callable, optional - Spectral throughput for a given order (ì_ord). - Default is given by the list of throughput saved as - the attribute `t_list`. - kwargs : iterable, optional - If provided, will be passed to interp1d function. - - Returns - ------- - pix_center, bin_val : array[float] - The pixel centers and the associated integrated values. - """ - # Take the value from the order if not given... - - # ... for the flux grid ... - if grid_f_k is None: - grid_f_k = self.wave_grid_c(i_order) - - # ... for the convolved flux ... - if convolved_spectrum is None: - # Use the spectrum (f_k) if the convolved_spectrum (f_k_c) not given. - if spectrum is None: - raise ValueError("`spectrum` or `convolved_spectrum` must be specified.") - else: - # Convolve the spectrum (f_k). - convolved_spectrum = self.kernels[i_order].dot(spectrum) - - # ... and for the pixel bins - if grid_pix is None: - pix_center, _ = self.grid_from_map(i_order) - - # Get pixels borders (plus and minus) - pix_p, pix_m = atoca_utils.get_wave_p_or_m(pix_center) - - else: # Else, unpack grid_pix - - # Could be a scalar or a 2-elements object) - if len(grid_pix) == 2: - - # 2-elements object, so we have the borders - pix_m, pix_p = grid_pix - - # Need to compute pixel center - d_pix = (pix_p - pix_m) - pix_center = grid_pix[0] + d_pix - else: - - # 1-element object, so we have the pix centers - pix_center = grid_pix - - # Need to compute the borders - pix_p, pix_m = atoca_utils.get_wave_p_or_m(pix_center) - - # Set the throughput to object attribute - # if not given - if throughput is None: - - # Need to interpolate - x, y = self.wave_grid, self.throughput[i_order] - throughput = interp1d(x, y) - - # Apply throughput on flux - convolved_spectrum = convolved_spectrum * throughput(grid_f_k) - - # Interpolate - kwargs['bounds_error'] = bounds_error - fct_f_k = interp1d(grid_f_k, convolved_spectrum, **kwargs) - - # Intergrate over each bins - bin_val = [] - for x1, x2 in zip(pix_m, pix_p): - - # Grid points that fall inside the pixel range - i_grid = (x1 < grid_f_k) & (grid_f_k < x2) - x_grid = grid_f_k[i_grid] - - # Add boundaries values to the integration grid - x_grid = np.concatenate([[x1], x_grid, [x2]]) - - # Integrate - integrand = fct_f_k(x_grid) * x_grid - bin_val.append(np.trapezoid(integrand, x_grid)) - - # Convert to array and return with the pixel centers. - return pix_center, np.array(bin_val) - class ExtractionEngine(_BaseOverlap): """ @@ -1641,6 +1538,8 @@ def get_mask_wave(self, i_order): def get_w(self, i_order): """Compute integration weights for each grid points and each pixels. Depends on the order `n`. + TODO: is this the same order as the spectral order? if so, can we ignore + a bunch of the cases in this function and just keep orders 1 and 2? Parameters ---------- diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index 1f87756056..2502ef8967 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -556,6 +556,7 @@ def make_combined_adaptive_grid(all_grids, all_estimate, grid_range=None, max_iter=10, rtol=10e-6, tol=0.0, max_total_size=1000000): """ TODO: can this be a class? e.g., class AdaptiveGrid? + TODO: why aren't any of the same helper functions used here as in get_soss_grid? Return an irregular oversampled grid needed to reach a given precision when integrating over each intervals of `grid`. diff --git a/jwst/extract_1d/soss_extract/pastasoss.py b/jwst/extract_1d/soss_extract/pastasoss.py index 4030df2224..6fec75c778 100644 --- a/jwst/extract_1d/soss_extract/pastasoss.py +++ b/jwst/extract_1d/soss_extract/pastasoss.py @@ -14,7 +14,7 @@ WAVEMAP_NWL = 5001 SUBARRAY_YMIN = 2048 - 256 -def get_wavelengths(refmodel, x, pwcpos, order): +def _get_wavelengths(refmodel, x, pwcpos, order): """Get the associated wavelength values for a given spectral order""" if order == 1: wavelengths = wavecal_model_order1_poly(refmodel, x, pwcpos) @@ -24,7 +24,7 @@ def get_wavelengths(refmodel, x, pwcpos, order): return wavelengths -def min_max_scaler(x, x_min, x_max): +def _min_max_scaler(x, x_min, x_max): """ Apply min-max scaling to input values. @@ -57,6 +57,9 @@ def min_max_scaler(x, x_min, x_max): def wavecal_model_order1_poly(refmodel, x, pwcpos): """Compute order 1 wavelengths. + TODO: surely there are numpy or scipy builtins to avoid having to write + out this polynomial expansion by hand? + Parameters ---------- refmodel : PastasossModel @@ -71,7 +74,7 @@ def wavecal_model_order1_poly(refmodel, x, pwcpos): to rotate the model """ x_scaler = partial( - min_max_scaler, + _min_max_scaler, **{ "x_min": refmodel.wavecal_models[0].scale_extents[0][0], "x_max": refmodel.wavecal_models[0].scale_extents[1][0], @@ -79,7 +82,7 @@ def wavecal_model_order1_poly(refmodel, x, pwcpos): ) pwcpos_offset_scaler = partial( - min_max_scaler, + _min_max_scaler, **{ "x_min": refmodel.wavecal_models[0].scale_extents[0][1], "x_max": refmodel.wavecal_models[0].scale_extents[1][1], @@ -148,7 +151,7 @@ def wavecal_model_order2_poly(refmodel, x, pwcpos): to rotate the model """ x_scaler = partial( - min_max_scaler, + _min_max_scaler, **{ "x_min": refmodel.wavecal_models[1].scale_extents[0][0], "x_max": refmodel.wavecal_models[1].scale_extents[1][0], @@ -156,7 +159,7 @@ def wavecal_model_order2_poly(refmodel, x, pwcpos): ) pwcpos_offset_scaler = partial( - min_max_scaler, + _min_max_scaler, **{ "x_min": refmodel.wavecal_models[1].scale_extents[0][1], "x_max": refmodel.wavecal_models[1].scale_extents[1][1], @@ -196,10 +199,13 @@ def get_poly_features(x, offset): return wavelengths -def rotate(x, y, angle, origin=(0, 0), interp=True): +def _rotate(x, y, angle, origin=(0, 0), interp=True): """ Applies a rotation transformation to a set of 2D points. + TODO: surely there are scipy builtins for this? is there + any difference between this and ndimage.rotate or something? + Parameters ---------- x : np.ndarray @@ -223,7 +229,7 @@ def rotate(x, y, angle, origin=(0, 0), interp=True): -------- >>> x = np.array([0, 1, 2, 3]) >>> y = np.array([0, 1, 2, 3]) - >>> x_rot, y_rot = rotate(x, y, 90) + >>> x_rot, y_rot = _rotate(x, y, 90) """ # shift to rotate about center @@ -251,7 +257,7 @@ def rotate(x, y, angle, origin=(0, 0), interp=True): return x_new, y_new -def find_spectral_order_index(refmodel, order): +def _find_spectral_order_index(refmodel, order): """Return index of trace and wavecal dict corresponding to order Parameters @@ -277,7 +283,7 @@ def find_spectral_order_index(refmodel, order): return -1 -def get_soss_traces(refmodel, pwcpos, order, subarray, interp=True): +def _get_soss_traces(refmodel, pwcpos, order, subarray, interp=True): """Generate the traces given a pupil wheel position. This is the primary method for generating the gr700xd trace position given a @@ -320,7 +326,7 @@ def get_soss_traces(refmodel, pwcpos, order, subarray, interp=True): ValueError If `order` is not '1', '2', '3', or a combination of '1', '2', and '3'. """ - spectral_order_index = find_spectral_order_index(refmodel, int(order)) + spectral_order_index = _find_spectral_order_index(refmodel, int(order)) if spectral_order_index < 0: error_message = f"Order {order} is not supported at this time." @@ -335,15 +341,15 @@ def get_soss_traces(refmodel, pwcpos, order, subarray, interp=True): if subarray == 'SUBSTRIP96': y -= 10 # rotated reference trace - x_new, y_new = rotate(x, y, pwcpos - refmodel.meta.pwcpos_cmd, origin, interp=interp) + x_new, y_new = _rotate(x, y, pwcpos - refmodel.meta.pwcpos_cmd, origin, interp=interp) # wavelength associated to trace at given pwcpos value - wavelengths = get_wavelengths(refmodel, x_new, pwcpos, int(order)) + wavelengths = _get_wavelengths(refmodel, x_new, pwcpos, int(order)) return order, x_new, y_new, wavelengths -def extrapolate_to_wavegrid(w_grid, wavelength, quantity): +def _extrapolate_to_wavegrid(w_grid, wavelength, quantity): """ Extrapolates quantities on the right and the left of a given array of quantity @@ -385,7 +391,7 @@ def extrapolate_to_wavegrid(w_grid, wavelength, quantity): return q_grid -def calc_2d_wave_map(wave_grid, x_dms, y_dms, tilt, oversample=2, padding=0, maxiter=5, dtol=1e-2): +def _calc_2d_wave_map(wave_grid, x_dms, y_dms, tilt, oversample=2, padding=0, maxiter=5, dtol=1e-2): """Compute the 2D wavelength map on the detector. Parameters @@ -483,8 +489,8 @@ def get_soss_wavemaps(refmodel, pwcpos, subarray, padding=False, padsize=0, spec Array, Array The 2D wavemaps and corresponding 1D spectraces """ - _, order1_x, order1_y, order1_wl = get_soss_traces(refmodel, pwcpos, order='1', subarray=subarray, interp=True) - _, order2_x, order2_y, order2_wl = get_soss_traces(refmodel, pwcpos, order='2', subarray=subarray, interp=True) + _, order1_x, order1_y, order1_wl = _get_soss_traces(refmodel, pwcpos, order='1', subarray=subarray, interp=True) + _, order2_x, order2_y, order2_wl = _get_soss_traces(refmodel, pwcpos, order='2', subarray=subarray, interp=True) # Make wavemap from trace center wavelengths, padding to shape (296, 2088) wavemin = WAVEMAP_WLMIN @@ -493,8 +499,8 @@ def get_soss_wavemaps(refmodel, pwcpos, subarray, padding=False, padsize=0, spec wave_grid = np.linspace(wavemin, wavemax, nwave) # Extrapolate wavelengths for order 1 trace - xtrace_order1 = extrapolate_to_wavegrid(wave_grid, order1_wl, order1_x) - ytrace_order1 = extrapolate_to_wavegrid(wave_grid, order1_wl, order1_y) + xtrace_order1 = _extrapolate_to_wavegrid(wave_grid, order1_wl, order1_x) + ytrace_order1 = _extrapolate_to_wavegrid(wave_grid, order1_wl, order1_y) spectrace_1 = np.array([xtrace_order1, ytrace_order1, wave_grid]) # Set cutoff for order 2 where it runs off the detector @@ -517,15 +523,15 @@ def get_soss_wavemaps(refmodel, pwcpos, subarray, padding=False, padsize=0, spec y_o2[o2_cutoff:] = y_o2[o2_cutoff - 1] + m * dx # Extrapolate wavelengths for order 2 trace - xtrace_order2 = extrapolate_to_wavegrid(wave_grid, w_o2, x_o2) - ytrace_order2 = extrapolate_to_wavegrid(wave_grid, w_o2, y_o2) + xtrace_order2 = _extrapolate_to_wavegrid(wave_grid, w_o2, x_o2) + ytrace_order2 = _extrapolate_to_wavegrid(wave_grid, w_o2, y_o2) spectrace_2 = np.array([xtrace_order2, ytrace_order2, wave_grid]) # Make wavemap from wavelength solution for order 1 - wavemap_1 = calc_2d_wave_map(wave_grid, xtrace_order1, ytrace_order1, np.zeros_like(xtrace_order1), oversample=1, padding=padsize) + wavemap_1 = _calc_2d_wave_map(wave_grid, xtrace_order1, ytrace_order1, np.zeros_like(xtrace_order1), oversample=1, padding=padsize) # Make wavemap from wavelength solution for order 2 - wavemap_2 = calc_2d_wave_map(wave_grid, xtrace_order2, ytrace_order2, np.zeros_like(xtrace_order2), oversample=1, padding=padsize) + wavemap_2 = _calc_2d_wave_map(wave_grid, xtrace_order2, ytrace_order2, np.zeros_like(xtrace_order2), oversample=1, padding=padsize) # Extrapolate wavemap to FULL frame wavemap_1[:SUBARRAY_YMIN - padsize, :] = wavemap_1[SUBARRAY_YMIN - padsize] diff --git a/jwst/extract_1d/soss_extract/soss_centroids.py b/jwst/extract_1d/soss_extract/soss_centroids.py deleted file mode 100644 index 2474c571d0..0000000000 --- a/jwst/extract_1d/soss_extract/soss_centroids.py +++ /dev/null @@ -1,150 +0,0 @@ -import logging -import numpy as np -import warnings - -from .soss_utils import robust_polyfit, get_image_dim - -log = logging.getLogger(__name__) -log.setLevel(logging.DEBUG) - - -def center_of_mass(column, ypos, halfwidth): - """Compute a windowed center-of-mass along a column. - - Parameters - ---------- - column : array[float] - The column values on which to compute the windowed center of mass. - ypos : float - The position along the column to center the window on. - halfwidth : int - The half-size of the window in pixels. - - Returns - -------- - ycom : float - The center-of-mass of the pixels within the window. - """ - - # Get the column shape and create a corresponding array of positions. - dimy, = column.shape - ypix = np.arange(dimy) - - # Find the indices of the window. - miny = int(np.fmax(np.around(ypos - halfwidth), 0)) - maxy = int(np.fmin(np.around(ypos + halfwidth + 1), dimy)) - - # Compute the center of mass on the window. - with np.errstate(invalid='ignore'): - ycom = (np.nansum(column[miny:maxy] * ypix[miny:maxy]) / - np.nansum(column[miny:maxy])) - - return ycom - - -def get_centroids_com(scidata_bkg, header=None, mask=None, poly_order=11): - """Determine the x, y coordinates of the trace using a center-of-mass - analysis. Works for either order if there is no contamination, or for - order 1 on a detector where the two orders are overlapping. - - Parameters - ---------- - scidata_bkg : array[float] - A background subtracted observation. - header : astropy.io.fits.Header - The header from one of the SOSS reference files. - mask : array[bool] - A boolean array of the same shape as image. Pixels corresponding to - True values will be masked. - poly_order : None or int - Order of the polynomial to fit to the extracted trace positions. - - Returns - -------- - xtrace : array[float] - The x coordinates of trace as computed from the best fit polynomial. - ytrace : array[float] - The y coordinates of trace as computed from the best fit polynomial. - param : array[float] - The best-fit polynomial parameters. - """ - - # If no mask was given use all pixels. - if mask is None: - mask = np.zeros_like(scidata_bkg, dtype='bool') - - # Call the script that determines the dimensions of the stack. - dimx, dimy, xos, yos, xnative, ynative, padding, refpix_mask = get_image_dim(scidata_bkg, header=header) - - # Replace masked pixel values with NaNs. - scidata_bkg_masked = np.where(mask | ~refpix_mask, np.nan, scidata_bkg) - - # Find centroid - first pass, use all pixels in the column. - - # Normalize each column - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=RuntimeWarning, message="All-NaN") - maxvals = np.nanmax(scidata_bkg_masked, axis=0) - scidata_norm = scidata_bkg_masked / maxvals - - # Create 2D Array of pixel positions. - xpix = np.arange(dimx) - ypix = np.arange(dimy) - _, ygrid = np.meshgrid(xpix, ypix) - - # CoM analysis to find initial positions using all rows. - with np.errstate(invalid='ignore'): - ytrace = np.nansum(scidata_norm * ygrid, axis=0) / np.nansum(scidata_norm, axis=0) - ytrace = np.where(np.abs(ytrace) == np.inf, np.nan, ytrace) - - # Second pass - use a windowed CoM at the previous position. - halfwidth = 30 * yos - for icol in range(dimx): - - ycom = center_of_mass(scidata_norm[:, icol], ytrace[icol], halfwidth) - - # If NaN was returned or centroid is out of bounds, we are done. - if not np.isfinite(ycom) or (ycom > (ynative - 1) * yos) or (ycom < 0): - ytrace[icol] = np.nan - continue - - # If the pixel at the centroid is below the local mean we are likely - # mid-way between orders and we should shift the window downward to - # get a reliable centroid for order 1. - irow = int(np.around(ycom)) - miny = int(np.fmax(np.around(ycom) - halfwidth, 0)) - maxy = int(np.fmin(np.around(ycom) + halfwidth + 1, dimy)) - if scidata_norm[irow, icol] < np.nanmean(scidata_norm[miny:maxy, icol]): - ycom = center_of_mass(scidata_norm[:, icol], ycom - halfwidth, halfwidth) - - # If the updated position is too close to the array edge, use NaN. - if not np.isfinite(ycom) or (ycom <= 5 * yos) or (ycom >= (ynative - 6) * yos): - ytrace[icol] = np.nan - continue - - # Update the position if the above checks were successful. - ytrace[icol] = ycom - - # Third pass - fine tuning using a smaller window. - halfwidth = 16 * yos - for icol in range(dimx): - - ytrace[icol] = center_of_mass(scidata_norm[:, icol], ytrace[icol], - halfwidth) - - # Fit the y-positions with a polynomial and use result as true y-positions. - xtrace = np.arange(dimx) - mask = np.isfinite(ytrace) - - # For padded arrays ignore padding for consistency with real data - if padding != 0: - mask = mask & (xtrace >= xos * padding) & (xtrace < (dimx - xos * padding)) - - # If no polynomial order was given return the raw measurements. - if poly_order is None: - param = [] - else: - param = robust_polyfit(xtrace[mask], ytrace[mask], poly_order) - ytrace = np.polyval(param, xtrace) - - return xtrace, ytrace, param diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index 7605afedf1..b9de2c7e3e 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -81,14 +81,6 @@ def get_ref_file_args(ref_files): for i, throughput in enumerate(pastasoss_ref.throughputs): throughput_index_dict[throughput.spectral_order] = i - # TODO: check that replacing legacy interp1d with make_interp_spline works right - # this is equivalent to order 3 interpolation with scipy interp1d except that - # fill_value=0.0 has been removed because make_interp_spline does not support it - # was the fill_value ever used? since this is a throughput there's likely nothing wrong with it - # What are the appropriate boundary conditions to pass to make_interp_spline bc_type? - # since expectation is all zeros outside range, first derivative should be zero at boundaries - # one idea is to retain ThroughputSOSS class and make it return a function that calls - # output of make_interp_spline when inside range and returns zero outside range throughput_o1 = ThroughputSOSS(pastasoss_ref.throughputs[throughput_index_dict[1]].wavelength[:], pastasoss_ref.throughputs[throughput_index_dict[1]].throughput[:]) throughput_o2 = ThroughputSOSS(pastasoss_ref.throughputs[throughput_index_dict[2]].wavelength[:], @@ -267,6 +259,7 @@ def get_native_grid_from_trace(ref_files, spectral_order): def get_grid_from_trace(ref_files, spectral_order, n_os=1): """ + TODO: is this partially or fully redundant with atoca_utils.grid_from_map? Make a 1d-grid of the pixels boundary and ready for ATOCA ExtractionEngine, based on the wavelength solution. Parameters diff --git a/jwst/extract_1d/soss_extract/soss_syscor.py b/jwst/extract_1d/soss_extract/soss_syscor.py index 36d66aa4f4..eaeda83b31 100644 --- a/jwst/extract_1d/soss_extract/soss_syscor.py +++ b/jwst/extract_1d/soss_extract/soss_syscor.py @@ -6,63 +6,6 @@ log.setLevel(logging.DEBUG) -def make_profile_mask(ref_2d_profile, threshold=1e-3): - """Build a mask of the trace based on the 2D profile reference file. - - Parameters - ---------- - ref_2d_profile : array[float] - The 2d trace profile reference. - threshold : float - Threshold value for excluding pixels based on ref_2d_profile. - - Returns - ------- - bkg_mask : array[bool] - Pixel mask in the trace based on the 2d profile reference file. - """ - - bkg_mask = (ref_2d_profile > threshold) - - return bkg_mask - - -def aperture_mask(xref, yref, halfwidth, shape): - """Build a mask of the trace based on the trace positions. - - Parameters - ---------- - xref : array[float] - The reference x-positions. - yref : array[float] - The reference y-positions. - halfwidth : float - Size of the aperture mask used when extracting the trace - positions from the data. - shape : Tuple(int, int) - The shape of the array to be masked. - - Returns - ------- - aper_mask : array[bool] - Pixel mask in the trace based on the given trace positions. - """ - - # Create a coordinate grid. - x = np.arange(shape[1]) - y = np.arange(shape[0]) - xx, yy = np.meshgrid(x, y) - - # Interpolate the trace positions onto the grid. - sort = np.argsort(xref) - ytrace = np.interp(x, xref[sort], yref[sort]) - - # Compute the aperture mask. - aper_mask = np.abs(yy - ytrace) > halfwidth - - return aper_mask - - def soss_background(scidata, scimask, bkg_mask=None): """Compute a columnwise background for a SOSS observation. @@ -168,89 +111,3 @@ def make_background_mask(deepstack, width=28): bkg_mask = (deepstack > threshold) | ~np.isfinite(deepstack) return bkg_mask - - -def soss_oneoverf_correction(scidata, scimask, deepstack, bkg_mask=None, - zero_bias=False): - """Compute a column-wise correction to the 1/f noise on the difference image - of an individual SOSS integration (i.e. an individual integration - a deep - image of the same observation). - - Parameters - ---------- - scidata : array[float] - Image of the SOSS trace. - scimask : array[boo] - Boolean mask of pixels to be excluded based on the DQ values. - deepstack : array[float] - Deep image of the trace constructed by combining - individual integrations of the observation. - bkg_mask : array[bool] - Boolean mask of pixels to be excluded because they are in the trace, - typically constructed with make_profile_mask. - zero_bias : bool - If True, the corrections to individual columns will be - adjusted so that their mean is zero. - - Returns - ------- - scidata_cor : array[float] - The 1/f-corrected image - col_cor : array[float] - The column-wise correction values - npix_cor : array[float] - Number of pixels used in each column - bias : float - Net change to the image, if zero_bias was False - """ - - # Check the validity of the input. - data_shape = scidata.shape - - if scimask.shape != data_shape: - msg = 'scidata and scimask must have the same shape.' - log.critical(msg) - raise ValueError(msg) - - if deepstack.shape != data_shape: - msg = 'scidata and deepstack must have the same shape.' - log.critical(msg) - raise ValueError(msg) - - if bkg_mask is not None: - - if bkg_mask.shape != data_shape: - msg = 'scidata and bkg_mask must have the same shape.' - log.critical(msg) - raise ValueError(msg) - - # Subtract the deep stack from the image. - diffimage = scidata - deepstack - - # Combine the masks and create a masked array. - mask = scimask | ~np.isfinite(deepstack) - - if bkg_mask is not None: - mask = mask | bkg_mask - - diffimage_masked = np.ma.array(diffimage, mask=mask) - - # Mask additional pixels using sigma-clipping. - sigclip = SigmaClip(sigma=3, maxiters=None, cenfunc='mean') - diffimage_clipped = sigclip(diffimage_masked, axis=0) - - # Compute the mean for each column and record the number of pixels used. - col_cor = diffimage_clipped.mean(axis=0) - npix_cor = (~diffimage_clipped.mask).sum(axis=0) - - # Compute the net change to the image. - bias = np.nanmean(col_cor) - - # Set the net bias to zero. - if zero_bias: - col_cor = col_cor - bias - - # Apply the 1/f correction to the image. - scidata_cor = scidata - col_cor - - return scidata_cor, col_cor, npix_cor, bias diff --git a/jwst/extract_1d/soss_extract/soss_utils.py b/jwst/extract_1d/soss_extract/soss_utils.py deleted file mode 100644 index 88c8b3212a..0000000000 --- a/jwst/extract_1d/soss_extract/soss_utils.py +++ /dev/null @@ -1,177 +0,0 @@ -import numpy as np -import logging - -log = logging.getLogger(__name__) -log.setLevel(logging.DEBUG) - - -def zero_roll(a, shift): - """Like np.roll but the wrapped around part is set to zero. - Only works along the first axis of the array. - - Parameters - ---------- - a : array - The input array. - shift : int - The number of rows to shift by. - - Returns - ------- - result : array - The array with the rows shifted. - """ - - result = np.zeros_like(a) - if shift >= 0: - result[shift:] = a[:-shift] - else: - result[:shift] = a[-shift:] - - return result - - -def robust_polyfit(x, y, order, maxiter=5, nstd=3.): - """Perform a robust polynomial fit. - - Parameters - ---------- - x : array[float] - x data to fit. - y : array[float] - y data to fit. - order : int - polynomial order to use. - maxiter : int, optional - number of iterations for rejecting outliers. - nstd : float, optional - number of standard deviations to use when rejecting outliers. - - Returns - ------- - param : array[float] - best-fit polynomial parameters. - """ - - mask = np.ones_like(x, dtype='bool') - for niter in range(maxiter): - - # Fit the data and evaluate the best-fit model. - param = np.polyfit(x[mask], y[mask], order) - yfit = np.polyval(param, x) - - # Compute residuals and mask outliers. - res = y - yfit - stddev = np.std(res) - mask = np.abs(res) <= nstd * stddev - - return param - - -def get_image_dim(image, header=None): - """Determine the properties of the image array. - - Parameters - ---------- - image : array[float] - A 2D image of the detector. - header : astropy.io.fits.Header object, optional - The header from one of the SOSS reference files. - - Returns - ------- - dimx, dimy : int - X and Y dimensions of the stack array. - xos, yos : int - Oversampling factors in x and y dimensions of the stack array. - xnative, ynative : int - Size of stack image x and y dimensions, in native pixels. - padding : int - Amount of padding around the image, in native pixels. - refpix_mask : array[bool] - Boolean array indicating which pixels are light-sensitive (True) - and which are reference pixels (False). - """ - - # Dimensions of the subarray. - dimy, dimx = np.shape(image) - - # If no header was passed we have to check all possible sizes. - if header is None: - - # Initialize padding to zero in this case because it is not a reference file. - padding = 0 - - # Assume the stack is a valid SOSS subarray. - # FULL: 2048x2048 or 2040x2040 (working pixels) or multiple if oversampled. - # SUBSTRIP96: 2048x96 or 2040x96 (working pixels) or multiple if oversampled. - # SUBSTRIP256: 2048x256 or 2040x252 (working pixels) or multiple if oversampled. - - # Check if the size of the x-axis is valid. - if (dimx % 2048) == 0: - xnative = 2048 - xos = int(dimx // 2048) - - elif (dimx % 2040) == 0: - xnative = 2040 - xos = int(dimx // 2040) - - else: - log_message = f'Stack X dimension has unrecognized size of {dimx}. Accepts 2048, 2040 or multiple of.' - log.critical(log_message) - raise ValueError(log_message) - - # Check if the y-axis is consistent with the x-axis. - if int(dimy / xos) in [96, 256, 252, 2040, 2048]: - yos = np.copy(xos) - ynative = int(dimy / yos) - - else: - log_message = f'Stack Y dimension ({dimy}) is inconsistent with stack X' \ - f'dimension ({dimx}) for acceptable SOSS arrays' - log.critical(log_message) - raise ValueError(log_message) - - # Create a boolean mask indicating which pixels are not reference pixels. - refpix_mask = np.ones_like(image, dtype='bool') - if xnative == 2048: - # Mask out the left and right columns of reference pixels. - refpix_mask[:, :xos * 4] = False - refpix_mask[:, -xos * 4:] = False - - if ynative == 2048: - # Mask out the top and bottom rows of reference pixels. - refpix_mask[:yos * 4, :] = False - refpix_mask[-yos * 4:, :] = False - - if ynative == 256: - # Mask the top rows of reference pixels. - refpix_mask[-yos * 4:, :] = False - - else: - # Read the oversampling and padding from the header. - padding = int(header['PADDING']) - xos, yos = int(header['OVERSAMP']), int(header['OVERSAMP']) - - # Check that the stack respects its intended format. - if (dimx / xos - 2 * padding) not in [2048]: - log_message = 'The header passed is inconsistent with the X dimension of the stack.' - log.critical(log_message) - raise ValueError(log_message) - else: - xnative = 2048 - - if (dimy / yos - 2 * padding) not in [96, 256, 2048]: - log_message = 'The header passed is inconsistent with the Y dimension of the stack.' - log.critical(log_message) - raise ValueError(log_message) - else: - ynative = int(dimy / yos - 2 * padding) - - # The trace file contains no reference pixels so all pixels are good. - refpix_mask = np.ones_like(image, dtype='bool') - - log.debug('Data dimensions:') - log.debug(f'dimx={dimx}, dimy={dimy}, xos={xos}, yos={yos}, xnative={xnative}, ynative={ynative}') - - return dimx, dimy, xos, yos, xnative, ynative, padding, refpix_mask From 5f70c9c33fc00a1e4146490a18c91572db6ccf11 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Tue, 29 Oct 2024 10:40:51 -0400 Subject: [PATCH 07/35] combine BaseOverlap class with ExtractionEngine --- jwst/extract_1d/soss_extract/atoca.py | 201 ++++++-------------- jwst/extract_1d/soss_extract/atoca_utils.py | 6 +- 2 files changed, 60 insertions(+), 147 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index e8b5c20301..c9004800be 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -15,7 +15,6 @@ import warnings from scipy.sparse import issparse, csr_matrix, diags from scipy.sparse.linalg import spsolve, lsqr, MatrixRankWarning -from scipy.interpolate import interp1d # Local imports. from . import atoca_utils @@ -33,8 +32,12 @@ def __init__(self, message): super().__init__(self.message) -class _BaseOverlap: - """Base class for the ATOCA algorithm (Darveau-Bernier 2021, in prep). +class ExtractionEngine(): + """ + Run the ATOCA algorithm (Darveau-Bernier 2021, in prep). + + TODO: fix the below, which came from _BaseOverlap class + Base class for the ATOCA algorithm (Darveau-Bernier 2021, in prep). Used to perform an overlapping extraction of the form: (B_T * B) * f = (data/sig)_T * B where B is a matrix and f is an array. @@ -47,8 +50,19 @@ class _BaseOverlap: The classes inheriting from this class should specify the methods get_w which computes the 'k' associated to each pixel 'i'. These depends of the type of interpolation used. - """ + This version models the pixels of the detector using an oversampled trapezoidal integration. + TODO: Merge with BaseOverlap class for readability + TODO: the following arguments can be simplified + mask - never used, always superseded by mask_trace_profile + wave_grid - always passed in explicitly; no need to have a default + wave_bounds - never used, default is always computed from wave_map + n_os - apparently never used because wave_grid is always explicit + + + TODO: I don't understand why data is not a required argument. If the data are only + needed when ExtractionEngine.__call__ happens, then the data should be input there somehow + """ # The desired data-type for computations, e.g., 'float32'. 'float64' is recommended. dtype = 'float64' @@ -59,18 +73,19 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, """ Parameters ---------- - wave_map : (N_ord, N, M) list or array of 2-D arrays - A list or array of the central wavelength position for each - order on the detector. It must have the same (N, M) as `data`. trace_profile : (N_ord, N, M) list or array of 2-D arrays A list or array of the spatial profile for each order - on the detector. It must have the same (N, M) as `data`. + on the detector. It has to have the same (N, M) as `data`. + wave_map : (N_ord, N, M) list or array of 2-D arrays + A list or array of the central wavelength position for each + order on the detector. + It has to have the same (N, M) as `data`. throughput : (N_ord [, N_k]) list of array or callable A list of functions or array of the throughput at each order. If callable, the functions depend on the wavelength. If array, projected on `wave_grid`. kernels : array, callable or sparse matrix - Convolution kernel to be applied on the spectrum (f_k) for each order. + Convolution kernel to be applied on spectrum (f_k) for each orders. Can be array of the shape (N_ker, N_k_c). Can be a callable with the form f(x, x0) where x0 is the position of the center of the kernel. In this case, it must @@ -81,21 +96,24 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, If sparse, the shape has to be (N_k_c, N_k) and it will be used directly. N_ker is the length of the effective kernel and N_k_c is the length of the spectrum (f_k) convolved. - global_mask : (N, M) array_like boolean, optional + data : (N, M) array_like, optional + A 2-D array of real values representing the detector image. + error : (N, M) array_like, optional + Estimate of the error on each pixel. Default is one everywhere. + mask : (N, M) array_like boolean, optional Boolean Mask of the detector pixels to mask for every extraction. Should not be related to a specific order (if so, use `mask_trace_profile` instead). mask_trace_profile : (N_ord, N, M) list or array of 2-D arrays[bool], optional - A list or array of the pixels that need to be used for extraction, + A list or array of the pixel that need to be used for extraction, for each order on the detector. It has to have the same (N_ord, N, M) as `trace_profile`. If not given, `threshold` will be applied on spatial profiles to define the masks. - orders : list, optional: + orders : list, optional List of orders considered. Default is orders = [1, 2] wave_grid : (N_k) array_like, optional The grid on which f(lambda) will be projected. - Default is a grid from `utils.get_soss_grid`. - `n_os` will be passed to this function. + Default still has to be improved. wave_bounds : list or array-like (N_ord, 2), optional - Boundary wavelengths covered by each order. + Boundary wavelengths covered by each orders. Default is the wavelength covered by `wave_map`. n_os : int, optional Oversampling rate. If `wave_grid`is None, it will be used to @@ -111,6 +129,17 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, If dictionary, the same c_kwargs will be used for each order. """ + # Get wavelength at the boundary of each pixel + wave_p, wave_m = [], [] + for wave in wave_map: # For each order + lp, lm = atoca_utils.get_wave_p_or_m(wave) # Lambda plus or minus + # Make sure it is the good precision + wave_p.append(lp.astype(self.dtype)) + wave_m.append(lm.astype(self.dtype)) + + # Save values + self.wave_p, self.wave_m = wave_p, wave_m + # If no orders specified extract on orders 1 and 2. if orders is None: orders = [1, 2] @@ -227,7 +256,6 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, self.tikho_mat = None self.w_t_wave_c = None - return def get_attributes(self, *args, i_order=None): """Return list of attributes @@ -258,6 +286,7 @@ def get_attributes(self, *args, i_order=None): def update_wave_map(self, wave_map): """Update internal wave_map + TODO: can this be removed? Parameters ---------- wave_map : array[float] @@ -270,10 +299,10 @@ def update_wave_map(self, wave_map): dtype = self.dtype self.wave_map = [wave_n.astype(dtype).copy() for wave_n in wave_map] - return def update_trace_profile(self, trace_profile): """Update internal trace_profiles + TODO: can this be removed? Parameters ---------- trace_profile : array[float] @@ -288,10 +317,10 @@ def update_trace_profile(self, trace_profile): # Update the trace_profile profile. self.trace_profile = [trace_profile_n.astype(dtype).copy() for trace_profile_n in trace_profile] - return def update_throughput(self, throughput): """Update internal throughput values + TODO: can this be removed? Parameters ---------- throughput : array[float] or callable @@ -325,10 +354,10 @@ def update_throughput(self, throughput): # Set the attribute to the new values. self.throughput = throughput_new - return def update_kernels(self, kernels, c_kwargs): """Update internal kernels + TODO: can this be removed? Parameters ---------- kernels : array, callable or sparse matrix @@ -377,30 +406,6 @@ def update_kernels(self, kernels, c_kwargs): self.kernels = kernels_new - def get_mask_wave(self, i_order): - """Generate mask bounded by limits of wavelength grid - Parameters - ---------- - i_order : int - Order to select the wave_map on which a mask - will be generated - - Returns - ------- - array[bool] - A mask with True where wave_map is outside the bounds - of wave_grid - """ - - wave = self.wave_map[i_order] - imin, imax = self.i_bounds[i_order] - wave_min = self.wave_grid[imin] - wave_max = self.wave_grid[imax - 1] - - mask = (wave <= wave_min) | (wave >= wave_max) - - return mask - def _get_masks(self, global_mask): """Compute a general mask on the detector and for each order. Depends on the spatial profile, the wavelength grid @@ -486,7 +491,6 @@ def update_mask(self, mask): # Re-compute weights self.weights, self.weights_k_idx = self.compute_weights() - return def _get_i_bnds(self, wave_bounds=None): """Define wavelength boundaries for each order using the order's mask. @@ -550,7 +554,6 @@ def update_i_bnds(self): # Update attribute. self.i_bounds = i_bnds_new - return def wave_grid_c(self, i_order): """Return wave_grid for the convolved flux at a given order. @@ -560,12 +563,6 @@ def wave_grid_c(self, i_order): return self.wave_grid[index] - def get_w(self, i_order): - """Dummy method to init this class - TODO: so is this an abstract base class? is it ever actually subclassed more than once? - """ - - return np.array([]), np.array([]) def compute_weights(self): """ @@ -583,9 +580,10 @@ def compute_weights(self): # Init lists weights, weights_k_idx = [], [] - for i_order in range(self.n_orders): # For each orders + for i_order in range(self.n_orders): - weights_n, k_idx_n = self.get_w(i_order) # Compute weights + # Compute weights + weights_n, k_idx_n = self.get_w(i_order) # Convert to sparse matrix # First get the dimension of the convolved grid @@ -608,7 +606,6 @@ def _set_w_t_wave_c(self, i_order, product): # Assign value self.w_t_wave_c[i_order] = product.copy() - return def grid_from_map(self, i_order=0): """Return the wavelength grid and the columns associated @@ -962,7 +959,6 @@ def set_tikho_matrix(self, t_mat=None, t_mat_func=None, fargs=None, fkwargs=None # Set attribute self.tikho_mat = t_mat - return def get_tikho_matrix(self, **kwargs): """ @@ -1276,9 +1272,8 @@ def compute_likelihood(self, spectrum=None, same=False): # Compute the log-likelihood for the spectrum. with np.errstate(divide='ignore'): logl = (model - data) / error - logl = -np.nansum((logl[~mask])**2) + return -np.nansum((logl[~mask])**2) - return logl @staticmethod def _solve(matrix, result): @@ -1387,87 +1382,6 @@ def __call__(self, tikhonov=False, tikho_kwargs=None, factor=None, **kwargs): return spectrum -class ExtractionEngine(_BaseOverlap): - """ - Run the ATOCA algorithm (Darveau-Bernier 2021, in prep). - - This version models the pixels of the detector using an oversampled trapezoidal integration. - """ - - def __init__(self, wave_map, trace_profile, *args, **kwargs): - """ - Parameters - ---------- - trace_profile : (N_ord, N, M) list or array of 2-D arrays - A list or array of the spatial profile for each order - on the detector. It has to have the same (N, M) as `data`. - wave_map : (N_ord, N, M) list or array of 2-D arrays - A list or array of the central wavelength position for each - order on the detector. - It has to have the same (N, M) as `data`. - throughput : (N_ord [, N_k]) list of array or callable - A list of functions or array of the throughput at each order. - If callable, the functions depend on the wavelength. - If array, projected on `wave_grid`. - kernels : array, callable or sparse matrix - Convolution kernel to be applied on spectrum (f_k) for each orders. - Can be array of the shape (N_ker, N_k_c). - Can be a callable with the form f(x, x0) where x0 is - the position of the center of the kernel. In this case, it must - return a 1D array (len(x)), so a kernel value - for each pairs of (x, x0). If array or callable, - it will be passed to `convolution.get_c_matrix` function - and the `c_kwargs` can be passed to this function. - If sparse, the shape has to be (N_k_c, N_k) and it will - be used directly. N_ker is the length of the effective kernel - and N_k_c is the length of the spectrum (f_k) convolved. - data : (N, M) array_like, optional - A 2-D array of real values representing the detector image. - error : (N, M) array_like, optional - Estimate of the error on each pixel. Default is one everywhere. - mask : (N, M) array_like boolean, optional - Boolean Mask of the detector pixels to mask for every extraction. - Should not be related to a specific order (if so, use `mask_trace_profile` instead). - mask_trace_profile : (N_ord, N, M) list or array of 2-D arrays[bool], optional - A list or array of the pixel that need to be used for extraction, - for each order on the detector. It has to have the same (N_ord, N, M) as `trace_profile`. - If not given, `threshold` will be applied on spatial profiles to define the masks. - orders : list, optional - List of orders considered. Default is orders = [1, 2] - wave_grid : (N_k) array_like, optional - The grid on which f(lambda) will be projected. - Default still has to be improved. - wave_bounds : list or array-like (N_ord, 2), optional - Boundary wavelengths covered by each orders. - Default is the wavelength covered by `wave_map`. - n_os : int, optional - Oversampling rate. If `wave_grid`is None, it will be used to - generate a grid. Default is 2. - threshold : float, optional: - The contribution of any order on a pixel is considered significant if - its estimated spatial profile is greater than this threshold value. - If it is not properly modeled (not covered by the wavelength grid), - it will be masked. Default is 1e-3. - c_kwargs : list of N_ord dictionaries or dictionary, optional - Inputs keywords arguments to pass to - `convolution.get_c_matrix` function for each order. - If dictionary, the same c_kwargs will be used for each order. - """ - - # Get wavelength at the boundary of each pixel - wave_p, wave_m = [], [] - for wave in wave_map: # For each order - lp, lm = atoca_utils.get_wave_p_or_m(wave) # Lambda plus or minus - # Make sure it is the good precision - wave_p.append(lp.astype(self.dtype)) - wave_m.append(lm.astype(self.dtype)) - - # Save values - self.wave_p, self.wave_m = wave_p, wave_m - - # Init upper class - super().__init__(wave_map, trace_profile, *args, **kwargs) - def _get_lo_hi(self, grid, i_order): """ Find the lowest (lo) and highest (hi) index @@ -1509,7 +1423,8 @@ def _get_lo_hi(self, grid, i_order): ma = mask_ord[~mask] lo[ma], hi[ma] = -1, -2 - return lo, hi + return lo, hi + def get_mask_wave(self, i_order): """Generate mask bounded by limits of wavelength grid @@ -1531,15 +1446,13 @@ def get_mask_wave(self, i_order): wave_min = self.wave_grid[i_bnds[0]] wave_max = self.wave_grid[i_bnds[1] - 1] - mask = (wave_m < wave_min) | (wave_p > wave_max) + return (wave_m < wave_min) | (wave_p > wave_max) - return mask def get_w(self, i_order): """Compute integration weights for each grid points and each pixels. Depends on the order `n`. - TODO: is this the same order as the spectral order? if so, can we ignore - a bunch of the cases in this function and just keep orders 1 and 2? + TODO: what is this doing? where can we find the math? Parameters ---------- @@ -1616,7 +1529,7 @@ def get_w(self, i_order): # Generate array of all k_i. Set to max value of uint16 if not valid k_n = atoca_utils.arange_2d(k_first, k_last + 1) - bad = k_n == np.iinfo(k_n.dtype).max + bad = k_n == -1 # Number of valid k per pixel n_k = np.sum(~bad, axis=-1) diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index 2502ef8967..811d7adacb 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -6,7 +6,6 @@ @authors: Antoine Darveau-Bernier, Geert Jan Talens """ -import warnings import numpy as np from scipy.sparse import diags, csr_matrix from scipy.sparse.linalg import spsolve @@ -35,7 +34,7 @@ def arange_2d(starts, stops): Returns ------- out : array[uint16] - 2D array of ranges with invalid values set to max uint16 value, 65535 + 2D array of ranges with invalid values set to -1 """ if starts.shape != stops.shape: msg = ('Shapes of starts and stops are not compatible, ' @@ -54,7 +53,7 @@ def arange_2d(starts, stops): # Initialize the output arrays with invalid value nrows = len(stops) ncols = np.amax(lengths) - out = np.ones((nrows, ncols), dtype=np.uint16)*np.iinfo(np.uint16).max + out = np.ones((nrows, ncols), dtype=np.int16)*-1 # Compute the indices. for irow in range(nrows): @@ -557,6 +556,7 @@ def make_combined_adaptive_grid(all_grids, all_estimate, grid_range=None, """ TODO: can this be a class? e.g., class AdaptiveGrid? TODO: why aren't any of the same helper functions used here as in get_soss_grid? + q: why are there multiple grids passed in here in the first place Return an irregular oversampled grid needed to reach a given precision when integrating over each intervals of `grid`. From c4de6673e5cbf81f8f04e24ff7f2c8efa0d16f72 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Wed, 30 Oct 2024 14:18:57 -0400 Subject: [PATCH 08/35] make data, err required for ExtractionEngine calls, remove them as engine attributes --- jwst/extract_1d/soss_extract/atoca.py | 816 +++++------------- jwst/extract_1d/soss_extract/soss_extract.py | 32 +- .../soss_extract/tests/test_atoca.py | 35 + 3 files changed, 257 insertions(+), 626 deletions(-) create mode 100644 jwst/extract_1d/soss_extract/tests/test_atoca.py diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index c9004800be..b9780c4c47 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -32,11 +32,11 @@ def __init__(self, message): super().__init__(self.message) -class ExtractionEngine(): +class ExtractionEngine: """ Run the ATOCA algorithm (Darveau-Bernier 2021, in prep). - TODO: fix the below, which came from _BaseOverlap class + TODO: merge the docstring below, which came from _BaseOverlap class, into this docstring Base class for the ATOCA algorithm (Darveau-Bernier 2021, in prep). Used to perform an overlapping extraction of the form: (B_T * B) * f = (data/sig)_T * B @@ -52,38 +52,30 @@ class ExtractionEngine(): These depends of the type of interpolation used. This version models the pixels of the detector using an oversampled trapezoidal integration. - TODO: Merge with BaseOverlap class for readability - TODO: the following arguments can be simplified - mask - never used, always superseded by mask_trace_profile - wave_grid - always passed in explicitly; no need to have a default - wave_bounds - never used, default is always computed from wave_map - n_os - apparently never used because wave_grid is always explicit - - - TODO: I don't understand why data is not a required argument. If the data are only - needed when ExtractionEngine.__call__ happens, then the data should be input there somehow """ # The desired data-type for computations, e.g., 'float32'. 'float64' is recommended. dtype = 'float64' def __init__(self, wave_map, trace_profile, throughput, kernels, - orders=None, global_mask=None, mask_trace_profile=None, - wave_grid=None, wave_bounds=None, n_os=2, - threshold=1e-3, c_kwargs=None): + wave_grid, mask_trace_profile, + orders=[1,2], threshold=1e-3, c_kwargs=None): """ Parameters ---------- - trace_profile : (N_ord, N, M) list or array of 2-D arrays - A list or array of the spatial profile for each order - on the detector. It has to have the same (N, M) as `data`. wave_map : (N_ord, N, M) list or array of 2-D arrays A list or array of the central wavelength position for each order on the detector. It has to have the same (N, M) as `data`. + + trace_profile : (N_ord, N, M) list or array of 2-D arrays + A list or array of the spatial profile for each order + on the detector. It has to have the same (N, M) as `data`. + throughput : (N_ord [, N_k]) list of array or callable A list of functions or array of the throughput at each order. If callable, the functions depend on the wavelength. If array, projected on `wave_grid`. + kernels : array, callable or sparse matrix Convolution kernel to be applied on spectrum (f_k) for each orders. Can be array of the shape (N_ker, N_k_c). @@ -96,161 +88,91 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, If sparse, the shape has to be (N_k_c, N_k) and it will be used directly. N_ker is the length of the effective kernel and N_k_c is the length of the spectrum (f_k) convolved. - data : (N, M) array_like, optional - A 2-D array of real values representing the detector image. - error : (N, M) array_like, optional - Estimate of the error on each pixel. Default is one everywhere. - mask : (N, M) array_like boolean, optional - Boolean Mask of the detector pixels to mask for every extraction. - Should not be related to a specific order (if so, use `mask_trace_profile` instead). - mask_trace_profile : (N_ord, N, M) list or array of 2-D arrays[bool], optional + + wave_grid : (N_k) array_like, required. + The grid on which f(lambda) will be projected. + + mask_trace_profile : (N_ord, N, M) list or array of 2-D arrays[bool], required. A list or array of the pixel that need to be used for extraction, for each order on the detector. It has to have the same (N_ord, N, M) as `trace_profile`. - If not given, `threshold` will be applied on spatial profiles to define the masks. + orders : list, optional List of orders considered. Default is orders = [1, 2] - wave_grid : (N_k) array_like, optional - The grid on which f(lambda) will be projected. - Default still has to be improved. - wave_bounds : list or array-like (N_ord, 2), optional - Boundary wavelengths covered by each orders. - Default is the wavelength covered by `wave_map`. - n_os : int, optional - Oversampling rate. If `wave_grid`is None, it will be used to - generate a grid. Default is 2. + threshold : float, optional: The contribution of any order on a pixel is considered significant if its estimated spatial profile is greater than this threshold value. If it is not properly modeled (not covered by the wavelength grid), it will be masked. Default is 1e-3. + c_kwargs : list of N_ord dictionaries or dictionary, optional Inputs keywords arguments to pass to `convolution.get_c_matrix` function for each order. If dictionary, the same c_kwargs will be used for each order. """ - # Get wavelength at the boundary of each pixel - wave_p, wave_m = [], [] - for wave in wave_map: # For each order - lp, lm = atoca_utils.get_wave_p_or_m(wave) # Lambda plus or minus - # Make sure it is the good precision - wave_p.append(lp.astype(self.dtype)) - wave_m.append(lm.astype(self.dtype)) - - # Save values - self.wave_p, self.wave_m = wave_p, wave_m - - # If no orders specified extract on orders 1 and 2. - if orders is None: - orders = [1, 2] - - ########################### - # Save basic parameters - ########################### - - # Spectral orders and number of orders. - self.data_shape = wave_map[0].shape - self.orders = orders - self.n_orders = len(orders) + # Set the attributes and ensure everything has correct dtype + self.wave_map = np.array(wave_map).astype(self.dtype) + self.trace_profile = np.array(trace_profile).astype(self.dtype) + self.mask_trace_profile = np.array(mask_trace_profile).astype(bool) self.threshold = threshold + self.data_shape = self.wave_map[0].shape - # Raise error if the number of orders is not consistent. - if self.n_orders != len(wave_map): - msg = ("The number of orders specified {} and the number of " - "wavelength maps provided {} do not match.") - log.critical(msg.format(self.n_orders, len(wave_map))) - raise ValueError(msg.format(self.n_orders, len(wave_map))) - - # Detector image. - self.data = np.full(self.data_shape, fill_value=np.nan) - - # Error map of each pixels. - self.error = np.ones(self.data_shape) - - # Set all reference file quantities to None. - self.wave_map = None - self.trace_profile = None - self.throughput = None - self.kernels = None - - # Set the wavelength map and trace_profile for each order. - self.update_wave_map(wave_map) - self.update_trace_profile(trace_profile) - - # Set the mask based on trace profiles and save - if mask_trace_profile is None: - # No mask (False everywhere) - log.warning('mask_trace_profile was not given. All detector pixels will be modeled. ' - 'It is preferable to limit the number of modeled pixels by specifying the region ' - 'of interest with mask_trace_profile.') - mask_trace_profile = np.array([np.zeros(self.data_shape, dtype=bool) for _ in orders]) - self.mask_trace_profile = mask_trace_profile - - # Generate a wavelength grid if none was provided. - if wave_grid is None: - if self.n_orders == 2: - wave_grid = atoca_utils.get_soss_grid(wave_map, trace_profile, n_os=n_os) - else: - wave_grid, _ = self.grid_from_map() - else: - # Check if the input wave_grid is sorted and strictly increasing. - is_sorted = (np.diff(wave_grid) > 0).all() - - # If not, sort it and make it unique - if not is_sorted: - log.warning('`wave_grid` is not strictly increasing. It will be sorted and unique.') - wave_grid = np.unique(wave_grid) - - # Set the wavelength grid and its size. + # Set wave_grid. Ensure it is sorted and strictly increasing. + is_sorted = (np.diff(wave_grid) > 0).all() + if not is_sorted: + log.warning('`wave_grid` is not strictly increasing. It will be sorted and made unique.') + wave_grid = np.unique(wave_grid) self.wave_grid = wave_grid.astype(self.dtype).copy() self.n_wavepoints = len(wave_grid) - # Set the throughput for each order. - self.update_throughput(throughput) - - ################################### - # Build detector mask - ################################### + # Get wavelengths at the boundaries of each pixel for all orders + wave_p, wave_m = [], [] + for wave in self.wave_map: + lp, lm = atoca_utils.get_wave_p_or_m(wave) + wave_p.append(lp) + wave_m.append(lm) + self.wave_p = np.array(wave_p).astype(self.dtype) + self.wave_m = np.array(wave_m).astype(self.dtype) + + # Set orders and ensure that the number of orders is consistent with wave_map length + self.orders = orders + self.n_orders = len(self.orders) + if self.n_orders != len(self.wave_map): + msg = ("The number of orders specified {} and the number of " + "wavelength maps provided {} do not match.") + log.critical(msg.format(self.n_orders, len(self.wave_map))) + raise ValueError(msg.format(self.n_orders, len(self.wave_map))) - # Assign a first estimate of i_bounds to be able to compute mask. - self.i_bounds = [[0, len(wave_grid)] for _ in range(self.n_orders)] + # Set a first estimate of i_bounds to estimate mask + self.i_bounds = [[0, len(self.wave_grid)] for _ in range(self.n_orders)] - # First estimate of a global mask and masks for each orders - self.mask, self.mask_ord = self._get_masks(global_mask) + # Estimate a global mask and masks for each orders + self.mask, self.mask_ord = self._get_masks() + # Save mask here as the general mask, since `mask` attribute can be changed. + self.general_mask = self.mask.copy() # Ensure there are adequate good pixels left in each order good_pixels_in_order = np.sum(np.sum(~self.mask_ord, axis=-1), axis=-1) min_good_pixels = 25 # hard-code to qualitatively reasonable value if np.any(good_pixels_in_order < min_good_pixels): - raise MaskOverlapError('At least one order has no valid pixels (mask_trace_profile and mask_wave do not overlap)') - - # Correct i_bounds if it was not specified - self.i_bounds = self._get_i_bnds(wave_bounds) + msg = (f'At least one order has less than {min_good_pixels} valid pixels. ' + '(mask_trace_profile and mask_wave have insufficient overlap') + raise MaskOverlapError(msg) - # Re-build global mask and masks for each orders - self.mask, self.mask_ord = self._get_masks(global_mask) + # Update i_bounds based on masked wavelengths + self.i_bnds = self._get_i_bnds() - # Save mask here as the general mask, - # since `mask` attribute can be changed. - self.general_mask = self.mask.copy() + # if throughput is given as callable, turn it into an array of proper shape + self.update_throughput(throughput) - #################################### - # Build convolution matrix - #################################### - self.update_kernels(kernels, c_kwargs) - - ############################# - # Compute integration weights - ############################# - # The weights depend on the integration method used solve - # the integral of the flux over a pixel and are encoded - # in the class method `get_w()`. + # turn kernels into sparse matrix + self._create_kernels(kernels, c_kwargs) + + # Compute integration weights. see method self.get_w() for details. self.weights, self.weights_k_idx = self.compute_weights() - ######################### - # Save remaining inputs - ######################### - # Init the pixel mapping (b_n) matrices. Matrices that transforms the 1D spectrum to a the image pixels. + # Initialize the pixel mapping (b_n) matrices self.pixel_mapping = [None for _ in range(self.n_orders)] self.i_grid = None self.tikho_mat = None @@ -264,6 +186,7 @@ def get_attributes(self, *args, i_order=None): ---------- args : str or list[str] All attributes to return. + i_order : None or int, optional Index of order to extract. If specified, it will be applied to all attributes in args, so it cannot @@ -284,52 +207,15 @@ def get_attributes(self, *args, i_order=None): return out - def update_wave_map(self, wave_map): - """Update internal wave_map - TODO: can this be removed? - Parameters - ---------- - wave_map : array[float] - Wavelength maps for each order - - Returns - ------- - None - """ - dtype = self.dtype - self.wave_map = [wave_n.astype(dtype).copy() for wave_n in wave_map] - - - def update_trace_profile(self, trace_profile): - """Update internal trace_profiles - TODO: can this be removed? - Parameters - ---------- - trace_profile : array[float] - Trace profiles for each order - - Returns - ------- - None - """ - dtype = self.dtype - - # Update the trace_profile profile. - self.trace_profile = [trace_profile_n.astype(dtype).copy() for trace_profile_n in trace_profile] - def update_throughput(self, throughput): """Update internal throughput values - TODO: can this be removed? + Parameters ---------- throughput : array[float] or callable Throughput values for each order, given either as an array or as a callable function with self.wave_grid as input. - - Returns - ------- - None """ # Update the throughput values. @@ -352,43 +238,36 @@ def update_throughput(self, throughput): raise ValueError(msg) # Set the attribute to the new values. - self.throughput = throughput_new + self.throughput = np.array(throughput_new).astype(self.dtype) - def update_kernels(self, kernels, c_kwargs): - """Update internal kernels - TODO: can this be removed? + def _create_kernels(self, kernels, c_kwargs): + """Make sparse matrix from input kernels + Parameters ---------- kernels : array, callable or sparse matrix Convolution kernel to be applied on the spectrum (f_k) for each order. + c_kwargs : list of N_ord dictionaries or dictionary, optional Inputs keywords arguments to pass to `convolution.get_c_matrix` function for each order. If dictionary, the same c_kwargs will be used for each order. - - Returns - ------- - None """ - # Check the c_kwargs inputs - # If not given + # If c_kwargs not given, use the kernels min_value attribute. + # It is a way to make sure that the full kernel is used. if c_kwargs is None: - # Then use the kernels min_value attribute. - # It is a way to make sure that the full kernel - # is used. c_kwargs = [] for ker in kernels: - # If the min_value not specified, then - # simply take the get_c_matrix defaults try: kwargs_ker = {'thresh': ker.min_value} except AttributeError: + # take the get_c_matrix defaults kwargs_ker = dict() c_kwargs.append(kwargs_ker) - # ...or same for each orders if only a dictionary was given + # ...or same for all orders if only a dictionary was given elif isinstance(c_kwargs, dict): c_kwargs = [c_kwargs for _ in kernels] @@ -406,21 +285,15 @@ def update_kernels(self, kernels, c_kwargs): self.kernels = kernels_new - def _get_masks(self, global_mask): + def _get_masks(self): """Compute a general mask on the detector and for each order. - Depends on the spatial profile, the wavelength grid - and the user defined mask (optional). These are all specified - when initializing the object. - - Parameters - ---------- - global_mask : array[bool] - Boolean mask of the detector pixels to mask for every extraction. + Depends on the trace profile and the wavelength grid. Returns ------- general_mask : array[bool] Mask that combines global_mask, wavelength mask, trace_profile mask + mask_ord : array[bool] Mask applied to each order """ @@ -430,18 +303,11 @@ def _get_masks(self, global_mask): needed_attr = self.get_attributes(*args) threshold, n_orders, mask_trace_profile, trace_profile = needed_attr - # Convert list to array (easier for coding) - mask_trace_profile = np.array(mask_trace_profile) - # Mask pixels not covered by the wavelength grid. mask_wave = np.array([self.get_mask_wave(i_order) for i_order in range(n_orders)]) # Apply user defined mask. - if global_mask is None: - mask_ord = np.any([mask_trace_profile, mask_wave], axis=0) - else: - mask = [global_mask for _ in range(n_orders)] # For each orders - mask_ord = np.any([mask_trace_profile, mask_wave, mask], axis=0) + mask_ord = np.any([mask_trace_profile, mask_wave], axis=0) # Find pixels that are masked in each order. general_mask = np.all(mask_ord, axis=0) @@ -459,6 +325,7 @@ def _get_masks(self, global_mask): return general_mask, mask_ord + def update_mask(self, mask): """Update `mask` attribute by combining the `general_mask` attribute with the input `mask`. Every time the mask is @@ -470,64 +337,40 @@ def update_mask(self, mask): mask : array[bool] New mask to be combined with internal general_mask and saved in self.mask. - - Returns - ------- - None """ - - # Get general mask - general_mask = self.general_mask - - # Complete with the input mask - new_mask = (general_mask | mask) - - # Update attribute - self.mask = new_mask - - # Correct i_bounds if it was not specified - # self.update_i_bnds() + self.mask = (self.general_mask | mask) # Re-compute weights self.weights, self.weights_k_idx = self.compute_weights() - def _get_i_bnds(self, wave_bounds=None): - """Define wavelength boundaries for each order using the order's mask. - Parameters - ---------- - wave_bounds : list[float], optional - Minimum and maximum values of masked wavelength map. If not given, - calculated from internal wave_map and mask_ord for each order. + def _get_i_bnds(self): + """Define wavelength boundaries for each order using the order's mask + and the wavelength map. + Returns ------- list[float] Wavelength boundaries for each order """ - wave_grid = self.wave_grid - i_bounds = self.i_bounds - - # Check if wave_bounds given - if wave_bounds is None: - wave_bounds = [] - for i in range(self.n_orders): - wave = self.wave_map[i][~self.mask_ord[i]] - wave_bounds.append([wave.min(), wave.max()]) + # Figure out boundary wavelengths + wave_bounds = [] + for i in range(self.n_orders): + wave = self.wave_map[i][~self.mask_ord[i]] + wave_bounds.append([wave.min(), wave.max()]) - # What we need is the boundary position - # on the wavelength grid. + # Determine the boundary position on the wavelength grid. i_bnds_new = [] - for bounds, i_bnds in zip(wave_bounds, i_bounds): + for bounds, i_bnds in zip(wave_bounds, self.i_bounds): - a = np.min(np.where(wave_grid >= bounds[0])[0]) - b = np.max(np.where(wave_grid <= bounds[1])[0]) + 1 + print(bounds, i_bnds) + a = np.min(np.where(self.wave_grid >= bounds[0])[0]) + b = np.max(np.where(self.wave_grid <= bounds[1])[0]) + 1 # Take the most restrictive bound a = np.maximum(a, i_bnds[0]) b = np.minimum(b, i_bnds[1]) - - # Keep value i_bnds_new.append([a, b]) return i_bnds_new @@ -603,7 +446,6 @@ def _set_w_t_wave_c(self, i_order, product): if self.w_t_wave_c is None: self.w_t_wave_c = [[] for _ in range(self.n_orders)] - # Assign value self.w_t_wave_c[i_order] = product.copy() @@ -616,73 +458,19 @@ def grid_from_map(self, i_order=0): wave_map, trace_profile = self.get_attributes(*attrs, i_order=i_order) wave_grid, icol = atoca_utils._grid_from_map(wave_map, trace_profile) - wave_grid = wave_grid.astype(self.dtype) - return wave_grid, icol - def estimate_noise(self, i_order=0, data=None, error=None, mask=None): - """Relative noise estimate over columns. - - Parameters - ---------- - i_order : int, optional - index of diffraction order. Default is 0 - data : 2d array, optional - map of the detector image - Default is `self.data`. - error : 2d array, optional - map of the estimate of the detector noise. - Default is `self.sig` - mask : 2d array, optional - Bool map of the masked pixels for order `i_order`. - Default is `self.mask_ord[i_order]` - Returns - ------ - wave_grid : array[float] - The wavelength grid. - noise : array[float] - The associated noise array. + def get_pixel_mapping(self, i_order, error=None, same=False, quick=False): """ - - # Use object attributes if not given - if data is None: - data = self.data - - if error is None: - error = self.error - - if mask is None: - mask = self.mask_ord[i_order] - - # Compute noise estimate only on the trace (mask the rest) - noise = np.ma.array(error, mask=mask) - - # RMS over columns - noise = np.sqrt((noise**2).sum(axis=0)) - - # Relative - noise /= np.ma.array(data, mask=mask).sum(axis=0) - - # Convert to array with nans - noise = noise.filled(fill_value=np.nan) - - # Get associated wavelengths - wave_grid, i_col = self.grid_from_map(i_order) - - # Return sorted according to wavelengths - return wave_grid, noise[i_col] - - def get_pixel_mapping(self, i_order, same=False, error=True, quick=False): - """Compute the matrix `b_n = (P/sig).w.T.lambda.c_n` , + Compute the matrix `b_n = (P/sig).w.T.lambda.c_n` , where `P` is the spatial profile matrix (diag), `w` is the integrations weights matrix, `T` is the throughput matrix (diag), `lambda` is the convolved wavelength grid matrix (diag), `c_n` is the convolution kernel. - The model of the detector at order n (`model_n`) - is given by the system: + The model of the detector at order n (`model_n`) is given by the system: model_n = b_n.c_n.f , where f is the incoming flux projected on the wavelength grid. This methods updates the `b_n_list` attribute. @@ -691,15 +479,16 @@ def get_pixel_mapping(self, i_order, same=False, error=True, quick=False): ---------- i_order: integer Label of the order (depending on the initiation of the object). + + error: (N, M) array_like or None, optional. + Estimate of the error on each pixel. Same shape as `data`. + If None, the error is set to 1, which means the method will return + b_n instead of b_n/sigma. Default is None. + same: bool, optional Do not recompute b_n. Take the last b_n computed. Useful to speed up code. Default is False. - error: bool or (N, M) array_like, optional - If 2-d array, `sig` is the new error estimation map. - It is the same shape as `sig` initiation input. If bool, - whether to apply sigma or not. The method will return - b_n/sigma if True or array_like and b_n if False. If True, - the default object attribute `sig` will be used. + quick: bool, optional If True, only perform one matrix multiplication instead of the whole system: (P/sig).(w.T.lambda.c_n) @@ -721,17 +510,9 @@ def get_pixel_mapping(self, i_order, same=False, error=True, quick=False): else: # Special treatment for error map # Can be bool or array. - if error is False: + if error is None: # Sigma will have no effect error = np.ones(self.data_shape) - else: - if error is not True: - # Sigma must be an array so - # update object attribute - self.error = error.copy() - - # Take sigma from object - error = self.error # Get needed attributes ... attrs = ['wave_grid', 'mask'] @@ -775,34 +556,21 @@ def get_pixel_mapping(self, i_order, same=False, error=True, quick=False): return pixel_mapping - def build_sys(self, data=None, error=True, mask=None, trace_profile=None, throughput=None): - """Build linear system arising from the logL maximisation. + + def build_sys(self, data, error): + """ + Build linear system arising from the logL maximisation. TIPS: To be quicker, only specify the psf (`p_list`) in kwargs. There will be only one matrix multiplication: (P/sig).(w.T.lambda.c_n). + Parameters ---------- - data : (N, M) array_like, optional + data : (N, M) array_like A 2-D array of real values representing the detector image. - Default is the object attribute `data`. - error : bool or (N, M) array_like, optional + + error : (N, M) array_like Estimate of the error on each pixel. - If 2-d array, `sig` is the new error estimation map. - It is the same shape as `sig` initiation input. If bool, - whether to apply sigma or not. The method will return - b_n/sigma if True or array_like and b_n if False. If True, - the default object attribute `sig` will be use. - mask : (N, M) array_like boolean, optional - Additional mask for a given exposure. Will be added - to the object general mask. - trace_profile : (N_ord, N, M) list or array of 2-D arrays, optional - A list or array of the spatial profile for each order - on the detector. It has to have the same (N, M) as `data`. - Default is the object attribute `p_list` - throughput : (N_ord [, N_k]) list or array of functions, optional - A list or array of the throughput at each order. - The functions depend on the wavelength - Default is the object attribute `t_list` Returns ------ @@ -810,7 +578,7 @@ def build_sys(self, data=None, error=True, mask=None, trace_profile=None, throug """ # Get the detector model - b_matrix, data = self.get_detector_model(data, error, mask, trace_profile, throughput) + b_matrix, data = self.get_detector_model(data, error) # (B_T * B) * f = (data/sig)_T * B # (matrix ) * f = result @@ -819,34 +587,22 @@ def build_sys(self, data=None, error=True, mask=None, trace_profile=None, throug return matrix, result.toarray().squeeze() - def get_detector_model(self, data=None, error=True, mask=None, trace_profile=None, throughput=None): - """Get the linear model of the detector pixel, B.dot(flux) = pixels + + def get_detector_model(self, data, error): + """ + TODO: are mask, trace_profile, throughput ever passed in here? + Get the linear model of the detector pixel, B.dot(flux) = pixels TIPS: To be quicker, only specify the psf (`p_list`) in kwargs. There will be only one matrix multiplication: (P/sig).(w.T.lambda.c_n). + Parameters ---------- - data : (N, M) array_like, optional + data : (N, M) array_like A 2-D array of real values representing the detector image. - Default is the object attribute `data`. - error: bool or (N, M) array_like, optional + + error: (N, M) array_like Estimate of the error on each pixel. - If 2-d array, `sig` is the new error estimation map. - It is the same shape as `sig` initiation input. If bool, - whether to apply sigma or not. The method will return - b_n/sigma if True or array_like and b_n if False. If True, - the default object attribute `sig` will be use. - mask : (N, M) array_like boolean, optional - Additional mask for a given exposure. Will be added - to the object general mask. - trace_profile : (N_ord, N, M) list or array of 2-D arrays, optional - A list or array of the spatial profile for each order - on the detector. It has to have the same (N, M) as `data`. - Default is the object attribute `p_list` - throughput : (N_ord [, N_k]) list or array of functions, optional - A list or array of the throughput at each order. - The functions depend on the wavelength - Default is the object attribute `t_list` Returns ------ @@ -854,129 +610,49 @@ def get_detector_model(self, data=None, error=True, mask=None, trace_profile=Non From the linear equation B.dot(flux) = pix_array """ - # Check if inputs are suited for quick mode; - # Quick mode if `t_list` is not specified. - quick = (throughput is None) - - # and if mask doesn't change - quick &= (mask is None) - quick &= (self.w_t_wave_c is not None) # Pre-computed - - # Use data from object as default - if data is None: - data = self.data - else: - # Update data - self.data = data - - # Update mask if given - if mask is not None: - self.update_mask(mask) - - # Take (updated) mask from object - mask = self.mask - - # Get some dimensions infos - n_wavepoints, n_orders = self.n_wavepoints, self.n_orders - - # Update trace_profile maps and throughput values. - if trace_profile is not None: - self.update_trace_profile(trace_profile) - - if throughput is not None: - self.update_throughput(throughput) - - # Calculations + # Check if `w_t_wave_c` is pre-computed + quick = (self.w_t_wave_c is not None) # Build matrix B # Initiate with empty matrix - n_i = (~mask).sum() # n good pixels - b_matrix = csr_matrix((n_i, n_wavepoints)) + n_i = (~self.mask).sum() # n good pixels + b_matrix = csr_matrix((n_i, self.n_wavepoints)) # Sum over orders - for i_order in range(n_orders): + for i_order in range(self.n_orders): # Get sparse pixel mapping matrix. - b_matrix += self.get_pixel_mapping(i_order, error=error, quick=quick) + b_matrix += self.get_pixel_mapping(i_order, error, quick=quick) # Build detector pixels' array - # Fisrt get `error` which have been update` - # when calling `get_pixel_mapping` - error = self.error - # Take only valid pixels and apply `error` on data - data = data[~mask] / error[~mask] + data = data[~self.mask] / error[~self.mask] return b_matrix, csr_matrix(data) - def set_tikho_matrix(self, t_mat=None, t_mat_func=None, fargs=None, fkwargs=None): - """Set the tikhonov matrix attribute. - The matrix can be directly specified as an input, or - it can be built using `t_mat_func` - Parameters - ---------- - t_mat : matrix-like, optional - Tikhonov regularization matrix. scipy.sparse matrix - are recommended. - t_mat_func : callable, optional - Function used to generate `t_mat` if not specified. - Will take `fargs` and `fkwargs`as input. - Use the `atoca_utils.get_tikho_matrix` as default. - fargs : tuple, optional - Arguments passed to `t_mat_func`. Default is `(self.wave_grid, )`. - fkwargs : dict, optional - Keyword arguments passed to `t_mat_func`. Default is - `{'n_derivative': 1, 'd_grid': True, 'estimate': None, 'pwr_law': 0}` - - Returns - ------- - None + @property + def tikho_mat(self): """ - - # Generate the matrix with the function - if t_mat is None: - - # Default function if not specified - if t_mat_func is None: - # Use the `atoca_utils.get_tikho_matrix` - # The default arguments will return the 1rst derivative - # as the tikhonov matrix - t_mat_func = atoca_utils.get_tikho_matrix - - # Default args - if fargs is None: - # The argument for `atoca_utils.get_tikho_matrix` is the wavelength grid - fargs = (self.wave_grid, ) - if fkwargs is None: - # The kwargs for `atoca_utils.get_tikho_matrix` are - # n_derivative = 1, d_grid = True, estimate = None, pwr_law = 0 - fkwargs = {'n_derivative': 1, 'd_grid': True, 'estimate': None, 'pwr_law': 0} - - # Call function - t_mat = t_mat_func(*fargs, **fkwargs) - - # Set attribute - self.tikho_mat = t_mat - - - def get_tikho_matrix(self, **kwargs): - """ - #TODO: why is there also an atoca_utils.get_tikho_matrix function? - Return the Tikhonov matrix. - Generate it with `set_tikho_matrix` method - if not defined yet. If so, all arguments are passed - to `set_tikho_matrix`. The result is saved as an attribute. + Generate it with `atoca_utils.get_tikho_matrix` method if not defined yet. """ + if self._tikho_mat is not None: + return self._tikho_mat - if self.tikho_mat is None: - self.set_tikho_matrix(**kwargs) + fkwargs = {'n_derivative': 1, 'd_grid': True, 'estimate': None, 'pwr_law': 0} + self._tikho_mat = atoca_utils.get_tikho_matrix(self.wave_grid, **fkwargs) + return self._tikho_mat + + + @tikho_mat.setter + def tikho_mat(self, t_mat): + self._tikho_mat = t_mat - return self.tikho_mat def estimate_tikho_factors(self, flux_estimate): - """Estimate an initial guess of the Tikhonov factor. The output factor will + """ + Estimate an initial guess of the Tikhonov factor. The output factor will be used to find the best Tikhonov factor. The flux_estimate is used to generate a factor_guess. The user should construct a grid with this output in log space, e.g. np.logspace(np.log10(flux_estimate)-4, np.log10(flux_estimate)+4, 9). @@ -1001,11 +677,8 @@ def estimate_tikho_factors(self, flux_estimate): # Project the estimate on the wavelength grid estimate_on_grid = flux_estimate(wave_grid) - # Get the tikhonov matrix - tikho_matrix = self.get_tikho_matrix() - # Estimate the norm-2 of the regularization term - reg_estimate = tikho_matrix.dot(estimate_on_grid) + reg_estimate = self.tikho_mat.dot(estimate_on_grid) reg_estimate = np.nansum(np.array(reg_estimate) ** 2) # Estimate of the factor @@ -1015,8 +688,8 @@ def estimate_tikho_factors(self, flux_estimate): return factor_guess - def get_tikho_tests(self, factors, tikho=None, tikho_kwargs=None, data=None, - error=None, mask=None, trace_profile=None, throughput=None): + + def get_tikho_tests(self, factors, data, error): """ Test different factors for Tikhonov regularization. @@ -1024,31 +697,12 @@ def get_tikho_tests(self, factors, tikho=None, tikho_kwargs=None, data=None, ---------- factors : 1D list or array-like Factors to be tested. - tikho : Tikhonov object, optional - Tikhonov regularization object (see regularization.Tikhonov). - If not given, an object will be initiated using the linear system - from `build_sys` method and kwargs will be passed. - tikho_kwargs : - passed to init Tikhonov object. Possible options - are `t_mat` and `grid` - data : (N, M) array_like, optional + + data : (N, M) array_like A 2-D array of real values representing the detector image. - Default is the object attribute `data`. - error : (N, M) array_like, optional - Estimate of the error on each pixel` - Same shape as `data`. - Default is the object attribute `sig`. - mask : (N, M) array_like boolean, optional - Additional mask for a given exposure. Will be added - to the object general mask. - trace_profile : (N_ord, N, M) list or array of 2-D arrays, optional - A list or array of the spatial profile for each order - on the detector. It has to have the same (N, M) as `data`. - Default is the object attribute `p_list` - throughput : (N_ord [, N_k]) list or array of functions, optional - A list or array of the throughput at each order. - The functions depend on the wavelength - Default is the object attribute `t_list` + + error : (N, M) array_like + Estimate of the error on each pixel. Same shape as `data`. Returns ------ @@ -1056,13 +710,9 @@ def get_tikho_tests(self, factors, tikho=None, tikho_kwargs=None, data=None, """ # Build the system to solve - b_matrix, pix_array = self.get_detector_model(data, error, mask, trace_profile, throughput) + b_matrix, pix_array = self.get_detector_model(data, error) - if tikho is None: - t_mat = self.get_tikho_matrix() - if tikho_kwargs is None: - tikho_kwargs = {} - tikho = atoca_utils.Tikhonov(b_matrix, pix_array, t_mat, **tikho_kwargs) + tikho = atoca_utils.Tikhonov(b_matrix, pix_array, self.t_mat) # Test all factors tests = tikho.test_factors(factors) @@ -1075,8 +725,10 @@ def get_tikho_tests(self, factors, tikho=None, tikho_kwargs=None, data=None, return tests - def best_tikho_factor(self, tests=None, fit_mode='all', mode_kwargs=None): + + def best_tikho_factor(self, tests, fit_mode): """ + TODO: why does function with same name exist here and in atoca_utils? Compute the best scale factor for Tikhonov regularization. It is determined by taking the factor giving the lowest reduced chi2 on the detector, the highest curvature of the l-curve or when the improvement @@ -1084,33 +736,27 @@ def best_tikho_factor(self, tests=None, fit_mode='all', mode_kwargs=None): Parameters ---------- - tests : dictionary, optional + tests : dictionary Results of Tikhonov extraction tests for different factors. Must have the keys "factors" and "-logl". - If not specified, the tests from self.tikho.tests are used. - fit_mode : string, optional + + fit_mode : string Which mode is used to find the best Tikhonov factor. Options are 'all', 'curvature', 'chi2', 'd_chi2'. If 'all' is chosen, the best of the three other options will be selected. - mode_kwargs : dictionary-like - Dictionary of keyword arguments to be passed to TikhoTests.best_tikho_factor(). - Example: mode_kwargs = {'curvature': curvature_kwargs}. - Here, curvature_kwargs is also a dictionary. Returns ------- best_fac : float The best Tikhonov factor. + best_mode : str The mode used to determine the best factor. + results : dict A dictionary holding the factors computed for each mode requested. """ - # Use pre-run tests if not specified - if tests is None: - tests = self.tikho_tests - # TODO Find a way to identify when the solution becomes unstable # and do nnot use these in the search for the best tikhonov factor. # The follwing commented bloc was an attemp to do it, but problems @@ -1138,30 +784,17 @@ def best_tikho_factor(self, tests=None, fit_mode='all', mode_kwargs=None): # Single mode list_mode = [fit_mode] - # Init the mode_kwargs if None were given - if mode_kwargs is None: - mode_kwargs = dict() - - # Fill the missing value in mode_kwargs - for mode in list_mode: - try: - # Try the mode - mode_kwargs[mode] - except KeyError: - # Init with empty dictionary if it was not given - mode_kwargs[mode] = dict() - # Evaluate best factor with different methods results = dict() for mode in list_mode: - best_fac = tests.best_tikho_factor(mode=mode, **mode_kwargs[mode]) + best_fac = tests.best_tikho_factor(mode=mode) results[mode] = best_fac if fit_mode == 'all': # Choose the best factor. - # In a well behave case, the results should be ordered as 'chi2', 'd_chi2', 'curvature' + # In a well-behaved case, the results should be ordered as 'chi2', 'd_chi2', 'curvature' # and 'd_chi2' will be the best criterion determine the best factor. - # 'chi2' usually overfitting the solution and 'curvature' may oversmooth the solution + # 'chi2' usually overfits the solution and 'curvature' may oversmooth the solution if results['curvature'] <= results['chi2'] or results['d_chi2'] <= results['chi2']: # In this case, 'chi2' is likely to not overfit the solution, so must be favored best_mode = 'chi2' @@ -1184,20 +817,22 @@ def best_tikho_factor(self, tests=None, fit_mode='all', mode_kwargs=None): return best_fac, best_mode, results - def rebuild(self, spectrum=None, i_orders=None, same=False, fill_value=0.0): - """Build current model image of the detector. + + def rebuild(self, spectrum, same=False, fill_value=0.0): + """ + Build current model image of the detector. + Parameters ---------- - spectrum : callable or array-like, optional + spectrum : callable or array-like flux as a function of wavelength if callable or array of flux values corresponding to self.wave_grid. - If not provided, will find via self.__call__(). - i_orders : list[int], optional - Indices of orders to model. Default is all available orders. + same : bool, optional If True, do not recompute the pixel_mapping matrix (b_n) and instead use the most recent pixel_mapping to speed up the computation. Default is False. + fill_value : float or np.nan, optional Pixel value where the detector is masked. Default is 0.0. @@ -1206,73 +841,56 @@ def rebuild(self, spectrum=None, i_orders=None, same=False, fill_value=0.0): array[float] The modeled detector image. """ - - # If no spectrum given compute it. - if spectrum is None: - spectrum = self.__call__() - # If flux is callable, evaluate on the wavelength grid. if callable(spectrum): spectrum = spectrum(self.wave_grid) - # Iterate over all orders by default. - if i_orders is None: - i_orders = range(self.n_orders) - - # Get required class attribute. - mask = self.mask + # Iterate over all orders + i_orders = range(self.n_orders) # Evaluate the detector model. model = np.zeros(self.data_shape) for i_order in i_orders: # Compute the pixel mapping matrix (b_n) for the current order. - pixel_mapping = self.get_pixel_mapping(i_order, error=False, same=same) + pixel_mapping = self.get_pixel_mapping(i_order, error=None, same=same) # Evaluate the model of the current order. - model[~mask] += pixel_mapping.dot(spectrum) + model[~self.mask] += pixel_mapping.dot(spectrum) # Assign masked values - model[mask] = fill_value - + model[self.mask] = fill_value return model - def compute_likelihood(self, spectrum=None, same=False): + + def compute_likelihood(self, spectrum, data, error): """Return the log likelihood associated with a particular spectrum. Parameters ---------- - spectrum : array[float] or callable, optional + spectrum : array[float] or callable Flux as a function of wavelength if callable or array of flux values corresponding to self.wave_grid. - If not given it will be computed by calling self.__call__(). - same : bool, optional - If True, do not recompute the pixel_mapping matrix (b_n) - and instead use the most recent pixel_mapping to speed up the computation. - Default is False. + + data : (N, M) array_like + A 2-D array of real values representing the detector image. + + error : (N, M) array_like + Estimate of the error on each pixel. + Same shape as `data`. Returns ------- array[float] The log-likelihood of the spectrum. """ - - # If no spectrum given compute it. - if spectrum is None: - spectrum = self.__call__() - # Evaluate the model image for the spectrum. - model = self.rebuild(spectrum, same=same) - - # Get data and error attributes. - data = self.data - error = self.error - mask = self.mask + model = self.rebuild(spectrum, same=False) # Compute the log-likelihood for the spectrum. with np.errstate(divide='ignore'): logl = (model - data) / error - return -np.nansum((logl[~mask])**2) + return -np.nansum((logl[~self.mask])**2) @staticmethod @@ -1311,7 +929,7 @@ def _solve_tikho(matrix, result, t_mat, **kwargs): return tikho.solve(**kwargs) - def __call__(self, tikhonov=False, tikho_kwargs=None, factor=None, **kwargs): + def __call__(self, data, error, tikhonov=False, factor=None): """ Extract underlying flux on the detector. All parameters are passed to `build_sys` method. @@ -1324,56 +942,38 @@ def __call__(self, tikhonov=False, tikho_kwargs=None, factor=None, **kwargs): Parameters ---------- + data : (N, M) array_like + A 2-D array of real values representing the detector image. + + error : (N, M) array_like + Estimate of the error on each pixel` + Same shape as `data`. + tikhonov : bool, optional Whether to use Tikhonov extraction Default is False. - tikho_kwargs : dictionary or None, optional - Arguments passed to `tikho_solve`. + factor : the Tikhonov factor to use if tikhonov is True - data : (N, M) array_like, optional - A 2-D array of real values representing the detector image. - Default is the object attribute `data`. - error : (N, M) array_like, optional - Estimate of the error on each pixel` - Same shape as `data`. - Default is the object attribute `sig`. - mask : (N, M) array_like boolean, optional - Additional mask for a given exposure. Will be added - to the object general mask. - trace_profile : (N_ord, N, M) list or array of 2-D arrays, optional - A list or array of the spatial profile for each order - on the detector. It has to have the same (N, M) as `data`. - Default is the object attribute `p_list` - throughput : (N_ord [, N_k]) list or array of functions, optional - A list or array of the throughput at each order. - The functions depend on the wavelength - Default is the object attribute `t_list` Returns ----- spectrum (f_k): solution of the linear system """ - # Solve with the specified solver. if tikhonov: # Build the system to solve - b_matrix, pix_array = self.get_detector_model(**kwargs) + b_matrix, pix_array = self.get_detector_model(data, error) if factor is None: msg = "Please specify tikhonov `factor`." log.critical(msg) raise ValueError(msg) - t_mat = self.get_tikho_matrix() - - if tikho_kwargs is None: - tikho_kwargs = {} - - spectrum = self._solve_tikho(b_matrix, pix_array, t_mat, factor=factor, **tikho_kwargs) + spectrum = self._solve_tikho(b_matrix, pix_array, self.tikho_mat, factor=factor) else: # Build the system to solve - matrix, result = self.build_sys(**kwargs) + matrix, result = self.build_sys(data, error) # Only solve for valid range `i_grid` (on the detector). # It will be a singular matrix otherwise. @@ -1402,16 +1002,13 @@ def _get_lo_hi(self, grid, i_order): log.debug('Computing lowest and highest indices of wave_grid.') - # Get needed attributes - mask = self.mask - # ... order dependent attributes attrs = ['wave_p', 'wave_m', 'mask_ord'] wave_p, wave_m, mask_ord = self.get_attributes(*attrs, i_order=i_order) # Compute only for valid pixels - wave_p = wave_p[~mask] - wave_m = wave_m[~mask] + wave_p = wave_p[~self.mask] + wave_m = wave_m[~self.mask] # Find lower (lo) index in the pixel lo = np.searchsorted(grid, wave_m, side='right') @@ -1420,7 +1017,7 @@ def _get_lo_hi(self, grid, i_order): hi = np.searchsorted(grid, wave_p) - 1 # Set invalid pixels for this order to lo=-1 and hi=-2 - ma = mask_ord[~mask] + ma = mask_ord[~self.mask] lo[ma], hi[ma] = -1, -2 return lo, hi @@ -1428,6 +1025,7 @@ def _get_lo_hi(self, grid, i_order): def get_mask_wave(self, i_order): """Generate mask bounded by limits of wavelength grid + Parameters ---------- i_order : int diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index b9de2c7e3e..47a9b8f863 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -205,13 +205,11 @@ def estim_flux_first_order(scidata_bkg, scierr, scimask, ref_file_args, mask_tra # Init extraction without convolution kernel (so extract the spectrum at order 1 resolution) ref_file_args = [wave_maps[0]], [spat_pros[0]], [thrpts[0]], [np.array([1.])] - kwargs = {'wave_grid': wave_grid, - 'orders': [1], - 'mask_trace_profile': [mask]} - engine = ExtractionEngine(*ref_file_args, **kwargs) + kwargs = {'orders': [1],} + engine = ExtractionEngine(*ref_file_args, wave_grid, [mask], **kwargs) # Extract estimate - spec_estimate = engine.__call__(data=scidata_bkg, error=scierr) + spec_estimate = engine.__call__(scidata_bkg, scierr) # Interpolate idx = np.isfinite(spec_estimate) @@ -538,10 +536,10 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, scimask = scimask | ~(scierr > 0) # Define mask based on box aperture (we want to model each contaminated pixels that will be extracted) - mask_trace_profile = [~(box_weights[order] > 0) for order in order_list] + mask_trace_profile = [(~(box_weights[order] > 0)) | (refmask) for order in order_list] # Define mask of pixel to model (all pixels inside box aperture) - global_mask = np.all(mask_trace_profile, axis=0) | refmask + global_mask = np.all(mask_trace_profile, axis=0).astype(bool) # Rough estimate of the underlying flux # Note: estim_flux func is not strictly necessary and factors could be a simple logspace - @@ -566,7 +564,6 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, engine = ExtractionEngine(*ref_file_args, wave_grid=wave_grid, mask_trace_profile=mask_trace_profile, - global_mask=scimask, threshold=threshold, c_kwargs=c_kwargs) @@ -608,10 +605,10 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, log.info('Using a Tikhonov factor of {}'.format(tikfac)) # Run the extract method of the Engine. - f_k = engine.__call__(data=scidata_bkg, error=scierr, tikhonov=True, factor=tikfac) + f_k = engine.__call__(scidata_bkg, scierr, tikhonov=True, factor=tikfac) # Compute the log-likelihood of the best fit. - logl = engine.compute_likelihood(f_k, same=False) + logl = engine.compute_likelihood(f_k, scidata_bkg, scierr) log.info('Optimal solution has a log-likelihood of {}'.format(logl)) @@ -777,11 +774,12 @@ def throughput(wavelength): # Initialize the engine engine = ExtractionEngine(*ref_file_args, wave_grid=wave_grid, + mask_trace_profile=[mask_fit], orders=[order], - mask_trace_profile=[mask_fit]) + ) # Extract estimate - spec_estimate = engine.__call__(data=data_order, error=err_order) + spec_estimate = engine.__call__(data_order, err_order) # Interpolate idx = np.isfinite(spec_estimate) @@ -797,8 +795,8 @@ def throughput(wavelength): # Initialize the Engine. engine = ExtractionEngine(*ref_file_args, wave_grid=wave_grid_os, - orders=[order], - mask_trace_profile=[mask_fit]) + mask_trace_profile=[mask_fit], + orders=[order],) # Find the tikhonov factor. # Initial pass with tikfac_range. @@ -819,7 +817,7 @@ def throughput(wavelength): all_tests = append_tiktests(all_tests, tiktests) # Run the extract method of the Engine. - f_k_final = engine.__call__(data=data_order, error=err_order, tikhonov=True, factor=tikfac) + f_k_final = engine.__call__(data_order, err_order, tikhonov=True, factor=tikfac) # Save binned spectra in a list of SingleSpecModels for optional output spec_list = [] @@ -841,8 +839,8 @@ def throughput(wavelength): # Initialize the Engine. engine = ExtractionEngine(*ref_file_args, wave_grid=wave_grid_os, - orders=[order], - mask_trace_profile=[mask_rebuild]) + mask_trace_profile=[mask_rebuild], + orders=[order],) # Project on detector and save in dictionary model = engine.rebuild(f_k_final, fill_value=np.nan) diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca.py b/jwst/extract_1d/soss_extract/tests/test_atoca.py new file mode 100644 index 0000000000..b556e83265 --- /dev/null +++ b/jwst/extract_1d/soss_extract/tests/test_atoca.py @@ -0,0 +1,35 @@ +import pytest +import numpy as np +from jwst.extract_1d.soss_extract import atoca + + +@pytest.fixture(scope="module") +def wave_map(): + shp = (100, 30) + wave_ord1 = np.linspace(2.8, 0.8, shp[0]) + + pass + +@pytest.fixture(scope="module") +def trace_profile(wave_map): + pass + +@pytest.fixture(scope="module") +def throughput(wave_map): + pass + +@pytest.fixture(scope="module") +def kernels(): + pass + +@pytest.fixture(scope="module") +def wave_grid(): + pass + +@pytest.fixture(scope="module") +def mask_trace_profile(wave_map): + pass + + +def test_extraction_engine(): + pass From 324b3ff453faf2c29ce380ff5e96a69a7bd8ebf1 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 1 Nov 2024 12:42:38 -0400 Subject: [PATCH 09/35] made functions private, removed some unused optionals --- jwst/extract_1d/soss_extract/atoca.py | 117 +++---- jwst/extract_1d/soss_extract/atoca_utils.py | 328 ++++-------------- jwst/extract_1d/soss_extract/pastasoss.py | 4 +- .../soss_extract/soss_boxextract.py | 44 +-- jwst/extract_1d/soss_extract/soss_extract.py | 245 ++++++------- jwst/extract_1d/soss_extract/soss_syscor.py | 38 +- .../soss_extract/tests/test_atoca_utils.py | 24 +- 7 files changed, 300 insertions(+), 500 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index b9780c4c47..6bf4015b3d 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -364,7 +364,6 @@ def _get_i_bnds(self): i_bnds_new = [] for bounds, i_bnds in zip(wave_bounds, self.i_bounds): - print(bounds, i_bnds) a = np.min(np.where(self.wave_grid >= bounds[0])[0]) b = np.max(np.where(self.wave_grid <= bounds[1])[0]) + 1 @@ -383,6 +382,7 @@ def update_i_bnds(self): # Get old and new boundaries. i_bnds_old = self.i_bounds i_bnds_new = self._get_i_bnds() + print(i_bnds_old, i_bnds_new) for i_order in range(self.n_orders): @@ -462,7 +462,7 @@ def grid_from_map(self, i_order=0): return wave_grid, icol - def get_pixel_mapping(self, i_order, error=None, same=False, quick=False): + def get_pixel_mapping(self, i_order, error=None, quick=False): """ Compute the matrix `b_n = (P/sig).w.T.lambda.c_n` , where `P` is the spatial profile matrix (diag), @@ -485,10 +485,6 @@ def get_pixel_mapping(self, i_order, error=None, same=False, quick=False): If None, the error is set to 1, which means the method will return b_n instead of b_n/sigma. Default is None. - same: bool, optional - Do not recompute b_n. Take the last b_n computed. - Useful to speed up code. Default is False. - quick: bool, optional If True, only perform one matrix multiplication instead of the whole system: (P/sig).(w.T.lambda.c_n) @@ -499,60 +495,51 @@ def get_pixel_mapping(self, i_order, error=None, same=False, quick=False): Sparse matrix of b_n coefficients """ - # Force to compute if b_n never computed. - if self.pixel_mapping[i_order] is None: - same = False - - # Take the last b_n computed if nothing changes - if same: - pixel_mapping = self.pixel_mapping[i_order] - - else: - # Special treatment for error map - # Can be bool or array. - if error is None: - # Sigma will have no effect - error = np.ones(self.data_shape) + # Special treatment for error map + # Can be bool or array. + if error is None: + # Sigma will have no effect + error = np.ones(self.data_shape) - # Get needed attributes ... - attrs = ['wave_grid', 'mask'] - wave_grid, mask = self.get_attributes(*attrs) + # Get needed attributes ... + attrs = ['wave_grid', 'mask'] + wave_grid, mask = self.get_attributes(*attrs) - # ... order dependent attributes - attrs = ['trace_profile', 'throughput', 'kernels', 'weights', 'i_bounds'] - trace_profile_n, throughput_n, kernel_n, weights_n, i_bnds = self.get_attributes(*attrs, i_order=i_order) + # ... order dependent attributes + attrs = ['trace_profile', 'throughput', 'kernels', 'weights', 'i_bounds'] + trace_profile_n, throughput_n, kernel_n, weights_n, i_bnds = self.get_attributes(*attrs, i_order=i_order) - # Keep only valid pixels (P and sig are still 2-D) - # And apply directly 1/sig here (quicker) - trace_profile_n = trace_profile_n[~mask] / error[~mask] + # Keep only valid pixels (P and sig are still 2-D) + # And apply directly 1/sig here (quicker) + trace_profile_n = trace_profile_n[~mask] / error[~mask] - # Compute b_n - # Quick mode if only `p_n` or `sig` has changed - if quick: - # Get pre-computed (right) part of the equation - right = self.w_t_wave_c[i_order] + # Compute b_n + # Quick mode if only `p_n` or `sig` has changed + if quick: + # Get pre-computed (right) part of the equation + right = self.w_t_wave_c[i_order] - # Apply new p_n - pixel_mapping = diags(trace_profile_n).dot(right) + # Apply new p_n + pixel_mapping = diags(trace_profile_n).dot(right) - else: - # First (T * lam) for the convolve axis (n_k_c) - product = (throughput_n * wave_grid)[slice(*i_bnds)] + else: + # First (T * lam) for the convolve axis (n_k_c) + product = (throughput_n * wave_grid)[slice(*i_bnds)] - # then convolution - product = diags(product).dot(kernel_n) + # then convolution + product = diags(product).dot(kernel_n) - # then weights - product = weights_n.dot(product) + # then weights + product = weights_n.dot(product) - # Save this product for quick mode - self._set_w_t_wave_c(i_order, product) + # Save this product for quick mode + self._set_w_t_wave_c(i_order, product) - # Then spatial profile - pixel_mapping = diags(trace_profile_n).dot(product) + # Then spatial profile + pixel_mapping = diags(trace_profile_n).dot(product) - # Save new pixel mapping matrix. - self.pixel_mapping[i_order] = pixel_mapping + # Save new pixel mapping matrix. + self.pixel_mapping[i_order] = pixel_mapping return pixel_mapping @@ -706,7 +693,8 @@ def get_tikho_tests(self, factors, data, error): Returns ------ - dictionary of the tests results + tests : dict + dictionary of the test results """ # Build the system to solve @@ -720,9 +708,6 @@ def get_tikho_tests(self, factors, data, error): # Save also grid tests["grid"] = self.wave_grid - # Save as attribute - self.tikho_tests = tests - return tests @@ -749,12 +734,6 @@ def best_tikho_factor(self, tests, fit_mode): ------- best_fac : float The best Tikhonov factor. - - best_mode : str - The mode used to determine the best factor. - - results : dict - A dictionary holding the factors computed for each mode requested. """ # TODO Find a way to identify when the solution becomes unstable @@ -812,13 +791,12 @@ def best_tikho_factor(self, tests, fit_mode): # Get the factor of the chosen mode best_fac = results[best_mode] - log.debug(f'Mode chosen to find regularization factor is {best_mode}') - return best_fac, best_mode, results + return best_fac - def rebuild(self, spectrum, same=False, fill_value=0.0): + def rebuild(self, spectrum, fill_value=0.0): """ Build current model image of the detector. @@ -828,11 +806,6 @@ def rebuild(self, spectrum, same=False, fill_value=0.0): flux as a function of wavelength if callable or array of flux values corresponding to self.wave_grid. - same : bool, optional - If True, do not recompute the pixel_mapping matrix (b_n) - and instead use the most recent pixel_mapping to speed up the computation. - Default is False. - fill_value : float or np.nan, optional Pixel value where the detector is masked. Default is 0.0. @@ -853,7 +826,7 @@ def rebuild(self, spectrum, same=False, fill_value=0.0): for i_order in i_orders: # Compute the pixel mapping matrix (b_n) for the current order. - pixel_mapping = self.get_pixel_mapping(i_order, error=None, same=same) + pixel_mapping = self.get_pixel_mapping(i_order, error=None) # Evaluate the model of the current order. model[~self.mask] += pixel_mapping.dot(spectrum) @@ -885,7 +858,7 @@ def compute_likelihood(self, spectrum, data, error): The log-likelihood of the spectrum. """ # Evaluate the model image for the spectrum. - model = self.rebuild(spectrum, same=False) + model = self.rebuild(spectrum) # Compute the log-likelihood for the spectrum. with np.errstate(divide='ignore'): @@ -1020,6 +993,12 @@ def _get_lo_hi(self, grid, i_order): ma = mask_ord[~self.mask] lo[ma], hi[ma] = -1, -2 + print("grid", np.min(grid), np.max(grid), grid.size) + # grid 0.8515140811891634 2.204728219634399 1408 on PR branch + # grid 0.8480011181377183 2.2057149524019297 1413 on main + # the grid could be the problem! + # try reverting all changes to atoca_utils.grid_from_map and dependencies + return lo, hi diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index 811d7adacb..e31449c236 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -113,7 +113,7 @@ def get_wave_p_or_m(wave_map, dispersion_axis=1): The wavelength upper and lower boundaries of each pixel, given the central value. """ # Get wavelength boundaries of each pixels - wave_left, wave_right = get_wv_map_bounds(wave_map, dispersion_axis=dispersion_axis) + wave_left, wave_right = _get_wv_map_bounds(wave_map, dispersion_axis=dispersion_axis) # The outputs depend on the direction of the spectral axis. invalid = (wave_map == 0) @@ -129,7 +129,7 @@ def get_wave_p_or_m(wave_map, dispersion_axis=1): return wave_plus, wave_minus -def get_wv_map_bounds(wave_map, dispersion_axis=1): +def _get_wv_map_bounds(wave_map, dispersion_axis=1): """ Compute boundaries of a pixel map, given the pixel central value. Parameters ---------- @@ -238,7 +238,7 @@ def oversample_grid(wave_grid, n_os): return np.unique(wave_grid_os) -def extrapolate_grid(wave_grid, wave_range, poly_ord=1): +def _extrapolate_grid(wave_grid, wave_range, poly_ord=1): """Extrapolate the given 1D wavelength grid to cover a given range of values by fitting the derivative with a polynomial of a given order and using it to compute subsequent values at both ends of the grid. @@ -403,7 +403,7 @@ def grid_from_map(wave_map, trace_profile, wave_range=None, n_os=1): grid = _grid_from_map(wave_map, trace_profile) # Extrapolate values out of the wv_map if needed - grid = extrapolate_grid(grid, wave_range, poly_ord=1) + grid = _extrapolate_grid(grid, wave_range, poly_ord=1) # Constrain grid to be within wave_range grid = grid[grid>=wave_range[0]] @@ -419,86 +419,7 @@ def grid_from_map(wave_map, trace_profile, wave_range=None, n_os=1): return oversample_grid(grid, n_os=n_os) -def get_soss_grid(wave_maps, trace_profiles, wave_min=0.55, wave_max=3.0, n_os=2): - """Create a wavelength grid specific to NIRISS SOSS mode observations. - Assumes 2 orders are given, use grid_from_map if only one order is needed. - - Parameters - ---------- - wave_maps : array[float] - Array containing the pixel wavelengths for order 1 and 2. - - trace_profiles : array[float] - Array containing the spatial profiles for order 1 and 2. - - wave_min : float - Minimum wavelength the output grid should cover. - - wave_max : float - Maximum wavelength the output grid should cover. - - n_os : int or list[int] - Oversampling of the grid compared to the pixel sampling. Can be - specified for each order if a list is given. If a single value is given - it will be used for all orders. Default is 2 for all orders. - - Returns - ------- - wave_grid_soss : array[float] - Wavelength grid optimized for extracting SOSS spectra across - order 1 and order 2. - """ - if np.ndim(n_os) == 0: - n_os = [n_os, n_os] - elif len(n_os) != 2: - msg = (f"n_os must be an integer or a 2 element list or array of " - f"integers, got {n_os} instead") - log.critical(msg) - raise ValueError(msg) - if (wave_maps.shape[0] != 2) or (trace_profiles.shape[0] != 2): - msg = 'wave_maps and trace_profiles must have shape (2, detector_x, detector_y).' - log.critical(msg) - raise ValueError(msg) - - wave_maps[wave_maps <= 0] = np.nan - trace_profiles[trace_profiles < 0] = 0 - if np.any(np.isnan(wave_maps)): - msg = 'Encountered NaN values in wave_maps, ignoring them.' - log.warning(msg) - - # Generate a wavelength range for each order. - # Order 1 covers the reddest part of the spectrum, - # so apply wave_max on order 1 and vice versa for order 2. - # TODO: This doesn't know about throughput, so it's possible these min/max wavelengths - # would be chosen such that the algorithm attemps to extract a wavelength from an order - # that has very low (but non-zero) throughput at that wavelength - - # current test data has min(wave_map) != min(wave_grid) because of throughput! - # how to fix this? - - wave_min_o1 = np.maximum(np.nanmin(wave_maps[0]), wave_min) - wave_max_o2 = np.minimum(np.nanmax(wave_maps[1]), wave_max) - - # Use grid_from_map to construct separate oversampled grids for both orders. - wave_grid_o1 = grid_from_map(wave_maps[0], trace_profiles[0], - wave_range=[wave_min_o1, wave_max], - n_os=n_os[0]) - wave_grid_o2 = grid_from_map(wave_maps[1], trace_profiles[1], - wave_range=[wave_min, wave_max_o2], - n_os=n_os[1]) - - # Keep only wavelengths in order 1 that aren't covered by order 2. - mask = wave_grid_o1 > wave_grid_o2.max() - wave_grid_o1 = wave_grid_o1[mask] - - # Combine the order 1 and order 2 grids. - wave_grid_soss = np.concatenate([wave_grid_o1, wave_grid_o2]) - - # Sort values (and keep only unique). - return np.unique(wave_grid_soss) - - -def _trim_grids(all_grids, grid_range=None): +def _trim_grids(all_grids, grid_range): """ Remove all parts of the grids that are not in range or that are already covered by grids with higher priority, i.e. preceding in the list. @@ -506,16 +427,14 @@ def _trim_grids(all_grids, grid_range=None): grids_trimmed = [] for grid in all_grids: # Remove parts of the grid that are not in the wavelength range - if grid_range is not None: - # Find where the limit values fall on the grid - i_min = np.searchsorted(grid, grid_range[0], side='right') - i_max = np.searchsorted(grid, grid_range[1], side='left') - # Make sure it is a valid value and take one grid point past the limit - # since the oversampling could squeeze some nodes near the limits - i_min = np.max([i_min - 1, 0]) - i_max = np.min([i_max, len(grid) - 1]) - # Trim the grid - grid = grid[i_min:i_max + 1] + i_min = np.searchsorted(grid, grid_range[0], side='right') + i_max = np.searchsorted(grid, grid_range[1], side='left') + # Make sure it is a valid value and take one grid point past the limit + # since the oversampling could squeeze some nodes near the limits + i_min = np.max([i_min - 1, 0]) + i_max = np.min([i_max, len(grid) - 1]) + # Trim the grid + grid = grid[i_min:i_max + 1] # Remove parts of the grid that are already covered if len(grids_trimmed) > 0: @@ -551,11 +470,11 @@ def _trim_grids(all_grids, grid_range=None): return grids_trimmed -def make_combined_adaptive_grid(all_grids, all_estimate, grid_range=None, - max_iter=10, rtol=10e-6, tol=0.0, max_total_size=1000000): +def make_combined_adaptive_grid(all_grids, all_estimate, grid_range, + max_iter=10, rtol=10e-6, max_total_size=1000000): """ TODO: can this be a class? e.g., class AdaptiveGrid? - TODO: why aren't any of the same helper functions used here as in get_soss_grid? + TODO: why aren't any of the same helper functions used here as in _get_soss_grid? q: why are there multiple grids passed in here in the first place Return an irregular oversampled grid needed to reach a @@ -568,32 +487,38 @@ def make_combined_adaptive_grid(all_grids, all_estimate, grid_range=None, Parameters ---------- all_grid : list[array] - List of grid (arrays) to pass to adapt_grid, in order of importance. + List of grid (arrays) to pass to _adapt_grid, in order of importance. + all_estimate : list[callable] List of function (callable) to estimate the precision needed to oversample the grid. Must match the corresponding `grid` in `all_grid`. + + grid_range : list[float] + Wavelength range the new grid should cover. + max_iter : int, optional Number of times the intervals can be subdivided. The smallest subdivison of the grid if max_iter is reached will then be given by delta_grid / 2^max_iter. Needs to be greater then zero. Default is 10. + rtol : float, optional The desired relative tolerance. Default is 10e-6, so 10 ppm. - tol : float, optional - The desired absolute tolerance. Default is 0 to prioritize `rtol`. + max_total_size : int, optional maximum size of the output grid. Default is 1 000 000. + Returns ------- os_grid : 1D array Oversampled combined grid which minimizes the integration error based on Romberg's method """ - # Save parameters for adapt_grid - kwargs = dict(max_iter=max_iter, rtol=rtol, tol=tol) + # Save parameters for _adapt_grid + kwargs = dict(max_iter=max_iter, rtol=rtol) # Remove unneeded parts of the grids - all_grids = _trim_grids(all_grids, grid_range=grid_range) + all_grids = _trim_grids(all_grids, grid_range) # Save native size of each grids (use later to adjust max_grid_size) all_sizes = [len(grid) for grid in all_grids] @@ -615,7 +540,7 @@ def make_combined_adaptive_grid(all_grids, all_estimate, grid_range=None, kwargs['max_grid_size'] = np.max([max_grid_size, all_sizes[i_grid]]) # Oversample the grid based on tolerance required - grid, is_converged = adapt_grid(grid, estimate, **kwargs) + grid, is_converged = _adapt_grid(grid, estimate, **kwargs) # Update grid sizes all_sizes[i_grid] = grid.size @@ -741,7 +666,7 @@ def _difftrap(fct, intervals, numtraps): return ordsum -def estim_integration_err(grid, fct): +def _estim_integration_err(grid, fct): """Estimate the integration error on each intervals of the grid using 1rst order Romberg integration. @@ -786,7 +711,7 @@ def estim_integration_err(grid, fct): return err, rel_err -def adapt_grid(grid, fct, max_iter=10, rtol=10e-6, tol=0.0, max_grid_size=None): +def _adapt_grid(grid, fct, max_iter=10, rtol=10e-6, max_grid_size=None): """Return an irregular oversampled grid needed to reach a given precision when integrating over each intervals of `grid`. The grid is built by subdividing iteratively each intervals that @@ -800,26 +725,31 @@ def adapt_grid(grid, fct, max_iter=10, rtol=10e-6, tol=0.0, max_grid_size=None): Grid for integration. Each sections of this grid are treated as separate integrals. So if grid has length N; N-1 integrals are optimized. + fct: callable Function to be integrated. Must be a function of `grid` + max_iter: int, optional Number of times the intervals can be subdivided. The smallest subdivison of the grid if max_iter is reached will then be given by delta_grid / 2^max_iter. Needs to be greater then zero. Default is 10. + rtol: float, optional The desired relative tolerance. Default is 10e-6, so 10 ppm. - tol: float, optional - The desired absolute tolerance. Default is 0 to prioritize `rtol`. + max_grid_size: int, optional maximum size of the output grid. Default is None, so no constraint. + Returns ------- os_grid : 1D array Oversampled grid which minimizes the integration error based on Romberg's method + convergence_flag: bool Whether the estimated tolerance was reach everywhere or not. + See Also -------- scipy.integrate.quadrature.romberg @@ -839,10 +769,10 @@ def adapt_grid(grid, fct, max_iter=10, rtol=10e-6, tol=0.0, max_grid_size=None): for _ in range(max_iter): # Estimate error using Romberg integration - err, rel_err = estim_integration_err(grid, fct) + err, rel_err = _estim_integration_err(grid, fct) # Check where precision is reached - converged = (err < tol) | (rel_err < rtol) + converged = rel_err < rtol is_converged = converged.all() # Check if max grid size was reached @@ -1114,50 +1044,7 @@ def __call__(self, wave, wave_c): return webbker -def gaussians(x, x0, sig, amp=None): - """ - TODO: can this be replaced by something in scipy or numpy, - e.g., scipy.signal.windows.gaussian? - Gaussian function - - Parameters - ---------- - x : array[float] - Array of points over which gaussian to be defined. - x0 : float - Center of the gaussian. - sig : float - Standard deviation of the gaussian. - amp : float - Value of the gaussian at the center. - - Returns - ------- - values : array[float] - Array of gaussian values for input x. - """ - if amp is None: - amp = 1. / np.sqrt(2. * np.pi * sig**2.) - return amp * np.exp(-0.5 * ((x - x0) / sig) ** 2.) - - -def fwhm2sigma(fwhm): - """Convert a full width half max to a standard deviation, assuming a gaussian - - Parameters - ---------- - fwhm : float - Full-width half-max of a gaussian. - - Returns - ------- - sigma : float - Standard deviation of a gaussian. - """ - return fwhm / np.sqrt(8. * np.log(2.)) - - -def to_2d(kernel, grid_range): +def _to_2d(kernel, grid_range): """Build a 2d kernel array with a constant 1D kernel (input) Parameters @@ -1251,7 +1138,7 @@ def _get_wings(fct, grid, h_len, i_a, i_b): return left, right -def trpz_weight(grid, length, shape, i_a, i_b): +def _trpz_weight(grid, length, shape, i_a, i_b): """ TODO: add to some integration class? @@ -1303,7 +1190,7 @@ def trpz_weight(grid, length, shape, i_a, i_b): return out -def fct_to_array(fct, grid, grid_range, thresh=1e-5, length=None): +def _fct_to_array(fct, grid, grid_range, thresh=1e-5, length=None): """ TODO: can't scipy do this? @@ -1373,7 +1260,7 @@ def fct_to_array(fct, grid, grid_range, thresh=1e-5, length=None): out = np.vstack([left, out, right]) # Weights due to integration (from the convolution) - weights = trpz_weight(grid, length, out.shape, i_a, i_b) + weights = _trpz_weight(grid, length, out.shape, i_a, i_b) elif (length % 2) == 1: # length needs to be odd # Generate a 2D array of the grid iteratively until @@ -1391,10 +1278,10 @@ def fct_to_array(fct, grid, grid_range, thresh=1e-5, length=None): out = np.vstack([left, out, right]) # Weights due to integration (from the convolution) - weights = trpz_weight(grid, length, out.shape, i_a, i_b) + weights = _trpz_weight(grid, length, out.shape, i_a, i_b) else: - msg = "`length` provided to `fct_to_array` must be odd." + msg = "`length` provided to `_fct_to_array` must be odd." log.critical(msg) raise ValueError(msg) @@ -1402,7 +1289,7 @@ def fct_to_array(fct, grid, grid_range, thresh=1e-5, length=None): return kern_array -def cut_ker(ker, n_out=None, thresh=None): +def _cut_ker(ker, n_out=None, thresh=None): """Apply a cut on the convolution matrix boundaries. Parameters @@ -1477,7 +1364,7 @@ def cut_ker(ker, n_out=None, thresh=None): return ker -def sparse_c(ker, n_k, i_zero=0): +def _sparse_c(ker, n_k, i_zero=0): """Convert a convolution kernel in compact form (N_ker, N_k_convolved) to sparse form (N_k_convolved, N_k) @@ -1502,7 +1389,7 @@ def sparse_c(ker, n_k, i_zero=0): # Algorithm works for odd kernel grid if n_ker % 2 != 1: - err_msg = "Length of the convolution kernel given to sparse_c should be odd." + err_msg = "Length of the convolution kernel given to _sparse_c should be odd." log.critical(err_msg) raise ValueError(err_msg) @@ -1599,9 +1486,9 @@ def get_c_matrix(kernel, grid, bounds=None, i_bounds=None, norm=True, # Generate a 2D kernel depending on the input if callable(kernel): - kernel = fct_to_array(kernel, grid, [a, b], **kwargs) + kernel = _fct_to_array(kernel, grid, [a, b], **kwargs) elif kernel.ndim == 1: - kernel = to_2d(kernel, [a, b]) + kernel = _to_2d(kernel, [a, b]) if kernel.ndim != 2: msg = ("Input kernel to get_c_matrix must be callable or" @@ -1615,81 +1502,16 @@ def get_c_matrix(kernel, grid, bounds=None, i_bounds=None, norm=True, kernel = kernel / np.nansum(kernel, axis=0) # Apply cut for kernel at boundaries - kernel = cut_ker(kernel, n_out, thresh_out) + kernel = _cut_ker(kernel, n_out, thresh_out) if sparse: # Convert to a sparse matrix. - kernel = sparse_c(kernel, len(grid), a) + kernel = _sparse_c(kernel, len(grid), a) return kernel -class NyquistKer: - """ - TODO: look into whether custom Gaussian function is needed, or if - something like scipy.ndimage.gaussian_filter1d could be used. - - Define a gaussian convolution kernel at the nyquist - sampling. For a given point on the grid x_i, the kernel - is given by a gaussian with - FWHM = n_sampling * (dx_(i-1) + dx_i) / 2. - The FWHM is computed for each elements of the grid except - the extremities (not defined). We can then generate FWHM as - a function of the grid and interpolate/extrapolate to get - the kernel as a function of its position relative to the grid. - """ - - def __init__(self, grid, n_sampling=2, bounds_error=False, - fill_value="extrapolate", **kwargs): - """Parameters - ---------- - grid : array[float] - Grid used to define the kernels - n_sampling : int, optional - Sampling of the grid. - bounds_error : bool - Argument for `interp1d` to get FWHM as a function of the grid. - fill_value : str - Argument for `interp1d` to choose fill method to get FWHM. - """ - - # Delta grid - d_grid = np.diff(grid) - - # The full width half max is n_sampling - # times the mean of d_grid - fwhm = (d_grid[:-1] + d_grid[1:]) / 2 - fwhm *= n_sampling - - # What we really want is sigma, not FWHM - sig = fwhm2sigma(fwhm) - - # Now put sigma as a function of the grid - sig = interp1d(grid[1:-1], sig, bounds_error=bounds_error, - fill_value=fill_value, **kwargs) - - self.fct_sig = sig - - def __call__(self, x, x0): - """Parameters - ---------- - x : array[float] - position where the kernel is evaluated - x0 : array[float] - position of the kernel center for each x. - - Returns - ------- - Value of the gaussian kernel for each set of (x, x0) - """ - - # Get the sigma of each gaussian - sig = self.fct_sig(x0) - - return gaussians(x, x0, sig) - - -def finite_diff(x): +def _finite_diff(x): """Returns the finite difference matrix operator based on x. Parameters @@ -1712,7 +1534,7 @@ def finite_diff(x): return diff_matrix -def finite_second_d(grid): +def _finite_second_d(grid): """Returns the second derivative operator based on grid Parameters @@ -1729,7 +1551,7 @@ def finite_second_d(grid): """ # Finite difference operator - d_matrix = finite_diff(grid) + d_matrix = _finite_diff(grid) # Delta lambda d_grid = d_matrix.dot(grid) @@ -1738,13 +1560,13 @@ def finite_second_d(grid): first_d = diags(1. / d_grid).dot(d_matrix) # Second derivative operator - second_d = finite_diff(grid[:-1]).dot(first_d) + second_d = _finite_diff(grid[:-1]).dot(first_d) # don't forget the delta lambda return diags(1. / d_grid[:-1]).dot(second_d) -def finite_first_d(grid): +def _finite_first_d(grid): """Returns the first derivative operator based on grid Parameters @@ -1761,7 +1583,7 @@ def finite_first_d(grid): """ # Finite difference operator - d_matrix = finite_diff(grid) + d_matrix = _finite_diff(grid) # Delta lambda d_grid = d_matrix.dot(grid) @@ -1807,9 +1629,9 @@ def get_tikho_matrix(grid, n_derivative=1, d_grid=True, estimate=None, pwr_law=0 input_grid = np.arange(len(grid)) if n_derivative == 1: - t_mat = finite_first_d(input_grid) + t_mat = _finite_first_d(input_grid) elif n_derivative == 2: - t_mat = finite_second_d(input_grid) + t_mat = _finite_second_d(input_grid) else: msg = "`n_derivative` must be 1 or 2." log.critical(msg) @@ -1850,7 +1672,7 @@ def get_tikho_matrix(grid, n_derivative=1, d_grid=True, estimate=None, pwr_law=0 return t_mat -def curvature_finite(factors, log_reg2, log_chi2): +def _curvature_finite(factors, log_reg2, log_chi2): """Compute the curvature in log space using finite differences Parameters @@ -1874,8 +1696,8 @@ def curvature_finite(factors, log_reg2, log_chi2): factors, log_chi2, log_reg2 = factors[idx], log_chi2[idx], log_reg2[idx] # Get first and second derivatives - chi2_deriv = get_finite_derivatives(factors, log_chi2) - reg2_deriv = get_finite_derivatives(factors, log_reg2) + chi2_deriv = _get_finite_derivatives(factors, log_chi2) + reg2_deriv = _get_finite_derivatives(factors, log_reg2) # Compute the curvature according to Hansen 2001 # @@ -1894,7 +1716,7 @@ def curvature_finite(factors, log_reg2, log_chi2): return factors, curv -def get_finite_derivatives(x_array, y_array): +def _get_finite_derivatives(x_array, y_array): """ Compute first and second finite derivatives Parameters ---------- @@ -2093,19 +1915,19 @@ def _find_intersect(factors, y_val, thresh, interpolate, search_range=None): return best_val -def soft_l1(z): +def _soft_l1(z): return 2 * ((1 + z)**0.5 - 1) -def cauchy(z): +def _cauchy(z): return np.log(1 + z) -def linear(z): +def _linear(z): return z -LOSS_FUNCTIONS = {'soft_l1': soft_l1, 'cauchy': cauchy, 'linear': linear} +LOSS_FUNCTIONS = {'soft_l1': _soft_l1, 'cauchy': _cauchy, 'linear': _linear} class TikhoTests(dict): @@ -2160,10 +1982,10 @@ def __init__(self, test_dict=None, default_chi2='chi2_cauchy'): # Save the chi2 self[chi2_type] except KeyError: - self[chi2_type] = self.compute_chi2(loss=loss) + self[chi2_type] = self._compute_chi2(loss=loss) - def compute_chi2(self, tests=None, n_points=None, loss='linear'): + def _compute_chi2(self, tests=None, loss='linear'): """ TODO: is there a scipy builtin that does this? Calculates the reduced chi squared statistic @@ -2246,7 +2068,7 @@ def compute_curvature(self, tests=None, key=None): # Get the norm-2 of the regularisation term reg2 = np.nansum(tests['reg'] ** 2, axis=-1) - factors, curv = curvature_finite(tests['factors'], + factors, curv = _curvature_finite(tests['factors'], np.log10(self[key]), np.log10(reg2)) @@ -2254,7 +2076,11 @@ def compute_curvature(self, tests=None, key=None): def best_tikho_factor(self, tests=None, interpolate=True, interp_index=None, mode='curvature', key=None, thresh=None): - """Compute the best scale factor for Tikhonov regularisation. + """ + TODO: why is there a function with identical name in atoca.py ExtractionEngine? + this one is called by the other one... + + Compute the best scale factor for Tikhonov regularisation. It is determined by taking the factor giving the highest logL on the detector or the highest curvature of the l-curve, depending on the chosen mode. diff --git a/jwst/extract_1d/soss_extract/pastasoss.py b/jwst/extract_1d/soss_extract/pastasoss.py index 6fec75c778..332bd1a287 100644 --- a/jwst/extract_1d/soss_extract/pastasoss.py +++ b/jwst/extract_1d/soss_extract/pastasoss.py @@ -552,6 +552,4 @@ def get_soss_wavemaps(refmodel, pwcpos, subarray, padding=False, padsize=0, spec if spectraces: return np.array([wavemap_1, wavemap_2]), np.array([spectrace_1, spectrace_2]) - - else: - return np.array([wavemap_1, wavemap_2]) + return np.array([wavemap_1, wavemap_2]) diff --git a/jwst/extract_1d/soss_extract/soss_boxextract.py b/jwst/extract_1d/soss_extract/soss_boxextract.py index f499b21d7b..67b28dff31 100644 --- a/jwst/extract_1d/soss_extract/soss_boxextract.py +++ b/jwst/extract_1d/soss_extract/soss_boxextract.py @@ -5,8 +5,9 @@ log.setLevel(logging.DEBUG) -def get_box_weights(centroid, n_pix, shape, cols=None): - """ Return the weights of a box aperture given the centroid and the width of +def get_box_weights(centroid, n_pix, shape, cols): + """ + Return the weights of a box aperture given the centroid and the width of the box in pixels. All pixels will have the same weights except at the ends of the box aperture. @@ -14,10 +15,13 @@ def get_box_weights(centroid, n_pix, shape, cols=None): ---------- centroid : array[float] Position of the centroid (in rows). Same shape as `cols` + n_pix : float Width of the extraction box in pixels. + shape : Tuple(int, int) Shape of the output image. (n_row, n_column) + cols : array[int] Column indices of good columns. Used if the centroid is defined for specific columns or a sub-range of columns. @@ -27,12 +31,7 @@ def get_box_weights(centroid, n_pix, shape, cols=None): weights : array[float] An array of pixel weights to use with the box extraction. """ - - nrows, ncols = shape - - # Use all columns if not specified - if cols is None: - cols = np.arange(ncols) + nrows, _ = shape # Row centers of all pixels. rows = np.indices((nrows, len(cols)))[0] @@ -59,38 +58,37 @@ def get_box_weights(centroid, n_pix, shape, cols=None): return out -def box_extract(scidata, scierr, scimask, box_weights, cols=None): - """ Perform a box extraction. +def box_extract(scidata, scierr, scimask, box_weights): + """ + Perform a box extraction. Parameters ---------- scidata : array[float] 2d array of science data with shape (n_row, n_columns) + scierr : array[float] 2d array of uncertainty map with same shape as scidata + scimask : array[bool] 2d boolean array of masked pixels with same shape as scidata + box_weights : array[float] 2d array of pre-computed weights for box extraction, with same shape as scidata - cols : array[int] - 1d integer array of column numbers to extract Returns ------- cols : array[int] Indices of extracted columns + flux : array[float] The flux in each column + flux_var : array[float] The variance of the flux in each column """ - - nrows, ncols = scidata.shape - - # Use all columns if not specified - if cols is None: - cols = np.arange(ncols) + cols = np.arange(scidata.shape[1]) # Keep only required columns and make a copy. data = scidata[:, cols].copy() @@ -137,6 +135,9 @@ def box_extract(scidata, scierr, scimask, box_weights, cols=None): def estim_error_nearest_data(err, data, pix_to_estim, valid_pix): """ + TODO: how similar is this to other places where we interpolate errors? + Could this be replaced with some algorithm involving smoothing of the error map? + Function to estimate pixel error empirically using the corresponding error of the nearest pixel value (`data`). Intended to be used in a box extraction when the bad pixels are modeled. @@ -145,12 +146,16 @@ def estim_error_nearest_data(err, data, pix_to_estim, valid_pix): ---------- err : 2d array[float] Uncertainty map of the pixels. + data : 2d array[float] Pixel values. + pix_to_estim : 2d array[bool] Map of the pixels where the uncertainty needs to be estimated. + valid_pix : 2d array[bool] Map of valid pixels to be used to find the error empirically. + Returns ------- err_filled : 2d array[float] @@ -161,9 +166,6 @@ def estim_error_nearest_data(err, data, pix_to_estim, valid_pix): err_valid = err[valid_pix] data_valid = data[valid_pix] - # - # Use np.searchsorted for efficiency - # # Need to sort the arrays used to find similar values idx_sort = np.argsort(data_valid) err_valid = err_valid[idx_sort] diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index 47a9b8f863..1f1e8d4f37 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -22,9 +22,13 @@ log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) +ORDER2_SHORT_CUTOFF = 0.58 + def get_ref_file_args(ref_files): - """Prepare the reference files for the extraction engine. + """ + Prepare the reference files for the extraction engine. + Parameters ---------- ref_files : dict @@ -45,9 +49,9 @@ def get_ref_file_args(ref_files): else: do_padding = False - (wavemap_o1, wavemap_o2), (spectrace_o1, spectrace_o2) = \ + (wavemap_o1, wavemap_o2) = \ get_soss_wavemaps(pastasoss_ref, pwcpos=ref_files['pwcpos'], subarray=ref_files['subarray'], - padding=do_padding, padsize=pad, spectraces=True) + padding=do_padding, padsize=pad, spectraces=False) # The spectral profiles for order 1 and 2. specprofile_ref = ref_files['specprofile'] @@ -96,20 +100,21 @@ def get_ref_file_args(ref_files): wave_maps = [wavemap_o1, wavemap_o2] centroid = dict() for wv_map, order in zip(wave_maps, [1, 2]): - # Needs the same number of columns as the detector. Put zeros where not define. + # Needs the same number of columns as the detector. Put zeros where not defined. wv_cent = np.zeros((1, wv_map.shape[1])) # Get central wavelength as a function of columns - col, _, wv = get_trace_1d(ref_files, order) + col, _, wv = _get_trace_1d(ref_files, order) wv_cent[:, col] = wv # Set invalid values to zero idx_invalid = ~np.isfinite(wv_cent) wv_cent[idx_invalid] = 0.0 centroid[order] = wv_cent + # Get kernels kernels_o1 = WebbKernel(speckernel_ref.wavelengths, speckernel_ref.kernels, centroid[1], ovs, n_pix) kernels_o2 = WebbKernel(speckernel_ref.wavelengths, speckernel_ref.kernels, centroid[2], ovs, n_pix) - # Temporary(?) fix to make sure that the kernels can cover the wavelength maps + # Make sure that the kernels cover the wavelength maps speckernel_wv_range = [np.min(speckernel_ref.wavelengths), np.max(speckernel_ref.wavelengths)] valid_wavemap = (speckernel_wv_range[0] <= wavemap_o1) & (wavemap_o1 <= speckernel_wv_range[1]) wavemap_o1 = np.where(valid_wavemap, wavemap_o1, 0.) @@ -120,13 +125,15 @@ def get_ref_file_args(ref_files): [throughput_o1, throughput_o2], [kernels_o1, kernels_o2] -def get_trace_1d(ref_files, order): +def _get_trace_1d(ref_files, order): """Get the x, y, wavelength of the trace after applying the transform. + Parameters ---------- ref_files : dict A dictionary of the reference file DataModels, along with values for subarray and pwcpos, i.e. the pupil wheel position. + order : int The spectral order for which to return the trace parameters. @@ -168,23 +175,29 @@ def get_trace_1d(ref_files, order): return xtrace, ytrace, wavetrace -def estim_flux_first_order(scidata_bkg, scierr, scimask, ref_file_args, mask_trace_profile, threshold=1e-4): +def _estim_flux_first_order(scidata_bkg, scierr, scimask, ref_file_args, mask_trace_profile, threshold=1e-4): """ Parameters ---------- scidata_bkg : array A single background subtracted NIRISS SOSS detector image. + scierr : array The uncertainties corresponding to the detector image. + scimask : array Pixel mask to apply to the detector image. + ref_file_args : tuple A tuple of reference file arguments constructed by get_ref_file_args(). + mask_trace_profile: array[bool] Mask determining the aperture used for extraction. Set to False where the pixel should be extracted. + threshold : float, optional: The pixels with an aperture[order 2] > `threshold` are considered contaminated and will be masked. Default is 1e-4. + Returns ------- func @@ -194,11 +207,8 @@ def estim_flux_first_order(scidata_bkg, scierr, scimask, ref_file_args, mask_tra # Unpack ref_file arguments wave_maps, spat_pros, thrpts, _ = ref_file_args - # Oversampling of 1 to make sure the solution will be stable - n_os = 1 - # Define wavelength grid based on order 1 only (so first index) - wave_grid = grid_from_map(wave_maps[0], spat_pros[0], n_os=n_os) + wave_grid = grid_from_map(wave_maps[0], spat_pros[0], n_os=1) # Mask parts contaminated by order 2 based on its spatial profile mask = ((spat_pros[1] >= threshold) | mask_trace_profile | scimask) @@ -218,23 +228,29 @@ def estim_flux_first_order(scidata_bkg, scierr, scimask, ref_file_args, mask_tra return estimate_spl -def get_native_grid_from_trace(ref_files, spectral_order): +def _get_native_grid_from_trace(ref_files, spectral_order): """ Make a 1d-grid of the pixels boundary and ready for ATOCA ExtractionEngine, based on the wavelength solution. + Parameters ---------- ref_files: dict A dictionary of the reference file DataModels. + spectral_order: int The spectral order for which to return the trace parameters. + Returns ------- - Grid of the pixels boundaries at the native sampling (1d array) + wave : + Grid of the pixels boundaries at the native sampling (1d array) + col : + The column number of the pixel """ - # From wavelenght solution - col, _, wave = get_trace_1d(ref_files, spectral_order) + # From wavelength solution + col, _, wave = _get_trace_1d(ref_files, spectral_order) # Keep only valid solution ... idx_valid = np.isfinite(wave) @@ -255,7 +271,7 @@ def get_native_grid_from_trace(ref_files, spectral_order): return wave, col -def get_grid_from_trace(ref_files, spectral_order, n_os=1): +def _get_grid_from_trace(ref_files, spectral_order, n_os): """ TODO: is this partially or fully redundant with atoca_utils.grid_from_map? Make a 1d-grid of the pixels boundary and ready for ATOCA ExtractionEngine, @@ -264,14 +280,20 @@ def get_grid_from_trace(ref_files, spectral_order, n_os=1): ---------- ref_files: dict A dictionary of the reference file DataModels. + spectral_order: int The spectral order for which to return the trace parameters. + + n_os: int or array + The oversampling factor of the wavelength grid used when solving for + the uncontaminated flux. + Returns ------- Grid of the pixels boundaries at the native sampling (1d array) """ - wave, _ = get_native_grid_from_trace(ref_files, spectral_order) + wave, _ = _get_native_grid_from_trace(ref_files, spectral_order) # Use pixel boundaries instead of the center values wv_upper_bnd, wv_lower_bnd = get_wave_p_or_m(wave[None, :]) @@ -283,13 +305,13 @@ def get_grid_from_trace(ref_files, spectral_order, n_os=1): wave_grid = np.append(wv_lower_bnd, wv_upper_bnd[-1]) # Oversample as needed - wave_grid = oversample_grid(wave_grid, n_os=n_os) - - return wave_grid + return oversample_grid(wave_grid, n_os=n_os) -def make_decontamination_grid(ref_files, rtol, max_grid_size, estimate, n_os, wv_range=None): - ''' Create the grid use for the simultaneous extraction of order 1 and 2. +def _make_decontamination_grid(ref_files, rtol, max_grid_size, estimate, n_os): + ''' + TODO: add docstring + Create the grid to use for the simultaneous extraction of order 1 and 2. The grid is made by: 1) requiring that it satisfies the oversampling n_os 2) trying to reach the specified tolerance for the spectral range shared between order 1 and 2 @@ -302,7 +324,7 @@ def make_decontamination_grid(ref_files, rtol, max_grid_size, estimate, n_os, wv spectral_orders = [2, 1] grids_ord = dict() for sp_ord in spectral_orders: - grids_ord[sp_ord] = get_grid_from_trace(ref_files, sp_ord, n_os=n_os) + grids_ord[sp_ord] = _get_grid_from_trace(ref_files, sp_ord, n_os=n_os) # Build the list of grids given to make_combined_grid. # It must be ordered in increasing priority. @@ -317,29 +339,24 @@ def make_decontamination_grid(ref_files, rtol, max_grid_size, estimate, n_os, wv # And make grid list all_grids = [grids_ord[2][is_shared], grids_ord[1], grids_ord[2][~is_shared]] - # Set wavelength range if not given - if wv_range is None: - # Cut order 2 at 0.77 (not smaller than that) - # because there is no contamination there. Can be extracted afterward. - # In the red, no cut. - wv_range = [0.77, np.max(grids_ord[1])] + # Cut order 2 at 0.77 (not smaller than that) + # because there is no contamination there. Can be extracted afterward. + # In the red, no cut. + wv_range = [0.77, np.max(grids_ord[1])] # Finally, build the list of corresponding estimates. # The estimate for the overlapping part is the order 1 estimate. # There is no estimate yet for the blue part of order 2, so give a flat spectrum. def flat_fct(wv): return np.ones_like(wv) - all_estimates = [estimate, estimate, flat_fct] # Generate the combined grid - kwargs = dict(rtol=rtol, max_total_size=max_grid_size, max_iter=30, grid_range=wv_range) - combined_grid = make_combined_adaptive_grid(all_grids, all_estimates, **kwargs) + kwargs = dict(rtol=rtol, max_total_size=max_grid_size, max_iter=30) + return make_combined_adaptive_grid(all_grids, all_estimates, wv_range, **kwargs) - return combined_grid - -def append_tiktests(test_a, test_b): +def _append_tiktests(test_a, test_b): out = dict() @@ -349,7 +366,7 @@ def append_tiktests(test_a, test_b): return out -def populate_tikho_attr(spec, tiktests, idx, sp_ord): +def _populate_tikho_attr(spec, tiktests, idx, sp_ord): spec.spectral_order = sp_ord spec.meta.soss_extract1d.type = 'TEST' @@ -360,10 +377,8 @@ def populate_tikho_attr(spec, tiktests, idx, sp_ord): spec.meta.soss_extract1d.factor = tiktests['factors'][idx] spec.int_num = 0 - return - -def f_to_spec(f_order, grid_order, ref_file_args, pixel_grid, mask, sp_ord=0): +def _f_to_spec(f_order, grid_order, ref_file_args, pixel_grid, mask, sp_ord): # Make sure the input is not modified ref_file_args = ref_file_args.copy() @@ -424,9 +439,9 @@ def _build_tracemodel_order(engine, ref_file_args, f_k, i_order, mask, ref_files tracemodel_ord = model.rebuild(flux_order, fill_value=np.nan) # Build 1d spectrum integrated over pixels - pixel_wave_grid, valid_cols = get_native_grid_from_trace(ref_files, sp_ord) - spec_ord = f_to_spec(flux_order, grid_order, ref_file_order, pixel_wave_grid, - np.all(mask, axis=0)[valid_cols], sp_ord=sp_ord) + pixel_wave_grid, valid_cols = _get_native_grid_from_trace(ref_files, sp_ord) + spec_ord = _f_to_spec(flux_order, grid_order, ref_file_order, pixel_wave_grid, + np.all(mask, axis=0)[valid_cols], sp_ord) return tracemodel_ord, spec_ord @@ -446,7 +461,7 @@ def _build_null_spec_table(wave_grid): Null SpecModel. Flux values are NaN, DQ flags are 1, but note that DQ gets overwritten at end of run_extract1d """ - wave_grid_cut = wave_grid[wave_grid > 0.58] # same cutoff applied for valid data + wave_grid_cut = wave_grid[wave_grid > ORDER2_SHORT_CUTOFF] spec = datamodels.SpecModel() spec.spectral_order = 2 spec.meta.soss_extract1d.type = 'OBSERVATION' @@ -460,7 +475,7 @@ def _build_null_spec_table(wave_grid): return spec -def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, +def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, tikfac=None, threshold=1e-4, n_os=2, wave_grid=None, estimate=None, rtol=1e-3, max_grid_size=1000000): """Perform the spectral extraction on a single image. @@ -469,38 +484,50 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, ---------- scidata_bkg : array[float] A single background subtracted NIRISS SOSS detector image. + scierr : array[float] The uncertainties corresponding to the detector image. + scimask : array[bool] Pixel mask to apply to detector image. + refmask : array[bool] Pixels that should never be reconstructed e.g. the reference pixels. + ref_files : dict A dictionary of the reference file DataModels, along with values for subarray and pwcpos, i.e. the pupil wheel position. + box_weights : dict A dictionary of the weights (for each order) used in the box extraction. The weights for each order are 2d arrays with the same size as the detector. + tikfac : float, optional The Tikhonov regularization factor used when solving for the uncontaminated flux. If not specified, the optimal Tikhonov factor is calculated. + threshold : float The threshold value for using pixels based on the spectral profile. Default value is 1e-4. + n_os : int, optional The oversampling factor of the wavelength grid used when solving for the uncontaminated flux. If not specified, defaults to 2. + wave_grid : str or SossWaveGridModel or None Filename of reference file or SossWaveGridModel containing the wavelength grid used by ATOCA to model each pixel valid pixel of the detector. If not given, the grid is determined based on an estimate of the flux (estimate), the relative tolerance (rtol) required on each pixel model and the maximum grid size (max_grid_size). + estimate : UnivariateSpline or None Estimate of the target flux as a function of wavelength in microns. + rtol : float The relative tolerance needed on a pixel model. It is used to determine the sampling of the soss_wave_grid when not directly given. Default is 1e-3. + max_grid_size : int Maximum grid size allowed. It is used when soss_wave_grid is not directly to make sure the computation time or the memory used stays reasonable. @@ -510,20 +537,21 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, ------- tracemodels : dict Dictionary of the modeled detector images for each order. + tikfac : float Optimal Tikhonov factor used in extraction + logl : float Log likelihood value associated with the Tikhonov factor selected. + wave_grid : 1d array - Same as wave_grid input + Same as wave_grid input. TODO: this isn't true if input wave_grid is None, update docstring + spec_list : list of SpecModel List of the underlying spectra for each integration and order. The tikhonov tests are also included. """ - # Init list of atoca 1d spectra - spec_list = [] - # Generate list of orders to simulate from pastasoss trace list order_list = [] for trace in ref_files['pastasoss'].traces: @@ -546,13 +574,13 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, # dq mask caused issues here and this may need a try/except wrap. # Dev suggested np.logspace(-19, -10, 10) if (tikfac is None or wave_grid is None) and estimate is None: - estimate = estim_flux_first_order(scidata_bkg, scierr, scimask, + estimate = _estim_flux_first_order(scidata_bkg, scierr, scimask, ref_file_args, mask_trace_profile[0]) # Generate grid based on estimate if not given if wave_grid is None: log.info(f'wave_grid not given: generating grid based on rtol={rtol}') - wave_grid = make_decontamination_grid(ref_files, rtol, max_grid_size, estimate, n_os) + wave_grid = _make_decontamination_grid(ref_files, rtol, max_grid_size, estimate, n_os) log.debug(f'wave_grid covering from {wave_grid.min()} to {wave_grid.max()}') else: log.info('Using previously computed or user specified wavelength grid.') @@ -567,6 +595,7 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, threshold=threshold, c_kwargs=c_kwargs) + spec_list = [] if tikfac is None: log.info('Solving for the optimal Tikhonov factor.') @@ -577,15 +606,15 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, log_guess = np.log10(guess_factor) factors = np.logspace(log_guess - 4, log_guess + 4, 10) all_tests = engine.get_tikho_tests(factors, data=scidata_bkg, error=scierr) - tikfac, mode, _ = engine.best_tikho_factor(tests=all_tests, fit_mode='all') + tikfac= engine.best_tikho_factor(tests=all_tests, fit_mode='all') # Refine across 4 orders of magnitude. tikfac = np.log10(tikfac) factors = np.logspace(tikfac - 2, tikfac + 2, 20) tiktests = engine.get_tikho_tests(factors, data=scidata_bkg, error=scierr) - tikfac, mode, _ = engine.best_tikho_factor(tests=tiktests, fit_mode='d_chi2') + tikfac = engine.best_tikho_factor(tests=tiktests, fit_mode='d_chi2') # Add all theses tests to previous ones - all_tests = append_tiktests(all_tests, tiktests) + all_tests = _append_tiktests(all_tests, tiktests) # Save spectra in a list of SingleSpecModels for optional output save_tiktests = True @@ -594,7 +623,7 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, f_k = all_tests['solution'][idx, :] args = (engine, ref_file_args, f_k, i_order, global_mask, ref_files) _, spec_ord = _build_tracemodel_order(*args) - populate_tikho_attr(spec_ord, all_tests, idx, i_order + 1) + _populate_tikho_attr(spec_ord, all_tests, idx, i_order + 1) spec_ord.meta.soss_extract1d.color_range = 'RED' # Add the result to spec_list @@ -652,7 +681,7 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, # mask_fit |= already_modeled # Build 1d spectrum integrated over pixels - pixel_wave_grid, valid_cols = get_native_grid_from_trace(ref_files, order) + pixel_wave_grid, valid_cols = _get_native_grid_from_trace(ref_files, order) # Hardcode wavelength highest boundary as well. # Must overlap with lower limit in make_decontamination_grid @@ -664,10 +693,10 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, # Model the remaining part of order 2 with atoca try: - model, spec_ord = model_single_order(scidata_bkg, scierr, ref_file_order, + model, spec_ord = _model_single_order(scidata_bkg, scierr, ref_file_order, mask_fit, global_mask, order, - pixel_wave_grid, valid_cols, save_tiktests, - tikfac_log_range=tikfac_log_range) + pixel_wave_grid, valid_cols, + tikfac_log_range, save_tiktests=save_tiktests) except MaskOverlapError: log.error('Not enough unmasked pixels to model the remaining part of order 2.' @@ -691,7 +720,7 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, return tracemodels, tikfac, logl, wave_grid, spec_list -def compute_box_weights(ref_files, shape, width=40.): +def _compute_box_weights(ref_files, shape, width): # Generate list of orders from pastasoss trace list order_list = [] @@ -709,13 +738,13 @@ def compute_box_weights(ref_files, shape, width=40.): log.debug(f'Compute box weights for {order}.') # Define the box aperture - xtrace, ytrace, wavelengths[order] = get_trace_1d(ref_files, order_integer) + xtrace, ytrace, wavelengths[order] = _get_trace_1d(ref_files, order_integer) box_weights[order] = get_box_weights(ytrace, width, shape, cols=xtrace) return box_weights, wavelengths -def decontaminate_image(scidata_bkg, tracemodels, subarray): +def _decontaminate_image(scidata_bkg, tracemodels, subarray): """Perform decontamination of the image based on the trace models""" # Which orders to extract. if subarray == 'SUBSTRIP96': @@ -753,8 +782,9 @@ def decontaminate_image(scidata_bkg, tracemodels, subarray): # TODO Add docstring -def model_single_order(data_order, err_order, ref_file_args, mask_fit, - mask_rebuild, order, wave_grid, valid_cols, save_tiktests=False, tikfac_log_range=None): +def _model_single_order(data_order, err_order, ref_file_args, mask_fit, + mask_rebuild, order, wave_grid, valid_cols, + tikfac_log_range, save_tiktests=False): # The throughput and kernel is not needed here; set them so they have no effect on the extraction. def throughput(wavelength): @@ -765,32 +795,9 @@ def throughput(wavelength): ref_file_args[2] = [throughput] ref_file_args[3] = [kernel] - # ########################### - # First, generate an estimate - # (only if the initial guess of tikhonov factor range is not given) - # ########################### - - if tikfac_log_range is None: - # Initialize the engine - engine = ExtractionEngine(*ref_file_args, - wave_grid=wave_grid, - mask_trace_profile=[mask_fit], - orders=[order], - ) - - # Extract estimate - spec_estimate = engine.__call__(data_order, err_order) - - # Interpolate - idx = np.isfinite(spec_estimate) - estimate_spl = UnivariateSpline(wave_grid[idx], spec_estimate[idx], k=3, s=0, ext=0) - - # ################################################## - # Second, do the extraction to get the best estimate - # ################################################## # Define wavelength grid with oversampling of 3 (should be enough) wave_grid_os = oversample_grid(wave_grid, n_os=3) - wave_grid_os = wave_grid_os[wave_grid_os > 0.58] + wave_grid_os = wave_grid_os[wave_grid_os > ORDER2_SHORT_CUTOFF] # Initialize the Engine. engine = ExtractionEngine(*ref_file_args, @@ -800,21 +807,16 @@ def throughput(wavelength): # Find the tikhonov factor. # Initial pass with tikfac_range. - if tikfac_log_range is None: - guess_factor = engine.estimate_tikho_factors(estimate_spl) - log_guess = np.log10(guess_factor) - factors = np.log_range(log_guess - 2, log_guess + 8, 10) - else: - factors = np.logspace(tikfac_log_range[0], tikfac_log_range[-1] + 8, 10) + factors = np.logspace(tikfac_log_range[0], tikfac_log_range[-1] + 8, 10) all_tests = engine.get_tikho_tests(factors, data=data_order, error=err_order) - tikfac, mode, _ = engine.best_tikho_factor(tests=all_tests, fit_mode='all') + tikfac = engine.best_tikho_factor(tests=all_tests, fit_mode='all') # Refine across 4 orders of magnitude. tikfac = np.log10(tikfac) factors = np.logspace(tikfac - 2, tikfac + 2, 20) tiktests = engine.get_tikho_tests(factors, data=data_order, error=err_order) - tikfac, mode, _ = engine.best_tikho_factor(tests=tiktests, fit_mode='d_chi2') - all_tests = append_tiktests(all_tests, tiktests) + tikfac = engine.best_tikho_factor(tests=tiktests, fit_mode='d_chi2') + all_tests = _append_tiktests(all_tests, tiktests) # Run the extract method of the Engine. f_k_final = engine.__call__(data_order, err_order, tikhonov=True, factor=tikfac) @@ -826,28 +828,24 @@ def throughput(wavelength): f_k = all_tests['solution'][idx, :] # Build 1d spectrum integrated over pixels - spec_ord = f_to_spec(f_k, wave_grid_os, ref_file_args, wave_grid, - np.all(mask_rebuild, axis=0)[valid_cols], sp_ord=order) - populate_tikho_attr(spec_ord, all_tests, idx, order) + spec_ord = _f_to_spec(f_k, wave_grid_os, ref_file_args, wave_grid, + np.all(mask_rebuild, axis=0)[valid_cols], order) + _populate_tikho_attr(spec_ord, all_tests, idx, order) # Add the result to spec_list spec_list.append(spec_ord) - # ########################################## - # Third, rebuild trace, including bad pixels - # ########################################## - # Initialize the Engine. + + # Rebuild trace, including bad pixels engine = ExtractionEngine(*ref_file_args, wave_grid=wave_grid_os, mask_trace_profile=[mask_rebuild], orders=[order],) - - # Project on detector and save in dictionary model = engine.rebuild(f_k_final, fill_value=np.nan) # Build 1d spectrum integrated over pixels - spec_ord = f_to_spec(f_k_final, wave_grid_os, ref_file_args, wave_grid, - np.all(mask_rebuild, axis=0)[valid_cols], sp_ord=order) + spec_ord = _f_to_spec(f_k_final, wave_grid_os, ref_file_args, wave_grid, + np.all(mask_rebuild, axis=0)[valid_cols], order) spec_ord.meta.soss_extract1d.factor = tikfac spec_ord.meta.soss_extract1d.type = 'OBSERVATION' @@ -858,27 +856,35 @@ def throughput(wavelength): # Remove bad pixels that are not modeled for pixel number -def extract_image(decontaminated_data, scierr, scimask, box_weights, bad_pix='model', tracemodels=None): - """Perform the box-extraction on the image, while using the trace model to +def _extract_image(decontaminated_data, scierr, scimask, box_weights, bad_pix='model', tracemodels=None): + """ + Perform the box-extraction on the image, while using the trace model to correct for contamination. + Parameters ---------- decontaminated_data : array[float] A single backround subtracted NIRISS SOSS detector image. + scierr : array[float] The uncertainties corresponding to the detector image. + scimask : array[float] Pixel mask to apply to the detector image. + box_weights : dict A dictionary of the weights (for each order) used in the box extraction. The weights for each order are 2d arrays with the same size as the detector. + bad_pix : str How to handle the bad pixels. Options are 'masking' and 'model'. 'masking' will simply mask the bad pixels, such that the number of pixels in each column in the box extraction will not be constant, while the 'model' option uses `tracemodels` to replace the bad pixels. + tracemodels : dict Dictionary of the modeled detector images for each order. + Returns ------- fluxes, fluxerrs, npixels : dict @@ -1002,10 +1008,10 @@ def run_extract1d(input_model, pastasoss_ref_name, if wave_grid_in is not None: log.info(f'Loading wavelength grid from {wave_grid_in}.') wave_grid = datamodels.SossWaveGridModel(wave_grid_in).wavegrid - # Make sure it as the correct precision + # Make sure it has the correct precision wave_grid = wave_grid.astype('float64') else: - # wave_grid will be estimated later in the first call of `model_image` + # wave_grid will be estimated later in the first call of `_model_image` log.info('Wavelength grid was not specified. Setting `wave_grid` to None.') wave_grid = None @@ -1086,15 +1092,16 @@ def run_extract1d(input_model, pastasoss_ref_name, if soss_kwargs['subtract_background']: log.info('Applying background subtraction.') bkg_mask = make_background_mask(scidata, width=40) - scidata_bkg, col_bkg, npix_bkg = soss_background(scidata, scimask, bkg_mask=bkg_mask) + scidata_bkg, col_bkg = soss_background(scidata, scimask, bkg_mask) else: log.info('Skip background subtraction.') scidata_bkg = scidata col_bkg = np.zeros(scidata.shape[1]) # Pre-compute the weights for box extraction (used in modeling and extraction) - args = (ref_files, scidata_bkg.shape) - box_weights, wavelengths = compute_box_weights(*args, width=soss_kwargs['width']) + box_weights, wavelengths = _compute_box_weights( + ref_files, scidata_bkg.shape, width=soss_kwargs['width'] + ) # Model the traces based on optics filter configuration (CLEAR or F277W) if soss_filter == 'CLEAR' and generate_model: @@ -1109,8 +1116,8 @@ def run_extract1d(input_model, pastasoss_ref_name, kwargs['wave_grid'] = wave_grid kwargs['threshold'] = soss_kwargs['threshold'] - result = model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, **kwargs) - tracemodels, soss_kwargs['tikfac'], logl, wave_grid, spec_list = result + result = _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, **kwargs) + tracemodels, soss_kwargs['tikfac'], _, wave_grid, spec_list = result # Add atoca spectra to multispec for output for spec in spec_list: @@ -1131,7 +1138,7 @@ def run_extract1d(input_model, pastasoss_ref_name, spec_list = None # Decontaminate the data using trace models (if tracemodels not empty) - data_to_extract = decontaminate_image(scidata_bkg, tracemodels, subarray) + data_to_extract = _decontaminate_image(scidata_bkg, tracemodels, subarray) if soss_kwargs['bad_pix'] == 'model': # Generate new trace models for each individual decontaminated orders @@ -1144,7 +1151,7 @@ def run_extract1d(input_model, pastasoss_ref_name, kwargs = dict() kwargs['bad_pix'] = soss_kwargs['bad_pix'] kwargs['tracemodels'] = bad_pix_models - result = extract_image(data_to_extract, scierr, scimask, box_weights, **kwargs) + result = _extract_image(data_to_extract, scierr, scimask, box_weights, **kwargs) fluxes, fluxerrs, npixels = result # Save trace models for output reference diff --git a/jwst/extract_1d/soss_extract/soss_syscor.py b/jwst/extract_1d/soss_extract/soss_syscor.py index eaeda83b31..f04a8a4e17 100644 --- a/jwst/extract_1d/soss_extract/soss_syscor.py +++ b/jwst/extract_1d/soss_extract/soss_syscor.py @@ -6,49 +6,40 @@ log.setLevel(logging.DEBUG) -def soss_background(scidata, scimask, bkg_mask=None): +def soss_background(scidata, scimask, bkg_mask): """Compute a columnwise background for a SOSS observation. Parameters ---------- scidata : array[float] The image of the SOSS trace. + scimask : array[bool] Boolean mask of pixels to be excluded. + bkg_mask : array[bool] Boolean mask of pixels to be excluded because they are in - the trace, typically constructed with make_profile_mask. + the trace, typically constructed with make_background_mask. Returns ------- scidata_bkg : array[float] Background-subtracted image + col_bkg : array[float] Column-wise background values - npix_bkg : array[float] - Number of pixels used to calculate each column value in col_bkg """ # Check the validity of the input. data_shape = scidata.shape - if scimask.shape != data_shape: - msg = 'scidata and scimask must have the same shape.' + if (scimask.shape != data_shape) or (bkg_mask.shape != data_shape): + msg = 'scidata, scimask, and bkg_mask must all have the same shape.' log.critical(msg) raise ValueError(msg) - if bkg_mask is not None: - if bkg_mask.shape != data_shape: - msg = 'scidata and bkg_mask must have the same shape.' - log.critical(msg) - raise ValueError(msg) - # Combine the masks and create a masked array. - if bkg_mask is not None: - mask = scimask | bkg_mask - else: - mask = scimask - + mask = scimask | bkg_mask scidata_masked = np.ma.array(scidata, mask=mask) # Mask additional pixels using sigma-clipping. @@ -58,15 +49,14 @@ def soss_background(scidata, scimask, bkg_mask=None): # Compute the mean for each column and record the number of pixels used. col_bkg = scidata_clipped.mean(axis=0) col_bkg = np.where(np.all(scidata_clipped.mask, axis=0), 0., col_bkg) - npix_bkg = (~scidata_clipped.mask).sum(axis=0) # Background subtract the science data. scidata_bkg = scidata - col_bkg - return scidata_bkg, col_bkg, npix_bkg + return scidata_bkg, col_bkg -def make_background_mask(deepstack, width=28): +def make_background_mask(deepstack, width): """Build a mask of the pixels considered to contain the majority of the flux, and should therefore not be used to compute the background. @@ -75,6 +65,7 @@ def make_background_mask(deepstack, width=28): deepstack : array[float] Deep image of the trace constructed by combining individual integrations of the observation. + width : int Width, in pixels, of the trace to exclude with the mask (i.e. width/256 for a SUBSTRIP256 observation). @@ -84,11 +75,10 @@ def make_background_mask(deepstack, width=28): bkg_mask : array[bool] Pixel mask in the trace based on the deepstack or non-finite in the image. - :rtype: array[bool] """ # Get the dimensions of the input image. - nrows, ncols = np.shape(deepstack) + nrows, _ = np.shape(deepstack) # Set the appropriate quantile for masking based on the subarray size. if nrows == 96: # SUBSTRIP96 @@ -108,6 +98,4 @@ def make_background_mask(deepstack, width=28): # Mask pixels above the threshold value. with np.errstate(invalid='ignore'): - bkg_mask = (deepstack > threshold) | ~np.isfinite(deepstack) - - return bkg_mask + return (deepstack > threshold) | ~np.isfinite(deepstack) diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py index f0b96a33cd..e6f468612e 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py @@ -71,7 +71,7 @@ def test_get_wv_map_bounds(wave_map, dispersion_axis): wave_flip = wave_map.T else: wave_flip = wave_map - wave_top, wave_bottom = au.get_wv_map_bounds(wave_flip, dispersion_axis=dispersion_axis) + wave_top, wave_bottom = au._get_wv_map_bounds(wave_flip, dispersion_axis=dispersion_axis) # flip the results back so we can re-use the same tests if dispersion_axis == 0: @@ -95,7 +95,7 @@ def test_get_wv_map_bounds(wave_map, dispersion_axis): # test bad input error raises with pytest.raises(ValueError): - au.get_wv_map_bounds(wave_flip, dispersion_axis=2) + au._get_wv_map_bounds(wave_flip, dispersion_axis=2) @@ -184,9 +184,9 @@ def test_oversample_irregular(os_factor): @pytest.mark.parametrize("wave_range", [(2.1, 3.9), (1.8, 4.5)]) -def test_extrapolate_grid(wave_range): +def test__extrapolate_grid(wave_range): - extrapolated = au.extrapolate_grid(WAVELENGTHS, wave_range, 1) + extrapolated = au._extrapolate_grid(WAVELENGTHS, wave_range, 1) assert extrapolated.max() > wave_range[1] assert extrapolated.min() < wave_range[0] @@ -201,14 +201,14 @@ def test_extrapolate_catch_failed_converge(): # give wavelengths some non-linearity wave_range = WAVELENGTHS.min(), WAVELENGTHS.max()+4.0 with pytest.raises(RuntimeError): - au.extrapolate_grid(WAVELENGTHS, wave_range, 1) + au._extrapolate_grid(WAVELENGTHS, wave_range, 1) def test_extrapolate_bad_inputs(): with pytest.raises(ValueError): - au.extrapolate_grid(WAVELENGTHS, (2.9, 2.1)) + au._extrapolate_grid(WAVELENGTHS, (2.9, 2.1)) with pytest.raises(ValueError): - au.extrapolate_grid(WAVELENGTHS, (4.1, 4.2)) + au._extrapolate_grid(WAVELENGTHS, (4.1, 4.2)) def test_grid_from_map(wave_map, trace_profile): @@ -238,7 +238,7 @@ def test_grid_from_map(wave_map, trace_profile): @pytest.mark.parametrize("n_os", ([4,1], 1)) -def test_get_soss_grid(n_os, wave_map, trace_profile, wave_map_o2, trace_profile_o2): +def test__get_soss_grid(n_os, wave_map, trace_profile, wave_map_o2, trace_profile_o2): """ wave_map has min, max wavelength of 1.5, 4.0 but throughput makes this 1.7, 4.2 wave_map_o2 has min, max wavelength of 0.5, 3.0 but throughput makes this 0.9, 3.4 @@ -250,7 +250,7 @@ def test_get_soss_grid(n_os, wave_map, trace_profile, wave_map_o2, trace_profile wave_max = 4.0 wave_maps = np.array([wave_map, wave_map_o2]) trace_profiles = np.array([trace_profile, trace_profile_o2]) - wave_grid = au.get_soss_grid(wave_maps, trace_profiles, wave_min, wave_max, n_os) + wave_grid = au._get_soss_grid(wave_maps, trace_profiles, wave_min, wave_max, n_os) delta_lower = wave_grid[1]-wave_grid[0] delta_upper = wave_grid[-1]-wave_grid[-2] @@ -284,17 +284,17 @@ def test_get_soss_grid(n_os, wave_map, trace_profile, wave_map_o2, trace_profile assert expected_ratio == actual_ratio -def test_get_soss_grid_bad_inputs(wave_map, trace_profile): +def test__get_soss_grid_bad_inputs(wave_map, trace_profile): with pytest.raises(ValueError): # test bad input shapes - au.get_soss_grid(wave_map, trace_profile, 0.5, 0.9, 1) + au._get_soss_grid(wave_map, trace_profile, 0.5, 0.9, 1) wave_maps = np.array([wave_map, wave_map]) trace_profiles = np.array([trace_profile, trace_profile]) with pytest.raises(ValueError): # test bad n_os shape - au.get_soss_grid(wave_maps, trace_profiles, 0.5, 0.9, [1,1,1]) + au._get_soss_grid(wave_maps, trace_profiles, 0.5, 0.9, [1,1,1]) From 333d8175497c2ef3411465d765d742bbf2a336a4 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Wed, 6 Nov 2024 16:47:00 -0500 Subject: [PATCH 10/35] added fixes and unit tests to WebbKernel --- jwst/extract_1d/soss_extract/atoca.py | 2 +- jwst/extract_1d/soss_extract/atoca_utils.py | 284 ++++++++---------- jwst/extract_1d/soss_extract/soss_extract.py | 15 +- .../soss_extract/tests/test_atoca_utils.py | 262 ++++++++++++---- 4 files changed, 344 insertions(+), 219 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index 6bf4015b3d..94f8acb5cb 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -993,7 +993,7 @@ def _get_lo_hi(self, grid, i_order): ma = mask_ord[~self.mask] lo[ma], hi[ma] = -1, -2 - print("grid", np.min(grid), np.max(grid), grid.size) + # print("grid", np.min(grid), np.max(grid), grid.size) # grid 0.8515140811891634 2.204728219634399 1408 on PR branch # grid 0.8480011181377183 2.2057149524019297 1413 on main # the grid could be the problem! diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index e31449c236..85b2ca52b6 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -420,7 +420,8 @@ def grid_from_map(wave_map, trace_profile, wave_range=None, n_os=1): def _trim_grids(all_grids, grid_range): - """ Remove all parts of the grids that are not in range + """ + Remove all parts of the grids that are not in range or that are already covered by grids with higher priority, i.e. preceding in the list. """ @@ -444,24 +445,19 @@ def _trim_grids(all_grids, grid_range): is_below = grid < np.min(conca_grid) is_above = grid > np.max(conca_grid) - # Do nothing yet if it surrounds the previous grid - if is_below.any() and is_above.any(): - msg = 'Grid surrounds another grid, better to split in 2 parts.' - log.warning(msg) - # Remove values already covered, but keep one # index past the limit - elif is_below.any(): + if is_below.any(): idx = np.max(np.nonzero(is_below)) idx = np.min([idx + 1, len(grid) - 1]) grid = grid[:idx + 1] - elif is_above.any(): + if is_above.any(): idx = np.min(np.nonzero(is_above)) idx = np.max([idx - 1, 0]) grid = grid[idx:] # If all is covered, no need to do it again, so empty grid. - else: + if not is_below.any() and not is_above.any(): grid = np.array([]) # Save trimmed grid @@ -470,11 +466,10 @@ def _trim_grids(all_grids, grid_range): return grids_trimmed -def make_combined_adaptive_grid(all_grids, all_estimate, grid_range, +def make_combined_adaptive_grid(all_grids, all_estimates, grid_range, max_iter=10, rtol=10e-6, max_total_size=1000000): """ TODO: can this be a class? e.g., class AdaptiveGrid? - TODO: why aren't any of the same helper functions used here as in _get_soss_grid? q: why are there multiple grids passed in here in the first place Return an irregular oversampled grid needed to reach a @@ -486,10 +481,10 @@ def make_combined_adaptive_grid(all_grids, all_estimate, grid_range, Parameters ---------- - all_grid : list[array] + all_grids : list[array] List of grid (arrays) to pass to _adapt_grid, in order of importance. - all_estimate : list[callable] + all_estimates : list[callable] List of function (callable) to estimate the precision needed to oversample the grid. Must match the corresponding `grid` in `all_grid`. @@ -499,7 +494,7 @@ def make_combined_adaptive_grid(all_grids, all_estimate, grid_range, max_iter : int, optional Number of times the intervals can be subdivided. The smallest subdivison of the grid if max_iter is reached will then be given - by delta_grid / 2^max_iter. Needs to be greater then zero. + by delta_grid / 2^max_iter. Needs to be greater than zero. Default is 10. rtol : float, optional @@ -514,8 +509,6 @@ def make_combined_adaptive_grid(all_grids, all_estimate, grid_range, Oversampled combined grid which minimizes the integration error based on Romberg's method """ - # Save parameters for _adapt_grid - kwargs = dict(max_iter=max_iter, rtol=rtol) # Remove unneeded parts of the grids all_grids = _trim_grids(all_grids, grid_range) @@ -527,8 +520,6 @@ def make_combined_adaptive_grid(all_grids, all_estimate, grid_range, combined_grid = np.array([]) # Init with empty array for i_grid, grid in enumerate(all_grids): - estimate = all_estimate[i_grid] - # Get the max_grid_size, considering the other grids # First, remove length already used max_grid_size = max_total_size - combined_grid.size @@ -537,19 +528,23 @@ def make_combined_adaptive_grid(all_grids, all_estimate, grid_range, if i_size > i_grid: max_grid_size = max_grid_size - size # Make sure it is at least the size of the native grid. - kwargs['max_grid_size'] = np.max([max_grid_size, all_sizes[i_grid]]) + max_grid_size = np.max([max_grid_size, all_sizes[i_grid]]) # Oversample the grid based on tolerance required - grid, is_converged = _adapt_grid(grid, estimate, **kwargs) + grid, is_converged = _adapt_grid(grid, + all_estimates[i_grid], + max_grid_size=max_grid_size, + max_iter=max_iter, + rtol=rtol) # Update grid sizes all_sizes[i_grid] = grid.size # Check convergence if not is_converged: - msg = 'Precision cannot be garanteed:' - if grid.size < kwargs['max_grid_size']: - msg += (f' smallest subdivision 1/{2 ** kwargs["max_iter"]:2.1e}' + msg = 'Precision cannot be guaranteed:' + if grid.size < max_grid_size: + msg += (f' smallest subdivision 1/{2 ** max_iter:2.1e}' f' was reached for grid index = {i_grid}') else: total_size = np.sum(all_sizes) @@ -558,25 +553,13 @@ def make_combined_adaptive_grid(all_grids, all_estimate, grid_range, msg += f' = {total_size} was reached for grid index = {i_grid}.' log.warning(msg) - # Remove regions already covered in the output grid - if len(combined_grid) > 0: - idx_covered = (np.min(combined_grid) <= grid) - idx_covered &= (grid <= np.max(combined_grid)) - grid = grid[~idx_covered] - # Combine grids combined_grid = np.concatenate([combined_grid, grid]) # Sort values (and keep only unique). - combined_grid = np.unique(combined_grid) - - # Final trim to make sure it respects the range - if grid_range is not None: - idx_in_range = (grid_range[0] <= combined_grid) - idx_in_range &= (combined_grid <= grid_range[-1]) - combined_grid = combined_grid[idx_in_range] - - return combined_grid + # This is necessary because trim_grids allows lowest index of one grid to + # equal highest index of another grid. + return np.unique(combined_grid) def _romberg_diff(b, c, k): @@ -667,7 +650,8 @@ def _difftrap(fct, intervals, numtraps): def _estim_integration_err(grid, fct): - """Estimate the integration error on each intervals + """ + Estimate the integration error on each intervals of the grid using 1rst order Romberg integration. Parameters @@ -676,12 +660,17 @@ def _estim_integration_err(grid, fct): Grid for integration. Each sections of this grid are treated as separate integrals. So if grid has length N; N-1 integrals are tested. + fct: callable Function to be integrated. Returns ------- - err, rel_err: error and relative error of each integrations, with length = length(grid) - 1 + err: + absolute error of each integration, with length = length(grid) - 1 + + rel_err: + relative error of each integration, with length = length(grid) - 1 """ # Change the 1D grid into a 2D set of intervals. @@ -711,14 +700,27 @@ def _estim_integration_err(grid, fct): return err, rel_err -def _adapt_grid(grid, fct, max_iter=10, rtol=10e-6, max_grid_size=None): - """Return an irregular oversampled grid needed to reach a +def _adapt_grid(grid, fct, max_grid_size, max_iter=10, rtol=10e-6, atol=1e-6): + """ + Return an irregular oversampled grid needed to reach a given precision when integrating over each intervals of `grid`. The grid is built by subdividing iteratively each intervals that did not reach the required precision. The precision is computed based on the estimate of the integrals using a first order Romberg integration. + TODO: the reliance on rel_err alone is not very robust. (on main, + an absolute tolerance can be specified but is always set to zero.) + for a simple test case, sometimes no matter how many times + you subdivide you still get the same relative error. + See test_adapt_grid in test_atoca_utils.py + The reason seems to be that you aren't comparing to the actual truth, + but instead an imperfect approx of the truth. The rel. difference between + those estimates is not necessarily going to improve even as the absolute + error does. For now, setting absolute error to be a small value, + but it's not obvious in practice how to set this default because it depends + on the units of the output spectrum. + Parameters ---------- grid: array @@ -729,6 +731,9 @@ def _adapt_grid(grid, fct, max_iter=10, rtol=10e-6, max_grid_size=None): fct: callable Function to be integrated. Must be a function of `grid` + max_grid_size: int, required. + maximum size of the output grid. + max_iter: int, optional Number of times the intervals can be subdivided. The smallest subdivison of the grid if max_iter is reached will then be given @@ -738,8 +743,8 @@ def _adapt_grid(grid, fct, max_iter=10, rtol=10e-6, max_grid_size=None): rtol: float, optional The desired relative tolerance. Default is 10e-6, so 10 ppm. - max_grid_size: int, optional - maximum size of the output grid. Default is None, so no constraint. + atol: float, optional + The desired absolute tolerance. Default is 1e-6. Returns ------- @@ -758,30 +763,27 @@ def _adapt_grid(grid, fct, max_iter=10, rtol=10e-6, max_grid_size=None): [1] 'Romberg's method' https://en.wikipedia.org/wiki/Romberg%27s_method """ - # No limit of max_grid_size not given - if max_grid_size is None: - max_grid_size = np.inf - # Init some flags max_size_reached = (grid.size >= max_grid_size) + if max_size_reached: + raise ValueError('max_grid_size is too small for the input grid.') # Iterate until precision is reached or max_iter for _ in range(max_iter): # Estimate error using Romberg integration - err, rel_err = _estim_integration_err(grid, fct) + abs_err, rel_err = _estim_integration_err(grid, fct) # Check where precision is reached - converged = rel_err < rtol + converged = (rel_err < rtol) | (abs_err < atol) is_converged = converged.all() - # Check if max grid size was reached + # Stop iterating if max grid size was reached if max_size_reached or is_converged: - # Then stop iteration break # Intervals that didn't reach the precision will be subdivided - n_oversample = np.full(err.shape, 2, dtype=int) + n_oversample = np.full(rel_err.shape, 2, dtype=int) # No subdivision for the converged ones n_oversample[converged] = 1 @@ -790,6 +792,8 @@ def _adapt_grid(grid, fct, max_iter=10, rtol=10e-6, max_grid_size=None): # to reach the maximum size os_grid_size = n_oversample.sum() if os_grid_size > max_grid_size: + max_size_reached = True + # How many nodes can be added to reach max? n_nodes_remaining = max_grid_size - grid.size @@ -797,18 +801,12 @@ def _adapt_grid(grid, fct, max_iter=10, rtol=10e-6, max_grid_size=None): idx_largest_err = np.argsort(rel_err)[-n_nodes_remaining:] # Build new oversample array and assign only largest errors - n_oversample = np.ones(err.shape, dtype=int) + n_oversample = np.ones(rel_err.shape, dtype=int) n_oversample[idx_largest_err] = 2 - # Flag to stop iterations - max_size_reached = True - - # Generate oversampled grid (subdivide) + # Generate oversampled grid (subdivide). Returns sorted and unique grid. grid = oversample_grid(grid, n_os=n_oversample) - # Make sure sorted and unique. - grid = np.unique(grid) - return grid, is_converged @@ -830,61 +828,56 @@ def ThroughputSOSS(wavelength, throughput): Notes ----- - Clamped boundary condition corresponds to a first derivative of zero, - which ensures a smooth curve given zero-padding outside the range. + Throughput is always zero at min, max of wavelength. """ + wavelength = np.sort(wavelength) wl_min, wl_max = np.min(wavelength), np.max(wavelength) - spline = make_interp_spline(wavelength, throughput, k=3, bc_type=("clamped", "clamped")) - - def interpolator(wl): - thru = np.zeros_like(wl) - inside = (wl > wl_min) & (wl < wl_max) - thru[inside] = spline(wl[inside]) - return thru - + throughput[0] = 0.0 + throughput[-1] = 0.0 + interp = make_interp_spline(wavelength, throughput, k=3, bc_type=("clamped", "clamped")) + def interpolator(wv): + wv = np.clip(wv, wl_min, wl_max) + return interp(wv) return interpolator class WebbKernel: # TODO could probably be cleaned-up somewhat, may need further adjustment. - def __init__(self, wave_kernels, kernels, wave_map, n_os, n_pix, # TODO kernels may need to be flipped? - bounds_error=False, fill_value="extrapolate"): - """A handler for the kernel values. + def __init__(self, wave_kernels, kernels, wave_trace, n_pix): # TODO kernels may need to be flipped?) + """ + Initialize the kernel object. Parameters ---------- wave_kernels : array[float] - Kernels for wavelength array. + Wavelength array for the kernel. Must have same shape as kernels. + kernels : array[float] - Kernels for throughput array. - wave_map : array[float] - Wavelength map of the detector. Since WebbPSF returns kernels in - the pixel space, we need a wave_map to convert to wavelength space. - n_os : int - Oversampling of the kernels. + Kernel for throughput array. + Dimensions are (wavelength, oversampled pixels). + Center (~max throughput) of the kernel is at the center of the 2nd axis. + + wave_trace : array[float] + 1-D trace of the detector central wavelengths for the given order. + Since WebbPSF returns kernels in the pixel space, this is used to + convert to wavelength space. + n_pix : int - Length of the kernels in pixels. - bounds_error : bool - If True, raise an error when trying to call the function out of the - interpolation range. If False, the values will be extrapolated. - fill_value : str - How to extrapolate when needed. Only default "extrapolate" - currently implemented. + Number of detector pixels spanned by the kernel. Second axis of kernels + has shape (n_os * n_pix) - (n_os - 1), where n_os is the + spectral oversampling factor. """ + self.n_pix = n_pix - # Mask where wv_map is equal to 0 - wave_map = np.ma.array(wave_map, mask=(wave_map == 0)) - - # Force wv_map to have the red wavelengths - # at the end of the detector - if np.diff(wave_map, axis=-1).mean() < 0: - wave_map = np.flip(wave_map, axis=-1) + # Mask where trace is equal to 0 + wave_trace = np.ma.array(wave_trace, mask=(wave_trace == 0)) - # Number of columns - ncols = wave_map.shape[-1] + # Force trace to have the red wavelengths at the end of the detector + if np.diff(wave_trace).mean() < 0: + wave_trace = np.flip(wave_trace) - # Create oversampled pixel position array - pixels = np.arange(-(n_pix // 2), n_pix // 2 + (1 / n_os), (1 / n_os)) + # Create oversampled pixel position array. Center index should have value 0. + self.pixels = np.linspace(-(n_pix // 2), n_pix // 2, wave_kernels.shape[0]) # `wave_kernel` has only the value of the central wavelength # of the kernel at each points because it's a function @@ -892,27 +885,25 @@ def __init__(self, wave_kernels, kernels, wave_map, n_os, n_pix, # TODO kernels wave_center = wave_kernels[0, :] # Use the wavelength solution to create a mapping between pixels and wavelengths - # First find the all kernels that fall on the detector. - wave_min = np.amin(wave_map[wave_map > 0]) - wave_max = np.amax(wave_map[wave_map > 0]) + wave_min = np.amin(wave_trace[wave_trace > 0]) + wave_max = np.amax(wave_trace[wave_trace > 0]) i_min = np.searchsorted(wave_center, wave_min) i_max = np.searchsorted(wave_center, wave_max) - 1 - # Use the next kernels at each extremities to define the - # boundaries of the interpolation to use in the class - # RectBivariateSpline (at the end) + # i_min, i_max correspond to the min, max indices of the kernel that are represented + # on the detector. Use those to define the boundaries of the interpolation to use + # in the RectBivariateSpline interpolation bbox = [None, None, wave_center[np.maximum(i_min - 1, 0)], wave_center[np.minimum(i_max + 1, len(wave_center) - 1)]] - ####################### # Keep only kernels that fall on the detector. - kernels = kernels[:, i_min:i_max + 1].copy() + self.kernels = kernels[:, i_min:i_max + 1].copy() wave_kernels = wave_kernels[:, i_min:i_max + 1].copy() - wave_center = np.array(wave_kernels[0, :]) + wave_center = np.array(wave_kernels[0]) # Save minimum kernel value (greater than zero) - kernels_min = np.min(kernels[(kernels > 0.0)]) + self.min_value = np.min(self.kernels[(self.kernels > 0.0)]) # Then find the pixel closest to each kernel center # and use the surrounding pixels (columns) @@ -927,56 +918,50 @@ def __init__(self, wave_kernels, kernels, wave_map, n_os, n_pix, # TODO kernels wv = np.ma.masked_all(i_surround.shape) # Closest pixel wv - i_row, i_col = np.unravel_index( - np.argmin(np.abs(wave_map - wv_c)), wave_map.shape - ) + i_col = np.argmin(np.abs(wave_trace - wv_c)) # Update wavelength center value # (take the nearest pixel center value) - wave_center[i_cen] = wave_map[i_row, i_col] + wave_center[i_cen] = wave_trace[i_col] # Surrounding columns index = i_col + i_surround # Make sure it's on the detector - i_good = (index >= 0) & (index < ncols) + i_good = (index >= 0) & (index < wave_trace.size) # Assign wv values - wv[i_good] = wave_map[i_row, index[i_good]] + wv[i_good] = wave_trace[index[i_good]] # Fit n=1 polynomial poly_i = np.polyfit(i_surround[~wv.mask], wv[~wv.mask], 1) # Project on os pixel grid - wave_kernels[:, i_cen] = np.poly1d(poly_i)(pixels) + wave_kernels[:, i_cen] = np.poly1d(poly_i)(self.pixels) # Save coeffs poly.append(poly_i) - # Save attributes - self.n_pix = n_pix - self.n_os = n_os + # Save computed attributes self.wave_kernels = wave_kernels - self.kernels = kernels - self.pixels = pixels self.wave_center = wave_center self.poly = np.array(poly) - self.fill_value = fill_value - self.bounds_error = bounds_error - self.min_value = kernels_min - # 2d Interpolate - self.f_ker = RectBivariateSpline(pixels, wave_center, kernels, bbox=bbox) + # 2D Interpolate + self.f_ker = RectBivariateSpline(self.pixels, self.wave_center, self.kernels, bbox=bbox) def __call__(self, wave, wave_c): - """Returns the kernel value, given the wavelength and the kernel central - wavelength. + """ + Returns the kernel value, given the wavelength and the kernel central + wavelength. Wavelengths that are out of bounds will be extrapolated. Parameters ---------- wave : array[float] Wavelength where the kernel is projected. + wave_c : array[float] Central wavelength of the kernel. + Returns ------- out : array[float] @@ -985,33 +970,16 @@ def __call__(self, wave, wave_c): wave_center = self.wave_center poly = self.poly - fill_value = self.fill_value - bounds_error = self.bounds_error n_wv_c = len(wave_center) - f_ker = self.f_ker - n_pix = self.n_pix - min_value = self.min_value - # ################################# - # First, convert wv value in pixels - # using a linear interpolation - # ################################# + # First, convert wavelength value into pixels using self.poly to interpolate # Find corresponding interval i_wv_c = np.searchsorted(wave_center, wave_c) - 1 - # Deal with values out of bounds - if bounds_error: - message = "Value of wv center out of interpolation range" - log.critical(message) - raise ValueError(message) - elif fill_value == "extrapolate": - i_wv_c[i_wv_c < 0] = 0 - i_wv_c[i_wv_c >= (n_wv_c - 1)] = n_wv_c - 2 - else: - message = f"`fill_value`={fill_value} is not an valid option." - log.critical(message) - raise ValueError(message) + # Extrapolate values out of bounds + i_wv_c[i_wv_c < 0] = 0 + i_wv_c[i_wv_c >= (n_wv_c - 1)] = n_wv_c - 2 # Compute coefficients that interpolate along wv_centers d_wv_c = wave_center[i_wv_c + 1] - wave_center[i_wv_c] @@ -1027,19 +995,13 @@ def __call__(self, wave, wave_c): # Compute pixel values pix = a_pix * wave + b_pix - # ###################################### - # Second, compute kernel value on the - # interpolation grid (pixel x wv_center) - # ###################################### - - webbker = f_ker(pix, wave_c, grid=False) - - # Make sure it's not negative and greater than the min value - webbker = np.clip(webbker, min_value, None) + # Second, compute kernel value on the interpolation grid (pixel x wv_center) + webbker = self.f_ker(pix, wave_c, grid=False) - # and put out-of-range values to zero. - webbker[pix > n_pix // 2] = 0 - webbker[pix < -(n_pix // 2)] = 0 + # Make sure it's not negative and greater than the min value, set pixels outside range to zero + webbker = np.clip(webbker, self.min_value, None) + webbker[pix > self.n_pix // 2] = 0 + webbker[pix < -(self.n_pix // 2)] = 0 return webbker diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index 1f1e8d4f37..6bef640635 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -24,6 +24,8 @@ ORDER2_SHORT_CUTOFF = 0.58 +# TODO: replace all .__call__ with just an actual call to the object + def get_ref_file_args(ref_files): """ @@ -92,7 +94,7 @@ def get_ref_file_args(ref_files): # The spectral kernels. speckernel_ref = ref_files['speckernel'] - ovs = speckernel_ref.meta.spectral_oversampling + # ovs = speckernel_ref.meta.spectral_oversampling n_pix = 2 * speckernel_ref.meta.halfwidth + 1 # Take the centroid of each trace as a grid to project the WebbKernel @@ -100,19 +102,20 @@ def get_ref_file_args(ref_files): wave_maps = [wavemap_o1, wavemap_o2] centroid = dict() for wv_map, order in zip(wave_maps, [1, 2]): - # Needs the same number of columns as the detector. Put zeros where not defined. - wv_cent = np.zeros((1, wv_map.shape[1])) + wv_cent = np.zeros((wv_map.shape[1])) + # Get central wavelength as a function of columns col, _, wv = _get_trace_1d(ref_files, order) - wv_cent[:, col] = wv + wv_cent[col] = wv + # Set invalid values to zero idx_invalid = ~np.isfinite(wv_cent) wv_cent[idx_invalid] = 0.0 centroid[order] = wv_cent # Get kernels - kernels_o1 = WebbKernel(speckernel_ref.wavelengths, speckernel_ref.kernels, centroid[1], ovs, n_pix) - kernels_o2 = WebbKernel(speckernel_ref.wavelengths, speckernel_ref.kernels, centroid[2], ovs, n_pix) + kernels_o1 = WebbKernel(speckernel_ref.wavelengths, speckernel_ref.kernels, centroid[1], n_pix) + kernels_o2 = WebbKernel(speckernel_ref.wavelengths, speckernel_ref.kernels, centroid[2], n_pix) # Make sure that the kernels cover the wavelength maps speckernel_wv_range = [np.min(speckernel_ref.wavelengths), np.max(speckernel_ref.wavelengths)] diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py index e6f468612e..965be8e055 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py @@ -9,7 +9,7 @@ def test_arange_2d(): stops = np.ones(starts.shape)*7 out = au.arange_2d(starts, stops) - bad = 65535 + bad = -1 expected_out = np.array([ [3,4,5,6], [4,5,6,bad], @@ -184,7 +184,7 @@ def test_oversample_irregular(os_factor): @pytest.mark.parametrize("wave_range", [(2.1, 3.9), (1.8, 4.5)]) -def test__extrapolate_grid(wave_range): +def test_extrapolate_grid(wave_range): extrapolated = au._extrapolate_grid(WAVELENGTHS, wave_range, 1) @@ -237,64 +237,224 @@ def test_grid_from_map(wave_map, trace_profile): au.grid_from_map(wave_map, trace_profile, wave_range=[0.5,0.9]) -@pytest.mark.parametrize("n_os", ([4,1], 1)) -def test__get_soss_grid(n_os, wave_map, trace_profile, wave_map_o2, trace_profile_o2): +def xsinx(x): + return x*np.sin(x) + + +def test_estim_integration_error(): """ - wave_map has min, max wavelength of 1.5, 4.0 but throughput makes this 1.7, 4.2 - wave_map_o2 has min, max wavelength of 0.5, 3.0 but throughput makes this 0.9, 3.4 + Use as truth the x sin(x) from 0 to pi, has an analytic solution == pi. + TODO: Find something more meaningful to test here """ - # choices of wave_min, wave_max force extrapolation on low end - # and also makes sure both orders matter - wave_min = 0.55 - wave_max = 4.0 - wave_maps = np.array([wave_map, wave_map_o2]) - trace_profiles = np.array([trace_profile, trace_profile_o2]) - wave_grid = au._get_soss_grid(wave_maps, trace_profiles, wave_min, wave_max, n_os) - - delta_lower = wave_grid[1]-wave_grid[0] - delta_upper = wave_grid[-1]-wave_grid[-2] - - # ensure no duplicates and strictly ascending - assert wave_grid.size == np.unique(wave_grid).size - assert np.all(wave_grid[1:] > wave_grid[:-1]) - - # esnure grid is within bounds but just abutting bounds - assert wave_grid.min() >= wave_min - assert wave_grid.min() <= wave_min + delta_lower - assert wave_grid.max() <= wave_max - assert wave_grid.max() >= wave_max - delta_upper - - # ensure oversample factor changes for different n_os - # this is a bit complicated because the wavelength spacing is nonlinear in wave_map - # by a factor of 2 to begin with, so check that the ratio of the wavelength spacing - # in the upper vs lower end of the wl ranges look approx like what was input, modulo - # the oversample factor of the two orders - if n_os == 1: - n_os = [1,1] - og_spacing_lower = WAVELENGTHS[1]-WAVELENGTHS[0] - og_spacing_upper = WAVELENGTHS[-1]-WAVELENGTHS[-2] - expected_ratio = int((n_os[0]/n_os[1])*(np.around(og_spacing_lower/og_spacing_upper))) - - spacing_lower = np.mean(wave_grid[1:6]-wave_grid[:5]) - spacing_upper = np.mean(wave_grid[-5:]-wave_grid[-6:-1]) - actual_ratio = int(np.around((spacing_lower/spacing_upper))) - - # for n=1 case we expect 2, for n=(4,1) case we expect 8 - assert expected_ratio == actual_ratio + n = 11 + grid = np.linspace(0, np.pi, n) + rel_err = au._estim_integration_err(grid, xsinx) + + assert len(rel_err) == n-1 + assert np.all(rel_err >= 0) + assert np.all(rel_err < 1) + +@pytest.mark.parametrize("max_iter, rtol", [(1,1e-3), (10, 1e-9), (10, 1e-3), (1, 1e-9)]) +def test_adapt_grid(max_iter, rtol): + """ + Use as truth the x sin(x) from 0 to pi, has an analytic solution == pi. + """ -def test__get_soss_grid_bad_inputs(wave_map, trace_profile): + input_grid = np.linspace(0, np.pi, 11) + input_grid_diff = input_grid[1] - input_grid[0] + max_grid_size = 100 + grid, is_converged = au._adapt_grid(input_grid, + xsinx, + max_grid_size, + max_iter=max_iter, + rtol=rtol) + + # ensure grid respects max_grid_size and max_iter in all cases + assert len(grid) <= max_grid_size + grid_diff = grid[1:] - grid[:-1] + assert np.min(grid_diff) >= input_grid_diff/(2**max_iter) + + numerical_integral = np.trapz(xsinx(grid), grid) + + # ensure this converges for at least one of our test cases + if max_iter == 10 and rtol == 1e-3: + assert is_converged + + if is_converged: + # test error of the answer is smaller than rtol + assert np.isclose(numerical_integral, np.pi, rtol=rtol) + # test that success was a stop condition + assert len(grid) < max_grid_size + + # test stop conditions + elif max_iter == 10: + # ensure hitting max_grid_size returns an array of exactly length max_grid_size + assert len(grid) == max_grid_size + elif max_iter == 1: + # ensure hitting max_iter can stop iteration before max_grid_size reached + assert len(grid) <= 2*len(input_grid) + + +def test_adapt_grid_bad_inputs(): with pytest.raises(ValueError): - # test bad input shapes - au._get_soss_grid(wave_map, trace_profile, 0.5, 0.9, 1) + # input grid larger than max_grid_size + au._adapt_grid(np.array([1,2,3]), xsinx, 2) + + +def test_trim_grids(): + + grid_range = (-3, 3) + grid0 = np.linspace(-3, 0, 4) # kept entirely. + grid1 = np.linspace(-3, 0, 16) # removed entirely. Finer spacing doesn't matter, preceded by grid0 + grid2 = np.linspace(0, 3, 5) # kept from 0 to 3 + grid3 = np.linspace(-4, 4, 5) # removed entirely. Outside of grid_range and the rest is superseded + + all_grids = [grid0, grid1, grid2, grid3] + trimmed_grids = au._trim_grids(all_grids, grid_range) + + assert len(trimmed_grids) == len(all_grids) + assert trimmed_grids[0].size == grid0.size + assert trimmed_grids[1].size == 0 + assert trimmed_grids[2].size == grid2.size + assert trimmed_grids[3].size == 0 + + +def test_make_combined_adaptive_grid(): + """see also tests of _adapt_grid and _trim_grids for more detailed tests""" + + grid_range = (0, np.pi) + grid0 = np.linspace(0, np.pi/2, 6) # kept entirely. + grid1 = np.linspace(0, np.pi/2, 15) # removed entirely. Finer spacing doesn't matter, preceded by grid0 + grid2 = np.linspace(np.pi/2, np.pi, 11) # kept from pi/2 to pi + + # purposely make same lower index for grid2 as upper index for grid0 to test uniqueness of output + + all_grids = [grid0, grid1, grid2] + all_estimate = [xsinx, xsinx, xsinx] + + rtol = 1e-3 + combined_grid = au.make_combined_adaptive_grid(all_grids, all_estimate, grid_range, + max_iter=10, rtol=rtol, max_total_size=100) + + numerical_integral = np.trapz(xsinx(combined_grid), combined_grid) - wave_maps = np.array([wave_map, wave_map]) - trace_profiles = np.array([trace_profile, trace_profile]) + assert np.unique(combined_grid).size == combined_grid.size + assert np.isclose(numerical_integral, np.pi, rtol=rtol) + +def test_throughput_soss(): + + wavelengths = np.linspace(2,5,10) + throughputs = np.ones_like(wavelengths) + interpolator = au.ThroughputSOSS(wavelengths, throughputs) + + # test that it returns 1 for all wavelengths inside range + interp = interpolator(wavelengths) + assert np.allclose(interp[1:-1], throughputs[1:-1]) + assert interp[0] == 0 + assert interp[-1] == 0 + + # test that it returns 0 for all wavelengths outside range + wavelengths_outside = np.linspace(1,1.5,5) + interp = interpolator(wavelengths_outside) + assert np.all(interp == 0) + + # test ValueError raise for shape mismatch with pytest.raises(ValueError): - # test bad n_os shape - au._get_soss_grid(wave_maps, trace_profiles, 0.5, 0.9, [1,1,1]) + au.ThroughputSOSS(wavelengths, throughputs[:-1]) + + +@pytest.fixture(scope="module") +def kernel_init(): + """ + Toy model of the JWST kernel. The kernel is a triangle function + with maximum at the center, uniform in wavelength. + """ + n_os = 3 + n_pix = 5 # full width of kernel + n_wave = 20 + wave_range = (2.0, 5.0) + wavelengths = np.linspace(*wave_range, n_wave) + kernel_width = n_os*n_pix - (n_os - 1) + ctr_idx = kernel_width//2 + wave_kernel = np.ones((kernel_width, wavelengths.size), dtype=float)*wavelengths[None,:] + triangle_function = ctr_idx - np.abs(ctr_idx - np.arange(0, kernel_width)) + kernel = np.ones((kernel_width, wavelengths.size), dtype=float)*triangle_function[:,None] + kernel/=np.max(kernel) + return wave_kernel, kernel, n_pix + +def test_webb_kernel(kernel_init): + + min_trace = 2.5 + max_trace = 3.5 + n_trace = 100 + wave_trace = np.linspace(min_trace, max_trace, n_trace) + (wave_kernel, kernel, n_pix) = kernel_init + + # instantiate the kernel object + kern = au.WebbKernel(wave_kernel, kernel, wave_trace, n_pix) + + # basic ensure that the input is stored and shapes + assert kern.n_pix == n_pix + assert kern.wave_kernels.shape == kern.kernels.shape + + # test that kernel and wave_kernel have been clipped to only keep wavelengths on detector + assert np.all(kern.wave_kernels >= min_trace) + assert np.all(kern.wave_kernels <= max_trace) + assert kern.wave_kernels.shape[0] == wave_kernel.shape[0] + assert kern.wave_kernels.shape[1] < wave_kernel.shape[1] + + # test that pixels is mirrored around the center and has zero at center + assert kern.pixels.size == wave_kernel.shape[0] + assert np.allclose(kern.pixels + kern.pixels[::-1], 0) + assert kern.pixels[kern.pixels.size//2] == 0 + + # test that wave_center has same shape as wavelength axis of wave_kernel + # but contains values that are in wave_trace + assert kern.wave_center.size == kern.wave_kernels.shape[1] + assert all(np.isin(kern.wave_center, wave_trace)) + + # test min value + assert kern.min_value > 0 + assert np.isin(kern.min_value, kern.kernels) + assert isinstance(kern.min_value, float) + + # test the polynomial fit has the proper shape. hard-coded to a first-order, i.e., linear fit + # since the throughput is constant in wavelength, the slopes should be close to zero + # and the y-intercepts should be close to kern.wave_center + # especially with so few points. just go with 10 percent, should catch egregious changes + assert kern.poly.shape == (kern.wave_kernels.shape[1], 2) + assert np.allclose(kern.poly[:,0], 0, atol=1e-1) + assert np.allclose(kern.poly[:,1], kern.wave_center, atol=1e-1) + + # test interpolation function, which takes in a pixel and a wavelength and returns a throughput + # this should return the triangle function at all wavelengths and zero outside range + wl_test = np.linspace(min_trace, max_trace, 10) + pixels_test = np.array([-3, -1.5, -1, 0, 1, 1.5, 2, 3]) + expected = np.array([0, 0.25, 0.5, 1, 0.5, 0.25, 0, 0]) + interp = kern.f_ker(pixels_test, wl_test) + + assert interp.shape == (pixels_test.size, wl_test.size) + diff = interp[:,1:] - interp[:,:-1] + assert np.allclose(diff, 0) + assert np.allclose(interp[:,0], expected, rtol=1e-3) + + # call the kernel object directly + # this takes a wavelength and a central wavelength of the kernel, + # then converts to pixels to use self.f_ker internally + wl_spacing = wave_trace[1] - wave_trace[0] + assert np.allclose(kern(wl_test, wl_test), 1) + assert np.allclose(kern(wl_test, wl_test - wl_spacing), 0.5) + + # test that clipping to minimum value works right + wl_on_edge = wl_test + (2*wl_spacing - 0.0001) + assert np.allclose(kern(wl_test, wl_on_edge), kern.min_value) + + # both inputs need to be same shape + with pytest.raises(ValueError): + kern(wl_test, wl_test[:-1]) \ No newline at end of file From 20ea48e5d18f62f8176119567b6e77b87f31eabc Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Mon, 11 Nov 2024 11:58:36 -0500 Subject: [PATCH 11/35] fix some failing unit tests --- jwst/extract_1d/soss_extract/atoca.py | 19 +- jwst/extract_1d/soss_extract/atoca_utils.py | 287 +++++------------- jwst/extract_1d/soss_extract/soss_extract.py | 2 +- .../soss_extract/tests/test_atoca_utils.py | 77 ++++- 4 files changed, 151 insertions(+), 234 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index 94f8acb5cb..6578f6ae90 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -76,18 +76,18 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, If callable, the functions depend on the wavelength. If array, projected on `wave_grid`. - kernels : array, callable or sparse matrix + kernels : callable, sparse matrix, or None. Convolution kernel to be applied on spectrum (f_k) for each orders. - Can be array of the shape (N_ker, N_k_c). Can be a callable with the form f(x, x0) where x0 is the position of the center of the kernel. In this case, it must return a 1D array (len(x)), so a kernel value - for each pairs of (x, x0). If array or callable, + for each pairs of (x, x0). If callable, it will be passed to `convolution.get_c_matrix` function and the `c_kwargs` can be passed to this function. If sparse, the shape has to be (N_k_c, N_k) and it will be used directly. N_ker is the length of the effective kernel and N_k_c is the length of the spectrum (f_k) convolved. + If None, the kernel is set to 1, i.e., do not do any convolution. wave_grid : (N_k) array_like, required. The grid on which f(lambda) will be projected. @@ -244,10 +244,14 @@ def update_throughput(self, throughput): def _create_kernels(self, kernels, c_kwargs): """Make sparse matrix from input kernels + TODO: the only kwarg ever passed in here appears to be thresh, which already gets + handled internally... so can we remove all kwargs handling here? + Parameters ---------- - kernels : array, callable or sparse matrix + kernels : callable, sparse matrix, or None Convolution kernel to be applied on the spectrum (f_k) for each order. + If None, kernel is set to 1, i.e., do not do any convolution. c_kwargs : list of N_ord dictionaries or dictionary, optional Inputs keywords arguments to pass to @@ -261,7 +265,7 @@ def _create_kernels(self, kernels, c_kwargs): c_kwargs = [] for ker in kernels: try: - kwargs_ker = {'thresh': ker.min_value} + kwargs_ker = {'thresh': ker.min_value} except AttributeError: # take the get_c_matrix defaults kwargs_ker = dict() @@ -274,8 +278,9 @@ def _create_kernels(self, kernels, c_kwargs): # Define convolution sparse matrix. kernels_new = [] for i_order, kernel_n in enumerate(kernels): - - if not issparse(kernel_n): + if kernel_n is None: + kernel_n = 1 + elif not issparse(kernel_n): kernel_n = atoca_utils.get_c_matrix(kernel_n, self.wave_grid, i_bounds=self.i_bounds[i_order], **c_kwargs[i_order]) diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index 85b2ca52b6..5cb2920286 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -1006,33 +1006,6 @@ def __call__(self, wave, wave_c): return webbker -def _to_2d(kernel, grid_range): - """Build a 2d kernel array with a constant 1D kernel (input) - - Parameters - ---------- - kernel : array[float] - Input 1D kernel. - grid_range : list[int] - Indices over which convolution is defined on grid. - - Returns - ------- - kernel_2d : array[float] - 2D array of input 1D kernel tiled over axis with - length equal to difference of grid_range values. - """ - - # Assign range where the convolution is defined on the grid - a, b = grid_range - - # Get length of the convolved axis - n_k_c = b - a - - # Return a 2D array with this length - return np.tile(kernel, (n_k_c, 1)).T - - def _get_wings(fct, grid, h_len, i_a, i_b): """Compute values of the kernel at grid[+-h_len] @@ -1043,14 +1016,18 @@ def _get_wings(fct, grid, h_len, i_a, i_b): a grid value and the center of the kernel. fct(grid, center) = kernel grid and center have the same length. + grid : array[float] grid where the kernel is projected + h_len : int Half-length where we compute kernel value. + i_a : int Index of grid axis 0 where to apply convolution. Once the convolution applied, the convolved grid will be equal to grid[i_a:i_b]. + i_b : int index of grid axis 1 where to apply convolution. @@ -1058,6 +1035,7 @@ def _get_wings(fct, grid, h_len, i_a, i_b): ------- left : array[float] Kernel values at left wing. + right : array[float] Kernel values at right wing. """ @@ -1110,14 +1088,18 @@ def _trpz_weight(grid, length, shape, i_a, i_b): ---------- grid : array[float] grid where the integration is projected + length : int length of the kernel + shape : tuple[int] shape of the compact convolution 2d array + i_a : int Index of grid axis 0 where to apply convolution. Once the convolution applied, the convolved grid will be equal to grid[i_a:i_b]. + i_b : int index of grid axis 1 where to apply convolution. @@ -1152,10 +1134,8 @@ def _trpz_weight(grid, length, shape, i_a, i_b): return out -def _fct_to_array(fct, grid, grid_range, thresh=1e-5, length=None): +def _fct_to_array(fct, grid, grid_range, thresh): """ - TODO: can't scipy do this? - Build a compact kernel 2d array based on a kernel function and a grid to project the kernel @@ -1166,17 +1146,18 @@ def _fct_to_array(fct, grid, grid_range, thresh=1e-5, length=None): a grid value and the center of the kernel. fct(grid, center) = kernel grid and center have the same length. + grid : array[float] Grid where the kernel is projected + grid_range : list[int] or tuple[int] Indices of the grid where to apply the convolution. Once the convolution applied, the convolved grid will be equal to grid[grid_range[0]:grid_range[1]]. - thresh : float, optional - Threshold to cut the kernel wings. If `length` is specified, - `thresh` will be ignored. - length : int, optional - Length of the kernel. Must be odd. + + thresh : float, required + Threshold to define the maximum length of the kernel. + Truncate when `kernel` < `thresh`. Returns ------- @@ -1187,155 +1168,54 @@ def _fct_to_array(fct, grid, grid_range, thresh=1e-5, length=None): # Assign range where the convolution is defined on the grid i_a, i_b = grid_range - # Init with the value at kernel's center - out = fct(grid, grid)[i_a:i_b] - - # Add wings - if length is None: - # Generate a 2D array of the grid iteratively until - # thresh is reached everywhere. - - # Init parameters - length = 1 - h_len = 0 # Half length - - # Add value on each sides until thresh is reached - while True: - # Already update half-length - h_len += 1 - - # Compute next left and right ends of the kernel - left, right = _get_wings(fct, grid, h_len, i_a, i_b) - - # Check if they are all below threshold. - if (left < thresh).all() and (right < thresh).all(): - break # Stop iteration - else: - # Update kernel length - length += 2 - - # Set value to zero if smaller than threshold - left[left < thresh] = 0. - right[right < thresh] = 0. - - # add new values to output - out = np.vstack([left, out, right]) + # Init 2-D array with first dimension length 1, with the value at kernel's center + out = fct(grid, grid)[i_a:i_b][np.newaxis,...] - # Weights due to integration (from the convolution) - weights = _trpz_weight(grid, length, out.shape, i_a, i_b) + # Add wings: Generate a 2D array of the grid iteratively until + # thresh is reached everywhere. + # TODO: surely we can make this faster? avoid a while loop by figuring out + # where threshold will be reached a priori somehow? + length = 1 + h_len = 0 # Half length + while True: + h_len += 1 - elif (length % 2) == 1: # length needs to be odd - # Generate a 2D array of the grid iteratively until - # specified length is reached. + # Compute next left and right ends of the kernel + left, right = _get_wings(fct, grid, h_len, i_a, i_b) - # Compute number of half-length - n_h_len = (length - 1) // 2 + # Check if they are all below threshold. + if (left < thresh).all() and (right < thresh).all(): + break # Stop iteration + else: + # Update kernel length + length += 2 - # Simply iterate to compute needed wings - for h_len in range(1, n_h_len + 1): - # Compute next left and right ends of the kernel - left, right = _get_wings(fct, grid, h_len, i_a, i_b) + # Set value to zero if smaller than threshold + left[left < thresh] = 0. + right[right < thresh] = 0. - # Add new kernel values + # add new values to output out = np.vstack([left, out, right]) - # Weights due to integration (from the convolution) - weights = _trpz_weight(grid, length, out.shape, i_a, i_b) - - else: - msg = "`length` provided to `_fct_to_array` must be odd." - log.critical(msg) - raise ValueError(msg) + # Weights due to integration (from the convolution) + weights = _trpz_weight(grid, length, out.shape, i_a, i_b) - kern_array = (out * weights) - return kern_array + return (out * weights) -def _cut_ker(ker, n_out=None, thresh=None): - """Apply a cut on the convolution matrix boundaries. - - Parameters - ---------- - ker : array[float] - convolution kernel in compact form, so - shape = (N_ker, N_k_convolved) - n_out : int, list[int] or tuple[int] - Number of kernel's grid point to keep on the boundaries. - If an int is given, the same number of points will be - kept on each boundaries of the kernel (left and right). - If 2 elements are given, it corresponds to the left and right - boundaries. - thresh : float - threshold used to determine the boundaries cut. - If n_out is specified, this is ignored. - - Returns - ------ - ker : array[float] - The same kernel matrix as the input ker, but with the cut applied. +def _sparse_c(ker, n_k, i_zero): """ - - # Assign kernel length and number of kernels - n_ker, n_k_c = ker.shape - - # Assign half-length of the kernel - h_len = (n_ker - 1) // 2 - - # Determine n_out with thresh if not given - if n_out is None: - - if thresh is None: - # No cut to apply - return ker - else: - # Find where to cut the kernel according to thresh - i_left = np.where(ker[:, 0] >= thresh)[0][0] - i_right = np.where(ker[:, -1] >= thresh)[0][-1] - - # Make sure it is on the good wing. Take center if not. - i_left = np.minimum(i_left, h_len) - i_right = np.maximum(i_right, h_len) - - # Else, unpack n_out - else: - # Could be a scalar or a 2-elements object) - try: - i_left, i_right = n_out - except TypeError: - i_left, i_right = n_out, n_out - - # Find the position where to cut the kernel - # Make sure it is not out of the kernel grid, - # so i_left >= 0 and i_right <= len(kernel) - i_left = np.maximum(h_len - i_left, 0) - i_right = np.minimum(h_len + i_right, n_ker - 1) - - # Apply the cut - for i_k in range(0, i_left): - # Add condition in case the kernel is larger - # than the grid where it's projected. - if i_k < n_k_c: - ker[:i_left - i_k, i_k] = 0 - - for i_k in range(i_right + 1 - n_ker, 0): - # Add condition in case the kernel is larger - # than the grid where it's projected. - if -i_k <= n_k_c: - ker[i_right - n_ker - i_k:, i_k] = 0 - - return ker - - -def _sparse_c(ker, n_k, i_zero=0): - """Convert a convolution kernel in compact form (N_ker, N_k_convolved) + Convert a convolution kernel in compact form (N_ker, N_k_convolved) to sparse form (N_k_convolved, N_k) Parameters ---------- ker : array[float] - Convolution kernel in compact form, with shape (N_kernel, N_kc) + Convolution kernel with shape (N_kernel, N_kc) + n_k : int Length of the original grid + i_zero : int Position of the first element of the convolved grid in the original grid. @@ -1372,14 +1252,12 @@ def _sparse_c(ker, n_k, i_zero=0): offset.append(i_k) # Build convolution matrix - matrix = diags(diag_val, offset, shape=(n_k_c, n_k), format="csr") + return diags(diag_val, offset, shape=(n_k_c, n_k), format="csr") - return matrix - -def get_c_matrix(kernel, grid, bounds=None, i_bounds=None, norm=True, - sparse=True, n_out=None, thresh_out=None, **kwargs): - """Return a convolution matrix +def get_c_matrix(kernel, grid, i_bounds=None, thresh=1e-5): + """ + Return a convolution matrix Can return a sparse matrix (N_k_convolved, N_k) or a matrix in the compact form (N_ker, N_k_convolved). N_k is the length of the grid on which the convolution @@ -1392,51 +1270,34 @@ def get_c_matrix(kernel, grid, bounds=None, i_bounds=None, norm=True, Parameters ---------- - kernel: ndarray (1D or 2D), callable + kernel: ndarray (2D) or callable Convolution kernel. Can be already 2D (N_ker, N_k_convolved), giving the kernel for each items of the convolved grid. - Can be 1D (N_ker), so the kernel is the same. Can be a callable + Can be a callable with the form f(x, x0) where x0 is the position of the center of the kernel. Must return a 1D array (len(x)), so a kernel value - for each pairs of (x, x0). If kernel is callable, the additional - kwargs `thresh` and `length` will be used to project the kernel. - grid: one-d-array: + for each pairs of (x, x0). + + grid: 1D np.array: The grid on which the convolution will be applied. For example, if C is the convolution matrix, f_convolved = C.f(grid) - bounds: 2-elements object + + i_bounds: 2-elements object, optional, default None. The bounds of the grid on which the convolution is defined. For example, if bounds = (a,b), then grid_convolved = grid[a <= grid <= b]. - It dictates also the dimension of f_convolved - sparse: bool, optional - return a sparse matrix (N_k_convolved, N_k) if True. - return a matrix (N_ker, N_k_convolved) if False. - n_out: integer or 2-integer object, optional - Specify how to deal with the ends of the convolved grid. - `n_out` points will be used outside from the convolved - grid. Can be different for each ends if 2-elements are given. - thresh_out: float, optional - Specify how to deal with the ends of the convolved grid. - Points with a kernel value less then `thresh_out` will - not be used outside from the convolved grid. + It dictates also the dimension of f_convolved. + If None, the convolution is defined on the whole grid. + thresh: float, optional Only used when `kernel` is callable to define the maximum length of the kernel. Truncate when `kernel` < `thresh` - length: int, optional - Only used when `kernel` is callable to define the maximum - length of the kernel. """ # Define range where the convolution is defined on the grid. - # If `i_bounds` is not specified, try with `bounds`. if i_bounds is None: - - if bounds is None: - a, b = 0, len(grid) - else: - a = np.min(np.where(grid >= bounds[0])[0]) - b = np.max(np.where(grid <= bounds[1])[0]) + 1 + a, b = 0, len(grid) else: # Make sure it is absolute index, not relative @@ -1446,31 +1307,21 @@ def get_c_matrix(kernel, grid, bounds=None, i_bounds=None, norm=True, a, b = i_bounds - # Generate a 2D kernel depending on the input + # Generate a 2D kernel of shape (N_kernel x N_kc) from a callable if callable(kernel): - kernel = _fct_to_array(kernel, grid, [a, b], **kwargs) - elif kernel.ndim == 1: - kernel = _to_2d(kernel, [a, b]) + kernel = _fct_to_array(kernel, grid, [a, b], thresh) - if kernel.ndim != 2: + elif kernel.ndim != 2: msg = ("Input kernel to get_c_matrix must be callable or" - " array with one or two dimensions.") + "2-dimensional array.") log.critical(msg) raise ValueError(msg) - # Kernel should now be a 2-D array (N_kernel x N_kc) - - # Normalize if specified - if norm: - kernel = kernel / np.nansum(kernel, axis=0) - - # Apply cut for kernel at boundaries - kernel = _cut_ker(kernel, n_out, thresh_out) - if sparse: - # Convert to a sparse matrix. - kernel = _sparse_c(kernel, len(grid), a) + # Normalize + kernel = kernel / np.nansum(kernel, axis=0) - return kernel + # Convert to a sparse matrix. + return _sparse_c(kernel, len(grid), a) def _finite_diff(x): diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index 6bef640635..6b37661d31 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -217,7 +217,7 @@ def _estim_flux_first_order(scidata_bkg, scierr, scimask, ref_file_args, mask_tr mask = ((spat_pros[1] >= threshold) | mask_trace_profile | scimask) # Init extraction without convolution kernel (so extract the spectrum at order 1 resolution) - ref_file_args = [wave_maps[0]], [spat_pros[0]], [thrpts[0]], [np.array([1.])] + ref_file_args = [wave_maps[0]], [spat_pros[0]], [thrpts[0]], [None] kwargs = {'orders': [1],} engine = ExtractionEngine(*ref_file_args, wave_grid, [mask], **kwargs) diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py index 965be8e055..d294c4252b 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py @@ -249,7 +249,7 @@ def test_estim_integration_error(): n = 11 grid = np.linspace(0, np.pi, n) - rel_err = au._estim_integration_err(grid, xsinx) + err, rel_err = au._estim_integration_err(grid, xsinx) assert len(rel_err) == n-1 assert np.all(rel_err >= 0) @@ -373,7 +373,7 @@ def kernel_init(): with maximum at the center, uniform in wavelength. """ n_os = 3 - n_pix = 5 # full width of kernel + n_pix = 15 # full width of kernel n_wave = 20 wave_range = (2.0, 5.0) wavelengths = np.linspace(*wave_range, n_wave) @@ -434,9 +434,11 @@ def test_webb_kernel(kernel_init): # test interpolation function, which takes in a pixel and a wavelength and returns a throughput # this should return the triangle function at all wavelengths and zero outside range + pix_half = n_pix//2 wl_test = np.linspace(min_trace, max_trace, 10) - pixels_test = np.array([-3, -1.5, -1, 0, 1, 1.5, 2, 3]) - expected = np.array([0, 0.25, 0.5, 1, 0.5, 0.25, 0, 0]) + pixels_test = np.array([-pix_half-1, -pix_half/2, 0, + pix_half/2, pix_half, pix_half+1]) + expected = np.array([0, 0.5, 1, 0.5, 0, 0]) interp = kern.f_ker(pixels_test, wl_test) assert interp.shape == (pixels_test.size, wl_test.size) @@ -447,14 +449,73 @@ def test_webb_kernel(kernel_init): # call the kernel object directly # this takes a wavelength and a central wavelength of the kernel, # then converts to pixels to use self.f_ker internally - wl_spacing = wave_trace[1] - wave_trace[0] + assert kern(wl_test, wl_test).ndim == 1 assert np.allclose(kern(wl_test, wl_test), 1) - assert np.allclose(kern(wl_test, wl_test - wl_spacing), 0.5) + # expect one wl spacing to correspond to one kernel pixel + wl_spacing = wave_trace[1] - wave_trace[0] + expected = 1 - 1/pix_half + assert np.allclose(kern(wl_test, wl_test - wl_spacing), expected) # test that clipping to minimum value works right - wl_on_edge = wl_test + (2*wl_spacing - 0.0001) + wl_on_edge = wl_test + (pix_half*wl_spacing - 0.0001) assert np.allclose(kern(wl_test, wl_on_edge), kern.min_value) # both inputs need to be same shape with pytest.raises(ValueError): - kern(wl_test, wl_test[:-1]) \ No newline at end of file + kern(wl_test, wl_test[:-1]) + + +@pytest.fixture(scope="module") +def kernel(kernel_init): + (wave_kernel, kernel, n_pix) = kernel_init + min_trace = 2.5 + max_trace = 3.5 + n_trace = 100 + wave_trace = np.linspace(min_trace, max_trace, n_trace) + return au.WebbKernel(wave_kernel, kernel, wave_trace, n_pix) + + +def test_fct_to_array(kernel): + + thresh = 1e-5 + grid = np.linspace(2.0, 4.0, 50) + grid_range = [0, grid.size] + + kern_array = au._fct_to_array(kernel, grid, grid_range, thresh) + assert kern_array.ndim == 2 + # test that kernel is set to unity at its center + + #TODO: need to define a more realistic kernel because this one doesn't have any wings + + +def test_sparse_c(): + """Here kernel must be a 2-D array already, of shape (N_ker, N_k_convolved)""" + + kern_array = np.array([]) + + matrix = au._sparse_c(kern_array, n_k, i_zero) + + +def test_get_c_matrix(kernel): + """See also test_fct_to_array and test_sparse_c for more detailed tests + of functions called by this one""" + + + #TODO: what to use for grid? is it / can it be the same as wave_trace? + # I think it can be the same but does not need to be, and would be a better + # test if it were different, because the kernel is an interpolator that was + # created using wavelengths that are included in wave_trace. + matrix = au.get_c_matrix(kernel, grid, i_bounds=None, thresh=1e-5) + + # test with WebbKernel as the kernel + # ensure normalized + # ensure sparse + + + # test where input kernel is a 2-D array instead of callable + + + # test where i_bounds is not None + + + # Test invalid kernel input (wrong dimensions) From a85d542e260c653cd4b43796113d6de240d263fa Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Wed, 13 Nov 2024 11:55:53 -0500 Subject: [PATCH 12/35] remove unused options and simplify call structure of Tikhonov and TikhoTests --- jwst/extract_1d/soss_extract/atoca.py | 3 +- jwst/extract_1d/soss_extract/atoca_utils.py | 294 +++++------------- jwst/extract_1d/soss_extract/soss_extract.py | 22 +- .../soss_extract/tests/test_atoca_utils.py | 68 +++- 4 files changed, 145 insertions(+), 242 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index 6578f6ae90..225fff0c6c 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -632,8 +632,7 @@ def tikho_mat(self): if self._tikho_mat is not None: return self._tikho_mat - fkwargs = {'n_derivative': 1, 'd_grid': True, 'estimate': None, 'pwr_law': 0} - self._tikho_mat = atoca_utils.get_tikho_matrix(self.wave_grid, **fkwargs) + self._tikho_mat = atoca_utils.finite_first_d(self.wave_grid) return self._tikho_mat diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index 5cb2920286..1b543ce143 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -1173,8 +1173,6 @@ def _fct_to_array(fct, grid, grid_range, thresh): # Add wings: Generate a 2D array of the grid iteratively until # thresh is reached everywhere. - # TODO: surely we can make this faster? avoid a while loop by figuring out - # where threshold will be reached a priori somehow? length = 1 h_len = 0 # Half length while True: @@ -1208,6 +1206,9 @@ def _sparse_c(ker, n_k, i_zero): Convert a convolution kernel in compact form (N_ker, N_k_convolved) to sparse form (N_k_convolved, N_k) + TODO: why is all the formalism for defining the diagonal necessary? why can't csr_matrix be + called directly? there must be a reason, but add documentation! + Parameters ---------- ker : array[float] @@ -1258,13 +1259,12 @@ def _sparse_c(ker, n_k, i_zero): def get_c_matrix(kernel, grid, i_bounds=None, thresh=1e-5): """ Return a convolution matrix - Can return a sparse matrix (N_k_convolved, N_k) - or a matrix in the compact form (N_ker, N_k_convolved). + Returns a sparse matrix (N_k_convolved, N_k). N_k is the length of the grid on which the convolution will be applied, N_k_convolved is the length of the grid after convolution and N_ker is the maximum length of - the kernel. If the default sparse matrix option is chosen, - the convolution can be applied on an array f | f = fct(grid) + the kernel. + The convolution can be applied on an array f | f = fct(grid) by a simple matrix multiplication: f_convolved = c_matrix.dot(f) @@ -1339,47 +1339,12 @@ def _finite_diff(x): the result is the same as np.diff(x) """ n_x = len(x) - - # Build matrix diff_matrix = diags([-1.], shape=(n_x - 1, n_x)) diff_matrix += diags([1.], 1, shape=(n_x - 1, n_x)) - return diff_matrix -def _finite_second_d(grid): - """Returns the second derivative operator based on grid - - Parameters - ---------- - grid : array[float] - 1D array where the second derivative will be computed. - - Returns - ------- - second_d : array[float] - Operator to compute the second derivative, so that - f" = second_d.dot(f), where f is a function - projected on `grid`. - """ - - # Finite difference operator - d_matrix = _finite_diff(grid) - - # Delta lambda - d_grid = d_matrix.dot(grid) - - # First derivative operator - first_d = diags(1. / d_grid).dot(d_matrix) - - # Second derivative operator - second_d = _finite_diff(grid[:-1]).dot(first_d) - - # don't forget the delta lambda - return diags(1. / d_grid[:-1]).dot(second_d) - - -def _finite_first_d(grid): +def finite_first_d(grid): """Returns the first derivative operator based on grid Parameters @@ -1390,7 +1355,7 @@ def _finite_first_d(grid): Returns ------- first_d : array[float] - Operator to compute the second derivative, so that + Operator to compute the first derivative, so that f' = first_d.dot(f), where f is a function projected on `grid`. """ @@ -1405,86 +1370,6 @@ def _finite_first_d(grid): return diags(1. / d_grid).dot(d_matrix) -def get_tikho_matrix(grid, n_derivative=1, d_grid=True, estimate=None, pwr_law=0): - """ - TODO: can all this Tikhonov stuff go into the classes? - - Wrapper to return the tikhonov matrix given a grid and the derivative degree. - - Parameters - ---------- - grid : array[float] - 1D grid where the Tikhonov matrix is projected - n_derivative : int, optional - Degree of derivative. Possible values are 1 or 2 - d_grid : bool, optional - Whether to divide the differential operator by the grid differences, - which corresponds to an actual approximation of the derivative or not. - estimate : callable (preferably scipy.interpolate.UnivariateSpline), optional - Estimate of the solution on which the tikhonov matrix is applied. - Must be a function of `grid`. If UnivariateSpline, then the derivatives - are given directly (so best option), otherwise the tikhonov matrix will be - applied to `estimate(grid)`. Note that it is better to use `d_grid=True` - pwr_law: float, optional - Power law applied to the scale differentiated estimate, so the estimate - of tikhonov_matrix.dot(solution). It will be applied as follows: - norm_factor * scale_factor.dot(tikhonov_matrix) - where scale_factor = 1/(estimate_derivative)**pwr_law - and norm_factor = 1/sum(scale_factor) - Returns - ------- - t_mat : array[float] - The tikhonov matrix. - """ - if d_grid: - input_grid = grid - else: - input_grid = np.arange(len(grid)) - - if n_derivative == 1: - t_mat = _finite_first_d(input_grid) - elif n_derivative == 2: - t_mat = _finite_second_d(input_grid) - else: - msg = "`n_derivative` must be 1 or 2." - log.critical(msg) - raise ValueError(msg) - - if estimate is not None: - if hasattr(estimate, 'derivative'): - # Get the derivatives directly from the spline - if n_derivative == 1: - derivative = estimate.derivative(n=n_derivative) - tikho_factor_scale = derivative(grid[:-1]) - elif n_derivative == 2: - derivative = estimate.derivative(n=n_derivative) - tikho_factor_scale = derivative(grid[1:-1]) - else: - # Apply tikho matrix on estimate - tikho_factor_scale = t_mat.dot(estimate(grid)) - - # Make sure all positive - tikho_factor_scale = np.abs(tikho_factor_scale) - # Apply power law - # (similar to 'kunasz1973'?) - tikho_factor_scale = np.power(tikho_factor_scale, -pwr_law) - # Normalize - valid = np.isfinite(tikho_factor_scale) - tikho_factor_scale /= np.sum(tikho_factor_scale[valid]) - - # If some values are not finite, set to the max value - # so it will be more regularized - valid = np.isfinite(tikho_factor_scale) - if not valid.all(): - value = np.max(tikho_factor_scale[valid]) - tikho_factor_scale[~valid] = value - - # Apply to tikhonov matrix - t_mat = diags(tikho_factor_scale).dot(t_mat) - - return t_mat - - def _curvature_finite(factors, log_reg2, log_chi2): """Compute the curvature in log space using finite differences @@ -1492,8 +1377,10 @@ def _curvature_finite(factors, log_reg2, log_chi2): ---------- factors : array[float] Regularisation factors (not in log). + log_reg2 : array[float] norm-2 of the regularisation term (in log10). + log_chi2 : array[float] norm-2 of the chi2 term (in log10). @@ -1501,8 +1388,8 @@ def _curvature_finite(factors, log_reg2, log_chi2): ------- factors : array[float] Sorted and cut version of input factors array. - curvature : array[float] + curvature : array[float] """ # Make sure it is sorted according to the factors idx = np.argsort(factors) @@ -1585,7 +1472,7 @@ def _get_interp_idx_array(idx, relative_range, max_length): return np.arange(*abs_range, 1) -def _minimize_on_grid(factors, val_to_minimize, interpolate, interp_index=None): +def _minimize_on_grid(factors, val_to_minimize, interpolate=True, interp_index=[-2,4]): """ Find minimum of a grid using akima spline interpolation to get a finer estimate Parameters @@ -1606,9 +1493,6 @@ def _minimize_on_grid(factors, val_to_minimize, interpolate, interp_index=None): The factor with minimized error/curvature. """ - if interp_index is None: - interp_index = [-2, 4] - # Only keep finite values idx_finite = np.isfinite(val_to_minimize) factors = factors[idx_finite] @@ -1650,23 +1534,26 @@ def _minimize_on_grid(factors, val_to_minimize, interpolate, interp_index=None): return min_fac -def _find_intersect(factors, y_val, thresh, interpolate, search_range=None): +def _find_intersect(factors, y_val, thresh, interpolate=True, search_range=[0,3]): """ Find the root of y_val - thresh (so the intersection between thresh and y_val) Parameters ---------- factors : array[float] 1D array of Tikhonov factors for which value array is calculated + y_val : array[float] 1D array of values. + thresh: float Threshold use in 'd_chi2' mode. Find the highest factor where the derivative of the chi2 derivative is below thresh. - interpolate: bool, optional + + interpolate: bool, optional, default True. If True, use interpolation to find a finer minimum; otherwise, return minimum value in array. - search_range : iterable[int], optional + + search_range : iterable[int], optional, default [0,3] Relative range of grid indices around the value to interpolate. - If not specified, defaults to [0,3]. Returns ------- @@ -1675,9 +1562,6 @@ def _find_intersect(factors, y_val, thresh, interpolate, search_range=None): point. """ - if search_range is None: - search_range = [0, 3] - # Only keep finite values idx_finite = np.isfinite(y_val) factors = factors[idx_finite] @@ -1755,12 +1639,18 @@ class TikhoTests(dict): by default. """ - DEFAULT_TRESH_DERIVATIVE = (('chi2', 1e-5), + DEFAULT_THRESH_DERIVATIVE = (('chi2', 1e-5), ('chi2_soft_l1', 1e-4), ('chi2_cauchy', 1e-3)) - def __init__(self, test_dict=None, default_chi2='chi2_cauchy'): + def __init__(self, test_dict, default_chi2='chi2_cauchy'): """ + TODO: always instantiated with a dict, no reason for optional + always uses default chi2 option, no need to make optional or support other options + + Dict always has the following five keys: + 'factors', 'solution', 'error', 'reg', 'grid' + Parameters ---------- test_dict : dict @@ -1771,18 +1661,14 @@ def __init__(self, test_dict=None, default_chi2='chi2_cauchy'): """ # Define the number of data points # (length of the "b" vector in the tikhonov regularisation) - if test_dict is None: - log.warning('Unable to get the number of data points. Setting `n_points` to 1') - n_points = 1 - else: - n_points = len(test_dict['error'][0].squeeze()) + n_points = len(test_dict['error'][0].squeeze()) # Save attributes self.n_points = n_points self.default_chi2 = default_chi2 self.default_thresh = {chi2_type: thresh for (chi2_type, thresh) - in self.DEFAULT_TRESH_DERIVATIVE} + in self.DEFAULT_THRESH_DERIVATIVE} # Initialize so it behaves like a dictionary super().__init__(test_dict) @@ -1795,59 +1681,43 @@ def __init__(self, test_dict=None, default_chi2='chi2_cauchy'): # Save the chi2 self[chi2_type] except KeyError: - self[chi2_type] = self._compute_chi2(loss=loss) + self[chi2_type] = self._compute_chi2(loss) - def _compute_chi2(self, tests=None, loss='linear'): + def _compute_chi2(self, loss): """ TODO: is there a scipy builtin that does this? Calculates the reduced chi squared statistic Parameters ---------- - tests : dict, optional - Dictionary from which we take the error array; if not provided, - self is used - n_points : int, optional - Number of data points; if not provided, self.n_points is used + loss: str + Type of loss function to use. Options are 'linear', 'soft_l1', 'cauchy'. Returns ------- float Sum of the squared error array divided by the number of data points """ - # If not given, take the tests from the object - if tests is None: - tests = self - - # Get the loss function - if isinstance(loss, str): - try: - loss = LOSS_FUNCTIONS[loss] - except KeyError as e: - keys = [key for key in LOSS_FUNCTIONS.keys()] - msg = f'loss={loss} not a valid key. Must be one of {keys} or callable.' - raise e(msg) - elif not callable(loss): - raise ValueError('Invalid value for loss.') + # retrieve loss function + try: + loss = LOSS_FUNCTIONS[loss] + except KeyError as e: + keys = [key for key in LOSS_FUNCTIONS.keys()] + msg = f'loss={loss} not a valid key. Must be one of {keys} or callable.' + raise e(msg) # Compute the reduced chi^2 for all tests - chi2 = np.nanmean(loss(tests['error']**2), axis=-1) + chi2 = np.nanmean(loss(self['error']**2), axis=-1) # Remove residual dimensions - chi2 = chi2.squeeze() + return chi2.squeeze() - return chi2 - def get_chi2_derivative(self, key=None): + def _get_chi2_derivative(self): """ TODO: is there a scipy builtin that does this? Compute derivative of the chi2 with respect to log10(factors) - Parameters - ---------- - key: str - which chi2 is used for computations. Default is self.default_chi2. - Returns ------- factors_leftd : array[float] @@ -1855,8 +1725,7 @@ def get_chi2_derivative(self, key=None): d_chi2 : array[float] derivative of chi squared array with respect to log10(factors) """ - if key is None: - key = self.default_chi2 + key = self.default_chi2 # Compute finite derivative fac_log = np.log10(self['factors']) @@ -1868,27 +1737,24 @@ def get_chi2_derivative(self, key=None): return factors_leftd, d_chi2 - def compute_curvature(self, tests=None, key=None): - - if key is None: - key = self.default_chi2 - # If not given, take the tests from the object - if tests is None: - tests = self + def _compute_curvature(self): + """ + TODO: add docstring + TODO: can this be combined with _curvature_finite? + """ + key = self.default_chi2 # Compute the curvature... # Get the norm-2 of the regularisation term - reg2 = np.nansum(tests['reg'] ** 2, axis=-1) + reg2 = np.nansum(self['reg'] ** 2, axis=-1) - factors, curv = _curvature_finite(tests['factors'], + return _curvature_finite(self['factors'], np.log10(self[key]), np.log10(reg2)) - return factors, curv - def best_tikho_factor(self, tests=None, interpolate=True, interp_index=None, - mode='curvature', key=None, thresh=None): + def best_tikho_factor(self, mode='curvature'): """ TODO: why is there a function with identical name in atoca.py ExtractionEngine? this one is called by the other one... @@ -1897,61 +1763,42 @@ def best_tikho_factor(self, tests=None, interpolate=True, interp_index=None, It is determined by taking the factor giving the highest logL on the detector or the highest curvature of the l-curve, depending on the chosen mode. + Parameters ---------- - tests : dictionary, optional - Results of tikhonov extraction tests for different factors. - Must have the keys "factors" and "-logl". If not specified, - the tests from self.tikho.tests are used. - interpolate : bool, optional - If True, use spline interpolation to find a finer minimum. - Default is true. - interp_index : list, optional - Relative range of grid indices around the minimum value to - interpolate across. If not specified, defaults to [-2,4]. mode : string How to find the best factor: 'chi2', 'curvature' or 'd_chi2'. - thresh : float - Threshold for use in 'd_chi2' mode. Find the highest factor where - the derivative of the chi2 derivative is below thresh. Returns ------- float Best scale factor as determined by the selected algorithm """ - if key is None: - key = self.default_chi2 - - if thresh is None: - thresh = self.default_thresh[key] - - # Use pre-run tests if not specified - if tests is None: - tests = self + key = self.default_chi2 + thresh = self.default_thresh[key] # Number of factors - n_fac = len(tests['factors']) + n_fac = len(self['factors']) # Determine the mode (what do we minimize?) if mode == 'curvature' and n_fac > 2: # Compute the curvature - factors, curv = tests.compute_curvature() + factors, curv = self._compute_curvature() # Find min factor - best_fac = _minimize_on_grid(factors, curv, interpolate, interp_index) + best_fac = _minimize_on_grid(factors, curv) elif mode == 'chi2': # Simply take the chi2 and factors - factors = tests['factors'] - y_val = tests[key] + factors = self['factors'] + y_val = self[key] # Find min factor - best_fac = _minimize_on_grid(factors, y_val, interpolate, interp_index) + best_fac = _minimize_on_grid(factors, y_val) elif mode == 'd_chi2' and n_fac > 1: # Compute the derivative of the chi2 - factors, y_val = tests.get_chi2_derivative() + factors, y_val = self._get_chi2_derivative() # Remove values for the higher factors that # are not already below thresh. If not _find_intersect @@ -1973,11 +1820,12 @@ def best_tikho_factor(self, tests=None, interpolate=True, interp_index=None, idx = slice(None) # Find intersection with threshold - best_fac = _find_intersect(factors[idx], y_val[idx], thresh, interpolate, interp_index) + best_fac = _find_intersect(factors[idx], y_val[idx], thresh) elif mode in ['curvature', 'd_chi2', 'chi2']: - best_fac = np.max(tests['factors']) - msg = (f'Could not compute {mode} because number of factor={n_fac}. ' + best_fac = np.max(self['factors']) + msg = (f'Could not compute {mode} because number of factors {n_fac} ' + 'is too small for that mode.' f'Setting best factor to max factor: {best_fac:.5e}') log.warning(msg) @@ -2138,10 +1986,12 @@ def test_factors(self, factors): reg = np.array(reg) # Save in a dictionary + print(factors) + print(sln) + print(err) + print(reg) - tests = TikhoTests({'factors': factors, + return TikhoTests({'factors': factors, 'solution': sln, 'error': err, 'reg': reg}) - - return tests diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index 6b37661d31..47f7ff3788 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -24,8 +24,6 @@ ORDER2_SHORT_CUTOFF = 0.58 -# TODO: replace all .__call__ with just an actual call to the object - def get_ref_file_args(ref_files): """ @@ -222,7 +220,7 @@ def _estim_flux_first_order(scidata_bkg, scierr, scimask, ref_file_args, mask_tr engine = ExtractionEngine(*ref_file_args, wave_grid, [mask], **kwargs) # Extract estimate - spec_estimate = engine.__call__(scidata_bkg, scierr) + spec_estimate = engine(scidata_bkg, scierr) # Interpolate idx = np.isfinite(spec_estimate) @@ -608,14 +606,14 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, guess_factor = engine.estimate_tikho_factors(estimate) log_guess = np.log10(guess_factor) factors = np.logspace(log_guess - 4, log_guess + 4, 10) - all_tests = engine.get_tikho_tests(factors, data=scidata_bkg, error=scierr) - tikfac= engine.best_tikho_factor(tests=all_tests, fit_mode='all') + all_tests = engine.get_tikho_tests(factors, scidata_bkg, scierr) + tikfac= engine.best_tikho_factor(all_tests, fit_mode='all') # Refine across 4 orders of magnitude. tikfac = np.log10(tikfac) factors = np.logspace(tikfac - 2, tikfac + 2, 20) - tiktests = engine.get_tikho_tests(factors, data=scidata_bkg, error=scierr) - tikfac = engine.best_tikho_factor(tests=tiktests, fit_mode='d_chi2') + tiktests = engine.get_tikho_tests(factors, scidata_bkg, scierr) + tikfac = engine.best_tikho_factor(tiktests, fit_mode='d_chi2') # Add all theses tests to previous ones all_tests = _append_tiktests(all_tests, tiktests) @@ -637,7 +635,7 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, log.info('Using a Tikhonov factor of {}'.format(tikfac)) # Run the extract method of the Engine. - f_k = engine.__call__(scidata_bkg, scierr, tikhonov=True, factor=tikfac) + f_k = engine(scidata_bkg, scierr, tikhonov=True, factor=tikfac) # Compute the log-likelihood of the best fit. logl = engine.compute_likelihood(f_k, scidata_bkg, scierr) @@ -811,18 +809,18 @@ def throughput(wavelength): # Find the tikhonov factor. # Initial pass with tikfac_range. factors = np.logspace(tikfac_log_range[0], tikfac_log_range[-1] + 8, 10) - all_tests = engine.get_tikho_tests(factors, data=data_order, error=err_order) + all_tests = engine.get_tikho_tests(factors, data_order, err_order) tikfac = engine.best_tikho_factor(tests=all_tests, fit_mode='all') # Refine across 4 orders of magnitude. tikfac = np.log10(tikfac) factors = np.logspace(tikfac - 2, tikfac + 2, 20) - tiktests = engine.get_tikho_tests(factors, data=data_order, error=err_order) - tikfac = engine.best_tikho_factor(tests=tiktests, fit_mode='d_chi2') + tiktests = engine.get_tikho_tests(factors, data_order, err_order) + tikfac = engine.best_tikho_factor(tiktests, fit_mode='d_chi2') all_tests = _append_tiktests(all_tests, tiktests) # Run the extract method of the Engine. - f_k_final = engine.__call__(data_order, err_order, tikhonov=True, factor=tikfac) + f_k_final = engine(data_order, err_order, tikhonov=True, factor=tikfac) # Save binned spectra in a list of SingleSpecModels for optional output spec_list = [] diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py index d294c4252b..7a728f0723 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py @@ -373,7 +373,7 @@ def kernel_init(): with maximum at the center, uniform in wavelength. """ n_os = 3 - n_pix = 15 # full width of kernel + n_pix = 25 # full width of kernel n_wave = 20 wave_range = (2.0, 5.0) wavelengths = np.linspace(*wave_range, n_wave) @@ -482,19 +482,38 @@ def test_fct_to_array(kernel): grid_range = [0, grid.size] kern_array = au._fct_to_array(kernel, grid, grid_range, thresh) + + # check shape assert kern_array.ndim == 2 - # test that kernel is set to unity at its center + assert kern_array.shape[1] == grid.size + assert kern_array.shape[0]%2 == 1 - #TODO: need to define a more realistic kernel because this one doesn't have any wings + # test that the max is at the center + kern_slice = kern_array[:,kern_array.shape[1]//2] + assert kern_slice[kern_slice.size//2] == np.max(kern_slice) + # test that weights have been applied at edges + kern_center = kern_array[kern_array.shape[0]//2] + assert np.isclose(kern_center[0], kern_center[1]/2) + assert np.isclose(kern_center[-1], kern_center[-2]/2) + + +@pytest.fixture(scope="module") +def kern_array(kernel): + + return au._fct_to_array(kernel, np.linspace(2.0, 4.0, 50), [0, 50], 1e-5) -def test_sparse_c(): - """Here kernel must be a 2-D array already, of shape (N_ker, N_k_convolved)""" - kern_array = np.array([]) +def test_sparse_c(kern_array): + """Here kernel must be a 2-D array already, of shape (N_ker, N_k_convolved)""" + # test typical case n_k = n_kc and i=0 + n_k = kern_array.shape[1] + i_zero = 0 matrix = au._sparse_c(kern_array, n_k, i_zero) + # TODO: add more here + def test_get_c_matrix(kernel): """See also test_fct_to_array and test_sparse_c for more detailed tests @@ -519,3 +538,40 @@ def test_get_c_matrix(kernel): # Test invalid kernel input (wrong dimensions) + + +def test_finite_first_diff(): + + wave_grid = np.linspace(0, 2*np.pi, 100) + test_0 = np.ones_like(wave_grid) + test_sin = np.sin(wave_grid) + + first_d = au.finite_first_d(wave_grid) + assert first_d.size == (wave_grid.size - 1)*2 + + # test trivial example returning zeros for constant + f0 = first_d.dot(test_0) + assert np.allclose(f0, 0) + + # test derivative of sin returns cos + wave_between = (wave_grid[1:] + wave_grid[:-1])/2 + f_sin = first_d.dot(test_sin) + assert np.allclose(f_sin, np.cos(wave_between), atol=1e-3) + + +@pytest.fixture(scope="module") +def tikhoTests(): + """Make a TikhoTests dictionary""" + + + return au.TikhoTests({'factors': factors, + 'solution': sln, + 'error': err, + 'reg': reg, + 'grid': wave_grid}) + + + +def test_tikho_tests(tikhoTests): + + assert False \ No newline at end of file From a6e448b06704b21eadfc4980c19c88264066959c Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Thu, 21 Nov 2024 10:19:47 -0500 Subject: [PATCH 13/35] simplify indexing in get_w --- jwst/extract_1d/soss_extract/atoca.py | 92 +++++---------------- jwst/extract_1d/soss_extract/atoca_utils.py | 38 ++------- 2 files changed, 30 insertions(+), 100 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index 225fff0c6c..39af6a032b 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -161,7 +161,7 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, raise MaskOverlapError(msg) # Update i_bounds based on masked wavelengths - self.i_bnds = self._get_i_bnds() + self.i_bounds = self._get_i_bnds() # if throughput is given as callable, turn it into an array of proper shape self.update_throughput(throughput) @@ -379,6 +379,7 @@ def _get_i_bnds(self): return i_bnds_new + def update_i_bnds(self): """Update the grid limits for the extraction. Needs to be done after modification of the mask @@ -387,7 +388,7 @@ def update_i_bnds(self): # Get old and new boundaries. i_bnds_old = self.i_bounds i_bnds_new = self._get_i_bnds() - print(i_bnds_old, i_bnds_new) + print("i_bounds old, new in update_i_bnds", i_bnds_old, i_bnds_new) for i_order in range(self.n_orders): @@ -914,9 +915,6 @@ def __call__(self, data, error, tikhonov=False, factor=None): There will be only one matrix multiplication: (P/sig).(w.T.lambda.c_n). - TODO: is there any way to avoid the need for this __call__ method given - that it's only a thin wrapper to the Tikhonov class? - Parameters ---------- data : (N, M) array_like @@ -959,7 +957,7 @@ def __call__(self, data, error, tikhonov=False, factor=None): return spectrum - def _get_lo_hi(self, grid, i_order): + def _get_lo_hi(self, grid, wave_p, wave_m, mask): """ Find the lowest (lo) and highest (hi) index of wave_grid for each pixels and orders. @@ -968,8 +966,10 @@ def _get_lo_hi(self, grid, i_order): ---------- grid : array[float] Wave_grid to check. - i_order : int - Order to check values. + wave_p : array[float] + TODO: add here + wave_m : array[float] + TODO: add here Returns ------- @@ -979,29 +979,14 @@ def _get_lo_hi(self, grid, i_order): log.debug('Computing lowest and highest indices of wave_grid.') - # ... order dependent attributes - attrs = ['wave_p', 'wave_m', 'mask_ord'] - wave_p, wave_m, mask_ord = self.get_attributes(*attrs, i_order=i_order) - - # Compute only for valid pixels - wave_p = wave_p[~self.mask] - wave_m = wave_m[~self.mask] - # Find lower (lo) index in the pixel - lo = np.searchsorted(grid, wave_m, side='right') + lo = np.searchsorted(grid, wave_m, side='right') - 1 # Find higher (hi) index in the pixel hi = np.searchsorted(grid, wave_p) - 1 - # Set invalid pixels for this order to lo=-1 and hi=-2 - ma = mask_ord[~self.mask] - lo[ma], hi[ma] = -1, -2 - - # print("grid", np.min(grid), np.max(grid), grid.size) - # grid 0.8515140811891634 2.204728219634399 1408 on PR branch - # grid 0.8480011181377183 2.2057149524019297 1413 on main - # the grid could be the problem! - # try reverting all changes to atoca_utils.grid_from_map and dependencies + # Set invalid pixels negative + lo[mask], hi[mask] = -1, -2 return lo, hi @@ -1024,6 +1009,7 @@ def get_mask_wave(self, i_order): attrs = ['wave_p', 'wave_m', 'i_bounds'] wave_p, wave_m, i_bnds = self.get_attributes(*attrs, i_order=i_order) + print("i_bnds in get_mask_wave", i_bnds) wave_min = self.wave_grid[i_bnds[0]] wave_max = self.wave_grid[i_bnds[1] - 1] @@ -1052,64 +1038,30 @@ def get_w(self, i_order): log.debug('Computing weights and k.') - # Get needed attributes - wave_grid, mask = self.get_attributes('wave_grid', 'mask') - - # ... order dependent attributes + # get order dependent attributes attrs = ['wave_p', 'wave_m', 'mask_ord', 'i_bounds'] wave_p, wave_m, mask_ord, i_bnds = self.get_attributes(*attrs, i_order=i_order) # Use the convolved grid (depends on the order) - wave_grid = wave_grid[i_bnds[0]:i_bnds[1]] + wave_grid = self.wave_grid[i_bnds[0]:i_bnds[1]] # Compute the wavelength coverage of the grid d_grid = np.diff(wave_grid) - # Get lo hi - lo, hi = self._get_lo_hi(wave_grid, i_order) # Get indexes - # Compute only valid pixels - wave_p, wave_m = wave_p[~mask], wave_m[~mask] - ma = mask_ord[~mask] + wave_p, wave_m = wave_p[~self.mask], wave_m[~self.mask] + ma = mask_ord[~self.mask] + + # Get lo hi + lo, hi = self._get_lo_hi(wave_grid, wave_p, wave_m, ma) # Get indexes # Number of used pixels n_i = len(lo) i = np.arange(n_i) - # Define first and last index of wave_grid - # for each pixel - k_first, k_last = -1 * np.ones(n_i), -1 * np.ones(n_i) - - # If lowest value close enough to the exact grid value, - # NOTE: Could be approximately equal to the exact grid - # value. It would look like that. - # >>> lo_dgrid = lo - # >>> lo_dgrid[lo_dgrid==len(d_grid)] = len(d_grid) - 1 - # >>> cond = (grid[lo]-wave_m)/d_grid[lo_dgrid] <= 1.0e-8 - # But let's stick with the exactly equal - cond = (wave_grid[lo] == wave_m) - - # special case (no need for lo_i - 1) - k_first[cond & ~ma] = lo[cond & ~ma] - wave_m[cond & ~ma] = wave_grid[lo[cond & ~ma]] - - # else, need lo_i - 1 - k_first[~cond & ~ma] = lo[~cond & ~ma] - 1 - - # Same situation for highest value. If we follow the note - # above (~=), the code could look like - # >>> cond = (wave_p-grid[hi])/d_grid[hi-1] <= 1.0e-8 - # But let's stick with the exactly equal - cond = (wave_p == wave_grid[hi]) - - # special case (no need for hi_i - 1) - k_last[cond & ~ma] = hi[cond & ~ma] - wave_p[cond & ~ma] = wave_grid[hi[cond & ~ma]] - - # else, need hi_i + 1 - k_last[~cond & ~ma] = hi[~cond & ~ma] + 1 - # Generate array of all k_i. Set to max value of uint16 if not valid - k_n = atoca_utils.arange_2d(k_first, k_last + 1) + k_n = atoca_utils.arange_2d(lo[~ma], hi[~ma]) + print(k_n.shape) + bad = k_n == -1 # Number of valid k per pixel diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index 1b543ce143..19e8629dae 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -1631,12 +1631,6 @@ class TikhoTests(dict): """ Class to save Tikhonov tests for different factors. All the tests are stored in the attribute `tests` as a dictionary - - Parameters - ---------- - test_dict : dict - Dictionary holding arrays for `factors`, `solution`, `error`, and `reg` - by default. """ DEFAULT_THRESH_DERIVATIVE = (('chi2', 1e-5), @@ -1645,17 +1639,11 @@ class TikhoTests(dict): def __init__(self, test_dict, default_chi2='chi2_cauchy'): """ - TODO: always instantiated with a dict, no reason for optional - always uses default chi2 option, no need to make optional or support other options - - Dict always has the following five keys: - 'factors', 'solution', 'error', 'reg', 'grid' - Parameters ---------- test_dict : dict - Dictionary holding arrays for `factors`, `solution`, `error`, and `reg` - by default. + Dictionary holding arrays for `factors`, `solution`, `error`, `reg`, and `grid`. + default_chi2: string Type of chi2 loss used by default. Options are chi2, chi2_soft_l1, chi2_cauchy. """ @@ -1686,7 +1674,6 @@ def __init__(self, test_dict, default_chi2='chi2_cauchy'): def _compute_chi2(self, loss): """ - TODO: is there a scipy builtin that does this? Calculates the reduced chi squared statistic Parameters @@ -1715,13 +1702,13 @@ def _compute_chi2(self, loss): def _get_chi2_derivative(self): """ - TODO: is there a scipy builtin that does this? Compute derivative of the chi2 with respect to log10(factors) Returns ------- factors_leftd : array[float] factors array, shortened to match length of derivative. + d_chi2 : array[float] derivative of chi squared array with respect to log10(factors) """ @@ -1851,19 +1838,18 @@ class Tikhonov: where gamma is the Tikhonov regularization matrix. """ - def __init__(self, a_mat, b_vec, t_mat, valid=True): + def __init__(self, a_mat, b_vec, t_mat): """ Parameters ---------- a_mat : matrix-like object (2d) matrix A in the system to solve A.x = b + b_vec : vector-like object (1d) vector b in the system to solve A.x = b + t_mat : matrix-like object (2d) Tikhonov regularisation matrix to be applied on b_vec. - valid : bool, optional - If True, solve the system only for valid indices. The - invalid values will be set to np.nan. Default is True. """ # Save input matrix @@ -1884,10 +1870,8 @@ def __init__(self, a_mat, b_vec, t_mat, valid=True): self.idx_valid = idx_valid # Save other attributes - self.valid = valid self.test = None - return def solve(self, factor=1.0): """ @@ -1909,8 +1893,7 @@ def solve(self, factor=1.0): a_mat_2 = self.a_mat_2 result = self.result t_mat_2 = self.t_mat_2 - valid = self.valid - idx_valid = self.idx_valid + idx = self.idx_valid # Matrix gamma squared (with scale factor) gamma_2 = factor ** 2 * t_mat_2 @@ -1921,12 +1904,6 @@ def solve(self, factor=1.0): # Initialize solution solution = np.full(matrix.shape[0], np.nan) - # Consider only valid indices if in valid mode - if valid: - idx = idx_valid - else: - idx = np.full(len(solution), True) - # Solve matrix = matrix[idx, :][:, idx] result = result[idx] @@ -1934,6 +1911,7 @@ def solve(self, factor=1.0): return solution + def test_factors(self, factors): """ Test multiple factors From 0a5441525948a57c592660eefb6ad492b5fd9cd8 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Thu, 21 Nov 2024 12:32:06 -0500 Subject: [PATCH 14/35] some style fixes --- jwst/extract_1d/soss_extract/atoca.py | 31 ++---------- jwst/extract_1d/soss_extract/atoca_utils.py | 47 +------------------ .../soss_extract/soss_boxextract.py | 11 ----- jwst/extract_1d/soss_extract/soss_extract.py | 30 ------------ jwst/extract_1d/soss_extract/soss_syscor.py | 4 -- 5 files changed, 6 insertions(+), 117 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index 39af6a032b..1ddae7b439 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -66,16 +66,13 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, A list or array of the central wavelength position for each order on the detector. It has to have the same (N, M) as `data`. - trace_profile : (N_ord, N, M) list or array of 2-D arrays A list or array of the spatial profile for each order on the detector. It has to have the same (N, M) as `data`. - throughput : (N_ord [, N_k]) list of array or callable A list of functions or array of the throughput at each order. If callable, the functions depend on the wavelength. If array, projected on `wave_grid`. - kernels : callable, sparse matrix, or None. Convolution kernel to be applied on spectrum (f_k) for each orders. Can be a callable with the form f(x, x0) where x0 is @@ -88,23 +85,18 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, be used directly. N_ker is the length of the effective kernel and N_k_c is the length of the spectrum (f_k) convolved. If None, the kernel is set to 1, i.e., do not do any convolution. - wave_grid : (N_k) array_like, required. The grid on which f(lambda) will be projected. - mask_trace_profile : (N_ord, N, M) list or array of 2-D arrays[bool], required. A list or array of the pixel that need to be used for extraction, for each order on the detector. It has to have the same (N_ord, N, M) as `trace_profile`. - orders : list, optional List of orders considered. Default is orders = [1, 2] - threshold : float, optional: The contribution of any order on a pixel is considered significant if its estimated spatial profile is greater than this threshold value. If it is not properly modeled (not covered by the wavelength grid), it will be masked. Default is 1e-3. - c_kwargs : list of N_ord dictionaries or dictionary, optional Inputs keywords arguments to pass to `convolution.get_c_matrix` function for each order. @@ -186,7 +178,6 @@ def get_attributes(self, *args, i_order=None): ---------- args : str or list[str] All attributes to return. - i_order : None or int, optional Index of order to extract. If specified, it will be applied to all attributes in args, so it cannot @@ -252,7 +243,6 @@ def _create_kernels(self, kernels, c_kwargs): kernels : callable, sparse matrix, or None Convolution kernel to be applied on the spectrum (f_k) for each order. If None, kernel is set to 1, i.e., do not do any convolution. - c_kwargs : list of N_ord dictionaries or dictionary, optional Inputs keywords arguments to pass to `convolution.get_c_matrix` function for each order. @@ -298,7 +288,6 @@ def _get_masks(self): ------- general_mask : array[bool] Mask that combines global_mask, wavelength mask, trace_profile mask - mask_ord : array[bool] Mask applied to each order """ @@ -485,12 +474,10 @@ def get_pixel_mapping(self, i_order, error=None, quick=False): ---------- i_order: integer Label of the order (depending on the initiation of the object). - error: (N, M) array_like or None, optional. Estimate of the error on each pixel. Same shape as `data`. If None, the error is set to 1, which means the method will return b_n instead of b_n/sigma. Default is None. - quick: bool, optional If True, only perform one matrix multiplication instead of the whole system: (P/sig).(w.T.lambda.c_n) @@ -515,6 +502,8 @@ def get_pixel_mapping(self, i_order, error=None, quick=False): attrs = ['trace_profile', 'throughput', 'kernels', 'weights', 'i_bounds'] trace_profile_n, throughput_n, kernel_n, weights_n, i_bnds = self.get_attributes(*attrs, i_order=i_order) + print("i_bounds in get_pixel_mapping", i_bnds) + # Keep only valid pixels (P and sig are still 2-D) # And apply directly 1/sig here (quicker) trace_profile_n = trace_profile_n[~mask] / error[~mask] @@ -546,6 +535,7 @@ def get_pixel_mapping(self, i_order, error=None, quick=False): # Save new pixel mapping matrix. self.pixel_mapping[i_order] = pixel_mapping + print(pixel_mapping.shape) return pixel_mapping @@ -561,7 +551,6 @@ def build_sys(self, data, error): ---------- data : (N, M) array_like A 2-D array of real values representing the detector image. - error : (N, M) array_like Estimate of the error on each pixel. @@ -593,7 +582,6 @@ def get_detector_model(self, data, error): ---------- data : (N, M) array_like A 2-D array of real values representing the detector image. - error: (N, M) array_like Estimate of the error on each pixel. @@ -610,6 +598,7 @@ def get_detector_model(self, data, error): # Initiate with empty matrix n_i = (~self.mask).sum() # n good pixels b_matrix = csr_matrix((n_i, self.n_wavepoints)) + print(b_matrix.shape) # Sum over orders for i_order in range(self.n_orders): @@ -689,10 +678,8 @@ def get_tikho_tests(self, factors, data, error): ---------- factors : 1D list or array-like Factors to be tested. - data : (N, M) array_like A 2-D array of real values representing the detector image. - error : (N, M) array_like Estimate of the error on each pixel. Same shape as `data`. @@ -729,7 +716,6 @@ def best_tikho_factor(self, tests, fit_mode): tests : dictionary Results of Tikhonov extraction tests for different factors. Must have the keys "factors" and "-logl". - fit_mode : string Which mode is used to find the best Tikhonov factor. Options are 'all', 'curvature', 'chi2', 'd_chi2'. If 'all' is chosen, the best of the @@ -810,7 +796,6 @@ def rebuild(self, spectrum, fill_value=0.0): spectrum : callable or array-like flux as a function of wavelength if callable or array of flux values corresponding to self.wave_grid. - fill_value : float or np.nan, optional Pixel value where the detector is masked. Default is 0.0. @@ -849,10 +834,8 @@ def compute_likelihood(self, spectrum, data, error): spectrum : array[float] or callable Flux as a function of wavelength if callable or array of flux values corresponding to self.wave_grid. - data : (N, M) array_like A 2-D array of real values representing the detector image. - error : (N, M) array_like Estimate of the error on each pixel. Same shape as `data`. @@ -919,15 +902,12 @@ def __call__(self, data, error, tikhonov=False, factor=None): ---------- data : (N, M) array_like A 2-D array of real values representing the detector image. - error : (N, M) array_like Estimate of the error on each pixel` Same shape as `data`. - tikhonov : bool, optional Whether to use Tikhonov extraction Default is False. - factor : the Tikhonov factor to use if tikhonov is True Returns @@ -1059,8 +1039,7 @@ def get_w(self, i_order): i = np.arange(n_i) # Generate array of all k_i. Set to max value of uint16 if not valid - k_n = atoca_utils.arange_2d(lo[~ma], hi[~ma]) - print(k_n.shape) + k_n = atoca_utils.arange_2d(lo[~ma]-1, hi[~ma]+1) bad = k_n == -1 diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index 19e8629dae..38fa6de05a 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -483,23 +483,18 @@ def make_combined_adaptive_grid(all_grids, all_estimates, grid_range, ---------- all_grids : list[array] List of grid (arrays) to pass to _adapt_grid, in order of importance. - all_estimates : list[callable] List of function (callable) to estimate the precision needed to oversample the grid. Must match the corresponding `grid` in `all_grid`. - grid_range : list[float] Wavelength range the new grid should cover. - max_iter : int, optional Number of times the intervals can be subdivided. The smallest subdivison of the grid if max_iter is reached will then be given by delta_grid / 2^max_iter. Needs to be greater than zero. Default is 10. - rtol : float, optional The desired relative tolerance. Default is 10e-6, so 10 ppm. - max_total_size : int, optional maximum size of the output grid. Default is 1 000 000. @@ -660,7 +655,6 @@ def _estim_integration_err(grid, fct): Grid for integration. Each sections of this grid are treated as separate integrals. So if grid has length N; N-1 integrals are tested. - fct: callable Function to be integrated. @@ -668,7 +662,6 @@ def _estim_integration_err(grid, fct): ------- err: absolute error of each integration, with length = length(grid) - 1 - rel_err: relative error of each integration, with length = length(grid) - 1 """ @@ -727,22 +720,17 @@ def _adapt_grid(grid, fct, max_grid_size, max_iter=10, rtol=10e-6, atol=1e-6): Grid for integration. Each sections of this grid are treated as separate integrals. So if grid has length N; N-1 integrals are optimized. - fct: callable Function to be integrated. Must be a function of `grid` - max_grid_size: int, required. maximum size of the output grid. - max_iter: int, optional Number of times the intervals can be subdivided. The smallest subdivison of the grid if max_iter is reached will then be given by delta_grid / 2^max_iter. Needs to be greater then zero. Default is 10. - rtol: float, optional The desired relative tolerance. Default is 10e-6, so 10 ppm. - atol: float, optional The desired absolute tolerance. Default is 1e-6. @@ -751,7 +739,6 @@ def _adapt_grid(grid, fct, max_grid_size, max_iter=10, rtol=10e-6, atol=1e-6): os_grid : 1D array Oversampled grid which minimizes the integration error based on Romberg's method - convergence_flag: bool Whether the estimated tolerance was reach everywhere or not. @@ -816,7 +803,6 @@ def ThroughputSOSS(wavelength, throughput): ---------- wavelength : array[float] A wavelength array. - throughput : array[float] The throughput values corresponding to the wavelengths. @@ -851,17 +837,14 @@ def __init__(self, wave_kernels, kernels, wave_trace, n_pix): # TODO kernels ma ---------- wave_kernels : array[float] Wavelength array for the kernel. Must have same shape as kernels. - kernels : array[float] Kernel for throughput array. Dimensions are (wavelength, oversampled pixels). Center (~max throughput) of the kernel is at the center of the 2nd axis. - wave_trace : array[float] 1-D trace of the detector central wavelengths for the given order. Since WebbPSF returns kernels in the pixel space, this is used to convert to wavelength space. - n_pix : int Number of detector pixels spanned by the kernel. Second axis of kernels has shape (n_os * n_pix) - (n_os - 1), where n_os is the @@ -958,7 +941,6 @@ def __call__(self, wave, wave_c): ---------- wave : array[float] Wavelength where the kernel is projected. - wave_c : array[float] Central wavelength of the kernel. @@ -1016,18 +998,14 @@ def _get_wings(fct, grid, h_len, i_a, i_b): a grid value and the center of the kernel. fct(grid, center) = kernel grid and center have the same length. - grid : array[float] grid where the kernel is projected - h_len : int Half-length where we compute kernel value. - i_a : int Index of grid axis 0 where to apply convolution. Once the convolution applied, the convolved grid will be equal to grid[i_a:i_b]. - i_b : int index of grid axis 1 where to apply convolution. @@ -1035,7 +1013,6 @@ def _get_wings(fct, grid, h_len, i_a, i_b): ------- left : array[float] Kernel values at left wing. - right : array[float] Kernel values at right wing. """ @@ -1088,18 +1065,14 @@ def _trpz_weight(grid, length, shape, i_a, i_b): ---------- grid : array[float] grid where the integration is projected - length : int length of the kernel - shape : tuple[int] shape of the compact convolution 2d array - i_a : int Index of grid axis 0 where to apply convolution. Once the convolution applied, the convolved grid will be equal to grid[i_a:i_b]. - i_b : int index of grid axis 1 where to apply convolution. @@ -1146,15 +1119,12 @@ def _fct_to_array(fct, grid, grid_range, thresh): a grid value and the center of the kernel. fct(grid, center) = kernel grid and center have the same length. - grid : array[float] Grid where the kernel is projected - grid_range : list[int] or tuple[int] Indices of the grid where to apply the convolution. Once the convolution applied, the convolved grid will be equal to grid[grid_range[0]:grid_range[1]]. - thresh : float, required Threshold to define the maximum length of the kernel. Truncate when `kernel` < `thresh`. @@ -1213,10 +1183,8 @@ def _sparse_c(ker, n_k, i_zero): ---------- ker : array[float] Convolution kernel with shape (N_kernel, N_kc) - n_k : int Length of the original grid - i_zero : int Position of the first element of the convolved grid in the original grid. @@ -1277,19 +1245,16 @@ def get_c_matrix(kernel, grid, i_bounds=None, thresh=1e-5): with the form f(x, x0) where x0 is the position of the center of the kernel. Must return a 1D array (len(x)), so a kernel value for each pairs of (x, x0). - grid: 1D np.array: The grid on which the convolution will be applied. For example, if C is the convolution matrix, f_convolved = C.f(grid) - i_bounds: 2-elements object, optional, default None. The bounds of the grid on which the convolution is defined. For example, if bounds = (a,b), then grid_convolved = grid[a <= grid <= b]. It dictates also the dimension of f_convolved. If None, the convolution is defined on the whole grid. - thresh: float, optional Only used when `kernel` is callable to define the maximum length of the kernel. Truncate when `kernel` < `thresh` @@ -1377,10 +1342,8 @@ def _curvature_finite(factors, log_reg2, log_chi2): ---------- factors : array[float] Regularisation factors (not in log). - log_reg2 : array[float] norm-2 of the regularisation term (in log10). - log_chi2 : array[float] norm-2 of the chi2 term (in log10). @@ -1388,8 +1351,8 @@ def _curvature_finite(factors, log_reg2, log_chi2): ------- factors : array[float] Sorted and cut version of input factors array. - curvature : array[float] + TODO: add documentation """ # Make sure it is sorted according to the factors idx = np.argsort(factors) @@ -1540,18 +1503,14 @@ def _find_intersect(factors, y_val, thresh, interpolate=True, search_range=[0,3] ---------- factors : array[float] 1D array of Tikhonov factors for which value array is calculated - y_val : array[float] 1D array of values. - thresh: float Threshold use in 'd_chi2' mode. Find the highest factor where the derivative of the chi2 derivative is below thresh. - interpolate: bool, optional, default True. If True, use interpolation to find a finer minimum; otherwise, return minimum value in array. - search_range : iterable[int], optional, default [0,3] Relative range of grid indices around the value to interpolate. @@ -1643,7 +1602,6 @@ def __init__(self, test_dict, default_chi2='chi2_cauchy'): ---------- test_dict : dict Dictionary holding arrays for `factors`, `solution`, `error`, `reg`, and `grid`. - default_chi2: string Type of chi2 loss used by default. Options are chi2, chi2_soft_l1, chi2_cauchy. """ @@ -1708,7 +1666,6 @@ def _get_chi2_derivative(self): ------- factors_leftd : array[float] factors array, shortened to match length of derivative. - d_chi2 : array[float] derivative of chi squared array with respect to log10(factors) """ @@ -1844,10 +1801,8 @@ def __init__(self, a_mat, b_vec, t_mat): ---------- a_mat : matrix-like object (2d) matrix A in the system to solve A.x = b - b_vec : vector-like object (1d) vector b in the system to solve A.x = b - t_mat : matrix-like object (2d) Tikhonov regularisation matrix to be applied on b_vec. """ diff --git a/jwst/extract_1d/soss_extract/soss_boxextract.py b/jwst/extract_1d/soss_extract/soss_boxextract.py index 67b28dff31..5adcf5bfb9 100644 --- a/jwst/extract_1d/soss_extract/soss_boxextract.py +++ b/jwst/extract_1d/soss_extract/soss_boxextract.py @@ -15,13 +15,10 @@ def get_box_weights(centroid, n_pix, shape, cols): ---------- centroid : array[float] Position of the centroid (in rows). Same shape as `cols` - n_pix : float Width of the extraction box in pixels. - shape : Tuple(int, int) Shape of the output image. (n_row, n_column) - cols : array[int] Column indices of good columns. Used if the centroid is defined for specific columns or a sub-range of columns. @@ -66,13 +63,10 @@ def box_extract(scidata, scierr, scimask, box_weights): ---------- scidata : array[float] 2d array of science data with shape (n_row, n_columns) - scierr : array[float] 2d array of uncertainty map with same shape as scidata - scimask : array[bool] 2d boolean array of masked pixels with same shape as scidata - box_weights : array[float] 2d array of pre-computed weights for box extraction, with same shape as scidata @@ -81,10 +75,8 @@ def box_extract(scidata, scierr, scimask, box_weights): ------- cols : array[int] Indices of extracted columns - flux : array[float] The flux in each column - flux_var : array[float] The variance of the flux in each column """ @@ -146,13 +138,10 @@ def estim_error_nearest_data(err, data, pix_to_estim, valid_pix): ---------- err : 2d array[float] Uncertainty map of the pixels. - data : 2d array[float] Pixel values. - pix_to_estim : 2d array[bool] Map of the pixels where the uncertainty needs to be estimated. - valid_pix : 2d array[bool] Map of valid pixels to be used to find the error empirically. diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index 47f7ff3788..13462c7707 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -134,7 +134,6 @@ def _get_trace_1d(ref_files, order): ref_files : dict A dictionary of the reference file DataModels, along with values for subarray and pwcpos, i.e. the pupil wheel position. - order : int The spectral order for which to return the trace parameters. @@ -182,19 +181,14 @@ def _estim_flux_first_order(scidata_bkg, scierr, scimask, ref_file_args, mask_tr ---------- scidata_bkg : array A single background subtracted NIRISS SOSS detector image. - scierr : array The uncertainties corresponding to the detector image. - scimask : array Pixel mask to apply to the detector image. - ref_file_args : tuple A tuple of reference file arguments constructed by get_ref_file_args(). - mask_trace_profile: array[bool] Mask determining the aperture used for extraction. Set to False where the pixel should be extracted. - threshold : float, optional: The pixels with an aperture[order 2] > `threshold` are considered contaminated and will be masked. Default is 1e-4. @@ -238,7 +232,6 @@ def _get_native_grid_from_trace(ref_files, spectral_order): ---------- ref_files: dict A dictionary of the reference file DataModels. - spectral_order: int The spectral order for which to return the trace parameters. @@ -281,10 +274,8 @@ def _get_grid_from_trace(ref_files, spectral_order, n_os): ---------- ref_files: dict A dictionary of the reference file DataModels. - spectral_order: int The spectral order for which to return the trace parameters. - n_os: int or array The oversampling factor of the wavelength grid used when solving for the uncontaminated flux. @@ -485,50 +476,38 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, ---------- scidata_bkg : array[float] A single background subtracted NIRISS SOSS detector image. - scierr : array[float] The uncertainties corresponding to the detector image. - scimask : array[bool] Pixel mask to apply to detector image. - refmask : array[bool] Pixels that should never be reconstructed e.g. the reference pixels. - ref_files : dict A dictionary of the reference file DataModels, along with values for subarray and pwcpos, i.e. the pupil wheel position. - box_weights : dict A dictionary of the weights (for each order) used in the box extraction. The weights for each order are 2d arrays with the same size as the detector. - tikfac : float, optional The Tikhonov regularization factor used when solving for the uncontaminated flux. If not specified, the optimal Tikhonov factor is calculated. - threshold : float The threshold value for using pixels based on the spectral profile. Default value is 1e-4. - n_os : int, optional The oversampling factor of the wavelength grid used when solving for the uncontaminated flux. If not specified, defaults to 2. - wave_grid : str or SossWaveGridModel or None Filename of reference file or SossWaveGridModel containing the wavelength grid used by ATOCA to model each pixel valid pixel of the detector. If not given, the grid is determined based on an estimate of the flux (estimate), the relative tolerance (rtol) required on each pixel model and the maximum grid size (max_grid_size). - estimate : UnivariateSpline or None Estimate of the target flux as a function of wavelength in microns. - rtol : float The relative tolerance needed on a pixel model. It is used to determine the sampling of the soss_wave_grid when not directly given. Default is 1e-3. - max_grid_size : int Maximum grid size allowed. It is used when soss_wave_grid is not directly to make sure the computation time or the memory used stays reasonable. @@ -538,16 +517,12 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, ------- tracemodels : dict Dictionary of the modeled detector images for each order. - tikfac : float Optimal Tikhonov factor used in extraction - logl : float Log likelihood value associated with the Tikhonov factor selected. - wave_grid : 1d array Same as wave_grid input. TODO: this isn't true if input wave_grid is None, update docstring - spec_list : list of SpecModel List of the underlying spectra for each integration and order. The tikhonov tests are also included. @@ -866,23 +841,18 @@ def _extract_image(decontaminated_data, scierr, scimask, box_weights, bad_pix='m ---------- decontaminated_data : array[float] A single backround subtracted NIRISS SOSS detector image. - scierr : array[float] The uncertainties corresponding to the detector image. - scimask : array[float] Pixel mask to apply to the detector image. - box_weights : dict A dictionary of the weights (for each order) used in the box extraction. The weights for each order are 2d arrays with the same size as the detector. - bad_pix : str How to handle the bad pixels. Options are 'masking' and 'model'. 'masking' will simply mask the bad pixels, such that the number of pixels in each column in the box extraction will not be constant, while the 'model' option uses `tracemodels` to replace the bad pixels. - tracemodels : dict Dictionary of the modeled detector images for each order. diff --git a/jwst/extract_1d/soss_extract/soss_syscor.py b/jwst/extract_1d/soss_extract/soss_syscor.py index f04a8a4e17..0e0ac47c46 100644 --- a/jwst/extract_1d/soss_extract/soss_syscor.py +++ b/jwst/extract_1d/soss_extract/soss_syscor.py @@ -13,10 +13,8 @@ def soss_background(scidata, scimask, bkg_mask): ---------- scidata : array[float] The image of the SOSS trace. - scimask : array[bool] Boolean mask of pixels to be excluded. - bkg_mask : array[bool] Boolean mask of pixels to be excluded because they are in the trace, typically constructed with make_background_mask. @@ -25,7 +23,6 @@ def soss_background(scidata, scimask, bkg_mask): ------- scidata_bkg : array[float] Background-subtracted image - col_bkg : array[float] Column-wise background values """ @@ -65,7 +62,6 @@ def make_background_mask(deepstack, width): deepstack : array[float] Deep image of the trace constructed by combining individual integrations of the observation. - width : int Width, in pixels, of the trace to exclude with the mask (i.e. width/256 for a SUBSTRIP256 observation). From 402f34dbf09a8bb3fcf34dcd468ff9013f28d47b Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 22 Nov 2024 11:12:15 -0500 Subject: [PATCH 15/35] tracking down bug --- jwst/extract_1d/soss_extract/atoca.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index 1ddae7b439..842eeedde3 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -159,7 +159,7 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, self.update_throughput(throughput) # turn kernels into sparse matrix - self._create_kernels(kernels, c_kwargs) + self.kernels =self._create_kernels(kernels, c_kwargs) # Compute integration weights. see method self.get_w() for details. self.weights, self.weights_k_idx = self.compute_weights() @@ -269,7 +269,7 @@ def _create_kernels(self, kernels, c_kwargs): kernels_new = [] for i_order, kernel_n in enumerate(kernels): if kernel_n is None: - kernel_n = 1 + kernel_n = np.array([1.0]) elif not issparse(kernel_n): kernel_n = atoca_utils.get_c_matrix(kernel_n, self.wave_grid, i_bounds=self.i_bounds[i_order], @@ -277,7 +277,7 @@ def _create_kernels(self, kernels, c_kwargs): kernels_new.append(kernel_n) - self.kernels = kernels_new + return kernels_new def _get_masks(self): @@ -502,7 +502,8 @@ def get_pixel_mapping(self, i_order, error=None, quick=False): attrs = ['trace_profile', 'throughput', 'kernels', 'weights', 'i_bounds'] trace_profile_n, throughput_n, kernel_n, weights_n, i_bnds = self.get_attributes(*attrs, i_order=i_order) - print("i_bounds in get_pixel_mapping", i_bnds) + print(trace_profile_n.shape, throughput_n.shape, kernel_n.shape, weights_n.shape, i_bnds) + # TODO: why is the kernel shape here not the same as on main? # Keep only valid pixels (P and sig are still 2-D) # And apply directly 1/sig here (quicker) From 2e20b6f011a1aace96094da53a099757cca35e5b Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 22 Nov 2024 15:52:37 -0500 Subject: [PATCH 16/35] bugfix for kernels --- jwst/extract_1d/soss_extract/atoca.py | 94 ++++++++----------- jwst/extract_1d/soss_extract/atoca_utils.py | 50 +++++++++- .../soss_extract/tests/test_atoca_utils.py | 3 + 3 files changed, 88 insertions(+), 59 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index 842eeedde3..586f106dc5 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -12,9 +12,7 @@ # General imports. import numpy as np -import warnings from scipy.sparse import issparse, csr_matrix, diags -from scipy.sparse.linalg import spsolve, lsqr, MatrixRankWarning # Local imports. from . import atoca_utils @@ -159,7 +157,7 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, self.update_throughput(throughput) # turn kernels into sparse matrix - self.kernels =self._create_kernels(kernels, c_kwargs) + self.kernels = self._create_kernels(kernels, c_kwargs) # Compute integration weights. see method self.get_w() for details. self.weights, self.weights_k_idx = self.compute_weights() @@ -270,7 +268,7 @@ def _create_kernels(self, kernels, c_kwargs): for i_order, kernel_n in enumerate(kernels): if kernel_n is None: kernel_n = np.array([1.0]) - elif not issparse(kernel_n): + if not issparse(kernel_n): kernel_n = atoca_utils.get_c_matrix(kernel_n, self.wave_grid, i_bounds=self.i_bounds[i_order], **c_kwargs[i_order]) @@ -369,30 +367,6 @@ def _get_i_bnds(self): return i_bnds_new - def update_i_bnds(self): - """Update the grid limits for the extraction. - Needs to be done after modification of the mask - """ - - # Get old and new boundaries. - i_bnds_old = self.i_bounds - i_bnds_new = self._get_i_bnds() - print("i_bounds old, new in update_i_bnds", i_bnds_old, i_bnds_new) - - for i_order in range(self.n_orders): - - # Take most restrictive lower bound. - low_bnds = [i_bnds_new[i_order][0], i_bnds_old[i_order][0]] - i_bnds_new[i_order][0] = np.max(low_bnds) - - # Take most restrictive upper bound. - up_bnds = [i_bnds_new[i_order][1], i_bnds_old[i_order][1]] - i_bnds_new[i_order][1] = np.min(up_bnds) - - # Update attribute. - self.i_bounds = i_bnds_new - - def wave_grid_c(self, i_order): """Return wave_grid for the convolved flux at a given order. """ @@ -502,9 +476,6 @@ def get_pixel_mapping(self, i_order, error=None, quick=False): attrs = ['trace_profile', 'throughput', 'kernels', 'weights', 'i_bounds'] trace_profile_n, throughput_n, kernel_n, weights_n, i_bnds = self.get_attributes(*attrs, i_order=i_order) - print(trace_profile_n.shape, throughput_n.shape, kernel_n.shape, weights_n.shape, i_bnds) - # TODO: why is the kernel shape here not the same as on main? - # Keep only valid pixels (P and sig are still 2-D) # And apply directly 1/sig here (quicker) trace_profile_n = trace_profile_n[~mask] / error[~mask] @@ -536,7 +507,6 @@ def get_pixel_mapping(self, i_order, error=None, quick=False): # Save new pixel mapping matrix. self.pixel_mapping[i_order] = pixel_mapping - print(pixel_mapping.shape) return pixel_mapping @@ -599,7 +569,6 @@ def get_detector_model(self, data, error): # Initiate with empty matrix n_i = (~self.mask).sum() # n good pixels b_matrix = csr_matrix((n_i, self.n_wavepoints)) - print(b_matrix.shape) # Sum over orders for i_order in range(self.n_orders): @@ -870,16 +839,8 @@ def _solve(matrix, result): # Only solve for valid indices, i.e. wavelengths that are # covered by the pixels on the detector. # It will be a singular matrix otherwise. - with warnings.catch_warnings(): - warnings.filterwarnings(action='error', category=MatrixRankWarning) - try: - sln[idx] = spsolve(matrix[idx, :][:, idx], result[idx]) - except MatrixRankWarning: - # on rare occasions spsolve's approximation of the matrix is not appropriate - # and fails on good input data. revert to different solver - log.info('ATOCA matrix solve failed with spsolve. Retrying with least-squares.') - sln[idx] = lsqr(matrix[idx, :][:, idx], result[idx])[0] - + matrix = matrix[idx, :][:, idx] + sln[idx] = atoca_utils.try_solve_two_methods(matrix, result[idx]) return sln @staticmethod @@ -990,7 +951,6 @@ def get_mask_wave(self, i_order): attrs = ['wave_p', 'wave_m', 'i_bounds'] wave_p, wave_m, i_bnds = self.get_attributes(*attrs, i_order=i_order) - print("i_bnds in get_mask_wave", i_bnds) wave_min = self.wave_grid[i_bnds[0]] wave_max = self.wave_grid[i_bnds[1] - 1] @@ -1039,9 +999,40 @@ def get_w(self, i_order): n_i = len(lo) i = np.arange(n_i) - # Generate array of all k_i. Set to max value of uint16 if not valid - k_n = atoca_utils.arange_2d(lo[~ma]-1, hi[~ma]+1) + # Define first and last index of wave_grid for each pixel + k_first, k_last = -1 * np.ones(n_i), -1 * np.ones(n_i) + + # If lowest value close enough to the exact grid value, + # NOTE: Could be approximately equal to the exact grid + # value. It would look like that. + # >>> lo_dgrid = lo + # >>> lo_dgrid[lo_dgrid==len(d_grid)] = len(d_grid) - 1 + # >>> cond = (grid[lo]-wave_m)/d_grid[lo_dgrid] <= 1.0e-8 + # But let's stick with the exactly equal + cond = (wave_grid[lo] == wave_m) + + # special case (no need for lo_i - 1) + k_first[cond & ~ma] = lo[cond & ~ma] + wave_m[cond & ~ma] = wave_grid[lo[cond & ~ma]] + + # else, need lo_i - 1 + k_first[~cond & ~ma] = lo[~cond & ~ma] - 1 + + # Same situation for highest value. If we follow the note + # above (~=), the code could look like + # >>> cond = (wave_p-grid[hi])/d_grid[hi-1] <= 1.0e-8 + # But let's stick with the exactly equal + cond = (wave_p == wave_grid[hi]) + # special case (no need for hi_i - 1) + k_last[cond & ~ma] = hi[cond & ~ma] + wave_p[cond & ~ma] = wave_grid[hi[cond & ~ma]] + + # else, need hi_i + k_last[~cond & ~ma] = hi[~cond & ~ma] + + # Generate array of all k_i. Set to -1 if not valid + k_n = atoca_utils.arange_2d(k_first, k_last + 1) bad = k_n == -1 # Number of valid k per pixel @@ -1050,19 +1041,16 @@ def get_w(self, i_order): # Compute array of all w_i. Set to np.nan if not valid # Initialize w_n = np.zeros(k_n.shape, dtype=float) - #################### + #################### # 4 different cases #################### - #################### # Valid for every cases w_n[:, 0] = wave_grid[k_n[:, 1]] - wave_m w_n[i, n_k - 1] = wave_p - wave_grid[k_n[i, n_k - 2]] - ################## # Case 1, n_k == 2 - ################## case = (n_k == 2) & ~ma if case.any(): @@ -1081,9 +1069,7 @@ def get_w(self, i_order): part2 = d_grid[k_n[case, 0]] w_n[case, :] *= (part1 / part2)[:, None] - ################## # Case 2, n_k >= 3 - ################## case = (n_k >= 3) & ~ma if case.any(): @@ -1109,9 +1095,7 @@ def get_w(self, i_order): w_n[cond, n_ki - 1] *= (nume1 / deno) w_n[cond, n_ki - 2] += (nume1 * nume2 / deno) - ################## # Case 3, n_k >= 4 - ################## case = (n_k >= 4) & ~ma if case.any(): log.debug('n_k = 4 in get_w().') @@ -1120,9 +1104,7 @@ def get_w(self, i_order): w_n[case, n_ki - 2] += (wave_grid[k_n[case, n_ki - 2]] - wave_grid[k_n[case, n_ki - 3]]) - ################## # Case 4, n_k > 4 - ################## case = (n_k > 4) & ~ma if case.any(): log.debug('n_k > 4 in get_w().') diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index 38fa6de05a..9e9971f482 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -7,8 +7,9 @@ """ import numpy as np +import warnings from scipy.sparse import diags, csr_matrix -from scipy.sparse.linalg import spsolve +from scipy.sparse.linalg import spsolve, lsqr, MatrixRankWarning from scipy.interpolate import interp1d, RectBivariateSpline, Akima1DInterpolator from scipy.optimize import minimize_scalar, brentq from scipy.interpolate import make_interp_spline @@ -988,6 +989,33 @@ def __call__(self, wave, wave_c): return webbker +def _constant_kernel_to_2d(c, grid_range): + """Build a 2d kernel array with a constant 1D kernel as input + + Parameters + ---------- + c : float or size-1 np.ndarray + Constant value to expand into a 2-D kernel + grid_range : list[int] + Indices over which convolution is defined on grid. + + Returns + ------- + kernel_2d : array[float] + 2D array of input 1D kernel tiled over axis with + length equal to difference of grid_range values. + """ + + # Assign range where the convolution is defined on the grid + a, b = grid_range + + # Get length of the convolved axis + n_k_c = b - a + + # Return a 2D array with this length + return np.tile(np.atleast_1d(c), (n_k_c, 1)).T + + def _get_wings(fct, grid, h_len, i_a, i_b): """Compute values of the kernel at grid[+-h_len] @@ -1272,10 +1300,13 @@ def get_c_matrix(kernel, grid, i_bounds=None, thresh=1e-5): a, b = i_bounds - # Generate a 2D kernel of shape (N_kernel x N_kc) from a callable + # Generate a 2D kernel of shape (N_kernel x N_kc) if callable(kernel): kernel = _fct_to_array(kernel, grid, [a, b], thresh) + elif kernel.size == 1: + kernel = _constant_kernel_to_2d(kernel, [a, b]) + elif kernel.ndim != 2: msg = ("Input kernel to get_c_matrix must be callable or" "2-dimensional array.") @@ -1783,6 +1814,19 @@ def best_tikho_factor(self, mode='curvature'): return best_fac +def try_solve_two_methods(matrix, result): + + with warnings.catch_warnings(): + warnings.filterwarnings(action='error', category=MatrixRankWarning) + try: + return spsolve(matrix, result) + except MatrixRankWarning: + # on rare occasions spsolve's approximation of the matrix is not appropriate + # and fails on good input data. revert to different solver + log.info('ATOCA matrix solve failed with spsolve. Retrying with least-squares.') + return lsqr(matrix, result)[0] + + class Tikhonov: """ TODO: can we avoid all of this by using scipy.optimize.least_squares @@ -1862,7 +1906,7 @@ def solve(self, factor=1.0): # Solve matrix = matrix[idx, :][:, idx] result = result[idx] - solution[idx] = spsolve(matrix, result) + solution[idx] = try_solve_two_methods(matrix, result) return solution diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py index 7a728f0723..126fa4c368 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py @@ -534,6 +534,9 @@ def test_get_c_matrix(kernel): # test where input kernel is a 2-D array instead of callable + # test where input kernel is size 1 + + # test where i_bounds is not None From 53e63494350a4b509cbec489da6af78cc5efc798 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 22 Nov 2024 16:05:07 -0500 Subject: [PATCH 17/35] trying to fix results is nan bug --- jwst/extract_1d/soss_extract/atoca_utils.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index 9e9971f482..cc55e52ff8 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -1857,16 +1857,12 @@ def __init__(self, a_mat, b_vec, t_mat): self.t_mat = t_mat # Pre-compute some matrix for the linear system to solve - t_mat_2 = (t_mat.T).dot(t_mat) # squared tikhonov matrix - a_mat_2 = a_mat.T.dot(a_mat) # squared model matrix - result = (a_mat.T).dot(b_vec.T) - idx_valid = (result.toarray() != 0).squeeze() # valid indices to use if `valid` is True - - # Save pre-computed matrix - self.t_mat_2 = t_mat_2 - self.a_mat_2 = a_mat_2 - self.result = result - self.idx_valid = idx_valid + self.t_mat_2 = (t_mat.T).dot(t_mat) # squared tikhonov matrix + self.a_mat_2 = a_mat.T.dot(a_mat) # squared model matrix + self.result = (a_mat.T).dot(b_vec.T) + # here self.result is all NaNs... why? + print(np.sum(~np.isnan(self.result.toarray()))) + self.idx_valid = (self.result.toarray() != 0).squeeze() # valid indices to use if `valid` is True # Save other attributes self.test = None @@ -1906,6 +1902,9 @@ def solve(self, factor=1.0): # Solve matrix = matrix[idx, :][:, idx] result = result[idx] + print(type(matrix)) + print(type(result)) + print(result.nanmin(), result.nanmax()) solution[idx] = try_solve_two_methods(matrix, result) return solution From 056424b76d87c41d85bb895c132be4905d4bca9e Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 22 Nov 2024 17:40:46 -0500 Subject: [PATCH 18/35] still trying to bugfix mask --- jwst/extract_1d/soss_extract/atoca.py | 34 ++++++++++----------- jwst/extract_1d/soss_extract/atoca_utils.py | 10 +----- 2 files changed, 17 insertions(+), 27 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index 586f106dc5..ddccf1a983 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -156,6 +156,9 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, # if throughput is given as callable, turn it into an array of proper shape self.update_throughput(throughput) + # Re-build global mask and masks for each orders + self.mask, self.mask_ord = self._get_masks() + # turn kernels into sparse matrix self.kernels = self._create_kernels(kernels, c_kwargs) @@ -318,24 +321,6 @@ def _get_masks(self): return general_mask, mask_ord - def update_mask(self, mask): - """Update `mask` attribute by combining the `general_mask` - attribute with the input `mask`. Every time the mask is - changed, the integration weights need to be recomputed - since the pixels change. - - Parameters - ---------- - mask : array[bool] - New mask to be combined with internal general_mask and - saved in self.mask. - """ - self.mask = (self.general_mask | mask) - - # Re-compute weights - self.weights, self.weights_k_idx = self.compute_weights() - - def _get_i_bnds(self): """Define wavelength boundaries for each order using the order's mask and the wavelength map. @@ -579,6 +564,19 @@ def get_detector_model(self, data, error): # Build detector pixels' array # Take only valid pixels and apply `error` on data data = data[~self.mask] / error[~self.mask] + print("sum of nans in data after mask", np.sum(np.isnan(data))) + print("n_i", n_i) + + import matplotlib.pyplot as plt + plt.imshow(self.mask, origin="lower") + plt.show() + + # FIXME: data[~self.mask] still contains NaNs in some cases + # but this is not the case on main + # first pass through, looking at just order 1, it appears self.mask has + # all the NaNs from the data already in it + # second pass through, mask is different and now it doesn't contain the NaNs + # in the data return b_matrix, csr_matrix(data) diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index cc55e52ff8..152ad8b80c 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -1861,7 +1861,7 @@ def __init__(self, a_mat, b_vec, t_mat): self.a_mat_2 = a_mat.T.dot(a_mat) # squared model matrix self.result = (a_mat.T).dot(b_vec.T) # here self.result is all NaNs... why? - print(np.sum(~np.isnan(self.result.toarray()))) + print("self.result is all nans", np.any(~np.isnan(self.result.toarray()))) self.idx_valid = (self.result.toarray() != 0).squeeze() # valid indices to use if `valid` is True # Save other attributes @@ -1902,9 +1902,6 @@ def solve(self, factor=1.0): # Solve matrix = matrix[idx, :][:, idx] result = result[idx] - print(type(matrix)) - print(type(result)) - print(result.nanmin(), result.nanmax()) solution[idx] = try_solve_two_methods(matrix, result) return solution @@ -1962,11 +1959,6 @@ def test_factors(self, factors): reg = np.array(reg) # Save in a dictionary - print(factors) - print(sln) - print(err) - print(reg) - return TikhoTests({'factors': factors, 'solution': sln, 'error': err, From da9c2e2c0919fdbb053733d8bebd0a97e85bdb0a Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 22 Nov 2024 17:51:36 -0500 Subject: [PATCH 19/35] add note --- jwst/extract_1d/soss_extract/atoca.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index ddccf1a983..4cda7f3294 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -576,7 +576,8 @@ def get_detector_model(self, data, error): # first pass through, looking at just order 1, it appears self.mask has # all the NaNs from the data already in it # second pass through, mask is different and now it doesn't contain the NaNs - # in the data + # that are present in the data, just the outline of the trace + # on main it appears this is somehow updated return b_matrix, csr_matrix(data) From b80b6fb2ba3e9f4e45361572750e167f6d8bf800 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Tue, 26 Nov 2024 14:38:33 -0500 Subject: [PATCH 20/35] adding tests of engine --- jwst/extract_1d/soss_extract/atoca.py | 49 +++-- jwst/extract_1d/soss_extract/soss_extract.py | 1 + .../soss_extract/tests/test_atoca.py | 175 ++++++++++++++++-- .../soss_extract/tests/test_atoca_utils.py | 2 + 4 files changed, 189 insertions(+), 38 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index 4cda7f3294..cf4e248c67 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -56,6 +56,7 @@ class ExtractionEngine: def __init__(self, wave_map, trace_profile, throughput, kernels, wave_grid, mask_trace_profile, + global_mask=None, orders=[1,2], threshold=1e-3, c_kwargs=None): """ Parameters @@ -88,6 +89,9 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, mask_trace_profile : (N_ord, N, M) list or array of 2-D arrays[bool], required. A list or array of the pixel that need to be used for extraction, for each order on the detector. It has to have the same (N_ord, N, M) as `trace_profile`. + global_mask : (N, M) array_like boolean, optional + Boolean Mask of the detector pixels to mask for every extraction, e.g. bad pixels. + Should not be related to a specific order (if so, use `mask_trace_profile` instead). orders : list, optional List of orders considered. Default is orders = [1, 2] threshold : float, optional: @@ -122,8 +126,8 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, lp, lm = atoca_utils.get_wave_p_or_m(wave) wave_p.append(lp) wave_m.append(lm) - self.wave_p = np.array(wave_p).astype(self.dtype) - self.wave_m = np.array(wave_m).astype(self.dtype) + self.wave_p = np.array(wave_p, dtype=self.dtype) + self.wave_m = np.array(wave_m, dtype=self.dtype) # Set orders and ensure that the number of orders is consistent with wave_map length self.orders = orders @@ -138,7 +142,7 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, self.i_bounds = [[0, len(self.wave_grid)] for _ in range(self.n_orders)] # Estimate a global mask and masks for each orders - self.mask, self.mask_ord = self._get_masks() + self.mask, self.mask_ord = self._get_masks(global_mask) # Save mask here as the general mask, since `mask` attribute can be changed. self.general_mask = self.mask.copy() @@ -153,11 +157,12 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, # Update i_bounds based on masked wavelengths self.i_bounds = self._get_i_bnds() - # if throughput is given as callable, turn it into an array of proper shape + # if throughput is given as callable, turn it into an array + # with shape (n_ord, wave_grid.size) self.update_throughput(throughput) # Re-build global mask and masks for each orders - self.mask, self.mask_ord = self._get_masks() + self.mask, self.mask_ord = self._get_masks(global_mask) # turn kernels into sparse matrix self.kernels = self._create_kernels(kernels, c_kwargs) @@ -281,10 +286,15 @@ def _create_kernels(self, kernels, c_kwargs): return kernels_new - def _get_masks(self): + def _get_masks(self, global_mask): """Compute a general mask on the detector and for each order. Depends on the trace profile and the wavelength grid. + Parameters + ---------- + global_mask : array[bool] + Boolean mask of the detector pixels to mask for every extraction. + Returns ------- general_mask : array[bool] @@ -295,14 +305,18 @@ def _get_masks(self): # Get needed attributes args = ('threshold', 'n_orders', 'mask_trace_profile', 'trace_profile') - needed_attr = self.get_attributes(*args) - threshold, n_orders, mask_trace_profile, trace_profile = needed_attr + threshold, n_orders, mask_trace_profile, trace_profile = self.get_attributes(*args) # Mask pixels not covered by the wavelength grid. mask_wave = np.array([self.get_mask_wave(i_order) for i_order in range(n_orders)]) - # Apply user defined mask. - mask_ord = np.any([mask_trace_profile, mask_wave], axis=0) + # combine trace profile mask with wavelength cutoff mask + # and apply detector bad pixel mask if specified + if global_mask is None: + mask_ord = np.any([mask_trace_profile, mask_wave], axis=0) + else: + mask = [global_mask for _ in range(n_orders)] # For each orders + mask_ord = np.any([mask_trace_profile, mask_wave, mask], axis=0) # Find pixels that are masked in each order. general_mask = np.all(mask_ord, axis=0) @@ -528,7 +542,6 @@ def build_sys(self, data, error): def get_detector_model(self, data, error): """ - TODO: are mask, trace_profile, throughput ever passed in here? Get the linear model of the detector pixel, B.dot(flux) = pixels TIPS: To be quicker, only specify the psf (`p_list`) in kwargs. There will be only one matrix multiplication: @@ -564,20 +577,6 @@ def get_detector_model(self, data, error): # Build detector pixels' array # Take only valid pixels and apply `error` on data data = data[~self.mask] / error[~self.mask] - print("sum of nans in data after mask", np.sum(np.isnan(data))) - print("n_i", n_i) - - import matplotlib.pyplot as plt - plt.imshow(self.mask, origin="lower") - plt.show() - - # FIXME: data[~self.mask] still contains NaNs in some cases - # but this is not the case on main - # first pass through, looking at just order 1, it appears self.mask has - # all the NaNs from the data already in it - # second pass through, mask is different and now it doesn't contain the NaNs - # that are present in the data, just the outline of the trace - # on main it appears this is somehow updated return b_matrix, csr_matrix(data) diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index 13462c7707..16e6d46bc5 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -568,6 +568,7 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, engine = ExtractionEngine(*ref_file_args, wave_grid=wave_grid, mask_trace_profile=mask_trace_profile, + global_mask=scimask, threshold=threshold, c_kwargs=c_kwargs) diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca.py b/jwst/extract_1d/soss_extract/tests/test_atoca.py index b556e83265..fd1aa0f360 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca.py @@ -1,35 +1,184 @@ import pytest import numpy as np +from functools import partial from jwst.extract_1d.soss_extract import atoca @pytest.fixture(scope="module") def wave_map(): - shp = (100, 30) - wave_ord1 = np.linspace(2.8, 0.8, shp[0]) + """Trying for a roughly factor-of-10 shrink on each axis + but otherwise relatively faithful reproduction of real detector behavior""" + shp = (25, 200) + wave_ord1 = np.linspace(2.8, 0.8, shp[1]) + wave_ord1 = np.ones(shp)*wave_ord1[np.newaxis, :] - pass + wave_ord2 = np.linspace(1.4, 0.5, shp[1]) + wave_ord2 = np.ones(shp)*wave_ord2[np.newaxis,:] + wave_ord2[:,190:] = 0.0 + + return [wave_ord1, wave_ord2] @pytest.fixture(scope="module") def trace_profile(wave_map): - pass + """ + order 2 is partially on top of, partially not on top of order 1 + give order 2 some slope to simulate that + """ + # order 1 + shp = wave_map[0].shape + ord1 = np.zeros((shp[0])) + ord1[3:9] = 0.2 + ord1[2] = 0.1 + ord1[9] = 0.1 + profile_ord1 = np.ones(shp)*ord1[:, np.newaxis] + + # order 2 + yy, xx = np.meshgrid(np.arange(shp[0]), np.arange(shp[1])) + yy = yy.astype(np.float32) - xx.astype(np.float32)*0.08 + yy = yy.T + + profile_ord2 = np.zeros_like(yy) + full = (yy >= 3) & (yy < 9) + half0 = (yy >= 9) & (yy < 11) + half1 = (yy >= 1) & (yy < 3) + profile_ord2[full] = 0.2 + profile_ord2[half0] = 0.1 + profile_ord2[half1] = 0.1 + + return [profile_ord1, profile_ord2] + + +@pytest.fixture(scope="module") +def wave_grid(): + """wave_grid has smaller spacings in some places than others + and is not backwards order like the wave map + Two duplicates are in there on purpose for testing""" + lo0 = np.linspace(0.7, 1.2, 6) + hi = np.linspace(1.2, 1.7, 19) + lo2 = np.linspace(1.7, 2.8, 12) + return np.concatenate([lo0, hi, lo2]) + @pytest.fixture(scope="module") -def throughput(wave_map): - pass +def throughput(): + """make a triangle function for each order but with different peak wavelength + """ + + def filter_function(wl, wl_max): + """Set free parameters to roughly mimic throughput functions on main""" + maxthru = 0.4 + thresh = 0.01 + scaling = 0.3 + dist = np.abs(wl - wl_max) + thru = maxthru - dist*scaling + thru[thru0] = 0 + if cut_low is not None: + trace[:,:cut_low] = 1 + if cut_hi is not None: + trace[:,cut_hi:] = 1 + return trace.astype(bool) + + trace_o1 = mask_from_trace(trace_profile[0], cut_low=0, cut_hi=199) + trace_o2 = mask_from_trace(trace_profile[1], cut_low=0, cut_hi=175) + return [trace_o1, trace_o2] + @pytest.fixture(scope="module") -def mask_trace_profile(wave_map): - pass +def detector_mask(wave_map): + """Add a few random bad pixels""" + shp = wave_map[0].shape + rng = np.random.default_rng(42) + mask = np.zeros(shp, dtype=bool) + bad = rng.choice(mask.size, 100) + bad = np.unravel_index(bad, shp) + mask[bad] = 1 + return mask + + +def test_extraction_engine(wave_map, + trace_profile, + throughput, + kernels, + wave_grid, + mask_trace_profile, + detector_mask): + + engine = atoca.ExtractionEngine(wave_map, + trace_profile, + throughput, + kernels, + wave_grid, + mask_trace_profile, + global_mask=detector_mask) + shp = wave_map[0].shape + + # test wave_grid became unique + assert engine.wave_grid.dtype == np.float64 + unq = np.unique(wave_grid) + assert np.allclose(engine.wave_grid, unq) + assert engine.n_wavepoints == unq.size + + for order in [0,1]: + # test assignment of attributes and conversion to expected float64 dtype + assert engine.wave_map[order].dtype == np.float64 + assert engine.trace_profile[order].dtype == np.float64 + assert engine.mask_trace_profile[order].dtype == np.bool_ + + assert np.allclose(engine.wave_map[order], wave_map[order]) + assert np.allclose(engine.trace_profile[order], trace_profile[order]) + assert np.allclose(engine.mask_trace_profile[order], mask_trace_profile[order]) + + # test derived attributes + assert engine.data_shape == shp + assert engine.n_orders == 2 + + # test wave_p and wave_m. separate unit test for their calculation + for att in ["wave_p", "wave_m"]: + wave = getattr(engine, att) + assert wave.dtype == np.float64 + assert wave.shape == (2,)+shp + + # test masks + # ensure they all include the input detector bad pixel mask + for mask in [engine.mask_ord[0], engine.mask_ord[1], engine.mask, engine.general_mask]: + assert mask.dtype == np.bool_ + assert mask.shape == shp + assert np.all(mask[detector_mask == 1]) + # general_mask should be a copy of mask except it doesn't respect wavelength bounds + #assert np.allclose(engine.mask, engine.general_mask) + # order masks should include mask of trace + for order in [0,1]: + mask = engine.mask_ord[order] + mask_profile = mask_trace_profile[order] + import matplotlib.pyplot as plt + fig, (ax0, ax1) = plt.subplots(2,1,figsize = (10,10)) + ax0.imshow(mask, origin='lower') + ax1.imshow(mask_profile, origin='lower') + plt.show() + #assert np.all(mask[mask_profile == 1]) + + -def test_extraction_engine(): - pass diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py index 126fa4c368..b0ee7aa57b 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py @@ -104,6 +104,8 @@ def test_get_wave_p_or_m(wave_map, dispersion_axis): """ Check that the plus and minus side is correctly identified for strictly ascending and strictly descending wavelengths. + + TODO: test that given a float64, output is float64 """ wave_reverse = np.fliplr(wave_map) if dispersion_axis == 0: From 7742e62f356308e94c3ad8f49df2e246f99a7972 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Tue, 26 Nov 2024 17:41:24 -0500 Subject: [PATCH 21/35] add more tests to extraction engine --- jwst/extract_1d/soss_extract/atoca.py | 8 +- .../soss_extract/tests/test_atoca.py | 138 +++++++++++++----- 2 files changed, 102 insertions(+), 44 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index cf4e248c67..02b1bc56a3 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -143,8 +143,6 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, # Estimate a global mask and masks for each orders self.mask, self.mask_ord = self._get_masks(global_mask) - # Save mask here as the general mask, since `mask` attribute can be changed. - self.general_mask = self.mask.copy() # Ensure there are adequate good pixels left in each order good_pixels_in_order = np.sum(np.sum(~self.mask_ord, axis=-1), axis=-1) @@ -163,6 +161,8 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, # Re-build global mask and masks for each orders self.mask, self.mask_ord = self._get_masks(global_mask) + # Save mask here as the general mask, since `mask` attribute can be changed. + self.general_mask = self.mask.copy() # turn kernels into sparse matrix self.kernels = self._create_kernels(kernels, c_kwargs) @@ -235,7 +235,7 @@ def update_throughput(self, throughput): raise ValueError(msg) # Set the attribute to the new values. - self.throughput = np.array(throughput_new).astype(self.dtype) + self.throughput = np.array(throughput_new, dtype=self.dtype) def _create_kernels(self, kernels, c_kwargs): @@ -361,7 +361,7 @@ def _get_i_bnds(self): # Take the most restrictive bound a = np.maximum(a, i_bnds[0]) b = np.minimum(b, i_bnds[1]) - i_bnds_new.append([a, b]) + i_bnds_new.append([int(a), int(b)]) return i_bnds_new diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca.py b/jwst/extract_1d/soss_extract/tests/test_atoca.py index fd1aa0f360..7738820a16 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca.py @@ -4,16 +4,21 @@ from jwst.extract_1d.soss_extract import atoca +DATA_SHAPE = (25,200) +WAVE_BNDS_O1 = [2.8, 0.8] +WAVE_BNDS_O2 = [1.4, 0.5] +WAVE_BNDS_GRID = [0.7, 2.7] + @pytest.fixture(scope="module") def wave_map(): """Trying for a roughly factor-of-10 shrink on each axis but otherwise relatively faithful reproduction of real detector behavior""" - shp = (25, 200) - wave_ord1 = np.linspace(2.8, 0.8, shp[1]) - wave_ord1 = np.ones(shp)*wave_ord1[np.newaxis, :] + wave_ord1 = np.linspace(WAVE_BNDS_O1[0], WAVE_BNDS_O1[1], DATA_SHAPE[1]) + wave_ord1 = np.ones(DATA_SHAPE)*wave_ord1[np.newaxis, :] - wave_ord2 = np.linspace(1.4, 0.5, shp[1]) - wave_ord2 = np.ones(shp)*wave_ord2[np.newaxis,:] + wave_ord2 = np.linspace(WAVE_BNDS_O2[0], WAVE_BNDS_O2[1], DATA_SHAPE[1]) + wave_ord2 = np.ones(DATA_SHAPE)*wave_ord2[np.newaxis,:] + # add a small region of zeros to mimic what is input to the step from ref files wave_ord2[:,190:] = 0.0 return [wave_ord1, wave_ord2] @@ -25,15 +30,15 @@ def trace_profile(wave_map): give order 2 some slope to simulate that """ # order 1 - shp = wave_map[0].shape - ord1 = np.zeros((shp[0])) + DATA_SHAPE = wave_map[0].shape + ord1 = np.zeros((DATA_SHAPE[0])) ord1[3:9] = 0.2 ord1[2] = 0.1 ord1[9] = 0.1 - profile_ord1 = np.ones(shp)*ord1[:, np.newaxis] + profile_ord1 = np.ones(DATA_SHAPE)*ord1[:, np.newaxis] # order 2 - yy, xx = np.meshgrid(np.arange(shp[0]), np.arange(shp[1])) + yy, xx = np.meshgrid(np.arange(DATA_SHAPE[0]), np.arange(DATA_SHAPE[1])) yy = yy.astype(np.float32) - xx.astype(np.float32)*0.08 yy = yy.T @@ -53,9 +58,9 @@ def wave_grid(): """wave_grid has smaller spacings in some places than others and is not backwards order like the wave map Two duplicates are in there on purpose for testing""" - lo0 = np.linspace(0.7, 1.2, 6) - hi = np.linspace(1.2, 1.7, 19) - lo2 = np.linspace(1.7, 2.8, 12) + lo0 = np.linspace(WAVE_BNDS_GRID[0], 1.2, 16) + hi = np.linspace(1.2, 1.7, 46) + lo2 = np.linspace(1.7, WAVE_BNDS_GRID[1], 31) return np.concatenate([lo0, hi, lo2]) @@ -108,31 +113,42 @@ def mask_from_trace(trace_in, cut_low=None, cut_hi=None): @pytest.fixture(scope="module") def detector_mask(wave_map): """Add a few random bad pixels""" - shp = wave_map[0].shape rng = np.random.default_rng(42) - mask = np.zeros(shp, dtype=bool) + mask = np.zeros(DATA_SHAPE, dtype=bool) bad = rng.choice(mask.size, 100) - bad = np.unravel_index(bad, shp) + bad = np.unravel_index(bad, DATA_SHAPE) mask[bad] = 1 return mask -def test_extraction_engine(wave_map, - trace_profile, - throughput, - kernels, - wave_grid, - mask_trace_profile, - detector_mask): - - engine = atoca.ExtractionEngine(wave_map, +@pytest.fixture(scope="module") +def engine(wave_map, + trace_profile, + throughput, + kernels, + wave_grid, + mask_trace_profile, + detector_mask, +): + return atoca.ExtractionEngine(wave_map, trace_profile, throughput, kernels, wave_grid, mask_trace_profile, global_mask=detector_mask) - shp = wave_map[0].shape + + +def test_extraction_engine(wave_map, + trace_profile, + throughput, + kernels, + wave_grid, + mask_trace_profile, + detector_mask, + engine, +): + """Test the init of the engine with default/good inputs""" # test wave_grid became unique assert engine.wave_grid.dtype == np.float64 @@ -151,34 +167,76 @@ def test_extraction_engine(wave_map, assert np.allclose(engine.mask_trace_profile[order], mask_trace_profile[order]) # test derived attributes - assert engine.data_shape == shp + assert engine.data_shape == DATA_SHAPE assert engine.n_orders == 2 # test wave_p and wave_m. separate unit test for their calculation for att in ["wave_p", "wave_m"]: wave = getattr(engine, att) assert wave.dtype == np.float64 - assert wave.shape == (2,)+shp + assert wave.shape == (2,)+DATA_SHAPE + + # test _get_i_bounds + assert len(engine.i_bounds) == 2 + for order in [0,1]: + assert len(engine.i_bounds[order]) == 2 + assert engine.i_bounds[order][0] >= 0 + assert engine.i_bounds[order][1] < DATA_SHAPE[1] + assert all([isinstance(val, int) for val in engine.i_bounds[order]]) + + # test to ensure that wave_map is considered for bounds + # in order 1 the wave_map is more restrictive on the shortwave end + # in order 2 the wave_grid is more restrictive on the longwave end + # in order 1 no restriction on the longwave end so get the full extent + # in order 2 no restriction on the shortwave end so get the full extent + assert engine.i_bounds[0][0] > 0 + assert engine.i_bounds[1][1] < engine.n_wavepoints + # TODO: off-by-one error here. why does this fail? + # check what this looks like on a real run on main + # assert engine.i_bounds[0][1] == engine.n_wavepoints + # assert engine.i_bounds[1][0] == 0 + # test masks # ensure they all include the input detector bad pixel mask for mask in [engine.mask_ord[0], engine.mask_ord[1], engine.mask, engine.general_mask]: assert mask.dtype == np.bool_ - assert mask.shape == shp + assert mask.shape == DATA_SHAPE assert np.all(mask[detector_mask == 1]) - # general_mask should be a copy of mask except it doesn't respect wavelength bounds - #assert np.allclose(engine.mask, engine.general_mask) - # order masks should include mask of trace - for order in [0,1]: - mask = engine.mask_ord[order] - mask_profile = mask_trace_profile[order] - import matplotlib.pyplot as plt - fig, (ax0, ax1) = plt.subplots(2,1,figsize = (10,10)) - ax0.imshow(mask, origin='lower') - ax1.imshow(mask_profile, origin='lower') - plt.show() - #assert np.all(mask[mask_profile == 1]) + # general_mask should be a copy of mask + assert np.allclose(engine.mask, engine.general_mask) + wave_bnds = [WAVE_BNDS_O1, WAVE_BNDS_O2] + for order in [0,1]: + # ensure wavelength bounds from wave_grid are respected in mask_ord + mask = engine.mask_ord[order] + wls = wave_map[order] + lo, hi = wave_bnds[order] + outside = (wls > lo) | (wls < hi) + assert np.all(mask[outside]) + + # a bit paradoxically, engine.mask_ord does not contain a single order's mask_trace_profile + # instead, it's mask_trace_profile[0] AND mask_trace_profile[1], i.e., + # the trace profiles of BOTH orders are UNmasked in the masks of each order + # the general mask and wavelength bounds are then applied, so the + # only difference between mask_ord[0] and mask_ord[1] are the wavelength bounds + # test that NOT all locations masked by mask_trace_profile are masked in mask_ord + assert not np.all(mask[mask_trace_profile[order]]) + # test that all locations masked by both profiles are masked in mask_ord + combined_profile = mask_trace_profile[0] & mask_trace_profile[1] + assert np.all(mask[combined_profile]) + + # test throughput function conversion to array + for order in [0,1]: + thru = engine.throughput[order] + assert thru.size == engine.n_wavepoints + assert np.all(thru >= 0) + assert thru.dtype == np.float64 + + # test kernel is cast to proper shape for input (trivial) kernel + print(engine.kernels[0].shape) + print(engine.kernels[0].dtype) + print(engine.kernels[0].toarray()) From 53607a6480d3e73cfc4f8b14ebc55af8b22750a5 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Wed, 27 Nov 2024 15:49:02 -0500 Subject: [PATCH 22/35] more unit tests for extraction engine methods --- jwst/extract_1d/soss_extract/atoca.py | 30 ++-- jwst/extract_1d/soss_extract/soss_extract.py | 2 +- .../soss_extract/tests/test_atoca.py | 135 ++++++++++++++++-- 3 files changed, 136 insertions(+), 31 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index 02b1bc56a3..4c7f1418f7 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -133,8 +133,8 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, self.orders = orders self.n_orders = len(self.orders) if self.n_orders != len(self.wave_map): - msg = ("The number of orders specified {} and the number of " - "wavelength maps provided {} do not match.") + msg = ("The number of orders specified ({}) and the number of " + "wavelength maps provided ({}) do not match.") log.critical(msg.format(self.n_orders, len(self.wave_map))) raise ValueError(msg.format(self.n_orders, len(self.wave_map))) @@ -170,9 +170,7 @@ def __init__(self, wave_map, trace_profile, throughput, kernels, # Compute integration weights. see method self.get_w() for details. self.weights, self.weights_k_idx = self.compute_weights() - # Initialize the pixel mapping (b_n) matrices self.pixel_mapping = [None for _ in range(self.n_orders)] - self.i_grid = None self.tikho_mat = None self.w_t_wave_c = None @@ -214,27 +212,22 @@ def update_throughput(self, throughput): Throughput values for each order, given either as an array or as a callable function with self.wave_grid as input. """ - - # Update the throughput values. throughput_new = [] for throughput_n in throughput: # Loop over orders. if callable(throughput_n): + throughput_n = throughput_n(self.wave_grid) - # Throughput was given as a callable function. - throughput_new.append(throughput_n(self.wave_grid)) - - elif throughput_n.shape == self.wave_grid.shape: - - # Throughput was given as an array. - throughput_new.append(throughput_n) - - else: - msg = 'Throughputs must be given as callable or arrays matching the extraction grid.' + msg = 'Throughputs must be given as callable or arrays matching the extraction grid.' + if not isinstance(throughput_n, np.ndarray): + log.critical(msg) + raise ValueError(msg) + if throughput_n.shape != self.wave_grid.shape: log.critical(msg) raise ValueError(msg) - # Set the attribute to the new values. + throughput_new.append(throughput_n) + self.throughput = np.array(throughput_new, dtype=self.dtype) @@ -660,7 +653,7 @@ def get_tikho_tests(self, factors, data, error): # Build the system to solve b_matrix, pix_array = self.get_detector_model(data, error) - tikho = atoca_utils.Tikhonov(b_matrix, pix_array, self.t_mat) + tikho = atoca_utils.Tikhonov(b_matrix, pix_array, self.tikho_mat) # Test all factors tests = tikho.test_factors(factors) @@ -959,6 +952,7 @@ def get_w(self, i_order): """Compute integration weights for each grid points and each pixels. Depends on the order `n`. TODO: what is this doing? where can we find the math? + TODO: why is the mask not just simply mask_ord? Parameters ---------- diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index 16e6d46bc5..389b834628 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -677,7 +677,7 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, except MaskOverlapError: log.error('Not enough unmasked pixels to model the remaining part of order 2.' - 'Model and spectrum will be NaN in that spectral region.') + ' Model and spectrum will be NaN in that spectral region.') spec_ord = [_build_null_spec_table(pixel_wave_grid)] model = np.nan * np.ones_like(scidata_bkg) diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca.py b/jwst/extract_1d/soss_extract/tests/test_atoca.py index 7738820a16..e3f676f3fc 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca.py @@ -87,7 +87,7 @@ def filter_function(wl, wl_max): @pytest.fixture(scope="module") def kernels(): """For now, just return unity""" - return [np.array([1.,])] + return [np.array([1.,]), np.array([1.,])] @@ -139,14 +139,15 @@ def engine(wave_map, global_mask=detector_mask) -def test_extraction_engine(wave_map, - trace_profile, - throughput, - kernels, - wave_grid, - mask_trace_profile, - detector_mask, - engine, +def test_extraction_engine( + wave_map, + trace_profile, + throughput, + kernels, + wave_grid, + mask_trace_profile, + detector_mask, + engine, ): """Test the init of the engine with default/good inputs""" @@ -160,6 +161,7 @@ def test_extraction_engine(wave_map, # test assignment of attributes and conversion to expected float64 dtype assert engine.wave_map[order].dtype == np.float64 assert engine.trace_profile[order].dtype == np.float64 + assert engine.kernels[order].dtype == np.float64 assert engine.mask_trace_profile[order].dtype == np.bool_ assert np.allclose(engine.wave_map[order], wave_map[order]) @@ -235,8 +237,117 @@ def test_extraction_engine(wave_map, assert thru.dtype == np.float64 # test kernel is cast to proper shape for input (trivial) kernel - print(engine.kernels[0].shape) - print(engine.kernels[0].dtype) - print(engine.kernels[0].toarray()) + for order in [0,1]: + n_valid = engine.i_bounds[order][1] - engine.i_bounds[order][0] + expected_shape = (n_valid, engine.n_wavepoints) + assert engine.kernels[order].shape == expected_shape + # for trivial kernel only one element per row is nonzero + assert engine.kernels[order].count_nonzero() == expected_shape[0] + + # test weights. see separate unit tests to ensure the calculation is correct + for order in [0,1]: + n_valid = engine.i_bounds[order][1] - engine.i_bounds[order][0] + weights = engine.weights[order] + k_idx = engine.weights_k_idx[order] + assert weights.dtype == np.float64 + assert np.issubdtype(k_idx.dtype, np.integer) + + # TODO: why is weights the same size as ~engine.mask, and not ~engine.mask_ord? + assert weights.shape == (np.sum(~engine.mask), n_valid) + assert k_idx.shape[0] == weights.shape[0] + + # test assignment of empty attributes + for att in ["w_t_wave_c", "tikho_mat", "_tikho_mat"]: + assert hasattr(engine, att) + assert getattr(engine, "pixel_mapping") == [None, None] + + +def test_extraction_engine_bad_inputs( + wave_map, + trace_profile, + throughput, + kernels, + wave_grid, + mask_trace_profile, + detector_mask, +): + # not enough good pixels in order + with pytest.raises(atoca.MaskOverlapError): + detector_mask = np.ones_like(detector_mask) + detector_mask[5:7,50:55] = 0 #still a few good pixels but very few + atoca.ExtractionEngine(wave_map, + trace_profile, + throughput, + kernels, + wave_grid, + mask_trace_profile, + global_mask=detector_mask) + + # wrong number of orders + with pytest.raises(ValueError): + atoca.ExtractionEngine(wave_map, + trace_profile, + throughput, + kernels, + wave_grid, + mask_trace_profile, + global_mask=detector_mask, + orders=[0,]) + + +def test_get_attributes(engine): + # test string input + assert np.allclose(engine.get_attributes("wave_map"), engine.wave_map) + + # test list of strings input + name_list = ["wave_map", "wave_grid"] + att_list = engine.get_attributes(*name_list) + expected = [engine.wave_map, engine.wave_grid] + for i in range(len(expected)): + for j in range(2): #orders + print(att_list[i][j]) + assert np.allclose(att_list[i][j], expected[i][j]) + + # test i_order not None + att_list = engine.get_attributes(*name_list, i_order=1) + expected = [engine.wave_map[1], engine.wave_grid[1]] + for i in range(len(expected)): + assert np.allclose(att_list[i], expected[i]) + + +def test_update_throughput(engine, throughput): + + old_thru = engine.throughput + + # test callable input + new_thru = [throughput[1], throughput[1]] + engine.update_throughput(new_thru) + for i, thru in enumerate(engine.throughput): + assert isinstance(thru, np.ndarray) + assert thru.shape == engine.wave_grid.shape + assert np.allclose(engine.throughput[0], engine.throughput[1]) + + # test array input + # first reset to old throughput + engine.throughput = old_thru + new_thru = [thru*2 for thru in engine.throughput] + engine.update_throughput(new_thru) + for i, thru in enumerate(engine.throughput): + assert np.allclose(thru, old_thru[i]*2) + + # test fail on bad array shape + new_thru = [thru[:-1] for thru in engine.throughput] + with pytest.raises(ValueError): + engine.update_throughput(new_thru) + + # test fail on callable that doesn't return correct array shape + def new_thru_f(wl): + return 1.0 + with pytest.raises(ValueError): + engine.update_throughput([new_thru_f, new_thru_f]) + print(engine.throughput) + + + From bed22ac052f68797835e7cbd1c5078d60528ff3a Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 29 Nov 2024 12:16:57 -0500 Subject: [PATCH 23/35] more unit tests of extraction engine --- jwst/extract_1d/soss_extract/atoca.py | 8 ++- .../soss_extract/tests/test_atoca.py | 67 ++++++++++++++++++- 2 files changed, 70 insertions(+), 5 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index 4c7f1418f7..0630a44377 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -360,7 +360,8 @@ def _get_i_bnds(self): def wave_grid_c(self, i_order): - """Return wave_grid for the convolved flux at a given order. + """Return wave_grid for a given order constrained according to the i_bounds + of that order. """ index = slice(*self.i_bounds[i_order]) @@ -400,7 +401,7 @@ def compute_weights(self): return weights, weights_k_idx def _set_w_t_wave_c(self, i_order, product): - """Save the matrix product of the weighs (w), the throughput (t), + """Save the matrix product of the weights (w), the throughput (t), the wavelength (lam) and the convolution matrix for faster computation. """ @@ -453,6 +454,9 @@ def get_pixel_mapping(self, i_order, error=None, quick=False): array[float] Sparse matrix of b_n coefficients """ + if (quick) and (self.w_t_wave_c is None): + msg = "Attribute w_t_wave_c of ExtractionEngine must exist if quick=True" + raise AttributeError(msg) # Special treatment for error map # Can be bool or array. diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca.py b/jwst/extract_1d/soss_extract/tests/test_atoca.py index e3f676f3fc..72c37dcb33 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca.py @@ -199,7 +199,7 @@ def test_extraction_engine( # assert engine.i_bounds[1][0] == 0 - # test masks + # test _get_masks # ensure they all include the input detector bad pixel mask for mask in [engine.mask_ord[0], engine.mask_ord[1], engine.mask, engine.general_mask]: assert mask.dtype == np.bool_ @@ -305,7 +305,6 @@ def test_get_attributes(engine): expected = [engine.wave_map, engine.wave_grid] for i in range(len(expected)): for j in range(2): #orders - print(att_list[i][j]) assert np.allclose(att_list[i][j], expected[i][j]) # test i_order not None @@ -345,9 +344,71 @@ def new_thru_f(wl): return 1.0 with pytest.raises(ValueError): engine.update_throughput([new_thru_f, new_thru_f]) - print(engine.throughput) +def test_create_kernels(engine): + # TODO: make a non-trivial kernel fixture to work with this + # trivial example already done + pass +def test_wave_grid_c(engine): + for order in [0,1]: + n_valid = engine.i_bounds[order][1] - engine.i_bounds[order][0] + assert engine.wave_grid_c(order).size == n_valid + + +def test_set_w_t_wave_c(engine): + """all this does is copy whatever is input""" + product = np.zeros((1,)) + engine._set_w_t_wave_c(0, product) + assert len(engine.w_t_wave_c) == engine.n_orders + assert engine.w_t_wave_c[0] == product + assert engine.w_t_wave_c[1] == [] + assert product is not engine.w_t_wave_c[0] + + +def test_grid_from_map(): + assert False + + +def test_get_pixel_mapping(engine): + + pixel_mapping_0 = engine.get_pixel_mapping(0) + # check attribute is set and identical to output + # check the second one is not set but there is space for it + assert hasattr(engine, "pixel_mapping") + assert len(engine.pixel_mapping) == engine.n_orders + assert np.allclose(engine.pixel_mapping[0].data, pixel_mapping_0.data) + assert engine.pixel_mapping[1] is None + + # set the second one so can check both at once + engine.get_pixel_mapping(1) + for order in [0,1]: + mapping = engine.pixel_mapping[order] + assert mapping.dtype == np.float64 + # TODO: why is this the shape, instead of using mask_ord and only the valid wave_grid? + expected_shape = (np.sum(~engine.mask), engine.wave_grid.size) + assert mapping.shape == expected_shape + + # test that w_t_wave_c is getting saved + w_t_wave_c = engine.w_t_wave_c[order] + assert w_t_wave_c.dtype == np.float64 + assert w_t_wave_c.shape == expected_shape + + # check if quick=True works + mapping_quick = engine.get_pixel_mapping(order, quick=True) + assert np.allclose(mapping.data, mapping_quick.data) + + # check that quick=True does not work if w_t_wave_c unsaved + engine.w_t_wave_c = None + with pytest.raises(AttributeError): + engine.get_pixel_mapping(1, quick=True) + + # TODO: follow the shape of every sparse matrix through this function, see if + # it makes sense or if it would make more sense to mask differently + # TODO: test that the math is actually correct using a trivial example (somehow) + +def test_build_sys(): + pass \ No newline at end of file From cddde89d5656139f53e2fa9a714a9125ad5ae24d Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Mon, 2 Dec 2024 17:50:18 -0500 Subject: [PATCH 24/35] toy model for tests is round-tripping successfully now --- jwst/extract_1d/soss_extract/atoca.py | 16 +- jwst/extract_1d/soss_extract/atoca_utils.py | 15 +- .../soss_extract/tests/test_atoca.py | 175 +++++++++++++++++- 3 files changed, 186 insertions(+), 20 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index 0630a44377..c3b4d6d934 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -540,9 +540,6 @@ def build_sys(self, data, error): def get_detector_model(self, data, error): """ Get the linear model of the detector pixel, B.dot(flux) = pixels - TIPS: To be quicker, only specify the psf (`p_list`) in kwargs. - There will be only one matrix multiplication: - (P/sig).(w.T.lambda.c_n). Parameters ---------- @@ -611,8 +608,8 @@ def estimate_tikho_factors(self, flux_estimate): Returns ------- - array[float] - Grid of Tikhonov factors. + float + Estimated Tikhonov factor. """ # Get some values from the object mask, wave_grid = self.get_attributes('mask', 'wave_grid') @@ -670,7 +667,6 @@ def get_tikho_tests(self, factors, data, error): def best_tikho_factor(self, tests, fit_mode): """ - TODO: why does function with same name exist here and in atoca_utils? Compute the best scale factor for Tikhonov regularization. It is determined by taking the factor giving the lowest reduced chi2 on the detector, the highest curvature of the l-curve or when the improvement @@ -722,7 +718,7 @@ def best_tikho_factor(self, tests, fit_mode): # Evaluate best factor with different methods results = dict() for mode in list_mode: - best_fac = tests.best_tikho_factor(mode=mode) + best_fac = tests.best_factor(mode=mode) results[mode] = best_fac if fit_mode == 'all': @@ -873,14 +869,14 @@ def __call__(self, data, error, tikhonov=False, factor=None): """ # Solve with the specified solver. if tikhonov: - # Build the system to solve - b_matrix, pix_array = self.get_detector_model(data, error) - if factor is None: msg = "Please specify tikhonov `factor`." log.critical(msg) raise ValueError(msg) + # Build the system to solve + b_matrix, pix_array = self.get_detector_model(data, error) + spectrum = self._solve_tikho(b_matrix, pix_array, self.tikho_mat, factor=factor) else: diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index 152ad8b80c..108dc360d3 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -1729,11 +1729,8 @@ def _compute_curvature(self): np.log10(reg2)) - def best_tikho_factor(self, mode='curvature'): + def best_factor(self, mode='curvature'): """ - TODO: why is there a function with identical name in atoca.py ExtractionEngine? - this one is called by the other one... - Compute the best scale factor for Tikhonov regularisation. It is determined by taking the factor giving the highest logL on the detector or the highest curvature of the l-curve, @@ -1806,7 +1803,7 @@ def best_tikho_factor(self, mode='curvature'): else: msg = (f'`mode`={mode} is not a valid option for ' - f'TikhoTests.best_tikho_factor().') + f'TikhoTests.best_factor().') log.critical(msg) raise ValueError(msg) @@ -1860,8 +1857,6 @@ def __init__(self, a_mat, b_vec, t_mat): self.t_mat_2 = (t_mat.T).dot(t_mat) # squared tikhonov matrix self.a_mat_2 = a_mat.T.dot(a_mat) # squared model matrix self.result = (a_mat.T).dot(b_vec.T) - # here self.result is all NaNs... why? - print("self.result is all nans", np.any(~np.isnan(self.result.toarray()))) self.idx_valid = (self.result.toarray() != 0).squeeze() # valid indices to use if `valid` is True # Save other attributes @@ -1933,13 +1928,17 @@ def test_factors(self, factors): sln, err, reg = [], [], [] # Test all factors + # TODO: does this need to be in a for loop? is it possible to allow solve + # to do a single large matrix multiplication? for i_fac, factor in enumerate(factors): # Save solution sln.append(self.solve(factor)) # Save error A.x - b - err.append(a_mat.dot(sln[-1]) - b_vec) + this_err = a_mat.dot(sln[-1]) - b_vec + # initially this is a np.matrix of shape (1, n_pixels); flatten and make array + err.append(np.array(this_err).flatten()) # Save regularization term reg_i = t_mat.dot(sln[-1]) diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca.py b/jwst/extract_1d/soss_extract/tests/test_atoca.py index 72c37dcb33..718f9ddc52 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca.py @@ -8,6 +8,9 @@ WAVE_BNDS_O1 = [2.8, 0.8] WAVE_BNDS_O2 = [1.4, 0.5] WAVE_BNDS_GRID = [0.7, 2.7] +ORDER1_SCALING = 20.0 +ORDER2_SCALING = 2.0 +SPECTRAL_SLOPE = 2 @pytest.fixture(scope="module") def wave_map(): @@ -23,6 +26,7 @@ def wave_map(): return [wave_ord1, wave_ord2] + @pytest.fixture(scope="module") def trace_profile(wave_map): """ @@ -410,5 +414,172 @@ def test_get_pixel_mapping(engine): # TODO: test that the math is actually correct using a trivial example (somehow) -def test_build_sys(): - pass \ No newline at end of file +def test_rebuild(engine): + + detector_model = engine.rebuild(f_lam) + assert detector_model.dtype == np.float64 + assert detector_model.shape == engine.wave_map[0].shape + + # test that input spectrum is ok as either callable or array + assert np.allclose(detector_model, engine.rebuild(f_lam(engine.wave_grid))) + + # test fill value + detector_model_nans = engine.rebuild(f_lam, fill_value=np.nan) + assert np.allclose(np.isnan(detector_model_nans), engine.general_mask) + + +def f_lam(wl, m=SPECTRAL_SLOPE, b=0): + """ + Estimator for flux as function of wavelength + Returns linear function of wl with slope m and intercept b + + This function is also used in this test suite as + """ + return m*wl + b + + +@pytest.fixture(scope="module") +def imagemodel(engine, detector_mask): + """ + use engine.rebuild to make an image model from an expected f(lambda). + Then we can ensure it round-trips + """ + + rng = np.random.default_rng(seed=42) + shp = engine.trace_profile[0].shape + + # make the detector bad values NaN, but leave the trace masks alone + # in reality, of course, this is backward: detector bad values + # would be determined from data + data = engine.rebuild(f_lam, fill_value=0.0) + data[detector_mask] = np.nan + + # add noise + noise_scaling = 3e-5 + data += noise_scaling*rng.standard_normal(shp) + + # error random, all positive, but also always larger than a certain value + # to avoid very large values of data / error + error = noise_scaling*(rng.standard_normal(shp)**2 + 0.5) + + return data, error + + +def test_build_sys(imagemodel, engine): + + data, error = imagemodel + matrix, result = engine.build_sys(data, error) + assert result.size == engine.n_wavepoints + assert matrix.shape == (result.size, result.size) + + + +def test_get_detector_model(imagemodel, engine): + + data, error = imagemodel + unmasked_size = np.sum(~engine.mask) + b_matrix, data_matrix = engine.get_detector_model(data, error) + + assert data_matrix.shape == (1, unmasked_size) + assert b_matrix.shape == (unmasked_size, engine.n_wavepoints) + assert np.allclose(data_matrix.toarray()[0], (data/error)[~engine.mask]) + + +def test_estimate_tikho_factors(engine): + + factor = engine.estimate_tikho_factors(f_lam) + assert isinstance(factor, float) + + + # very approximate calculation of tik fac looks like + # n_pixels = (~engine.mask).sum() + # flux = f_lam(engine.wave_grid) + # dlam = engine.wave_grid[1:] - engine.wave_grid[:-1] + # print(n_pixels/np.mean(flux[1:] * dlam)) + + +@pytest.fixture(scope="module") +def tikho_tests(imagemodel, engine): + data, error = imagemodel + + log_guess = np.log10(engine.estimate_tikho_factors(f_lam)) + factors = np.logspace(log_guess - 9, log_guess + 9, 19) + return factors, engine.get_tikho_tests(factors, data, error) + + +def test_get_tikho_tests(tikho_tests, engine): + + factors, tests = tikho_tests + unmasked_size = np.sum(~engine.mask) + + # test all the output shapes + assert np.allclose(tests["factors"], factors) + assert tests["solution"].shape == (len(factors), engine.n_wavepoints) + assert tests["error"].shape == (len(factors), unmasked_size) + assert tests["reg"].shape == (len(factors), engine.n_wavepoints-1) + assert tests["chi2"].shape == (len(factors),) + assert tests["chi2_soft_l1"].shape == (len(factors),) + assert tests["chi2_cauchy"].shape == (len(factors),) + assert np.allclose(tests["grid"], engine.wave_grid) + + # test data type is preserved through solve + for key in tests.keys(): + assert tests[key].dtype == np.dtype("float64") + + +def test_best_tikho_factor(engine, tikho_tests): + + input_factors, tests = tikho_tests + fit_modes = ["all", "curvature", "chi2", "d_chi2"] + best_factors = [] + for mode in fit_modes: + factor = engine.best_tikho_factor(tests, mode) + assert isinstance(factor, float) + best_factors.append(factor) + + # ensure fit_mode=all found one of the three others + assert best_factors[0] in best_factors[1:] + + # TODO: test the logic tree by manually changing the tests dict + # this is non-trivial because dchi2 and curvature + # are both derived from other keys in the dictionary, not just the statistical metrics + + +def test_call(engine, tikho_tests, imagemodel): + """ + Run the actual extract method. + Ensure it can retrieve the input spectrum based on f_lam to within a few percent + at all points on the wave_grid. + + Note this round-trip implicitly checks the math of the build_sys, get_detector_model, + _solve, and _solve_tikho, at least at first blush. + """ + data, error = imagemodel + _, tests = tikho_tests + best_factor = engine.best_tikho_factor(tests, "all") + + expected_spectrum = f_lam(engine.wave_grid) + for tikhonov in [True, False]: + spectrum = engine(data, error, tikhonov=tikhonov, factor=best_factor) + diff = (spectrum - expected_spectrum)/expected_spectrum + assert not np.all(np.isnan(diff)) + diff = diff[~np.isnan(diff)] + assert np.all(np.abs(diff) < 0.05) + + # test bad input, failing to put factor in for Tikhonov solver + with pytest.raises(ValueError): + engine(data, error, tikhonov=True) + + +def test_compute_likelihood(engine, imagemodel): + """Ensure log-likelihood is highest for the correct slope""" + + data, error = imagemodel + test_slopes = np.arange(0, 5, 0.5) + logl = [] + for slope in test_slopes: + spectrum = partial(f_lam, m=slope) + logl.append(engine.compute_likelihood(spectrum, data, error)) + + assert np.argmax(logl) == np.argwhere(test_slopes == SPECTRAL_SLOPE) + assert np.all(np.array(logl) < 0) From 85e7fb9a851abb3dc38ae1cfe872c8e8bf2c8951 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Wed, 4 Dec 2024 09:22:53 -0500 Subject: [PATCH 25/35] move utils tests into atoca main --- jwst/extract_1d/soss_extract/soss_extract.py | 2 - .../soss_extract/tests/test_atoca.py | 542 +++++++++++++++++- .../soss_extract/tests/test_atoca_utils.py | 471 --------------- 3 files changed, 530 insertions(+), 485 deletions(-) diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index 389b834628..16f4a73802 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -767,8 +767,6 @@ def _model_single_order(data_order, err_order, ref_file_args, mask_fit, def throughput(wavelength): return np.ones_like(wavelength) kernel = np.array([1.]) - - # Set reference file arguments ref_file_args[2] = [throughput] ref_file_args[3] = [kernel] diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca.py b/jwst/extract_1d/soss_extract/tests/test_atoca.py index 718f9ddc52..8de18bd5ce 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca.py @@ -2,6 +2,26 @@ import numpy as np from functools import partial from jwst.extract_1d.soss_extract import atoca +from jwst.extract_1d.soss_extract import atoca_utils as au + + +""" +Tests for atoca.py and atoca_utils.py + +Create a miniature, slightly simplified model of the SOSS detector/optics. +Use those to instantiate an extraction engine, use the engine to create mock data +from a known input spectrum, and then check if the engine can retrieve that spectrum +from the data. + +The model has the following features: +- Factor-of-10 smaller along each dimension +- Similar wavelengths for each of the two orders +- Partially overlapping traces for the two orders +- Randomly-selected bad pixels in the data +- Wave grid of size ~100 with varying resolution +- Triangle function throughput for each spectral order +- Kernel set to unity (for now) +""" DATA_SHAPE = (25,200) @@ -12,6 +32,7 @@ ORDER2_SCALING = 2.0 SPECTRAL_SLOPE = 2 + @pytest.fixture(scope="module") def wave_map(): """Trying for a roughly factor-of-10 shrink on each axis @@ -89,11 +110,47 @@ def filter_function(wl, wl_max): return [thru_o1, thru_o2] @pytest.fixture(scope="module") -def kernels(): +def kernels_unity(): """For now, just return unity""" return [np.array([1.,]), np.array([1.,])] +@pytest.fixture(scope="module") +def webb_kernels(wave_map): + """ + Toy model of the JWST kernels. + Let the kernel be a triangle function along the pixel position axis, + peaking at the center, + and independent of wavelength. + + Same for both orders except the wavelengths and wave_trace are different. + """ + n_os = 5 + n_pix = 15 # full pixel width of kernel + n_wave = 10 + peakiness = 4 + min_val = 1 + kernel_width = n_os*n_pix - (n_os - 1) + ctr_idx = kernel_width//2 + + kernels = [] + for order in [0,1]: + # set up wavelength grid over which kernel is defined + wave_trace = wave_map[order][0] + wave_range = (np.min(wave_trace), np.max(wave_trace)) + wavelengths = np.linspace(*wave_range, n_wave) + wave_kernel = np.ones((kernel_width, wavelengths.size), dtype=float)*wavelengths[None,:] + + # model kernel as simply a triangle function that peaks at the center + triangle_function = ctr_idx - peakiness*np.abs(ctr_idx - np.arange(0, kernel_width)) + triangle_function[triangle_function<=min_val] = min_val + kernel = np.ones((kernel_width, wavelengths.size), dtype=float)*triangle_function[:,None] + kernel/=np.sum(kernel) + + kernels.append(au.WebbKernel(wave_kernel, kernel, wave_trace, n_pix)) + + return kernels + @pytest.fixture(scope="module") def mask_trace_profile(trace_profile): @@ -129,7 +186,7 @@ def detector_mask(wave_map): def engine(wave_map, trace_profile, throughput, - kernels, + kernels_unity, wave_grid, mask_trace_profile, detector_mask, @@ -137,17 +194,453 @@ def engine(wave_map, return atoca.ExtractionEngine(wave_map, trace_profile, throughput, - kernels, + kernels_unity, wave_grid, mask_trace_profile, global_mask=detector_mask) -def test_extraction_engine( +def test_arange_2d(): + + starts = np.array([3,4,5]) + stops = np.ones(starts.shape)*7 + out = au.arange_2d(starts, stops) + + bad = -1 + expected_out = np.array([ + [3,4,5,6], + [4,5,6,bad], + [5,6,bad,bad] + ]) + assert np.allclose(out, expected_out) + + # test bad input catches + starts_wrong_shape = starts[1:] + with pytest.raises(ValueError): + au.arange_2d(starts_wrong_shape, stops) + + stops_too_small = np.copy(stops) + stops_too_small[2] = 4 + with pytest.raises(ValueError): + au.arange_2d(starts, stops_too_small) + + +@pytest.mark.parametrize("dispersion_axis", [0,1]) +def test_get_wv_map_bounds(wave_map, dispersion_axis): + """ + top is the low-wavelength end, bottom is high-wavelength end + """ + wave_map = wave_map[0].copy() + wave_map[1,3] = -1 #test skip of bad value + wavelengths = wave_map[0] + + if dispersion_axis == 0: + wave_flip = wave_map.T + else: + wave_flip = wave_map + wave_top, wave_bottom = au._get_wv_map_bounds(wave_flip, dispersion_axis=dispersion_axis) + + # flip the results back so we can re-use the same tests + if dispersion_axis == 0: + wave_top = wave_top.T + wave_bottom = wave_bottom.T + + diff = (wavelengths[1:]-wavelengths[:-1])/2 + diff_lower = np.insert(diff,0,diff[0]) + diff_upper = np.append(diff,diff[-1]) + wave_top_expected = wavelengths-diff_lower + wave_bottom_expected = wavelengths+diff_upper + + # basic test + assert wave_top.shape == wave_bottom.shape == (wave_map.shape[0],)+wavelengths.shape + assert np.allclose(wave_top[0], wave_top_expected) + assert np.allclose(wave_bottom[0], wave_bottom_expected) + + # test skip bad pixel + assert wave_top[1,3] == 0 + assert wave_bottom[1,3] == 0 + + # test bad input error raises + with pytest.raises(ValueError): + au._get_wv_map_bounds(wave_flip, dispersion_axis=2) + + +@pytest.mark.parametrize("dispersion_axis", [0,1]) +def test_get_wave_p_or_m(wave_map, dispersion_axis): + """ + Check that the plus and minus side is correctly identified + for strictly ascending and strictly descending wavelengths. + """ + wave_map = wave_map[0].copy() + wave_reverse = np.fliplr(wave_map) + if dispersion_axis == 0: + wave_flip = wave_map.T + wave_reverse = wave_reverse.T + else: + wave_flip = wave_map + + wave_p_0, wave_m_0 = au.get_wave_p_or_m(wave_flip, dispersion_axis=dispersion_axis) + wave_p_1, wave_m_1 = au.get_wave_p_or_m(wave_reverse, dispersion_axis=dispersion_axis) + + if dispersion_axis==0: + wave_p_0 = wave_p_0.T + wave_m_0 = wave_m_0.T + wave_p_1 = wave_p_1.T + wave_m_1 = wave_m_1.T + assert np.all(wave_p_0 >= wave_m_0) + assert np.allclose(wave_p_0, np.fliplr(wave_p_1)) + assert np.allclose(wave_m_0, np.fliplr(wave_m_1)) + + +def test_get_wave_p_or_m_not_ascending(wave_map): + wave_map = wave_map[0].copy() + with pytest.raises(ValueError): + wave_map[0,5] = 2 # make it not strictly ascending + au.get_wave_p_or_m(wave_map, dispersion_axis=1) + + +FIBONACCI = np.array([1,1,2,3,5,8,13,21,35], dtype=float) +@pytest.mark.parametrize("n_os", [1,5]) +def test_oversample_grid(n_os): + + oversample = au.oversample_grid(FIBONACCI, n_os) + + # oversample_grid is supposed to remove any duplicates, and there is a duplicate + # in FIBONACCI. So the output should be 4 times the size of FIBONACCI minus 1 + assert oversample.size == n_os*(FIBONACCI.size - 1) - (n_os-1) + assert oversample.min() == FIBONACCI.min() + assert oversample.max() == FIBONACCI.max() + + # test whether np.interp could have been used instead + grid = np.arange(0, FIBONACCI.size, 1/n_os) + wls = np.unique(np.interp(grid, np.arange(FIBONACCI.size), FIBONACCI)) + assert np.allclose(oversample, wls) + + +@pytest.mark.parametrize("os_factor", [1,2,5]) +def test_oversample_irregular(os_factor): + """Test oversampling to a grid with irregular spacing""" + # oversampling function removes duplicates, + # this is tested in previous test, and just complicates counting for this test + # for FIBONACCI, unique is just removing zeroth element + fib_unq = np.unique(FIBONACCI) + n_os = np.ones((fib_unq.size-1,), dtype=int) + n_os[2:5] = os_factor + n_os[3] = os_factor*2 + # this gives n_os = [1 1 2 4 2] for os_factor = 2 + + oversample = au.oversample_grid(fib_unq, n_os) + + # test no oversampling was done on the elements where not requested + assert np.allclose(oversample[0:2], fib_unq[0:2]) + assert np.allclose(oversample[-1:], fib_unq[-1:]) + + # test output shape. + assert oversample.size == np.sum(n_os)+1 + + # test that this could have been done easily with np.interp + intervals = 1/n_os + intervals = np.insert(np.repeat(intervals, n_os),0,0) + grid = np.cumsum(intervals) + wls = np.interp(grid, np.arange(fib_unq.size), fib_unq) + assert wls.size == oversample.size + assert np.allclose(oversample, wls) + + # test that n_os shape must match input shape - 1 + with pytest.raises(ValueError): + au.oversample_grid(fib_unq, n_os[:-1]) + + +WAVELENGTHS = np.linspace(1.5, 3.0, 50) + np.sin(np.linspace(0, np.pi/2, 50)) +@pytest.mark.parametrize("wave_range", [(2.1, 3.9), (1.8, 4.5)]) +def test_extrapolate_grid(wave_range): + + extrapolated = au._extrapolate_grid(WAVELENGTHS, wave_range, 1) + + assert extrapolated.max() > wave_range[1] + assert extrapolated.min() < wave_range[0] + assert np.all(extrapolated[1:] >= extrapolated[:-1]) + + # if interpolation not needed on either side, should return the original + if wave_range == (2.1, 3.9): + assert extrapolated is WAVELENGTHS + + +def test_extrapolate_catch_failed_converge(): + # give wavelengths some non-linearity + wave_range = WAVELENGTHS.min(), WAVELENGTHS.max()+4.0 + with pytest.raises(RuntimeError): + au._extrapolate_grid(WAVELENGTHS, wave_range, 1) + + +def test_extrapolate_bad_inputs(): + with pytest.raises(ValueError): + au._extrapolate_grid(WAVELENGTHS, (2.9, 2.1)) + with pytest.raises(ValueError): + au._extrapolate_grid(WAVELENGTHS, (4.1, 4.2)) + + +def test_grid_from_map(wave_map, trace_profile): + """Covers expected behavior of grid_from_map, including coverage of a previous bug + where bad wavelengths were not being ignored properly""" + + wave_map = wave_map[0].copy() + wavelengths = wave_map[0][::-1] + trace_profile = trace_profile[0].copy() + wave_grid = au.grid_from_map(wave_map, trace_profile, wave_range=None) + + assert np.allclose(wave_grid, wavelengths) + + # test custom wave_range + wave_range = [wavelengths[2], wavelengths[-2]+0.01] + wave_grid = au.grid_from_map(wave_map, trace_profile, wave_range=wave_range) + assert np.allclose(wave_grid, wavelengths[2:-1]) + + # test custom wave_range with extrapolation + wave_range = [wavelengths[2], wavelengths[-1]+1] + wave_grid = au.grid_from_map(wave_map, trace_profile, wave_range=wave_range) + assert len(wave_grid) > len(wavelengths[2:]) + n_inside = wavelengths[2:].size + assert np.allclose(wave_grid[:n_inside], wavelengths[2:]) + + with pytest.raises(ValueError): + au.grid_from_map(wave_map, trace_profile, wave_range=[0.1,0.2]) + + +def xsinx(x): + return x*np.sin(x) + + +def test_estim_integration_error(): + """ + Use as truth the x sin(x) from 0 to pi, has an analytic solution == pi. + """ + + n = 11 + grid = np.linspace(0, np.pi, n) + err, rel_err = au._estim_integration_err(grid, xsinx) + + assert len(rel_err) == n-1 + assert np.all(rel_err >= 0) + assert np.all(rel_err < 1) + + +@pytest.mark.parametrize("max_iter, rtol", [(1,1e-3), (10, 1e-9), (10, 1e-3), (1, 1e-9)]) +def test_adapt_grid(max_iter, rtol): + """ + Use as truth the x sin(x) from 0 to pi, has an analytic solution == pi. + """ + + input_grid = np.linspace(0, np.pi, 11) + input_grid_diff = input_grid[1] - input_grid[0] + max_grid_size = 100 + grid, is_converged = au._adapt_grid(input_grid, + xsinx, + max_grid_size, + max_iter=max_iter, + rtol=rtol) + + # ensure grid respects max_grid_size and max_iter in all cases + assert len(grid) <= max_grid_size + grid_diff = grid[1:] - grid[:-1] + assert np.min(grid_diff) >= input_grid_diff/(2**max_iter) + + numerical_integral = np.trapz(xsinx(grid), grid) + + # ensure this converges for at least one of our test cases + if max_iter == 10 and rtol == 1e-3: + assert is_converged + + if is_converged: + # test error of the answer is smaller than rtol + assert np.isclose(numerical_integral, np.pi, rtol=rtol) + # test that success was a stop condition + assert len(grid) < max_grid_size + + # test stop conditions + elif max_iter == 10: + # ensure hitting max_grid_size returns an array of exactly length max_grid_size + assert len(grid) == max_grid_size + elif max_iter == 1: + # ensure hitting max_iter can stop iteration before max_grid_size reached + assert len(grid) <= 2*len(input_grid) + + +def test_adapt_grid_bad_inputs(): + with pytest.raises(ValueError): + # input grid larger than max_grid_size + au._adapt_grid(np.array([1,2,3]), xsinx, 2) + + +def test_trim_grids(): + + grid_range = (-3, 3) + grid0 = np.linspace(-3, 0, 4) # kept entirely. + grid1 = np.linspace(-3, 0, 16) # removed entirely. Finer spacing doesn't matter, preceded by grid0 + grid2 = np.linspace(0, 3, 5) # kept from 0 to 3 + grid3 = np.linspace(-4, 4, 5) # removed entirely. Outside of grid_range and the rest is superseded + + all_grids = [grid0, grid1, grid2, grid3] + trimmed_grids = au._trim_grids(all_grids, grid_range) + + assert len(trimmed_grids) == len(all_grids) + assert trimmed_grids[0].size == grid0.size + assert trimmed_grids[1].size == 0 + assert trimmed_grids[2].size == grid2.size + assert trimmed_grids[3].size == 0 + + +def test_make_combined_adaptive_grid(): + """see also tests of _adapt_grid and _trim_grids for more detailed tests""" + + grid_range = (0, np.pi) + grid0 = np.linspace(0, np.pi/2, 6) # kept entirely. + grid1 = np.linspace(0, np.pi/2, 15) # removed entirely. Finer spacing doesn't matter, preceded by grid0 + grid2 = np.linspace(np.pi/2, np.pi, 11) # kept from pi/2 to pi + + # purposely make same lower index for grid2 as upper index for grid0 to test uniqueness of output + + all_grids = [grid0, grid1, grid2] + all_estimate = [xsinx, xsinx, xsinx] + + rtol = 1e-3 + combined_grid = au.make_combined_adaptive_grid(all_grids, all_estimate, grid_range, + max_iter=10, rtol=rtol, max_total_size=100) + + numerical_integral = np.trapz(xsinx(combined_grid), combined_grid) + + assert np.unique(combined_grid).size == combined_grid.size + assert np.isclose(numerical_integral, np.pi, rtol=rtol) + + +def test_throughput_soss(): + + wavelengths = np.linspace(2,5,10) + throughputs = np.ones_like(wavelengths) + interpolator = au.ThroughputSOSS(wavelengths, throughputs) + + # test that it returns 1 for all wavelengths inside range + interp = interpolator(wavelengths) + assert np.allclose(interp[1:-1], throughputs[1:-1]) + assert interp[0] == 0 + assert interp[-1] == 0 + + # test that it returns 0 for all wavelengths outside range + wavelengths_outside = np.linspace(1,1.5,5) + interp = interpolator(wavelengths_outside) + assert np.all(interp == 0) + + # test ValueError raise for shape mismatch + with pytest.raises(ValueError): + au.ThroughputSOSS(wavelengths, throughputs[:-1]) + + +def test_webb_kernel(webb_kernels, wave_map): + + wave_trace = wave_map[0][0] + min_trace, max_trace = np.min(wave_trace), np.max(wave_trace) + kern = webb_kernels[0] + + # basic ensure that the input is stored and shapes + assert kern.wave_kernels.shape == kern.kernels.shape + + # test that pixels is mirrored around the center and has zero at center + assert np.allclose(kern.pixels + kern.pixels[::-1], 0) + assert kern.pixels[kern.pixels.size//2] == 0 + + # test that wave_center has same shape as wavelength axis of wave_kernel + # but contains values that are in wave_trace + assert kern.wave_center.size == kern.wave_kernels.shape[1] + assert all(np.isin(kern.wave_center, wave_trace)) + + # test min value + assert kern.min_value > 0 + assert np.isin(kern.min_value, kern.kernels) + assert isinstance(kern.min_value, float) + + # test the polynomial fit has the proper shape. hard-coded to a first-order, i.e., linear fit + # since the throughput is constant in wavelength, the slopes should be close to zero + # and the y-intercepts should be close to kern.wave_center + # especially with so few points. just go with 10 percent, should catch egregious changes + assert kern.poly.shape == (kern.wave_kernels.shape[1], 2) + assert np.allclose(kern.poly[:,0], 0, atol=1e-1) + assert np.allclose(kern.poly[:,1], kern.wave_center, atol=1e-1) + + # test interpolation function, which takes in a pixel and a wavelength and returns a throughput + # this should return the triangle function at all wavelengths and zero outside range + pix_half = kern.n_pix//2 + wl_test = np.linspace(min_trace, max_trace, 10) + pixels_test = np.array([-pix_half-1, 0, pix_half, pix_half+1]) + + data_in = kern.kernels[:,0] + m = kern.min_value + expected = np.array([m, np.max(data_in), m, m]) + + interp = kern.f_ker(pixels_test, wl_test) + assert interp.shape == (pixels_test.size, wl_test.size) + diff = interp[:,1:] - interp[:,:-1] + assert np.allclose(diff, 0) + assert np.allclose(interp[:,0], expected, rtol=1e-3) + + # call the kernel object directly + # this takes a wavelength and a central wavelength of the kernel, + # then converts to pixels to use self.f_ker internally + assert kern(wl_test, wl_test).ndim == 1 + assert np.allclose(kern(wl_test, wl_test), np.max(data_in)) + + # both inputs need to be same shape + with pytest.raises(ValueError): + kern(wl_test, wl_test[:-1]) + + +def test_get_c_matrix(kernel): + + #TODO: what to use for grid? is it / can it be the same as wave_trace? + # I think it can be the same but does not need to be, and would be a better + # test if it were different, because the kernel is an interpolator that was + # created using wavelengths that are included in wave_trace. + matrix = au.get_c_matrix(kernel, grid, i_bounds=None, thresh=1e-5) + + # test with WebbKernel as the kernel + # ensure normalized + # ensure sparse + + + # test where input kernel is a 2-D array instead of callable + + + # test where input kernel is size 1 + + + # test where i_bounds is not None + + + # Test invalid kernel input (wrong dimensions) + + +def test_finite_first_diff(): + + wave_grid = np.linspace(0, 2*np.pi, 100) + test_0 = np.ones_like(wave_grid) + test_sin = np.sin(wave_grid) + + first_d = au.finite_first_d(wave_grid) + assert first_d.size == (wave_grid.size - 1)*2 + + # test trivial example returning zeros for constant + f0 = first_d.dot(test_0) + assert np.allclose(f0, 0) + + # test derivative of sin returns cos + wave_between = (wave_grid[1:] + wave_grid[:-1])/2 + f_sin = first_d.dot(test_sin) + assert np.allclose(f_sin, np.cos(wave_between), atol=1e-3) + + +def test_extraction_engine_init( wave_map, trace_profile, throughput, - kernels, wave_grid, mask_trace_profile, detector_mask, @@ -270,7 +763,7 @@ def test_extraction_engine_bad_inputs( wave_map, trace_profile, throughput, - kernels, + kernels_unity, wave_grid, mask_trace_profile, detector_mask, @@ -282,7 +775,7 @@ def test_extraction_engine_bad_inputs( atoca.ExtractionEngine(wave_map, trace_profile, throughput, - kernels, + kernels_unity, wave_grid, mask_trace_profile, global_mask=detector_mask) @@ -292,7 +785,7 @@ def test_extraction_engine_bad_inputs( atoca.ExtractionEngine(wave_map, trace_profile, throughput, - kernels, + kernels_unity, wave_grid, mask_trace_profile, global_mask=detector_mask, @@ -372,10 +865,6 @@ def test_set_w_t_wave_c(engine): assert product is not engine.w_t_wave_c[0] -def test_grid_from_map(): - assert False - - def test_get_pixel_mapping(engine): pixel_mapping_0 = engine.get_pixel_mapping(0) @@ -583,3 +1072,32 @@ def test_compute_likelihood(engine, imagemodel): assert np.argmax(logl) == np.argwhere(test_slopes == SPECTRAL_SLOPE) assert np.all(np.array(logl) < 0) + + +def test_sparse_c(kern_array): + """Here kernel must be a 2-D array already, of shape (N_ker, N_k_convolved)""" + + # test typical case n_k = n_kc and i=0 + n_k = kern_array.shape[1] + i_zero = 0 + matrix = au._sparse_c(kern_array, n_k, i_zero) + + # TODO: add more here + + +@pytest.fixture(scope="module") +def tikhoTests(): + """Make a TikhoTests dictionary""" + + + return au.TikhoTests({'factors': factors, + 'solution': sln, + 'error': err, + 'reg': reg, + 'grid': wave_grid}) + + + +def test_tikho_tests(tikhoTests): + + assert False \ No newline at end of file diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py index b0ee7aa57b..81f69a2fe2 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py @@ -3,29 +3,7 @@ from jwst.extract_1d.soss_extract import atoca_utils as au import numpy as np -def test_arange_2d(): - starts = np.array([3,4,5]) - stops = np.ones(starts.shape)*7 - out = au.arange_2d(starts, stops) - - bad = -1 - expected_out = np.array([ - [3,4,5,6], - [4,5,6,bad], - [5,6,bad,bad] - ]) - assert np.allclose(out, expected_out) - - # test bad input catches - starts_wrong_shape = starts[1:] - with pytest.raises(ValueError): - au.arange_2d(starts_wrong_shape, stops) - - stops_too_small = np.copy(stops) - stops_too_small[2] = 4 - with pytest.raises(ValueError): - au.arange_2d(starts, stops_too_small) # wavelengths have min, max (1.5, 4.0) and a bit of non-linearity @@ -62,442 +40,11 @@ def trace_profile_o2(wave_map_o2): return trace_profile*thrpt[:,None] -@pytest.mark.parametrize("dispersion_axis", [0,1]) -def test_get_wv_map_bounds(wave_map, dispersion_axis): - """ - top is the low-wavelength end, bottom is high-wavelength end - """ - if dispersion_axis == 0: - wave_flip = wave_map.T - else: - wave_flip = wave_map - wave_top, wave_bottom = au._get_wv_map_bounds(wave_flip, dispersion_axis=dispersion_axis) - - # flip the results back so we can re-use the same tests - if dispersion_axis == 0: - wave_top = wave_top.T - wave_bottom = wave_bottom.T - - diff = (WAVELENGTHS[1:]-WAVELENGTHS[:-1])/2 - diff_lower = np.insert(diff,0,diff[0]) - diff_upper = np.append(diff,diff[-1]) - wave_top_expected = WAVELENGTHS-diff_lower - wave_bottom_expected = WAVELENGTHS+diff_upper - - # basic test - assert wave_top.shape == wave_bottom.shape == (wave_map.shape[0],)+WAVELENGTHS.shape - assert np.allclose(wave_top[0], wave_top_expected) - assert np.allclose(wave_bottom[0], wave_bottom_expected) - - # test skip bad pixel - assert wave_top[1,3] == 0 - assert wave_bottom[1,3] == 0 - - # test bad input error raises - with pytest.raises(ValueError): - au._get_wv_map_bounds(wave_flip, dispersion_axis=2) - - - -@pytest.mark.parametrize("dispersion_axis", [0,1]) -def test_get_wave_p_or_m(wave_map, dispersion_axis): - """ - Check that the plus and minus side is correctly identified - for strictly ascending and strictly descending wavelengths. - - TODO: test that given a float64, output is float64 - """ - wave_reverse = np.fliplr(wave_map) - if dispersion_axis == 0: - wave_flip = wave_map.T - wave_reverse = wave_reverse.T - else: - wave_flip = wave_map - - wave_p_0, wave_m_0 = au.get_wave_p_or_m(wave_flip, dispersion_axis=dispersion_axis) - wave_p_1, wave_m_1 = au.get_wave_p_or_m(wave_reverse, dispersion_axis=dispersion_axis) - - if dispersion_axis==0: - wave_p_0 = wave_p_0.T - wave_m_0 = wave_m_0.T - wave_p_1 = wave_p_1.T - wave_m_1 = wave_m_1.T - assert np.all(wave_p_0 >= wave_m_0) - assert np.allclose(wave_p_0, np.fliplr(wave_p_1)) - assert np.allclose(wave_m_0, np.fliplr(wave_m_1)) - - -def test_get_wave_p_or_m_not_ascending(wave_map): - with pytest.raises(ValueError): - wave_map[0,5] = 2 # make it not strictly ascending - au.get_wave_p_or_m(wave_map, dispersion_axis=1) - - -FIBONACCI = np.array([1,1,2,3,5,8,13,21,35], dtype=float) -@pytest.mark.parametrize("n_os", [1,5]) -def test_oversample_grid(n_os): - - oversample = au.oversample_grid(FIBONACCI, n_os) - - # oversample_grid is supposed to remove any duplicates, and there is a duplicate - # in FIBONACCI. So the output should be 4 times the size of FIBONACCI minus 1 - assert oversample.size == n_os*(FIBONACCI.size - 1) - (n_os-1) - assert oversample.min() == FIBONACCI.min() - assert oversample.max() == FIBONACCI.max() - - # test whether np.interp could have been used instead - grid = np.arange(0, FIBONACCI.size, 1/n_os) - wls = np.unique(np.interp(grid, np.arange(FIBONACCI.size), FIBONACCI)) - assert np.allclose(oversample, wls) - - -@pytest.mark.parametrize("os_factor", [1,2,5]) -def test_oversample_irregular(os_factor): - """Test oversampling to a grid with irregular spacing""" - # oversampling function removes duplicates, - # this is tested in previous test, and just complicates counting for this test - # for FIBONACCI, unique is just removing zeroth element - fib_unq = np.unique(FIBONACCI) - n_os = np.ones((fib_unq.size-1,), dtype=int) - n_os[2:5] = os_factor - n_os[3] = os_factor*2 - # this gives n_os = [1 1 2 4 2] for os_factor = 2 - - oversample = au.oversample_grid(fib_unq, n_os) - - # test no oversampling was done on the elements where not requested - assert np.allclose(oversample[0:2], fib_unq[0:2]) - assert np.allclose(oversample[-1:], fib_unq[-1:]) - - # test output shape. - assert oversample.size == np.sum(n_os)+1 - - # test that this could have been done easily with np.interp - intervals = 1/n_os - intervals = np.insert(np.repeat(intervals, n_os),0,0) - grid = np.cumsum(intervals) - wls = np.interp(grid, np.arange(fib_unq.size), fib_unq) - assert wls.size == oversample.size - assert np.allclose(oversample, wls) - - # test that n_os shape must match input shape - 1 - with pytest.raises(ValueError): - au.oversample_grid(fib_unq, n_os[:-1]) - - -@pytest.mark.parametrize("wave_range", [(2.1, 3.9), (1.8, 4.5)]) -def test_extrapolate_grid(wave_range): - - extrapolated = au._extrapolate_grid(WAVELENGTHS, wave_range, 1) - - assert extrapolated.max() > wave_range[1] - assert extrapolated.min() < wave_range[0] - assert np.all(extrapolated[1:] >= extrapolated[:-1]) - - # if interpolation not needed on either side, should return the original - if wave_range == (2.1, 3.9): - assert extrapolated is WAVELENGTHS - - -def test_extrapolate_catch_failed_converge(): - # give wavelengths some non-linearity - wave_range = WAVELENGTHS.min(), WAVELENGTHS.max()+4.0 - with pytest.raises(RuntimeError): - au._extrapolate_grid(WAVELENGTHS, wave_range, 1) - - -def test_extrapolate_bad_inputs(): - with pytest.raises(ValueError): - au._extrapolate_grid(WAVELENGTHS, (2.9, 2.1)) - with pytest.raises(ValueError): - au._extrapolate_grid(WAVELENGTHS, (4.1, 4.2)) - - -def test_grid_from_map(wave_map, trace_profile): - """Covers expected behavior of grid_from_map, including coverage of a previous bug - where bad wavelengths were not being ignored properly""" - - wave_grid = au.grid_from_map(wave_map, trace_profile, wave_range=None) - - # expected output is very near WAVELENGTHS+0.2 because that's what all the high-weight - # rows of the wave_map are set to. - assert np.allclose(wave_grid, wave_map[2]) - - # test custom wave_range - wave_range = [wave_map[2,2], wave_map[2,-2]+0.01] - wave_grid = au.grid_from_map(wave_map, trace_profile, wave_range=wave_range) - assert np.allclose(wave_grid, wave_map[2,2:-1]) - - # test custom wave_range with extrapolation - wave_range = [wave_map[2,2], wave_map[2,-1]+1] - wave_grid = au.grid_from_map(wave_map, trace_profile, wave_range=wave_range) - assert len(wave_grid) > len(wave_map[2,2:]) - n_inside = wave_map[2,2:].size - assert np.allclose(wave_grid[:n_inside], wave_map[2,2:]) - with pytest.raises(ValueError): - au.grid_from_map(wave_map, trace_profile, wave_range=[0.5,0.9]) -def xsinx(x): - return x*np.sin(x) -def test_estim_integration_error(): - """ - Use as truth the x sin(x) from 0 to pi, has an analytic solution == pi. - TODO: Find something more meaningful to test here - """ - - n = 11 - grid = np.linspace(0, np.pi, n) - err, rel_err = au._estim_integration_err(grid, xsinx) - - assert len(rel_err) == n-1 - assert np.all(rel_err >= 0) - assert np.all(rel_err < 1) - - -@pytest.mark.parametrize("max_iter, rtol", [(1,1e-3), (10, 1e-9), (10, 1e-3), (1, 1e-9)]) -def test_adapt_grid(max_iter, rtol): - """ - Use as truth the x sin(x) from 0 to pi, has an analytic solution == pi. - """ - - input_grid = np.linspace(0, np.pi, 11) - input_grid_diff = input_grid[1] - input_grid[0] - max_grid_size = 100 - grid, is_converged = au._adapt_grid(input_grid, - xsinx, - max_grid_size, - max_iter=max_iter, - rtol=rtol) - - # ensure grid respects max_grid_size and max_iter in all cases - assert len(grid) <= max_grid_size - grid_diff = grid[1:] - grid[:-1] - assert np.min(grid_diff) >= input_grid_diff/(2**max_iter) - - numerical_integral = np.trapz(xsinx(grid), grid) - - # ensure this converges for at least one of our test cases - if max_iter == 10 and rtol == 1e-3: - assert is_converged - - if is_converged: - # test error of the answer is smaller than rtol - assert np.isclose(numerical_integral, np.pi, rtol=rtol) - # test that success was a stop condition - assert len(grid) < max_grid_size - - # test stop conditions - elif max_iter == 10: - # ensure hitting max_grid_size returns an array of exactly length max_grid_size - assert len(grid) == max_grid_size - elif max_iter == 1: - # ensure hitting max_iter can stop iteration before max_grid_size reached - assert len(grid) <= 2*len(input_grid) - - -def test_adapt_grid_bad_inputs(): - with pytest.raises(ValueError): - # input grid larger than max_grid_size - au._adapt_grid(np.array([1,2,3]), xsinx, 2) - - -def test_trim_grids(): - - grid_range = (-3, 3) - grid0 = np.linspace(-3, 0, 4) # kept entirely. - grid1 = np.linspace(-3, 0, 16) # removed entirely. Finer spacing doesn't matter, preceded by grid0 - grid2 = np.linspace(0, 3, 5) # kept from 0 to 3 - grid3 = np.linspace(-4, 4, 5) # removed entirely. Outside of grid_range and the rest is superseded - - all_grids = [grid0, grid1, grid2, grid3] - trimmed_grids = au._trim_grids(all_grids, grid_range) - - assert len(trimmed_grids) == len(all_grids) - assert trimmed_grids[0].size == grid0.size - assert trimmed_grids[1].size == 0 - assert trimmed_grids[2].size == grid2.size - assert trimmed_grids[3].size == 0 - - -def test_make_combined_adaptive_grid(): - """see also tests of _adapt_grid and _trim_grids for more detailed tests""" - - grid_range = (0, np.pi) - grid0 = np.linspace(0, np.pi/2, 6) # kept entirely. - grid1 = np.linspace(0, np.pi/2, 15) # removed entirely. Finer spacing doesn't matter, preceded by grid0 - grid2 = np.linspace(np.pi/2, np.pi, 11) # kept from pi/2 to pi - - # purposely make same lower index for grid2 as upper index for grid0 to test uniqueness of output - - all_grids = [grid0, grid1, grid2] - all_estimate = [xsinx, xsinx, xsinx] - - rtol = 1e-3 - combined_grid = au.make_combined_adaptive_grid(all_grids, all_estimate, grid_range, - max_iter=10, rtol=rtol, max_total_size=100) - - numerical_integral = np.trapz(xsinx(combined_grid), combined_grid) - - assert np.unique(combined_grid).size == combined_grid.size - assert np.isclose(numerical_integral, np.pi, rtol=rtol) - - -def test_throughput_soss(): - - wavelengths = np.linspace(2,5,10) - throughputs = np.ones_like(wavelengths) - interpolator = au.ThroughputSOSS(wavelengths, throughputs) - - # test that it returns 1 for all wavelengths inside range - interp = interpolator(wavelengths) - assert np.allclose(interp[1:-1], throughputs[1:-1]) - assert interp[0] == 0 - assert interp[-1] == 0 - - # test that it returns 0 for all wavelengths outside range - wavelengths_outside = np.linspace(1,1.5,5) - interp = interpolator(wavelengths_outside) - assert np.all(interp == 0) - - # test ValueError raise for shape mismatch - with pytest.raises(ValueError): - au.ThroughputSOSS(wavelengths, throughputs[:-1]) - - -@pytest.fixture(scope="module") -def kernel_init(): - """ - Toy model of the JWST kernel. The kernel is a triangle function - with maximum at the center, uniform in wavelength. - """ - n_os = 3 - n_pix = 25 # full width of kernel - n_wave = 20 - wave_range = (2.0, 5.0) - wavelengths = np.linspace(*wave_range, n_wave) - kernel_width = n_os*n_pix - (n_os - 1) - ctr_idx = kernel_width//2 - - wave_kernel = np.ones((kernel_width, wavelengths.size), dtype=float)*wavelengths[None,:] - triangle_function = ctr_idx - np.abs(ctr_idx - np.arange(0, kernel_width)) - kernel = np.ones((kernel_width, wavelengths.size), dtype=float)*triangle_function[:,None] - kernel/=np.max(kernel) - - return wave_kernel, kernel, n_pix - - -def test_webb_kernel(kernel_init): - - min_trace = 2.5 - max_trace = 3.5 - n_trace = 100 - wave_trace = np.linspace(min_trace, max_trace, n_trace) - (wave_kernel, kernel, n_pix) = kernel_init - - # instantiate the kernel object - kern = au.WebbKernel(wave_kernel, kernel, wave_trace, n_pix) - - # basic ensure that the input is stored and shapes - assert kern.n_pix == n_pix - assert kern.wave_kernels.shape == kern.kernels.shape - - # test that kernel and wave_kernel have been clipped to only keep wavelengths on detector - assert np.all(kern.wave_kernels >= min_trace) - assert np.all(kern.wave_kernels <= max_trace) - assert kern.wave_kernels.shape[0] == wave_kernel.shape[0] - assert kern.wave_kernels.shape[1] < wave_kernel.shape[1] - - # test that pixels is mirrored around the center and has zero at center - assert kern.pixels.size == wave_kernel.shape[0] - assert np.allclose(kern.pixels + kern.pixels[::-1], 0) - assert kern.pixels[kern.pixels.size//2] == 0 - - # test that wave_center has same shape as wavelength axis of wave_kernel - # but contains values that are in wave_trace - assert kern.wave_center.size == kern.wave_kernels.shape[1] - assert all(np.isin(kern.wave_center, wave_trace)) - - # test min value - assert kern.min_value > 0 - assert np.isin(kern.min_value, kern.kernels) - assert isinstance(kern.min_value, float) - - # test the polynomial fit has the proper shape. hard-coded to a first-order, i.e., linear fit - # since the throughput is constant in wavelength, the slopes should be close to zero - # and the y-intercepts should be close to kern.wave_center - # especially with so few points. just go with 10 percent, should catch egregious changes - assert kern.poly.shape == (kern.wave_kernels.shape[1], 2) - assert np.allclose(kern.poly[:,0], 0, atol=1e-1) - assert np.allclose(kern.poly[:,1], kern.wave_center, atol=1e-1) - - # test interpolation function, which takes in a pixel and a wavelength and returns a throughput - # this should return the triangle function at all wavelengths and zero outside range - pix_half = n_pix//2 - wl_test = np.linspace(min_trace, max_trace, 10) - pixels_test = np.array([-pix_half-1, -pix_half/2, 0, - pix_half/2, pix_half, pix_half+1]) - expected = np.array([0, 0.5, 1, 0.5, 0, 0]) - interp = kern.f_ker(pixels_test, wl_test) - - assert interp.shape == (pixels_test.size, wl_test.size) - diff = interp[:,1:] - interp[:,:-1] - assert np.allclose(diff, 0) - assert np.allclose(interp[:,0], expected, rtol=1e-3) - - # call the kernel object directly - # this takes a wavelength and a central wavelength of the kernel, - # then converts to pixels to use self.f_ker internally - assert kern(wl_test, wl_test).ndim == 1 - assert np.allclose(kern(wl_test, wl_test), 1) - # expect one wl spacing to correspond to one kernel pixel - wl_spacing = wave_trace[1] - wave_trace[0] - expected = 1 - 1/pix_half - assert np.allclose(kern(wl_test, wl_test - wl_spacing), expected) - - # test that clipping to minimum value works right - wl_on_edge = wl_test + (pix_half*wl_spacing - 0.0001) - assert np.allclose(kern(wl_test, wl_on_edge), kern.min_value) - - # both inputs need to be same shape - with pytest.raises(ValueError): - kern(wl_test, wl_test[:-1]) - - -@pytest.fixture(scope="module") -def kernel(kernel_init): - (wave_kernel, kernel, n_pix) = kernel_init - min_trace = 2.5 - max_trace = 3.5 - n_trace = 100 - wave_trace = np.linspace(min_trace, max_trace, n_trace) - return au.WebbKernel(wave_kernel, kernel, wave_trace, n_pix) - - -def test_fct_to_array(kernel): - - thresh = 1e-5 - grid = np.linspace(2.0, 4.0, 50) - grid_range = [0, grid.size] - - kern_array = au._fct_to_array(kernel, grid, grid_range, thresh) - - # check shape - assert kern_array.ndim == 2 - assert kern_array.shape[1] == grid.size - assert kern_array.shape[0]%2 == 1 - - # test that the max is at the center - kern_slice = kern_array[:,kern_array.shape[1]//2] - assert kern_slice[kern_slice.size//2] == np.max(kern_slice) - - # test that weights have been applied at edges - kern_center = kern_array[kern_array.shape[0]//2] - assert np.isclose(kern_center[0], kern_center[1]/2) - assert np.isclose(kern_center[-1], kern_center[-2]/2) @pytest.fixture(scope="module") @@ -545,24 +92,6 @@ def test_get_c_matrix(kernel): # Test invalid kernel input (wrong dimensions) -def test_finite_first_diff(): - - wave_grid = np.linspace(0, 2*np.pi, 100) - test_0 = np.ones_like(wave_grid) - test_sin = np.sin(wave_grid) - - first_d = au.finite_first_d(wave_grid) - assert first_d.size == (wave_grid.size - 1)*2 - - # test trivial example returning zeros for constant - f0 = first_d.dot(test_0) - assert np.allclose(f0, 0) - - # test derivative of sin returns cos - wave_between = (wave_grid[1:] + wave_grid[:-1])/2 - f_sin = first_d.dot(test_sin) - assert np.allclose(f_sin, np.cos(wave_between), atol=1e-3) - @pytest.fixture(scope="module") def tikhoTests(): From f42a74e1267a63319402dc731f903ddcbef2fc12 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Wed, 4 Dec 2024 15:56:28 -0500 Subject: [PATCH 26/35] split fixtures into conftest, add more tests of kernels --- jwst/extract_1d/soss_extract/atoca_utils.py | 5 +- .../extract_1d/soss_extract/tests/conftest.py | 229 ++++++ .../soss_extract/tests/test_atoca.py | 726 +----------------- .../soss_extract/tests/test_atoca_utils.py | 467 +++++++++-- 4 files changed, 657 insertions(+), 770 deletions(-) create mode 100644 jwst/extract_1d/soss_extract/tests/conftest.py diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index 108dc360d3..fe905136ac 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -1812,14 +1812,13 @@ def best_factor(self, mode='curvature'): def try_solve_two_methods(matrix, result): - + """on rare occasions spsolve's approximation of the matrix is not appropriate + and fails on good input data. revert to different solver""" with warnings.catch_warnings(): warnings.filterwarnings(action='error', category=MatrixRankWarning) try: return spsolve(matrix, result) except MatrixRankWarning: - # on rare occasions spsolve's approximation of the matrix is not appropriate - # and fails on good input data. revert to different solver log.info('ATOCA matrix solve failed with spsolve. Retrying with least-squares.') return lsqr(matrix, result)[0] diff --git a/jwst/extract_1d/soss_extract/tests/conftest.py b/jwst/extract_1d/soss_extract/tests/conftest.py new file mode 100644 index 0000000000..2dd068c4d3 --- /dev/null +++ b/jwst/extract_1d/soss_extract/tests/conftest.py @@ -0,0 +1,229 @@ +import pytest +import numpy as np +from functools import partial +from jwst.extract_1d.soss_extract import atoca +from jwst.extract_1d.soss_extract import atoca_utils as au + +""" +Create a miniature, slightly simplified model of the SOSS detector/optics. +Use those to instantiate an extraction engine, use the engine to create mock data +from a known input spectrum, and then check if the engine can retrieve that spectrum +from the data. + +The model has the following features: +- Factor-of-10 smaller along each dimension +- Similar wavelengths for each of the two orders +- Partially overlapping traces for the two orders +- Randomly-selected bad pixels in the data +- Wave grid of size ~100 with varying resolution +- Triangle function throughput for each spectral order +- Kernel set to unity (for now) +""" + +DATA_SHAPE = (25,200) +WAVE_BNDS_O1 = [2.8, 0.8] +WAVE_BNDS_O2 = [1.4, 0.5] +WAVE_BNDS_GRID = [0.7, 2.7] +ORDER1_SCALING = 20.0 +ORDER2_SCALING = 2.0 +SPECTRAL_SLOPE = 2 + + +@pytest.fixture(scope="package") +def wave_map(): + wave_ord1 = np.linspace(WAVE_BNDS_O1[0], WAVE_BNDS_O1[1], DATA_SHAPE[1]) + wave_ord1 = np.ones(DATA_SHAPE)*wave_ord1[np.newaxis, :] + + wave_ord2 = np.linspace(WAVE_BNDS_O2[0], WAVE_BNDS_O2[1], DATA_SHAPE[1]) + wave_ord2 = np.ones(DATA_SHAPE)*wave_ord2[np.newaxis,:] + # add a small region of zeros to mimic what is input to the step from ref files + wave_ord2[:,190:] = 0.0 + + return [wave_ord1, wave_ord2] + + +@pytest.fixture(scope="package") +def trace_profile(wave_map): + """order 2 is partially on top of, partially not on top of order 1 + give order 2 some slope to simulate that""" + # order 1 + DATA_SHAPE = wave_map[0].shape + ord1 = np.zeros((DATA_SHAPE[0])) + ord1[3:9] = 0.2 + ord1[2] = 0.1 + ord1[9] = 0.1 + profile_ord1 = np.ones(DATA_SHAPE)*ord1[:, np.newaxis] + + # order 2 + yy, xx = np.meshgrid(np.arange(DATA_SHAPE[0]), np.arange(DATA_SHAPE[1])) + yy = yy.astype(np.float32) - xx.astype(np.float32)*0.08 + yy = yy.T + + profile_ord2 = np.zeros_like(yy) + full = (yy >= 3) & (yy < 9) + half0 = (yy >= 9) & (yy < 11) + half1 = (yy >= 1) & (yy < 3) + profile_ord2[full] = 0.2 + profile_ord2[half0] = 0.1 + profile_ord2[half1] = 0.1 + + return [profile_ord1, profile_ord2] + + +@pytest.fixture(scope="package") +def wave_grid(): + """wave_grid has smaller spacings in some places than others + and is not backwards order like the wave map + Two duplicates are in there on purpose for testing""" + lo0 = np.linspace(WAVE_BNDS_GRID[0], 1.2, 16) + hi = np.linspace(1.2, 1.7, 46) + lo2 = np.linspace(1.7, WAVE_BNDS_GRID[1], 31) + return np.concatenate([lo0, hi, lo2]) + + +@pytest.fixture(scope="package") +def throughput(): + """make a triangle function for each order but with different peak wavelength + """ + + def filter_function(wl, wl_max): + """Set free parameters to roughly mimic throughput functions on main""" + maxthru = 0.4 + thresh = 0.01 + scaling = 0.3 + dist = np.abs(wl - wl_max) + thru = maxthru - dist*scaling + thru[thru0] = 0 + if cut_low is not None: + trace[:,:cut_low] = 1 + if cut_hi is not None: + trace[:,cut_hi:] = 1 + return trace.astype(bool) + + trace_o1 = mask_from_trace(trace_profile[0], cut_low=0, cut_hi=199) + trace_o2 = mask_from_trace(trace_profile[1], cut_low=0, cut_hi=175) + return [trace_o1, trace_o2] + + +@pytest.fixture(scope="package") +def detector_mask(wave_map): + """Add a few random bad pixels""" + rng = np.random.default_rng(42) + mask = np.zeros(DATA_SHAPE, dtype=bool) + bad = rng.choice(mask.size, 100) + bad = np.unravel_index(bad, DATA_SHAPE) + mask[bad] = 1 + return mask + + +@pytest.fixture(scope="package") +def engine(wave_map, + trace_profile, + throughput, + kernels_unity, + wave_grid, + mask_trace_profile, + detector_mask, +): + return atoca.ExtractionEngine(wave_map, + trace_profile, + throughput, + kernels_unity, + wave_grid, + mask_trace_profile, + global_mask=detector_mask) + + +def f_lam(wl, m=SPECTRAL_SLOPE, b=0): + """ + Estimator for flux as function of wavelength + Returns linear function of wl with slope m and intercept b + + This function is also used in this test suite as + """ + return m*wl + b + + +@pytest.fixture(scope="package") +def imagemodel(engine, detector_mask): + """ + use engine.rebuild to make an image model from an expected f(lambda). + Then we can ensure it round-trips + """ + + rng = np.random.default_rng(seed=42) + shp = engine.trace_profile[0].shape + + # make the detector bad values NaN, but leave the trace masks alone + # in reality, of course, this is backward: detector bad values + # would be determined from data + data = engine.rebuild(f_lam, fill_value=0.0) + data[detector_mask] = np.nan + + # add noise + noise_scaling = 3e-5 + data += noise_scaling*rng.standard_normal(shp) + + # error random, all positive, but also always larger than a certain value + # to avoid very large values of data / error + error = noise_scaling*(rng.standard_normal(shp)**2 + 0.5) + + return data, error \ No newline at end of file diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca.py b/jwst/extract_1d/soss_extract/tests/test_atoca.py index 8de18bd5ce..57fae0c43a 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca.py @@ -1,640 +1,11 @@ import pytest import numpy as np from functools import partial +from scipy.sparse import csr_matrix from jwst.extract_1d.soss_extract import atoca from jwst.extract_1d.soss_extract import atoca_utils as au - - -""" -Tests for atoca.py and atoca_utils.py - -Create a miniature, slightly simplified model of the SOSS detector/optics. -Use those to instantiate an extraction engine, use the engine to create mock data -from a known input spectrum, and then check if the engine can retrieve that spectrum -from the data. - -The model has the following features: -- Factor-of-10 smaller along each dimension -- Similar wavelengths for each of the two orders -- Partially overlapping traces for the two orders -- Randomly-selected bad pixels in the data -- Wave grid of size ~100 with varying resolution -- Triangle function throughput for each spectral order -- Kernel set to unity (for now) -""" - - -DATA_SHAPE = (25,200) -WAVE_BNDS_O1 = [2.8, 0.8] -WAVE_BNDS_O2 = [1.4, 0.5] -WAVE_BNDS_GRID = [0.7, 2.7] -ORDER1_SCALING = 20.0 -ORDER2_SCALING = 2.0 -SPECTRAL_SLOPE = 2 - - -@pytest.fixture(scope="module") -def wave_map(): - """Trying for a roughly factor-of-10 shrink on each axis - but otherwise relatively faithful reproduction of real detector behavior""" - wave_ord1 = np.linspace(WAVE_BNDS_O1[0], WAVE_BNDS_O1[1], DATA_SHAPE[1]) - wave_ord1 = np.ones(DATA_SHAPE)*wave_ord1[np.newaxis, :] - - wave_ord2 = np.linspace(WAVE_BNDS_O2[0], WAVE_BNDS_O2[1], DATA_SHAPE[1]) - wave_ord2 = np.ones(DATA_SHAPE)*wave_ord2[np.newaxis,:] - # add a small region of zeros to mimic what is input to the step from ref files - wave_ord2[:,190:] = 0.0 - - return [wave_ord1, wave_ord2] - - -@pytest.fixture(scope="module") -def trace_profile(wave_map): - """ - order 2 is partially on top of, partially not on top of order 1 - give order 2 some slope to simulate that - """ - # order 1 - DATA_SHAPE = wave_map[0].shape - ord1 = np.zeros((DATA_SHAPE[0])) - ord1[3:9] = 0.2 - ord1[2] = 0.1 - ord1[9] = 0.1 - profile_ord1 = np.ones(DATA_SHAPE)*ord1[:, np.newaxis] - - # order 2 - yy, xx = np.meshgrid(np.arange(DATA_SHAPE[0]), np.arange(DATA_SHAPE[1])) - yy = yy.astype(np.float32) - xx.astype(np.float32)*0.08 - yy = yy.T - - profile_ord2 = np.zeros_like(yy) - full = (yy >= 3) & (yy < 9) - half0 = (yy >= 9) & (yy < 11) - half1 = (yy >= 1) & (yy < 3) - profile_ord2[full] = 0.2 - profile_ord2[half0] = 0.1 - profile_ord2[half1] = 0.1 - - return [profile_ord1, profile_ord2] - - -@pytest.fixture(scope="module") -def wave_grid(): - """wave_grid has smaller spacings in some places than others - and is not backwards order like the wave map - Two duplicates are in there on purpose for testing""" - lo0 = np.linspace(WAVE_BNDS_GRID[0], 1.2, 16) - hi = np.linspace(1.2, 1.7, 46) - lo2 = np.linspace(1.7, WAVE_BNDS_GRID[1], 31) - return np.concatenate([lo0, hi, lo2]) - - -@pytest.fixture(scope="module") -def throughput(): - """make a triangle function for each order but with different peak wavelength - """ - - def filter_function(wl, wl_max): - """Set free parameters to roughly mimic throughput functions on main""" - maxthru = 0.4 - thresh = 0.01 - scaling = 0.3 - dist = np.abs(wl - wl_max) - thru = maxthru - dist*scaling - thru[thru0] = 0 - if cut_low is not None: - trace[:,:cut_low] = 1 - if cut_hi is not None: - trace[:,cut_hi:] = 1 - return trace.astype(bool) - - trace_o1 = mask_from_trace(trace_profile[0], cut_low=0, cut_hi=199) - trace_o2 = mask_from_trace(trace_profile[1], cut_low=0, cut_hi=175) - return [trace_o1, trace_o2] - - -@pytest.fixture(scope="module") -def detector_mask(wave_map): - """Add a few random bad pixels""" - rng = np.random.default_rng(42) - mask = np.zeros(DATA_SHAPE, dtype=bool) - bad = rng.choice(mask.size, 100) - bad = np.unravel_index(bad, DATA_SHAPE) - mask[bad] = 1 - return mask - - -@pytest.fixture(scope="module") -def engine(wave_map, - trace_profile, - throughput, - kernels_unity, - wave_grid, - mask_trace_profile, - detector_mask, -): - return atoca.ExtractionEngine(wave_map, - trace_profile, - throughput, - kernels_unity, - wave_grid, - mask_trace_profile, - global_mask=detector_mask) - - -def test_arange_2d(): - - starts = np.array([3,4,5]) - stops = np.ones(starts.shape)*7 - out = au.arange_2d(starts, stops) - - bad = -1 - expected_out = np.array([ - [3,4,5,6], - [4,5,6,bad], - [5,6,bad,bad] - ]) - assert np.allclose(out, expected_out) - - # test bad input catches - starts_wrong_shape = starts[1:] - with pytest.raises(ValueError): - au.arange_2d(starts_wrong_shape, stops) - - stops_too_small = np.copy(stops) - stops_too_small[2] = 4 - with pytest.raises(ValueError): - au.arange_2d(starts, stops_too_small) - - -@pytest.mark.parametrize("dispersion_axis", [0,1]) -def test_get_wv_map_bounds(wave_map, dispersion_axis): - """ - top is the low-wavelength end, bottom is high-wavelength end - """ - wave_map = wave_map[0].copy() - wave_map[1,3] = -1 #test skip of bad value - wavelengths = wave_map[0] - - if dispersion_axis == 0: - wave_flip = wave_map.T - else: - wave_flip = wave_map - wave_top, wave_bottom = au._get_wv_map_bounds(wave_flip, dispersion_axis=dispersion_axis) - - # flip the results back so we can re-use the same tests - if dispersion_axis == 0: - wave_top = wave_top.T - wave_bottom = wave_bottom.T - - diff = (wavelengths[1:]-wavelengths[:-1])/2 - diff_lower = np.insert(diff,0,diff[0]) - diff_upper = np.append(diff,diff[-1]) - wave_top_expected = wavelengths-diff_lower - wave_bottom_expected = wavelengths+diff_upper - - # basic test - assert wave_top.shape == wave_bottom.shape == (wave_map.shape[0],)+wavelengths.shape - assert np.allclose(wave_top[0], wave_top_expected) - assert np.allclose(wave_bottom[0], wave_bottom_expected) - - # test skip bad pixel - assert wave_top[1,3] == 0 - assert wave_bottom[1,3] == 0 - - # test bad input error raises - with pytest.raises(ValueError): - au._get_wv_map_bounds(wave_flip, dispersion_axis=2) - - -@pytest.mark.parametrize("dispersion_axis", [0,1]) -def test_get_wave_p_or_m(wave_map, dispersion_axis): - """ - Check that the plus and minus side is correctly identified - for strictly ascending and strictly descending wavelengths. - """ - wave_map = wave_map[0].copy() - wave_reverse = np.fliplr(wave_map) - if dispersion_axis == 0: - wave_flip = wave_map.T - wave_reverse = wave_reverse.T - else: - wave_flip = wave_map - - wave_p_0, wave_m_0 = au.get_wave_p_or_m(wave_flip, dispersion_axis=dispersion_axis) - wave_p_1, wave_m_1 = au.get_wave_p_or_m(wave_reverse, dispersion_axis=dispersion_axis) - - if dispersion_axis==0: - wave_p_0 = wave_p_0.T - wave_m_0 = wave_m_0.T - wave_p_1 = wave_p_1.T - wave_m_1 = wave_m_1.T - assert np.all(wave_p_0 >= wave_m_0) - assert np.allclose(wave_p_0, np.fliplr(wave_p_1)) - assert np.allclose(wave_m_0, np.fliplr(wave_m_1)) - - -def test_get_wave_p_or_m_not_ascending(wave_map): - wave_map = wave_map[0].copy() - with pytest.raises(ValueError): - wave_map[0,5] = 2 # make it not strictly ascending - au.get_wave_p_or_m(wave_map, dispersion_axis=1) - - -FIBONACCI = np.array([1,1,2,3,5,8,13,21,35], dtype=float) -@pytest.mark.parametrize("n_os", [1,5]) -def test_oversample_grid(n_os): - - oversample = au.oversample_grid(FIBONACCI, n_os) - - # oversample_grid is supposed to remove any duplicates, and there is a duplicate - # in FIBONACCI. So the output should be 4 times the size of FIBONACCI minus 1 - assert oversample.size == n_os*(FIBONACCI.size - 1) - (n_os-1) - assert oversample.min() == FIBONACCI.min() - assert oversample.max() == FIBONACCI.max() - - # test whether np.interp could have been used instead - grid = np.arange(0, FIBONACCI.size, 1/n_os) - wls = np.unique(np.interp(grid, np.arange(FIBONACCI.size), FIBONACCI)) - assert np.allclose(oversample, wls) - - -@pytest.mark.parametrize("os_factor", [1,2,5]) -def test_oversample_irregular(os_factor): - """Test oversampling to a grid with irregular spacing""" - # oversampling function removes duplicates, - # this is tested in previous test, and just complicates counting for this test - # for FIBONACCI, unique is just removing zeroth element - fib_unq = np.unique(FIBONACCI) - n_os = np.ones((fib_unq.size-1,), dtype=int) - n_os[2:5] = os_factor - n_os[3] = os_factor*2 - # this gives n_os = [1 1 2 4 2] for os_factor = 2 - - oversample = au.oversample_grid(fib_unq, n_os) - - # test no oversampling was done on the elements where not requested - assert np.allclose(oversample[0:2], fib_unq[0:2]) - assert np.allclose(oversample[-1:], fib_unq[-1:]) - - # test output shape. - assert oversample.size == np.sum(n_os)+1 - - # test that this could have been done easily with np.interp - intervals = 1/n_os - intervals = np.insert(np.repeat(intervals, n_os),0,0) - grid = np.cumsum(intervals) - wls = np.interp(grid, np.arange(fib_unq.size), fib_unq) - assert wls.size == oversample.size - assert np.allclose(oversample, wls) - - # test that n_os shape must match input shape - 1 - with pytest.raises(ValueError): - au.oversample_grid(fib_unq, n_os[:-1]) - - -WAVELENGTHS = np.linspace(1.5, 3.0, 50) + np.sin(np.linspace(0, np.pi/2, 50)) -@pytest.mark.parametrize("wave_range", [(2.1, 3.9), (1.8, 4.5)]) -def test_extrapolate_grid(wave_range): - - extrapolated = au._extrapolate_grid(WAVELENGTHS, wave_range, 1) - - assert extrapolated.max() > wave_range[1] - assert extrapolated.min() < wave_range[0] - assert np.all(extrapolated[1:] >= extrapolated[:-1]) - - # if interpolation not needed on either side, should return the original - if wave_range == (2.1, 3.9): - assert extrapolated is WAVELENGTHS - - -def test_extrapolate_catch_failed_converge(): - # give wavelengths some non-linearity - wave_range = WAVELENGTHS.min(), WAVELENGTHS.max()+4.0 - with pytest.raises(RuntimeError): - au._extrapolate_grid(WAVELENGTHS, wave_range, 1) - - -def test_extrapolate_bad_inputs(): - with pytest.raises(ValueError): - au._extrapolate_grid(WAVELENGTHS, (2.9, 2.1)) - with pytest.raises(ValueError): - au._extrapolate_grid(WAVELENGTHS, (4.1, 4.2)) - - -def test_grid_from_map(wave_map, trace_profile): - """Covers expected behavior of grid_from_map, including coverage of a previous bug - where bad wavelengths were not being ignored properly""" - - wave_map = wave_map[0].copy() - wavelengths = wave_map[0][::-1] - trace_profile = trace_profile[0].copy() - wave_grid = au.grid_from_map(wave_map, trace_profile, wave_range=None) - - assert np.allclose(wave_grid, wavelengths) - - # test custom wave_range - wave_range = [wavelengths[2], wavelengths[-2]+0.01] - wave_grid = au.grid_from_map(wave_map, trace_profile, wave_range=wave_range) - assert np.allclose(wave_grid, wavelengths[2:-1]) - - # test custom wave_range with extrapolation - wave_range = [wavelengths[2], wavelengths[-1]+1] - wave_grid = au.grid_from_map(wave_map, trace_profile, wave_range=wave_range) - assert len(wave_grid) > len(wavelengths[2:]) - n_inside = wavelengths[2:].size - assert np.allclose(wave_grid[:n_inside], wavelengths[2:]) - - with pytest.raises(ValueError): - au.grid_from_map(wave_map, trace_profile, wave_range=[0.1,0.2]) - - -def xsinx(x): - return x*np.sin(x) - - -def test_estim_integration_error(): - """ - Use as truth the x sin(x) from 0 to pi, has an analytic solution == pi. - """ - - n = 11 - grid = np.linspace(0, np.pi, n) - err, rel_err = au._estim_integration_err(grid, xsinx) - - assert len(rel_err) == n-1 - assert np.all(rel_err >= 0) - assert np.all(rel_err < 1) - - -@pytest.mark.parametrize("max_iter, rtol", [(1,1e-3), (10, 1e-9), (10, 1e-3), (1, 1e-9)]) -def test_adapt_grid(max_iter, rtol): - """ - Use as truth the x sin(x) from 0 to pi, has an analytic solution == pi. - """ - - input_grid = np.linspace(0, np.pi, 11) - input_grid_diff = input_grid[1] - input_grid[0] - max_grid_size = 100 - grid, is_converged = au._adapt_grid(input_grid, - xsinx, - max_grid_size, - max_iter=max_iter, - rtol=rtol) - - # ensure grid respects max_grid_size and max_iter in all cases - assert len(grid) <= max_grid_size - grid_diff = grid[1:] - grid[:-1] - assert np.min(grid_diff) >= input_grid_diff/(2**max_iter) - - numerical_integral = np.trapz(xsinx(grid), grid) - - # ensure this converges for at least one of our test cases - if max_iter == 10 and rtol == 1e-3: - assert is_converged - - if is_converged: - # test error of the answer is smaller than rtol - assert np.isclose(numerical_integral, np.pi, rtol=rtol) - # test that success was a stop condition - assert len(grid) < max_grid_size - - # test stop conditions - elif max_iter == 10: - # ensure hitting max_grid_size returns an array of exactly length max_grid_size - assert len(grid) == max_grid_size - elif max_iter == 1: - # ensure hitting max_iter can stop iteration before max_grid_size reached - assert len(grid) <= 2*len(input_grid) - - -def test_adapt_grid_bad_inputs(): - with pytest.raises(ValueError): - # input grid larger than max_grid_size - au._adapt_grid(np.array([1,2,3]), xsinx, 2) - - -def test_trim_grids(): - - grid_range = (-3, 3) - grid0 = np.linspace(-3, 0, 4) # kept entirely. - grid1 = np.linspace(-3, 0, 16) # removed entirely. Finer spacing doesn't matter, preceded by grid0 - grid2 = np.linspace(0, 3, 5) # kept from 0 to 3 - grid3 = np.linspace(-4, 4, 5) # removed entirely. Outside of grid_range and the rest is superseded - - all_grids = [grid0, grid1, grid2, grid3] - trimmed_grids = au._trim_grids(all_grids, grid_range) - - assert len(trimmed_grids) == len(all_grids) - assert trimmed_grids[0].size == grid0.size - assert trimmed_grids[1].size == 0 - assert trimmed_grids[2].size == grid2.size - assert trimmed_grids[3].size == 0 - - -def test_make_combined_adaptive_grid(): - """see also tests of _adapt_grid and _trim_grids for more detailed tests""" - - grid_range = (0, np.pi) - grid0 = np.linspace(0, np.pi/2, 6) # kept entirely. - grid1 = np.linspace(0, np.pi/2, 15) # removed entirely. Finer spacing doesn't matter, preceded by grid0 - grid2 = np.linspace(np.pi/2, np.pi, 11) # kept from pi/2 to pi - - # purposely make same lower index for grid2 as upper index for grid0 to test uniqueness of output - - all_grids = [grid0, grid1, grid2] - all_estimate = [xsinx, xsinx, xsinx] - - rtol = 1e-3 - combined_grid = au.make_combined_adaptive_grid(all_grids, all_estimate, grid_range, - max_iter=10, rtol=rtol, max_total_size=100) - - numerical_integral = np.trapz(xsinx(combined_grid), combined_grid) - - assert np.unique(combined_grid).size == combined_grid.size - assert np.isclose(numerical_integral, np.pi, rtol=rtol) - - -def test_throughput_soss(): - - wavelengths = np.linspace(2,5,10) - throughputs = np.ones_like(wavelengths) - interpolator = au.ThroughputSOSS(wavelengths, throughputs) - - # test that it returns 1 for all wavelengths inside range - interp = interpolator(wavelengths) - assert np.allclose(interp[1:-1], throughputs[1:-1]) - assert interp[0] == 0 - assert interp[-1] == 0 - - # test that it returns 0 for all wavelengths outside range - wavelengths_outside = np.linspace(1,1.5,5) - interp = interpolator(wavelengths_outside) - assert np.all(interp == 0) - - # test ValueError raise for shape mismatch - with pytest.raises(ValueError): - au.ThroughputSOSS(wavelengths, throughputs[:-1]) - - -def test_webb_kernel(webb_kernels, wave_map): - - wave_trace = wave_map[0][0] - min_trace, max_trace = np.min(wave_trace), np.max(wave_trace) - kern = webb_kernels[0] - - # basic ensure that the input is stored and shapes - assert kern.wave_kernels.shape == kern.kernels.shape - - # test that pixels is mirrored around the center and has zero at center - assert np.allclose(kern.pixels + kern.pixels[::-1], 0) - assert kern.pixels[kern.pixels.size//2] == 0 - - # test that wave_center has same shape as wavelength axis of wave_kernel - # but contains values that are in wave_trace - assert kern.wave_center.size == kern.wave_kernels.shape[1] - assert all(np.isin(kern.wave_center, wave_trace)) - - # test min value - assert kern.min_value > 0 - assert np.isin(kern.min_value, kern.kernels) - assert isinstance(kern.min_value, float) - - # test the polynomial fit has the proper shape. hard-coded to a first-order, i.e., linear fit - # since the throughput is constant in wavelength, the slopes should be close to zero - # and the y-intercepts should be close to kern.wave_center - # especially with so few points. just go with 10 percent, should catch egregious changes - assert kern.poly.shape == (kern.wave_kernels.shape[1], 2) - assert np.allclose(kern.poly[:,0], 0, atol=1e-1) - assert np.allclose(kern.poly[:,1], kern.wave_center, atol=1e-1) - - # test interpolation function, which takes in a pixel and a wavelength and returns a throughput - # this should return the triangle function at all wavelengths and zero outside range - pix_half = kern.n_pix//2 - wl_test = np.linspace(min_trace, max_trace, 10) - pixels_test = np.array([-pix_half-1, 0, pix_half, pix_half+1]) - - data_in = kern.kernels[:,0] - m = kern.min_value - expected = np.array([m, np.max(data_in), m, m]) - - interp = kern.f_ker(pixels_test, wl_test) - assert interp.shape == (pixels_test.size, wl_test.size) - diff = interp[:,1:] - interp[:,:-1] - assert np.allclose(diff, 0) - assert np.allclose(interp[:,0], expected, rtol=1e-3) - - # call the kernel object directly - # this takes a wavelength and a central wavelength of the kernel, - # then converts to pixels to use self.f_ker internally - assert kern(wl_test, wl_test).ndim == 1 - assert np.allclose(kern(wl_test, wl_test), np.max(data_in)) - - # both inputs need to be same shape - with pytest.raises(ValueError): - kern(wl_test, wl_test[:-1]) - - -def test_get_c_matrix(kernel): - - #TODO: what to use for grid? is it / can it be the same as wave_trace? - # I think it can be the same but does not need to be, and would be a better - # test if it were different, because the kernel is an interpolator that was - # created using wavelengths that are included in wave_trace. - matrix = au.get_c_matrix(kernel, grid, i_bounds=None, thresh=1e-5) - - # test with WebbKernel as the kernel - # ensure normalized - # ensure sparse - - - # test where input kernel is a 2-D array instead of callable - - - # test where input kernel is size 1 - - - # test where i_bounds is not None - - - # Test invalid kernel input (wrong dimensions) - - -def test_finite_first_diff(): - - wave_grid = np.linspace(0, 2*np.pi, 100) - test_0 = np.ones_like(wave_grid) - test_sin = np.sin(wave_grid) - - first_d = au.finite_first_d(wave_grid) - assert first_d.size == (wave_grid.size - 1)*2 - - # test trivial example returning zeros for constant - f0 = first_d.dot(test_0) - assert np.allclose(f0, 0) - - # test derivative of sin returns cos - wave_between = (wave_grid[1:] + wave_grid[:-1])/2 - f_sin = first_d.dot(test_sin) - assert np.allclose(f_sin, np.cos(wave_between), atol=1e-3) +from jwst.extract_1d.soss_extract.tests.conftest import ( + SPECTRAL_SLOPE, f_lam, DATA_SHAPE, WAVE_BNDS_O1, WAVE_BNDS_O2) def test_extraction_engine_init( @@ -690,6 +61,7 @@ def test_extraction_engine_init( # in order 2 no restriction on the shortwave end so get the full extent assert engine.i_bounds[0][0] > 0 assert engine.i_bounds[1][1] < engine.n_wavepoints + # TODO: off-by-one error here. why does this fail? # check what this looks like on a real run on main # assert engine.i_bounds[0][1] == engine.n_wavepoints @@ -843,10 +215,20 @@ def new_thru_f(wl): engine.update_throughput([new_thru_f, new_thru_f]) -def test_create_kernels(engine): - # TODO: make a non-trivial kernel fixture to work with this - # trivial example already done - pass +def test_create_kernels(webb_kernels, engine): + """test_atoca_utisl.test_get_c_matrix already tests the creation + of individual kernels for different input types, here just ensure + the options get passed into that function properly""" + + kernels_0 = engine._create_kernels(webb_kernels, None) + kernels_1 = engine._create_kernels([None, None], None) + + for kernel_list in [kernels_0, kernels_1]: + assert len(kernel_list) == 2 + for order in [0,1]: + kern = kernel_list[order] + assert isinstance(kern, csr_matrix) + assert kern.dtype == np.float64 def test_wave_grid_c(engine): @@ -898,10 +280,6 @@ def test_get_pixel_mapping(engine): with pytest.raises(AttributeError): engine.get_pixel_mapping(1, quick=True) - # TODO: follow the shape of every sparse matrix through this function, see if - # it makes sense or if it would make more sense to mask differently - # TODO: test that the math is actually correct using a trivial example (somehow) - def test_rebuild(engine): @@ -917,43 +295,6 @@ def test_rebuild(engine): assert np.allclose(np.isnan(detector_model_nans), engine.general_mask) -def f_lam(wl, m=SPECTRAL_SLOPE, b=0): - """ - Estimator for flux as function of wavelength - Returns linear function of wl with slope m and intercept b - - This function is also used in this test suite as - """ - return m*wl + b - - -@pytest.fixture(scope="module") -def imagemodel(engine, detector_mask): - """ - use engine.rebuild to make an image model from an expected f(lambda). - Then we can ensure it round-trips - """ - - rng = np.random.default_rng(seed=42) - shp = engine.trace_profile[0].shape - - # make the detector bad values NaN, but leave the trace masks alone - # in reality, of course, this is backward: detector bad values - # would be determined from data - data = engine.rebuild(f_lam, fill_value=0.0) - data[detector_mask] = np.nan - - # add noise - noise_scaling = 3e-5 - data += noise_scaling*rng.standard_normal(shp) - - # error random, all positive, but also always larger than a certain value - # to avoid very large values of data / error - error = noise_scaling*(rng.standard_normal(shp)**2 + 0.5) - - return data, error - - def test_build_sys(imagemodel, engine): data, error = imagemodel @@ -962,7 +303,6 @@ def test_build_sys(imagemodel, engine): assert matrix.shape == (result.size, result.size) - def test_get_detector_model(imagemodel, engine): data, error = imagemodel @@ -979,7 +319,6 @@ def test_estimate_tikho_factors(engine): factor = engine.estimate_tikho_factors(f_lam) assert isinstance(factor, float) - # very approximate calculation of tik fac looks like # n_pixels = (~engine.mask).sum() # flux = f_lam(engine.wave_grid) @@ -1072,32 +411,3 @@ def test_compute_likelihood(engine, imagemodel): assert np.argmax(logl) == np.argwhere(test_slopes == SPECTRAL_SLOPE) assert np.all(np.array(logl) < 0) - - -def test_sparse_c(kern_array): - """Here kernel must be a 2-D array already, of shape (N_ker, N_k_convolved)""" - - # test typical case n_k = n_kc and i=0 - n_k = kern_array.shape[1] - i_zero = 0 - matrix = au._sparse_c(kern_array, n_k, i_zero) - - # TODO: add more here - - -@pytest.fixture(scope="module") -def tikhoTests(): - """Make a TikhoTests dictionary""" - - - return au.TikhoTests({'factors': factors, - 'solution': sln, - 'error': err, - 'reg': reg, - 'grid': wave_grid}) - - - -def test_tikho_tests(tikhoTests): - - assert False \ No newline at end of file diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py index 81f69a2fe2..533814affe 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca_utils.py @@ -4,108 +4,457 @@ import numpy as np +def test_arange_2d(): + starts = np.array([3,4,5]) + stops = np.ones(starts.shape)*7 + out = au.arange_2d(starts, stops) -# wavelengths have min, max (1.5, 4.0) and a bit of non-linearity -WAVELENGTHS = np.linspace(1.5, 3.0, 50) + np.sin(np.linspace(0, np.pi/2, 50)) -@pytest.fixture(scope="function") -def wave_map(): - wave_map = np.array([ - WAVELENGTHS, - WAVELENGTHS+0.2, - WAVELENGTHS+0.2, - WAVELENGTHS+0.2, - WAVELENGTHS+0.4, + bad = -1 + expected_out = np.array([ + [3,4,5,6], + [4,5,6,bad], + [5,6,bad,bad] ]) + assert np.allclose(out, expected_out) + + # test bad input catches + starts_wrong_shape = starts[1:] + with pytest.raises(ValueError): + au.arange_2d(starts_wrong_shape, stops) + + stops_too_small = np.copy(stops) + stops_too_small[2] = 4 + with pytest.raises(ValueError): + au.arange_2d(starts, stops_too_small) + + +@pytest.mark.parametrize("dispersion_axis", [0,1]) +def test_get_wv_map_bounds(wave_map, dispersion_axis): + """ + top is the low-wavelength end, bottom is high-wavelength end + """ + wave_map = wave_map[0].copy() wave_map[1,3] = -1 #test skip of bad value - return wave_map + wavelengths = wave_map[0] + + if dispersion_axis == 0: + wave_flip = wave_map.T + else: + wave_flip = wave_map + wave_top, wave_bottom = au._get_wv_map_bounds(wave_flip, dispersion_axis=dispersion_axis) + + # flip the results back so we can re-use the same tests + if dispersion_axis == 0: + wave_top = wave_top.T + wave_bottom = wave_bottom.T + + diff = (wavelengths[1:]-wavelengths[:-1])/2 + diff_lower = np.insert(diff,0,diff[0]) + diff_upper = np.append(diff,diff[-1]) + wave_top_expected = wavelengths-diff_lower + wave_bottom_expected = wavelengths+diff_upper + + # basic test + assert wave_top.shape == wave_bottom.shape == (wave_map.shape[0],)+wavelengths.shape + assert np.allclose(wave_top[0], wave_top_expected) + assert np.allclose(wave_bottom[0], wave_bottom_expected) + + # test skip bad pixel + assert wave_top[1,3] == 0 + assert wave_bottom[1,3] == 0 + + # test bad input error raises + with pytest.raises(ValueError): + au._get_wv_map_bounds(wave_flip, dispersion_axis=2) + + +@pytest.mark.parametrize("dispersion_axis", [0,1]) +def test_get_wave_p_or_m(wave_map, dispersion_axis): + """ + Check that the plus and minus side is correctly identified + for strictly ascending and strictly descending wavelengths. + """ + wave_map = wave_map[0].copy() + wave_reverse = np.fliplr(wave_map) + if dispersion_axis == 0: + wave_flip = wave_map.T + wave_reverse = wave_reverse.T + else: + wave_flip = wave_map + + wave_p_0, wave_m_0 = au.get_wave_p_or_m(wave_flip, dispersion_axis=dispersion_axis) + wave_p_1, wave_m_1 = au.get_wave_p_or_m(wave_reverse, dispersion_axis=dispersion_axis) + + if dispersion_axis==0: + wave_p_0 = wave_p_0.T + wave_m_0 = wave_m_0.T + wave_p_1 = wave_p_1.T + wave_m_1 = wave_m_1.T + assert np.all(wave_p_0 >= wave_m_0) + assert np.allclose(wave_p_0, np.fliplr(wave_p_1)) + assert np.allclose(wave_m_0, np.fliplr(wave_m_1)) + + +def test_get_wave_p_or_m_not_ascending(wave_map): + wave_map = wave_map[0].copy() + with pytest.raises(ValueError): + wave_map[0,5] = 2 # make it not strictly ascending + au.get_wave_p_or_m(wave_map, dispersion_axis=1) + + +FIBONACCI = np.array([1,1,2,3,5,8,13,21,35], dtype=float) +@pytest.mark.parametrize("n_os", [1,5]) +def test_oversample_grid(n_os): + + oversample = au.oversample_grid(FIBONACCI, n_os) + + # oversample_grid is supposed to remove any duplicates, and there is a duplicate + # in FIBONACCI. So the output should be 4 times the size of FIBONACCI minus 1 + assert oversample.size == n_os*(FIBONACCI.size - 1) - (n_os-1) + assert oversample.min() == FIBONACCI.min() + assert oversample.max() == FIBONACCI.max() + + # test whether np.interp could have been used instead + grid = np.arange(0, FIBONACCI.size, 1/n_os) + wls = np.unique(np.interp(grid, np.arange(FIBONACCI.size), FIBONACCI)) + assert np.allclose(oversample, wls) + + +@pytest.mark.parametrize("os_factor", [1,2,5]) +def test_oversample_irregular(os_factor): + """Test oversampling to a grid with irregular spacing""" + # oversampling function removes duplicates, + # this is tested in previous test, and just complicates counting for this test + # for FIBONACCI, unique is just removing zeroth element + fib_unq = np.unique(FIBONACCI) + n_os = np.ones((fib_unq.size-1,), dtype=int) + n_os[2:5] = os_factor + n_os[3] = os_factor*2 + # this gives n_os = [1 1 2 4 2] for os_factor = 2 + + oversample = au.oversample_grid(fib_unq, n_os) + + # test no oversampling was done on the elements where not requested + assert np.allclose(oversample[0:2], fib_unq[0:2]) + assert np.allclose(oversample[-1:], fib_unq[-1:]) + + # test output shape. + assert oversample.size == np.sum(n_os)+1 + + # test that this could have been done easily with np.interp + intervals = 1/n_os + intervals = np.insert(np.repeat(intervals, n_os),0,0) + grid = np.cumsum(intervals) + wls = np.interp(grid, np.arange(fib_unq.size), fib_unq) + assert wls.size == oversample.size + assert np.allclose(oversample, wls) + + # test that n_os shape must match input shape - 1 + with pytest.raises(ValueError): + au.oversample_grid(fib_unq, n_os[:-1]) -@pytest.fixture(scope="function") -def wave_map_o2(wave_map): - return np.copy(wave_map) - 1.0 +WAVELENGTHS = np.linspace(1.5, 3.0, 50) + np.sin(np.linspace(0, np.pi/2, 50)) +@pytest.mark.parametrize("wave_range", [(2.1, 3.9), (1.8, 4.5)]) +def test_extrapolate_grid(wave_range): + + extrapolated = au._extrapolate_grid(WAVELENGTHS, wave_range, 1) + assert extrapolated.max() > wave_range[1] + assert extrapolated.min() < wave_range[0] + assert np.all(extrapolated[1:] >= extrapolated[:-1]) -@pytest.fixture(scope="function") -def trace_profile(wave_map): - thrpt = np.array([0.01, 0.95, 1.0, 0.8, 0.01]) - trace_profile = np.ones_like(wave_map) - return trace_profile*thrpt[:,None] + # if interpolation not needed on either side, should return the original + if wave_range == (2.1, 3.9): + assert extrapolated is WAVELENGTHS -@pytest.fixture(scope="function") -def trace_profile_o2(wave_map_o2): - thrpt = np.array([0.001, 0.01, 0.01, 0.2, 0.99]) - trace_profile = np.ones_like(wave_map_o2) - return trace_profile*thrpt[:,None] +def test_extrapolate_catch_failed_converge(): + # give wavelengths some non-linearity + wave_range = WAVELENGTHS.min(), WAVELENGTHS.max()+4.0 + with pytest.raises(RuntimeError): + au._extrapolate_grid(WAVELENGTHS, wave_range, 1) +def test_extrapolate_bad_inputs(): + with pytest.raises(ValueError): + au._extrapolate_grid(WAVELENGTHS, (2.9, 2.1)) + with pytest.raises(ValueError): + au._extrapolate_grid(WAVELENGTHS, (4.1, 4.2)) +def test_grid_from_map(wave_map, trace_profile): + """Covers expected behavior of grid_from_map, including coverage of a previous bug + where bad wavelengths were not being ignored properly""" + wave_map = wave_map[0].copy() + wavelengths = wave_map[0][::-1] + trace_profile = trace_profile[0].copy() + wave_grid = au.grid_from_map(wave_map, trace_profile, wave_range=None) + assert np.allclose(wave_grid, wavelengths) + # test custom wave_range + wave_range = [wavelengths[2], wavelengths[-2]+0.01] + wave_grid = au.grid_from_map(wave_map, trace_profile, wave_range=wave_range) + assert np.allclose(wave_grid, wavelengths[2:-1]) + # test custom wave_range with extrapolation + wave_range = [wavelengths[2], wavelengths[-1]+1] + wave_grid = au.grid_from_map(wave_map, trace_profile, wave_range=wave_range) + assert len(wave_grid) > len(wavelengths[2:]) + n_inside = wavelengths[2:].size + assert np.allclose(wave_grid[:n_inside], wavelengths[2:]) -@pytest.fixture(scope="module") -def kern_array(kernel): + with pytest.raises(ValueError): + au.grid_from_map(wave_map, trace_profile, wave_range=[0.1,0.2]) - return au._fct_to_array(kernel, np.linspace(2.0, 4.0, 50), [0, 50], 1e-5) +def xsinx(x): + return x*np.sin(x) -def test_sparse_c(kern_array): - """Here kernel must be a 2-D array already, of shape (N_ker, N_k_convolved)""" - # test typical case n_k = n_kc and i=0 - n_k = kern_array.shape[1] - i_zero = 0 - matrix = au._sparse_c(kern_array, n_k, i_zero) +def test_estim_integration_error(): + """ + Use as truth the x sin(x) from 0 to pi, has an analytic solution == pi. + """ - # TODO: add more here + n = 11 + grid = np.linspace(0, np.pi, n) + err, rel_err = au._estim_integration_err(grid, xsinx) + assert len(rel_err) == n-1 + assert np.all(rel_err >= 0) + assert np.all(rel_err < 1) -def test_get_c_matrix(kernel): - """See also test_fct_to_array and test_sparse_c for more detailed tests - of functions called by this one""" +@pytest.mark.parametrize("max_iter, rtol", [(1,1e-3), (10, 1e-9), (10, 1e-3), (1, 1e-9)]) +def test_adapt_grid(max_iter, rtol): + """ + Use as truth the x sin(x) from 0 to pi, has an analytic solution == pi. + """ - #TODO: what to use for grid? is it / can it be the same as wave_trace? - # I think it can be the same but does not need to be, and would be a better - # test if it were different, because the kernel is an interpolator that was - # created using wavelengths that are included in wave_trace. - matrix = au.get_c_matrix(kernel, grid, i_bounds=None, thresh=1e-5) + input_grid = np.linspace(0, np.pi, 11) + input_grid_diff = input_grid[1] - input_grid[0] + max_grid_size = 100 + grid, is_converged = au._adapt_grid(input_grid, + xsinx, + max_grid_size, + max_iter=max_iter, + rtol=rtol) - # test with WebbKernel as the kernel - # ensure normalized - # ensure sparse + # ensure grid respects max_grid_size and max_iter in all cases + assert len(grid) <= max_grid_size + grid_diff = grid[1:] - grid[:-1] + assert np.min(grid_diff) >= input_grid_diff/(2**max_iter) + numerical_integral = np.trapz(xsinx(grid), grid) - # test where input kernel is a 2-D array instead of callable + # ensure this converges for at least one of our test cases + if max_iter == 10 and rtol == 1e-3: + assert is_converged + if is_converged: + # test error of the answer is smaller than rtol + assert np.isclose(numerical_integral, np.pi, rtol=rtol) + # test that success was a stop condition + assert len(grid) < max_grid_size - # test where input kernel is size 1 + # test stop conditions + elif max_iter == 10: + # ensure hitting max_grid_size returns an array of exactly length max_grid_size + assert len(grid) == max_grid_size + elif max_iter == 1: + # ensure hitting max_iter can stop iteration before max_grid_size reached + assert len(grid) <= 2*len(input_grid) - # test where i_bounds is not None +def test_adapt_grid_bad_inputs(): + with pytest.raises(ValueError): + # input grid larger than max_grid_size + au._adapt_grid(np.array([1,2,3]), xsinx, 2) - # Test invalid kernel input (wrong dimensions) +def test_trim_grids(): + + grid_range = (-3, 3) + grid0 = np.linspace(-3, 0, 4) # kept entirely. + grid1 = np.linspace(-3, 0, 16) # removed entirely. Finer spacing doesn't matter, preceded by grid0 + grid2 = np.linspace(0, 3, 5) # kept from 0 to 3 + grid3 = np.linspace(-4, 4, 5) # removed entirely. Outside of grid_range and the rest is superseded + + all_grids = [grid0, grid1, grid2, grid3] + trimmed_grids = au._trim_grids(all_grids, grid_range) + + assert len(trimmed_grids) == len(all_grids) + assert trimmed_grids[0].size == grid0.size + assert trimmed_grids[1].size == 0 + assert trimmed_grids[2].size == grid2.size + assert trimmed_grids[3].size == 0 + + +def test_make_combined_adaptive_grid(): + """see also tests of _adapt_grid and _trim_grids for more detailed tests""" + + grid_range = (0, np.pi) + grid0 = np.linspace(0, np.pi/2, 6) # kept entirely. + grid1 = np.linspace(0, np.pi/2, 15) # removed entirely. Finer spacing doesn't matter, preceded by grid0 + grid2 = np.linspace(np.pi/2, np.pi, 11) # kept from pi/2 to pi + + # purposely make same lower index for grid2 as upper index for grid0 to test uniqueness of output + + all_grids = [grid0, grid1, grid2] + all_estimate = [xsinx, xsinx, xsinx] + + rtol = 1e-3 + combined_grid = au.make_combined_adaptive_grid(all_grids, all_estimate, grid_range, + max_iter=10, rtol=rtol, max_total_size=100) + + numerical_integral = np.trapz(xsinx(combined_grid), combined_grid) + + assert np.unique(combined_grid).size == combined_grid.size + assert np.isclose(numerical_integral, np.pi, rtol=rtol) + + +def test_throughput_soss(): + wavelengths = np.linspace(2,5,10) + throughputs = np.ones_like(wavelengths) + interpolator = au.ThroughputSOSS(wavelengths, throughputs) + # test that it returns 1 for all wavelengths inside range + interp = interpolator(wavelengths) + assert np.allclose(interp[1:-1], throughputs[1:-1]) + assert interp[0] == 0 + assert interp[-1] == 0 -@pytest.fixture(scope="module") -def tikhoTests(): - """Make a TikhoTests dictionary""" + # test that it returns 0 for all wavelengths outside range + wavelengths_outside = np.linspace(1,1.5,5) + interp = interpolator(wavelengths_outside) + assert np.all(interp == 0) + # test ValueError raise for shape mismatch + with pytest.raises(ValueError): + au.ThroughputSOSS(wavelengths, throughputs[:-1]) - return au.TikhoTests({'factors': factors, - 'solution': sln, - 'error': err, - 'reg': reg, - 'grid': wave_grid}) +def test_webb_kernel(webb_kernels, wave_map): + wave_trace = wave_map[0][0] + min_trace, max_trace = np.min(wave_trace), np.max(wave_trace) + kern = webb_kernels[0] -def test_tikho_tests(tikhoTests): + # basic ensure that the input is stored and shapes + assert kern.wave_kernels.shape == kern.kernels.shape - assert False \ No newline at end of file + # test that pixels is mirrored around the center and has zero at center + assert np.allclose(kern.pixels + kern.pixels[::-1], 0) + assert kern.pixels[kern.pixels.size//2] == 0 + + # test that wave_center has same shape as wavelength axis of wave_kernel + # but contains values that are in wave_trace + assert kern.wave_center.size == kern.wave_kernels.shape[1] + assert all(np.isin(kern.wave_center, wave_trace)) + + # test min value + assert kern.min_value > 0 + assert np.isin(kern.min_value, kern.kernels) + assert isinstance(kern.min_value, float) + + # test the polynomial fit has the proper shape. hard-coded to a first-order, i.e., linear fit + # since the throughput is constant in wavelength, the slopes should be close to zero + # and the y-intercepts should be close to kern.wave_center + # especially with so few points. just go with 10 percent, should catch egregious changes + assert kern.poly.shape == (kern.wave_kernels.shape[1], 2) + assert np.allclose(kern.poly[:,0], 0, atol=1e-1) + assert np.allclose(kern.poly[:,1], kern.wave_center, atol=1e-1) + + # test interpolation function, which takes in a pixel and a wavelength and returns a throughput + # this should return the triangle function at all wavelengths and zero outside range + pix_half = kern.n_pix//2 + wl_test = np.linspace(min_trace, max_trace, 10) + pixels_test = np.array([-pix_half-1, 0, pix_half, pix_half+1]) + + data_in = kern.kernels[:,0] + m = kern.min_value + expected = np.array([m, np.max(data_in), m, m]) + + interp = kern.f_ker(pixels_test, wl_test) + assert interp.shape == (pixels_test.size, wl_test.size) + diff = interp[:,1:] - interp[:,:-1] + assert np.allclose(diff, 0) + assert np.allclose(interp[:,0], expected, rtol=1e-3) + + # call the kernel object directly + # this takes a wavelength and a central wavelength of the kernel, + # then converts to pixels to use self.f_ker internally + assert kern(wl_test, wl_test).ndim == 1 + assert np.allclose(kern(wl_test, wl_test), np.max(data_in)) + + # both inputs need to be same shape + with pytest.raises(ValueError): + kern(wl_test, wl_test[:-1]) + + +def test_finite_first_diff(): + + wave_grid = np.linspace(0, 2*np.pi, 100) + test_0 = np.ones_like(wave_grid) + test_sin = np.sin(wave_grid) + + first_d = au.finite_first_d(wave_grid) + assert first_d.size == (wave_grid.size - 1)*2 + + # test trivial example returning zeros for constant + f0 = first_d.dot(test_0) + assert np.allclose(f0, 0) + + # test derivative of sin returns cos + wave_between = (wave_grid[1:] + wave_grid[:-1])/2 + f_sin = first_d.dot(test_sin) + assert np.allclose(f_sin, np.cos(wave_between), atol=1e-3) + + +def test_get_c_matrix(kernels_unity, webb_kernels, wave_grid): + """See also test_fct_to_array and test_sparse_c for more detailed tests + of functions called by this one""" + + # only need to test one order + kern = webb_kernels[0] + matrix = au.get_c_matrix(kern, wave_grid, i_bounds=None) + + # ensure proper shape + assert matrix.shape == (wave_grid.size, wave_grid.size) + assert matrix.dtype == np.float64 + + # ensure normalized + assert matrix.sum() == matrix.shape[0] + + # test where input kernel is a 2-D array instead of callable + i_bounds = [0, len(wave_grid)] + kern_array = au._fct_to_array(kern, wave_grid, i_bounds, 1e-5) + matrix_from_array = au.get_c_matrix(kern_array, wave_grid, i_bounds=i_bounds) + assert np.allclose(matrix.toarray(), matrix_from_array.toarray()) + + # test where input kernel is size 1 + kern_unity = kernels_unity[0] + matrix_from_unity = au.get_c_matrix(kern_unity, wave_grid, i_bounds=i_bounds) + assert matrix_from_unity.shape == (wave_grid.size, wave_grid.size) + + + # test where i_bounds is not None + i_bounds = [10, wave_grid.size-10] + matrix_ibnds = au.get_c_matrix(kern, wave_grid, i_bounds=i_bounds) + expected_shape = (wave_grid[i_bounds[0]:i_bounds[1]].size, wave_grid.size) + assert matrix_ibnds.shape == expected_shape + + # Test invalid kernel input (wrong dimensions) + with pytest.raises(ValueError): + kern_array_bad = kern_array[np.newaxis, ...] + au.get_c_matrix(kern_array_bad, wave_grid, i_bounds=i_bounds) + + # Test invalid kernel input (odd shape) + with pytest.raises(ValueError): + kern_array_bad = kern_array[1:,1:] + au.get_c_matrix(kern_array_bad, wave_grid, i_bounds=i_bounds) From 1797037438d1b8049902ba3e682698162ef4e22a Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Wed, 4 Dec 2024 16:14:06 -0500 Subject: [PATCH 27/35] ruff check according to stcal rules --- jwst/extract_1d/soss_extract/atoca.py | 6 +-- jwst/extract_1d/soss_extract/atoca_utils.py | 16 +++----- jwst/extract_1d/soss_extract/pastasoss.py | 6 +-- jwst/extract_1d/soss_extract/soss_extract.py | 39 ++++++++----------- .../soss_extract/tests/test_atoca.py | 4 +- 5 files changed, 30 insertions(+), 41 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index c3b4d6d934..36fca3b4d3 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -257,7 +257,7 @@ def _create_kernels(self, kernels, c_kwargs): kwargs_ker = {'thresh': ker.min_value} except AttributeError: # take the get_c_matrix defaults - kwargs_ker = dict() + kwargs_ker = {} c_kwargs.append(kwargs_ker) # ...or same for all orders if only a dictionary was given @@ -701,7 +701,7 @@ def best_tikho_factor(self, tests, fit_mode): # # idx_max = np.argmax(tests['factors']) # # idx_to_keep[idx_max] = True # # Make new tests with remaining factors -# new_tests = dict() +# new_tests = {} # for key in tests: # if key != 'grid': # new_tests[key] = tests[key][idx_to_keep] @@ -716,7 +716,7 @@ def best_tikho_factor(self, tests, fit_mode): list_mode = [fit_mode] # Evaluate best factor with different methods - results = dict() + results = {} for mode in list_mode: best_fac = tests.best_factor(mode=mode) results[mode] = best_fac diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index fe905136ac..d04953c78b 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -1615,7 +1615,9 @@ def _linear(z): LOSS_FUNCTIONS = {'soft_l1': _soft_l1, 'cauchy': _cauchy, 'linear': _linear} - +DEFAULT_THRESH_DERIVATIVE = {'chi2':1e-5, + 'chi2_soft_l1':1e-4, + 'chi2_cauchy':1e-3} class TikhoTests(dict): """ @@ -1623,10 +1625,6 @@ class TikhoTests(dict): All the tests are stored in the attribute `tests` as a dictionary """ - DEFAULT_THRESH_DERIVATIVE = (('chi2', 1e-5), - ('chi2_soft_l1', 1e-4), - ('chi2_cauchy', 1e-3)) - def __init__(self, test_dict, default_chi2='chi2_cauchy'): """ Parameters @@ -1643,9 +1641,7 @@ def __init__(self, test_dict, default_chi2='chi2_cauchy'): # Save attributes self.n_points = n_points self.default_chi2 = default_chi2 - self.default_thresh = {chi2_type: thresh - for (chi2_type, thresh) - in self.DEFAULT_THRESH_DERIVATIVE} + self.default_thresh = DEFAULT_THRESH_DERIVATIVE # Initialize so it behaves like a dictionary super().__init__(test_dict) @@ -1679,8 +1675,8 @@ def _compute_chi2(self, loss): try: loss = LOSS_FUNCTIONS[loss] except KeyError as e: - keys = [key for key in LOSS_FUNCTIONS.keys()] - msg = f'loss={loss} not a valid key. Must be one of {keys} or callable.' + msg = (f"loss={loss} not a valid key." + f"Must be one of {[LOSS_FUNCTIONS.keys()]} or callable.") raise e(msg) # Compute the reduced chi^2 for all tests diff --git a/jwst/extract_1d/soss_extract/pastasoss.py b/jwst/extract_1d/soss_extract/pastasoss.py index 332bd1a287..7751294243 100644 --- a/jwst/extract_1d/soss_extract/pastasoss.py +++ b/jwst/extract_1d/soss_extract/pastasoss.py @@ -367,9 +367,9 @@ def _extrapolate_to_wavegrid(w_grid, wavelength, quantity): Array The interpolated quantities """ - sorted = np.argsort(wavelength) - q = quantity[sorted] - w = wavelength[sorted] + sort_i = np.argsort(wavelength) + q = quantity[sort_i] + w = wavelength[sort_i] # Determine the slope on the right of the array slope_right = (q[-1] - q[-2]) / (w[-1] - w[-2]) diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index 16f4a73802..ea24312d8e 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -7,8 +7,8 @@ from stdatamodels.jwst import datamodels from stdatamodels.jwst.datamodels import dqflags, SossWaveGridModel -from ..extract import populate_time_keywords -from ...lib import pipe_utils +from jwst.extract_1d.extract import populate_time_keywords +from jwst.lib import pipe_utils from astropy.nddata.bitmask import bitfield_to_boolean_mask from .soss_syscor import make_background_mask, soss_background @@ -81,7 +81,7 @@ def get_ref_file_args(ref_files): # The throughput curves for order 1 and 2. - throughput_index_dict = dict() + throughput_index_dict = {} for i, throughput in enumerate(pastasoss_ref.throughputs): throughput_index_dict[throughput.spectral_order] = i @@ -98,7 +98,7 @@ def get_ref_file_args(ref_files): # Take the centroid of each trace as a grid to project the WebbKernel # WebbKer needs a 2d input, so artificially add axis wave_maps = [wavemap_o1, wavemap_o2] - centroid = dict() + centroid = {} for wv_map, order in zip(wave_maps, [1, 2]): wv_cent = np.zeros((wv_map.shape[1])) @@ -314,7 +314,7 @@ def _make_decontamination_grid(ref_files, rtol, max_grid_size, estimate, n_os): # Build native grid for each orders. spectral_orders = [2, 1] - grids_ord = dict() + grids_ord = {} for sp_ord in spectral_orders: grids_ord[sp_ord] = _get_grid_from_trace(ref_files, sp_ord, n_os=n_os) @@ -344,14 +344,13 @@ def flat_fct(wv): all_estimates = [estimate, estimate, flat_fct] # Generate the combined grid - kwargs = dict(rtol=rtol, max_total_size=max_grid_size, max_iter=30) + kwargs = {"rtol":rtol, "max_total_size":max_grid_size, "max_iter":30} return make_combined_adaptive_grid(all_grids, all_estimates, wv_range, **kwargs) def _append_tiktests(test_a, test_b): - out = dict() - + out = {} for key in test_a: out[key] = np.append(test_a[key], test_b[key], axis=0) @@ -621,7 +620,7 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, # Create a new instance of the engine for evaluating the trace model. # This allows bad pixels and pixels below the threshold to be reconstructed as well. # Model the order 1 and order 2 trace separately. - tracemodels = dict() + tracemodels = {} for i_order, order in enumerate(order_list): @@ -705,8 +704,7 @@ def _compute_box_weights(ref_files, shape, width): order_list.append(trace.spectral_order) # Extract each order from order list - box_weights = dict() - wavelengths = dict() + box_weights, wavelengths = {}, {} order_str = {order: f'Order {order}' for order in order_list} for order_integer in order_list: # Order string-name is used more often than integer-name @@ -735,7 +733,7 @@ def _decontaminate_image(scidata_bkg, tracemodels, subarray): mod_order_list = tracemodels.keys() # Create dictionaries for the output images. - decontaminated_data = dict() + decontaminated_data = {} log.debug('Performing the decontamination.') @@ -862,15 +860,13 @@ def _extract_image(decontaminated_data, scierr, scimask, box_weights, bad_pix='m """ # Init models with an empty dictionary if not given if tracemodels is None: - tracemodels = dict() + tracemodels = {} # Which orders to extract (extract the ones with given box aperture). order_list = box_weights.keys() # Create dictionaries for the output spectra. - fluxes = dict() - fluxerrs = dict() - npixels = dict() + fluxes, fluxerrs, npixels = {}, {}, {} log.info('Performing the box extraction.') @@ -966,7 +962,7 @@ def run_extract1d(input_model, pastasoss_ref_name, specprofile_ref = datamodels.SpecProfileModel(specprofile_ref_name) speckernel_ref = datamodels.SpecKernelModel(speckernel_ref_name) - ref_files = dict() + ref_files = {} ref_files['pastasoss'] = pastasoss_ref ref_files['specprofile'] = specprofile_ref ref_files['speckernel'] = speckernel_ref @@ -1013,8 +1009,7 @@ def run_extract1d(input_model, pastasoss_ref_name, output_references = datamodels.SossExtractModel() output_references.update(input_model) - all_tracemodels = dict() - all_box_weights = dict() + all_tracemodels, all_box_weights = {}, {} # Convert to Cube if datamodels is an ImageModel if isinstance(input_model, datamodels.ImageModel): @@ -1077,7 +1072,7 @@ def run_extract1d(input_model, pastasoss_ref_name, if soss_filter == 'CLEAR' and generate_model: # Model the image. - kwargs = dict() + kwargs = {} kwargs['estimate'] = estimate kwargs['tikfac'] = soss_kwargs['tikfac'] kwargs['max_grid_size'] = soss_kwargs['max_grid_size'] @@ -1104,7 +1099,7 @@ def run_extract1d(input_model, pastasoss_ref_name, raise ValueError(msg) else: # Return empty tracemodels and no spec_list - tracemodels = dict() + tracemodels = {} spec_list = None # Decontaminate the data using trace models (if tracemodels not empty) @@ -1118,7 +1113,7 @@ def run_extract1d(input_model, pastasoss_ref_name, bad_pix_models = None # Use the bad pixel models to perform a de-contaminated extraction. - kwargs = dict() + kwargs = {} kwargs['bad_pix'] = soss_kwargs['bad_pix'] kwargs['tracemodels'] = bad_pix_models result = _extract_image(data_to_extract, scierr, scimask, box_weights, **kwargs) diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca.py b/jwst/extract_1d/soss_extract/tests/test_atoca.py index 57fae0c43a..6b11c08800 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca.py @@ -3,7 +3,6 @@ from functools import partial from scipy.sparse import csr_matrix from jwst.extract_1d.soss_extract import atoca -from jwst.extract_1d.soss_extract import atoca_utils as au from jwst.extract_1d.soss_extract.tests.conftest import ( SPECTRAL_SLOPE, f_lam, DATA_SHAPE, WAVE_BNDS_O1, WAVE_BNDS_O2) @@ -52,7 +51,6 @@ def test_extraction_engine_init( assert len(engine.i_bounds[order]) == 2 assert engine.i_bounds[order][0] >= 0 assert engine.i_bounds[order][1] < DATA_SHAPE[1] - assert all([isinstance(val, int) for val in engine.i_bounds[order]]) # test to ensure that wave_map is considered for bounds # in order 1 the wave_map is more restrictive on the shortwave end @@ -61,7 +59,7 @@ def test_extraction_engine_init( # in order 2 no restriction on the shortwave end so get the full extent assert engine.i_bounds[0][0] > 0 assert engine.i_bounds[1][1] < engine.n_wavepoints - + # TODO: off-by-one error here. why does this fail? # check what this looks like on a real run on main # assert engine.i_bounds[0][1] == engine.n_wavepoints From 4d6a71c420fdd2e5db0ac9bd2b3a6c781dc65c46 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Thu, 5 Dec 2024 12:27:54 -0500 Subject: [PATCH 28/35] starting tests for box extract --- jwst/extract_1d/soss_extract/soss_boxextract.py | 3 --- jwst/extract_1d/soss_extract/tests/test_atoca.py | 6 ++++++ jwst/extract_1d/soss_extract/tests/test_soss_boxextract.py | 6 ++++++ 3 files changed, 12 insertions(+), 3 deletions(-) create mode 100644 jwst/extract_1d/soss_extract/tests/test_soss_boxextract.py diff --git a/jwst/extract_1d/soss_extract/soss_boxextract.py b/jwst/extract_1d/soss_extract/soss_boxextract.py index 5adcf5bfb9..65d21124ab 100644 --- a/jwst/extract_1d/soss_extract/soss_boxextract.py +++ b/jwst/extract_1d/soss_extract/soss_boxextract.py @@ -127,9 +127,6 @@ def box_extract(scidata, scierr, scimask, box_weights): def estim_error_nearest_data(err, data, pix_to_estim, valid_pix): """ - TODO: how similar is this to other places where we interpolate errors? - Could this be replaced with some algorithm involving smoothing of the error map? - Function to estimate pixel error empirically using the corresponding error of the nearest pixel value (`data`). Intended to be used in a box extraction when the bad pixels are modeled. diff --git a/jwst/extract_1d/soss_extract/tests/test_atoca.py b/jwst/extract_1d/soss_extract/tests/test_atoca.py index 6b11c08800..021fc21cda 100644 --- a/jwst/extract_1d/soss_extract/tests/test_atoca.py +++ b/jwst/extract_1d/soss_extract/tests/test_atoca.py @@ -6,6 +6,12 @@ from jwst.extract_1d.soss_extract.tests.conftest import ( SPECTRAL_SLOPE, f_lam, DATA_SHAPE, WAVE_BNDS_O1, WAVE_BNDS_O2) +"""Tests for the ATOCA extraction engine, taking advantage of the miniature +model set up by conftest.py. +The test_call() function ensures that the engine can retrieve the spectrum +with SPECTRAL_SLOPE that we put into the data, which implicitly checks +a lot of the matrix math.""" + def test_extraction_engine_init( wave_map, diff --git a/jwst/extract_1d/soss_extract/tests/test_soss_boxextract.py b/jwst/extract_1d/soss_extract/tests/test_soss_boxextract.py new file mode 100644 index 0000000000..3556e05bf2 --- /dev/null +++ b/jwst/extract_1d/soss_extract/tests/test_soss_boxextract.py @@ -0,0 +1,6 @@ +from jwst.extract_1d.soss_extract.soss_boxextract import get_box_weights, box_extract + + +def test_get_box_weights(): + + \ No newline at end of file From 81ffd5234c63fbecc94778318e0f4a842b7fc7bc Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 6 Dec 2024 16:52:31 -0500 Subject: [PATCH 29/35] added unit tests for boxextract functions --- .../extract_1d/soss_extract/tests/conftest.py | 42 ++++++- .../tests/test_soss_boxextract.py | 117 +++++++++++++++++- 2 files changed, 155 insertions(+), 4 deletions(-) diff --git a/jwst/extract_1d/soss_extract/tests/conftest.py b/jwst/extract_1d/soss_extract/tests/conftest.py index 2dd068c4d3..26af5743e8 100644 --- a/jwst/extract_1d/soss_extract/tests/conftest.py +++ b/jwst/extract_1d/soss_extract/tests/conftest.py @@ -1,5 +1,6 @@ import pytest import numpy as np +from scipy.signal import savgol_filter from functools import partial from jwst.extract_1d.soss_extract import atoca from jwst.extract_1d.soss_extract import atoca_utils as au @@ -26,6 +27,7 @@ WAVE_BNDS_GRID = [0.7, 2.7] ORDER1_SCALING = 20.0 ORDER2_SCALING = 2.0 +TRACE_END_IDX = [DATA_SHAPE[1],180] SPECTRAL_SLOPE = 2 @@ -37,7 +39,7 @@ def wave_map(): wave_ord2 = np.linspace(WAVE_BNDS_O2[0], WAVE_BNDS_O2[1], DATA_SHAPE[1]) wave_ord2 = np.ones(DATA_SHAPE)*wave_ord2[np.newaxis,:] # add a small region of zeros to mimic what is input to the step from ref files - wave_ord2[:,190:] = 0.0 + wave_ord2[:,TRACE_END_IDX[1]:] = 0.0 return [wave_ord1, wave_ord2] @@ -70,6 +72,41 @@ def trace_profile(wave_map): return [profile_ord1, profile_ord2] +@pytest.fixture(scope="package") +def trace1d(wave_map, trace_profile): + """For each order, return tuple (xtrace, ytrace, wavetrace)""" + + trace_list = [] + for order in [0,1]: + + profile = trace_profile[order] + wave2d = wave_map[order].copy() #avoid modifying wave_map, it's needed elsewhere! + end_idx = TRACE_END_IDX[order] + + # find mean y-index at each x where trace_profile is nonzero + shp = profile.shape + xx, yy = np.mgrid[:shp[0], :shp[1]] + xx = xx.astype("float") + xx[profile==0] = np.nan + mean_trace = np.nanmean(xx, axis=0) + # same strategy for wavelength + wave2d[profile==0] = np.nan + mean_wave = np.nanmean(wave2d, axis=0) + + # smooth it. we know it should be linear so use 1st order poly and large box size + ytrace = savgol_filter(mean_trace, mean_trace.size-1, 1) + wavetrace = savgol_filter(mean_wave, mean_wave.size-1, 1) + + # apply cutoff + wavetrace = wavetrace[:end_idx] + ytrace = ytrace[:end_idx] + xtrace = np.arange(0, end_idx) + + trace_list.append((xtrace, ytrace, wavetrace)) + + return trace_list + + @pytest.fixture(scope="package") def wave_grid(): """wave_grid has smaller spacings in some places than others @@ -226,4 +263,7 @@ def imagemodel(engine, detector_mask): # to avoid very large values of data / error error = noise_scaling*(rng.standard_normal(shp)**2 + 0.5) + # TODO: Why does the data here have some kind of beat frequency? + # TODO: why does the data here have one deep negative bar at the end of each spectral order? + return data, error \ No newline at end of file diff --git a/jwst/extract_1d/soss_extract/tests/test_soss_boxextract.py b/jwst/extract_1d/soss_extract/tests/test_soss_boxextract.py index 3556e05bf2..fae441f9df 100644 --- a/jwst/extract_1d/soss_extract/tests/test_soss_boxextract.py +++ b/jwst/extract_1d/soss_extract/tests/test_soss_boxextract.py @@ -1,6 +1,117 @@ -from jwst.extract_1d.soss_extract.soss_boxextract import get_box_weights, box_extract +import pytest +import numpy as np +from .conftest import DATA_SHAPE +from jwst.extract_1d.soss_extract.soss_boxextract import ( + get_box_weights, box_extract, estim_error_nearest_data) -def test_get_box_weights(): - \ No newline at end of file +WIDTH = 5.1 + +@pytest.fixture() +def box_weights(trace1d): + + weights_list = [] + for order in [0,1]: + tracex, tracey, wavetrace = trace1d[order] + weights_list.append(get_box_weights(tracey, WIDTH, DATA_SHAPE, tracex)) + return weights_list + + +def test_get_box_weights(trace1d, box_weights): + """ + Order 1 is easy because tracex.size and tracey.size are equal to data_shape[1] + Order 2 tests the case where they are not equal + """ + + for order in [0,1]: + + tracex, tracey, wavetrace = trace1d[order] + weights = box_weights[order] + + # check weights are between zero and 1 + assert weights.shape == DATA_SHAPE + assert np.max(weights) == 1 + assert np.min(weights) == 0 + + weight_sum = np.sum(weights, axis=0) + # some weights are zero because of the trace profile cutoff in order 2 + assert np.sum(weight_sum == 0) == (DATA_SHAPE[1] - tracey.size) + + # check sum of weights across y-axis is width + weight_sum = weight_sum[np.nonzero(weight_sum)] + assert np.allclose(weight_sum, WIDTH) + + # check at least some partial weights exist + assert np.sum(weight_sum == 1) < weight_sum.size + + # TODO: need a test, maybe regtest, for subarray sub96 problem + # see https://github.com/spacetelescope/jwst/issues/8780 + + +def test_box_extract(trace1d, box_weights, imagemodel): + + data, err = imagemodel + mask = np.isnan(data) + + for order in [0,1]: + weights = box_weights[order] + cols, flux, flux_err, npix = box_extract(data, err, mask, weights) + + # test that cols just represents the data + assert np.allclose(cols, np.arange(data.shape[1])) + + # test flux and flux_err are NaN where order 2 is cut off, but have good values elsewhere + xtrace = trace1d[order][0] + for f in [flux, flux_err]: + assert np.sum(~np.isnan(flux)) == xtrace.size + # test npix is zero there too + assert np.count_nonzero(npix) == xtrace.size + + # test that most of npix are equal to width (Not all, because of NaN mask, but NaN fraction) + # is small enough that it should still be the most represented count for such a small width + unique, counts = np.unique(npix, return_counts=True) + assert np.isclose(unique[np.argmax(counts)], WIDTH) + + # TODO: somehow check the fluxes retrieved look like what we would expect from data + # although this is hard because the wavelengths are not extracted here + + # TODO: why does flux have very low values at edges even after cutting by good? + + +def test_estim_error_nearest_data(imagemodel, mask_trace_profile): + + data, err = imagemodel + + for order in [0, 1]: + # this has bad pixels set to 1, ONLY within the spectral trace. + # everything else is zero, i.e., regions outside trace and good data + pix_to_estim = np.zeros(data.shape, dtype="bool") + pix_to_estim[np.isnan(data)] = 1 + pix_to_estim[mask_trace_profile[order] == 1] = 0 + + # this has bad pixels set to 0, and regions outside trace set to 0, and good data 1 + valid_pix = ~mask_trace_profile[order] + valid_pix[pix_to_estim] = 0 + + err_out = estim_error_nearest_data(err, data, pix_to_estim, valid_pix) + + # test that all replaced values are positive and no NaNs are left + assert np.sum(np.isnan(err_out)) == 0 + assert np.all(err_out > 0) + + # test that the replaced pixels are not statistical outliers c.f. the other pixels + replaced_pix = err_out[pix_to_estim] + original_pix = err_out[valid_pix] + + # diff = np.mean(replaced_pix)/np.mean(original_pix) + # assert np.isclose(diff, 1, rtol=0.5) # assert False + # TODO: why does this fail? + # In both orders, the errors on the replaced pixels are roughly + # half of the errors on the original good pixels + # There are enough replaced pixels here (~30) that small-number statistics cannot account for this + # The reason is because the code chooses the lower error between the two nearest-flux + # data points, and since the errors in our tests are uncorrelated with the flux values, + # this leads to a factor-of-2 decrease + # It's not clear to me that picking the smaller of two error values is the right thing to do + # but that behavior is documented From ba52f91ddca0637de3e37c6acc5b15d483a291c0 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Tue, 10 Dec 2024 13:30:48 -0500 Subject: [PATCH 30/35] added test coverage for pastasoss helper functions --- jwst/extract_1d/soss_extract/pastasoss.py | 74 ++++++-------- .../extract_1d/soss_extract/tests/conftest.py | 42 +++++++- .../soss_extract/tests/test_pastasoss.py | 99 +++++++++++++++++++ .../tests/test_soss_boxextract.py | 6 +- 4 files changed, 173 insertions(+), 48 deletions(-) create mode 100644 jwst/extract_1d/soss_extract/tests/test_pastasoss.py diff --git a/jwst/extract_1d/soss_extract/pastasoss.py b/jwst/extract_1d/soss_extract/pastasoss.py index 7751294243..7b4b843d67 100644 --- a/jwst/extract_1d/soss_extract/pastasoss.py +++ b/jwst/extract_1d/soss_extract/pastasoss.py @@ -199,7 +199,7 @@ def get_poly_features(x, offset): return wavelengths -def _rotate(x, y, angle, origin=(0, 0), interp=True): +def _rotate(x, y, angle, origin=(0, 0)): """ Applies a rotation transformation to a set of 2D points. @@ -216,9 +216,6 @@ def _rotate(x, y, angle, origin=(0, 0), interp=True): The angle (in degrees) by which to rotate the points. origin : Tuple[float, float], optional The point about which to rotate the points. Default is (0, 0). - interp : bool, optional - Whether to interpolate the rotated positions onto the original x-pixel - column values. Default is True. Returns ------- @@ -244,17 +241,15 @@ def _rotate(x, y, angle, origin=(0, 0), interp=True): # apply transformation x_new, y_new = R @ (xy - xy_center) + xy_center - # interpolate rotated positions onto x-pixel column values (default) - if interp: - # interpolate new coordinates onto original x values and mask values - # outside of the domain of the image 0<=x<=2047 and 0<=y<=255. - y_new = interp1d(x_new, y_new, fill_value="extrapolate")(x) - mask = np.where(y_new <= 255.0) - x = x[mask] - y_new = y_new[mask] - return x, y_new + # interpolate rotated positions onto x-pixel column values + # interpolate new coordinates onto original x values and mask values + # outside of the domain of the image 0<=x<=2047 and 0<=y<=255. + y_new = interp1d(x_new, y_new, fill_value="extrapolate")(x) + mask = np.where(y_new <= 255.0) + x = x[mask] + y_new = y_new[mask] + return x, y_new - return x_new, y_new def _find_spectral_order_index(refmodel, order): @@ -274,6 +269,10 @@ def _find_spectral_order_index(refmodel, order): The index to provide the reference file lists of traces and wavecal models to retrieve the arrays for the desired spectral order """ + if order not in [1,2]: + error_message = f"Order {order} is not supported at this time." + log.error(error_message) + raise ValueError(error_message) for i, entry in enumerate(refmodel.traces): if entry.spectral_order == order: @@ -283,7 +282,7 @@ def _find_spectral_order_index(refmodel, order): return -1 -def _get_soss_traces(refmodel, pwcpos, order, subarray, interp=True): +def _get_soss_traces(refmodel, pwcpos, order, subarray): """Generate the traces given a pupil wheel position. This is the primary method for generating the gr700xd trace position given a @@ -301,14 +300,11 @@ def _get_soss_traces(refmodel, pwcpos, order, subarray, interp=True): pwcpos : float The pupil wheel positions angle provided in the FITS header under keyword PWCPOS. - order : str + order : str or int The spectral order for which a trace is computed. Order 3 is currently unsupported. subarray : str Name of subarray in use, typically 'SUBSTRIP96' or 'SUBSTRIP256'. - interp : bool, optional - Whether to interpolate the rotated positions onto the original x-pixel - column values. Default is True. Returns ------- @@ -317,36 +313,28 @@ def _get_soss_traces(refmodel, pwcpos, order, subarray, interp=True): points for the first spectral order. If `order` is '2', a tuple of the x and y coordinates of the rotated points for the second spectral order. - If `order` is '3' or a combination of '1', '2', and '3', a list of - tuples of the x and y coordinates of the rotated points for each - spectral order. Raises ------ ValueError - If `order` is not '1', '2', '3', or a combination of '1', '2', and '3'. + If `order` is not in ['1', '2']. """ spectral_order_index = _find_spectral_order_index(refmodel, int(order)) - if spectral_order_index < 0: - error_message = f"Order {order} is not supported at this time." - log.error(error_message) - raise ValueError(error_message) - else: - # reference trace data - x, y = refmodel.traces[spectral_order_index].trace.T.copy() - origin = refmodel.traces[spectral_order_index].pivot_x, refmodel.traces[spectral_order_index].pivot_y + # reference trace data + x, y = refmodel.traces[spectral_order_index].trace.T.copy() + origin = refmodel.traces[spectral_order_index].pivot_x, refmodel.traces[spectral_order_index].pivot_y - # Offset for SUBSTRIP96 - if subarray == 'SUBSTRIP96': - y -= 10 - # rotated reference trace - x_new, y_new = _rotate(x, y, pwcpos - refmodel.meta.pwcpos_cmd, origin, interp=interp) + # Offset for SUBSTRIP96 + if subarray == 'SUBSTRIP96': + y -= 10 + # rotated reference trace + x_new, y_new = _rotate(x, y, pwcpos - refmodel.meta.pwcpos_cmd, origin) - # wavelength associated to trace at given pwcpos value - wavelengths = _get_wavelengths(refmodel, x_new, pwcpos, int(order)) + # wavelength associated to trace at given pwcpos value + wavelengths = _get_wavelengths(refmodel, x_new, pwcpos, int(order)) - return order, x_new, y_new, wavelengths + return order, x_new, y_new, wavelengths def _extrapolate_to_wavegrid(w_grid, wavelength, quantity): @@ -386,9 +374,7 @@ def _extrapolate_to_wavegrid(w_grid, wavelength, quantity): q = np.concatenate((q_left, q, q_right)) # resample at the w_grid everywhere - q_grid = np.interp(w_grid, w, q) - - return q_grid + return np.interp(w_grid, w, q) def _calc_2d_wave_map(wave_grid, x_dms, y_dms, tilt, oversample=2, padding=0, maxiter=5, dtol=1e-2): @@ -489,8 +475,8 @@ def get_soss_wavemaps(refmodel, pwcpos, subarray, padding=False, padsize=0, spec Array, Array The 2D wavemaps and corresponding 1D spectraces """ - _, order1_x, order1_y, order1_wl = _get_soss_traces(refmodel, pwcpos, order='1', subarray=subarray, interp=True) - _, order2_x, order2_y, order2_wl = _get_soss_traces(refmodel, pwcpos, order='2', subarray=subarray, interp=True) + _, order1_x, order1_y, order1_wl = _get_soss_traces(refmodel, pwcpos, order='1', subarray=subarray) + _, order2_x, order2_y, order2_wl = _get_soss_traces(refmodel, pwcpos, order='2', subarray=subarray) # Make wavemap from trace center wavelengths, padding to shape (296, 2088) wavemin = WAVEMAP_WLMIN diff --git a/jwst/extract_1d/soss_extract/tests/conftest.py b/jwst/extract_1d/soss_extract/tests/conftest.py index 26af5743e8..4186336b55 100644 --- a/jwst/extract_1d/soss_extract/tests/conftest.py +++ b/jwst/extract_1d/soss_extract/tests/conftest.py @@ -4,6 +4,7 @@ from functools import partial from jwst.extract_1d.soss_extract import atoca from jwst.extract_1d.soss_extract import atoca_utils as au +from stdatamodels.jwst.datamodels import PastasossModel """ Create a miniature, slightly simplified model of the SOSS detector/optics. @@ -21,6 +22,7 @@ - Kernel set to unity (for now) """ +PWCPOS = 245.85932900002442 DATA_SHAPE = (25,200) WAVE_BNDS_O1 = [2.8, 0.8] WAVE_BNDS_O2 = [1.4, 0.5] @@ -266,4 +268,42 @@ def imagemodel(engine, detector_mask): # TODO: Why does the data here have some kind of beat frequency? # TODO: why does the data here have one deep negative bar at the end of each spectral order? - return data, error \ No newline at end of file + return data, error + + +@pytest.fixture(scope="module") +def refmodel(trace1d): + """Mock Pastasoss reference model with spatial dimensions scaled + down by a factor of 10. Since the traces are just linear, the polynomials + also have coefficients equal to 0 except for the constant and linear terms""" + model = PastasossModel() + model.meta.pwcpos_cmd = 245.76 + + trace0 = {"pivot_x": 189.0, + "pivot_y": 5.0, + "spectral_order": 1, + "trace": np.array([trace1d[0][0], trace1d[0][1]], dtype=np.float64).T, + "padding": 0,} + trace1 = {"pivot_x": 168.0, + "pivot_y": 20.0, + "spectral_order": 2, + "trace": np.array([trace1d[1][0], trace1d[1][1]], dtype=np.float64).T,} + model.traces = [trace0, trace1] + + wavecal0 = {"coefficients": [WAVE_BNDS_O1[0], -2.0,] + [0.0 for i in range(19)], + "polynomial_degree": 5, + "scale_extents": [[0, -1.03552000e-01], [DATA_SHAPE[1], 1.62882080e-01]]} + wavecal1 = {"coefficients": [WAVE_BNDS_O2[0], -1.0,] + [0.0 for i in range(8)], + "polynomial_degree": 3, + "scale_extents": [[0, 245.5929], [DATA_SHAPE[1], 245.9271]]} + model.wavecal_models = [wavecal0, wavecal1] + + thru0 = {"spectral_order": 1, + "wavelength": np.linspace(0.5, 5.5, 501), + "throughput": np.ones((501,)),} #peaks around 1.22 at value 0.37 + thru1 = {"spectral_order": 2, + "wavelength": np.linspace(0.5, 5.5, 501), + "throughput": np.ones((501,)),} #peaks around 0.7 at value of 0.16 + model.throughputs = [thru0, thru1] + + return model \ No newline at end of file diff --git a/jwst/extract_1d/soss_extract/tests/test_pastasoss.py b/jwst/extract_1d/soss_extract/tests/test_pastasoss.py new file mode 100644 index 0000000000..e46bae3119 --- /dev/null +++ b/jwst/extract_1d/soss_extract/tests/test_pastasoss.py @@ -0,0 +1,99 @@ +import pytest +import numpy as np + +from jwst.extract_1d.soss_extract.pastasoss import ( + _get_wavelengths, _find_spectral_order_index, _get_soss_traces, _extrapolate_to_wavegrid, +) + +from .conftest import TRACE_END_IDX, PWCPOS, WAVE_BNDS_O1, WAVE_BNDS_O2 + + +"""Test coverage for the helper functions in pastasoss.py""" + +def test_wavecal_models(refmodel): + + wave_bnds = [WAVE_BNDS_O1, WAVE_BNDS_O2] + for order in [1,2]: + idx = order-1 + bnds = wave_bnds[idx] + x = np.arange(0, TRACE_END_IDX[idx]+1) + wavelengths = _get_wavelengths(refmodel, x, PWCPOS, order) + + # check shapes + assert wavelengths.shape == x.shape + assert np.isclose(wavelengths[0], bnds[0]) + assert np.isclose(wavelengths[-1], bnds[1]) + + # ensure unique and descending + diff = wavelengths[1:] - wavelengths[:-1] + assert np.all(diff < 0) + + +def test_rotate(): + + # TODO: add meaningful tests of rotate + pass + + +def test_find_spectral_order_index(refmodel): + """TODO: why doesn't this raise an error when order is not recognized? + Surely it's a bad idea to have the index set to -1?""" + for order in [1,2]: + idx = _find_spectral_order_index(refmodel, order) + assert idx == order-1 + + for order in [0, "bad", None]: + with pytest.raises(ValueError): + _find_spectral_order_index(refmodel, order) + + +def test_get_soss_traces(refmodel): + + for order in ["1","2"]: + idx = int(order)-1 + for subarray in ["SUBSTRIP96", "SUBSTRIP256"]: + order_out, x_new, y_new, wavelengths = _get_soss_traces( + refmodel, + PWCPOS, + order, + subarray) + + assert str(order_out) == order + # since always interpolated back to original x, x_new should equal x + x_in, y_in = refmodel.traces[idx].trace.T.copy() + assert np.allclose(x_new, x_in) + # and wavelengths are same as what you get from _get_wavelengths on x_in + wave_expected = _get_wavelengths(refmodel, x_in, PWCPOS, int(order)) + assert np.allclose(wavelengths, wave_expected) + + # the y coordinate is the tricky one. it was rotated by pwcpos - refmodel.meta.pwcpos_cmd + # about pivot_x, pivot_y + assert y_new.shape == wavelengths.shape + # TODO: add meaningful tests of y + + +def test_extrapolate_to_wavegrid(refmodel): + + wavemin = 0.5 + wavemax = 5.5 + nwave = 501 + wave_grid = np.linspace(wavemin, wavemax, nwave) + + # only test first order + x = np.arange(0, TRACE_END_IDX[0]+1) + wl = _get_wavelengths(refmodel, x, PWCPOS, 1) + + # first ensure test setup gives all wl in wave_grid + # floating-point precision issues make np.around calls necessary + assert np.all(np.isin(np.around(wl,5), np.around(wave_grid,5))) + + x_extrap = _extrapolate_to_wavegrid(wave_grid, wl, x) + assert x_extrap.shape == wave_grid.shape + + # test that all x in x_extrap + assert np.all(np.isin(np.around(x,5), np.around(x_extrap,5))) + + # test extrapolated slope is same as input slope, since these are linear + m_extrap = (x_extrap[-1] - x_extrap[0])/(wave_grid[-1] - wave_grid[0]) + m = (x[-1] - x[0])/(wl[-1] - wl[0]) + assert np.isclose(m_extrap, m) diff --git a/jwst/extract_1d/soss_extract/tests/test_soss_boxextract.py b/jwst/extract_1d/soss_extract/tests/test_soss_boxextract.py index fae441f9df..da6f074cc7 100644 --- a/jwst/extract_1d/soss_extract/tests/test_soss_boxextract.py +++ b/jwst/extract_1d/soss_extract/tests/test_soss_boxextract.py @@ -101,11 +101,11 @@ def test_estim_error_nearest_data(imagemodel, mask_trace_profile): assert np.all(err_out > 0) # test that the replaced pixels are not statistical outliers c.f. the other pixels - replaced_pix = err_out[pix_to_estim] - original_pix = err_out[valid_pix] - + # replaced_pix = err_out[pix_to_estim] + # original_pix = err_out[valid_pix] # diff = np.mean(replaced_pix)/np.mean(original_pix) # assert np.isclose(diff, 1, rtol=0.5) # assert False + # TODO: why does this fail? # In both orders, the errors on the replaced pixels are roughly # half of the errors on the original good pixels From 1359e05d842e465766d6bce7f7fca477a8702d27 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Tue, 10 Dec 2024 16:38:07 -0500 Subject: [PATCH 31/35] added unit tests for soss_syscor functions --- jwst/extract_1d/soss_extract/soss_extract.py | 1 + jwst/extract_1d/soss_extract/soss_syscor.py | 8 +-- .../soss_extract/tests/test_soss_syscor.py | 56 +++++++++++++++++++ 3 files changed, 60 insertions(+), 5 deletions(-) create mode 100644 jwst/extract_1d/soss_extract/tests/test_soss_syscor.py diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index ea24312d8e..50cb9fab0d 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -1054,6 +1054,7 @@ def run_extract1d(input_model, pastasoss_ref_name, refmask &= ~not_finite # Perform background correction. + print("bkg sub on?", soss_kwargs["subtract_background"]) if soss_kwargs['subtract_background']: log.info('Applying background subtraction.') bkg_mask = make_background_mask(scidata, width=40) diff --git a/jwst/extract_1d/soss_extract/soss_syscor.py b/jwst/extract_1d/soss_extract/soss_syscor.py index 0e0ac47c46..45589c851c 100644 --- a/jwst/extract_1d/soss_extract/soss_syscor.py +++ b/jwst/extract_1d/soss_extract/soss_syscor.py @@ -78,11 +78,9 @@ def make_background_mask(deepstack, width): # Set the appropriate quantile for masking based on the subarray size. if nrows == 96: # SUBSTRIP96 - quantile = 100 * (1 - width / 96) # Mask 1 order worth of pixels. - elif nrows == 256: # SUBSTRIP256 - quantile = 100 * (1 - 2 * width / 256) # Mask 2 orders worth of pixels. - elif nrows == 2048: # FULL - quantile = 100 * (1 - 2 * width / 2048) # Mask 2 orders worth of pixels. + quantile = 100 * (1 - width / nrows) # Mask 1 order worth of pixels. + elif nrows in [256, 2048]: # SUBSTRIP256, FULL + quantile = 100 * (1 - 2 * width / nrows) # Mask 2 orders worth of pixels. else: msg = (f'Unexpected image dimensions, expected nrows = 96, 256 or 2048, ' f'got nrows = {nrows}.') diff --git a/jwst/extract_1d/soss_extract/tests/test_soss_syscor.py b/jwst/extract_1d/soss_extract/tests/test_soss_syscor.py new file mode 100644 index 0000000000..6576c878c3 --- /dev/null +++ b/jwst/extract_1d/soss_extract/tests/test_soss_syscor.py @@ -0,0 +1,56 @@ +import pytest +import numpy as np + +from jwst.extract_1d.soss_extract.soss_syscor import ( + soss_background, + make_background_mask, +) + + +def test_soss_background(imagemodel, detector_mask, mask_trace_profile): + + data, err = imagemodel + bkg_mask = ~mask_trace_profile[0] | ~mask_trace_profile[1] | detector_mask + + data_bkg, col_bkg = soss_background(data, detector_mask, bkg_mask) + assert data_bkg.shape == data.shape + assert col_bkg.size == data.shape[1] + + # check background now has mean zero + mean_bkg = np.mean(data_bkg[~bkg_mask]) + assert np.isclose(mean_bkg, 0.0) + + # check col_bkg are at least close to the non-sigma-clipped version which is much easier to calculate + # For the test case, there are no outliers so this should be quite a close match + data[bkg_mask] = np.nan + col_bkg_unclipped = np.nanmean(data, axis=0) + assert np.allclose(col_bkg, col_bkg_unclipped) + + +def test_make_background_mask(): + + rng = np.random.default_rng(seed=42) + for sub in [96, 256, 2048]: + + shape = (sub, 2048) + width = int(sub/4) + data = rng.normal(0.0, 1.0, shape) + + mask = make_background_mask(data, width) + + if sub == 96: + expected_bad_frac = 1/4 + else: + expected_bad_frac = 1/2 + + bad_frac = np.sum(mask)/mask.size + # test that bad fraction is computed properly for all modes + assert np.isclose(bad_frac, expected_bad_frac) + + # test that mask=True is the high-flux pixels + assert np.mean(data[mask]) > np.mean(data) + + # test unrecognized shape + with pytest.raises(ValueError): + data = rng.normal(0.0, 1.0, (40, 2048)) + make_background_mask(data, width) From 3ae226d8c079069fc0b5b4e4fd3c042bd7b2f5a3 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Wed, 11 Dec 2024 10:14:14 -0500 Subject: [PATCH 32/35] added test for estim_flux_first_order --- jwst/extract_1d/soss_extract/soss_extract.py | 12 ++---- .../soss_extract/tests/test_soss_extract.py | 37 +++++++++++++++++++ 2 files changed, 41 insertions(+), 8 deletions(-) create mode 100644 jwst/extract_1d/soss_extract/tests/test_soss_extract.py diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index 50cb9fab0d..4bc4942947 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -206,21 +206,18 @@ def _estim_flux_first_order(scidata_bkg, scierr, scimask, ref_file_args, mask_tr wave_grid = grid_from_map(wave_maps[0], spat_pros[0], n_os=1) # Mask parts contaminated by order 2 based on its spatial profile - mask = ((spat_pros[1] >= threshold) | mask_trace_profile | scimask) + mask = ((spat_pros[1] >= threshold) | mask_trace_profile[0]) # Init extraction without convolution kernel (so extract the spectrum at order 1 resolution) ref_file_args = [wave_maps[0]], [spat_pros[0]], [thrpts[0]], [None] - kwargs = {'orders': [1],} - engine = ExtractionEngine(*ref_file_args, wave_grid, [mask], **kwargs) + engine = ExtractionEngine(*ref_file_args, wave_grid, [mask], global_mask=scimask, orders=[1]) # Extract estimate spec_estimate = engine(scidata_bkg, scierr) # Interpolate idx = np.isfinite(spec_estimate) - estimate_spl = UnivariateSpline(wave_grid[idx], spec_estimate[idx], k=3, s=0, ext=0) - - return estimate_spl + return UnivariateSpline(wave_grid[idx], spec_estimate[idx], k=3, s=0, ext=0) def _get_native_grid_from_trace(ref_files, spectral_order): @@ -550,7 +547,7 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, # Dev suggested np.logspace(-19, -10, 10) if (tikfac is None or wave_grid is None) and estimate is None: estimate = _estim_flux_first_order(scidata_bkg, scierr, scimask, - ref_file_args, mask_trace_profile[0]) + ref_file_args, mask_trace_profile) # Generate grid based on estimate if not given if wave_grid is None: @@ -1054,7 +1051,6 @@ def run_extract1d(input_model, pastasoss_ref_name, refmask &= ~not_finite # Perform background correction. - print("bkg sub on?", soss_kwargs["subtract_background"]) if soss_kwargs['subtract_background']: log.info('Applying background subtraction.') bkg_mask = make_background_mask(scidata, width=40) diff --git a/jwst/extract_1d/soss_extract/tests/test_soss_extract.py b/jwst/extract_1d/soss_extract/tests/test_soss_extract.py new file mode 100644 index 0000000000..7fde43b972 --- /dev/null +++ b/jwst/extract_1d/soss_extract/tests/test_soss_extract.py @@ -0,0 +1,37 @@ +import pytest +import numpy as np +from numpy.polynomial import Polynomial + +from jwst.extract_1d.soss_extract.soss_extract import ( + _estim_flux_first_order, +) +from .conftest import SPECTRAL_SLOPE + + +def test_estim_flux_first_order(imagemodel, + detector_mask, + wave_map, + trace_profile, + throughput, + mask_trace_profile): + + data, err = imagemodel + ref_file_args = [wave_map, trace_profile, throughput, None] + + func = _estim_flux_first_order(data, + err, + detector_mask, + ref_file_args, + mask_trace_profile, + threshold=1e-4) + + # use the function to generate a test spectrum + test_wl = np.linspace(1.0, 2.5, 100) + test_flux = func(test_wl) + + # check slope against expected value to within a few percent + p_fitted = Polynomial.fit(test_wl, test_flux, 1) + b, m = p_fitted.coef + assert np.isclose(m, SPECTRAL_SLOPE, rtol=0.05) + + From 9bb45abb1cd72b94bb509ece255290e14d99f077 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Thu, 12 Dec 2024 16:36:29 -0500 Subject: [PATCH 33/35] Added unit tests and supporting fixtures for model_image testing --- jwst/extract_1d/soss_extract/soss_extract.py | 133 ++++++++++++-- .../extract_1d/soss_extract/tests/conftest.py | 12 +- .../soss_extract/tests/test_soss_extract.py | 172 +++++++++++++++++- 3 files changed, 299 insertions(+), 18 deletions(-) diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index 4bc4942947..0bda97da51 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -177,6 +177,9 @@ def _get_trace_1d(ref_files, order): def _estim_flux_first_order(scidata_bkg, scierr, scimask, ref_file_args, mask_trace_profile, threshold=1e-4): """ + Roughly estimate the underlying flux of the target spectrum by simply masking + out order 2 and retrieving the flux from order 1. + Parameters ---------- scidata_bkg : array @@ -264,9 +267,9 @@ def _get_native_grid_from_trace(ref_files, spectral_order): def _get_grid_from_trace(ref_files, spectral_order, n_os): """ - TODO: is this partially or fully redundant with atoca_utils.grid_from_map? Make a 1d-grid of the pixels boundary and ready for ATOCA ExtractionEngine, based on the wavelength solution. + Parameters ---------- ref_files: dict @@ -298,8 +301,7 @@ def _get_grid_from_trace(ref_files, spectral_order, n_os): def _make_decontamination_grid(ref_files, rtol, max_grid_size, estimate, n_os): - ''' - TODO: add docstring + """ Create the grid to use for the simultaneous extraction of order 1 and 2. The grid is made by: 1) requiring that it satisfies the oversampling n_os @@ -307,7 +309,26 @@ def _make_decontamination_grid(ref_files, rtol, max_grid_size, estimate, n_os): 3) trying to reach the specified tolerance in the rest of spectral range The max_grid_size overrules steps 2) and 3), so the precision may not be reached if the grid size needed is too large. - ''' + + Parameters + ---------- + ref_files : dict + A dictionary of the reference file DataModels. + rtol : float + The relative tolerance needed on a pixel model. + max_grid_size : int + Maximum grid size allowed. + estimate : UnivariateSpline + Estimate of the target flux as a function of wavelength in microns. + n_os : int + The oversampling factor of the wavelength grid used when solving for + the uncontaminated flux. + + Returns + ------- + wave_grid : 1d array + The grid of the pixels boundaries at the native sampling. + """ # Build native grid for each orders. spectral_orders = [2, 1] @@ -367,7 +388,36 @@ def _populate_tikho_attr(spec, tiktests, idx, sp_ord): def _f_to_spec(f_order, grid_order, ref_file_args, pixel_grid, mask, sp_ord): + """ + TODO: would it be better to pass the engine directly in here somehow? + + Bin the flux to the pixel grid and build a SpecModel. + + Parameters + ---------- + f_order : np.array + The solution f_k of the linear system. + + grid_order : np.array + The wavelength grid of the solution, usually oversampled compared to the pixel grid. + + ref_file_args : list + The reference file arguments used by the ExtractionEngine. + + pixel_grid : np.array + The pixel grid to which the flux should be binned. + mask : np.array + The mask of the pixels to be extracted. + + sp_ord : int + The spectral order of the flux. + + Returns + ------- + spec : SpecModel + + """ # Make sure the input is not modified ref_file_args = ref_file_args.copy() @@ -383,6 +433,7 @@ def _f_to_spec(f_order, grid_order, ref_file_args, pixel_grid, mask, sp_ord): pixel_grid = np.squeeze(pixel_grid) f_binned = np.squeeze(f_binned) + # Remove Nans to save space is_valid = np.isfinite(f_binned) table_size = np.sum(is_valid) @@ -499,15 +550,18 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, to model each pixel valid pixel of the detector. If not given, the grid is determined based on an estimate of the flux (estimate), the relative tolerance (rtol) required on each pixel model and the maximum grid size (max_grid_size). + # TODO: none of the options work except None or an ndarray, including on main. + # docstring needs updates. + # Should we add support for these? If not, is SossWaveGridModel used for anything, + # and can that be removed from stdatamodels and as a valid argument to soss_wave_grid_in? estimate : UnivariateSpline or None Estimate of the target flux as a function of wavelength in microns. rtol : float The relative tolerance needed on a pixel model. It is used to determine the sampling - of the soss_wave_grid when not directly given. Default is 1e-3. + of wave_grid when the input wave_grid is None. Default is 1e-3. max_grid_size : int - Maximum grid size allowed. It is used when soss_wave_grid is not directly - to make sure the computation time or the memory used stays reasonable. - Default is 1000000 + Maximum grid size allowed when wave_grid is None. + Default is 1000000. Returns ------- @@ -542,9 +596,6 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, global_mask = np.all(mask_trace_profile, axis=0).astype(bool) # Rough estimate of the underlying flux - # Note: estim_flux func is not strictly necessary and factors could be a simple logspace - - # dq mask caused issues here and this may need a try/except wrap. - # Dev suggested np.logspace(-19, -10, 10) if (tikfac is None or wave_grid is None) and estimate is None: estimate = _estim_flux_first_order(scidata_bkg, scierr, scimask, ref_file_args, mask_trace_profile) @@ -683,7 +734,9 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, model = np.where(already_modeled, 0., model) # Add to tracemodels + both_nan = np.isnan(tracemodels[order_str]) & np.isnan(model) tracemodels[order_str] = np.nansum([tracemodels[order_str], model], axis=0) + tracemodels[order_str][both_nan] = np.nan # Add the result to spec_list for sp in spec_ord: @@ -753,10 +806,63 @@ def _decontaminate_image(scidata_bkg, tracemodels, subarray): return decontaminated_data -# TODO Add docstring def _model_single_order(data_order, err_order, ref_file_args, mask_fit, mask_rebuild, order, wave_grid, valid_cols, tikfac_log_range, save_tiktests=False): + """ + Extract an output spectrum for a single spectral order using the ATOCA + algorithm, testing a range of Tikhonov factors. + The Tikhonov factor is derived in two stages: first, ten factors are tested + spanning tikfac_log_range, and then a further 20 factors are tested across + 2 orders of magnitude in each direction around the best factor from the first + stage. + The best-fitting model and spectrum are reconstructed using the best-fit Tikhonov factor + and respecting mask_rebuild. + + Parameters + ---------- + data_order : np.array + The 2D data array for the spectral order to be extracted. + err_order : np.array + The 2D error array for the spectral order to be extracted. + ref_file_args : list + The reference file arguments used by the ExtractionEngine. + mask_fit : np.array + Mask determining the aperture used for extraction. This typically includes + detector bad pixels and any pixels that are not part of the trace + mask_rebuild : np.array + Mask determining the aperture used for rebuilding the trace. This typically includes + only pixels that do not belong to either spectral trace, i.e., regions of the detector + where no real data could exist. + order : int + The spectral order to be extracted. + wave_grid : np.array + The wavelength grid used to model the data. + valid_cols : np.array + The columns of the detector that are valid for extraction. + tikfac_log_range : list + The range of Tikhonov factors to test, in log space. + save_tiktests : bool, optional. + If True, save the intermediate models and spectra for each Tikhonov factor tested. + + Returns + ------- + model : np.array + Model derived from the best Tikhonov factor, same shape as data_order. + spec_list : list of SpecModel + If save_tiktests is True, returns a list of the model spectra for each Tikhonov factor tested, + with the best-fitting spectrum last in the list. + If save_tiktests is False, returns a one-element list with the best-fitting spectrum. + + Notes + ----- + The last spectrum in the list of SpecModels lacks the "chi2", "chi2_soft_l1", "chi2_cauchy", and "reg" + attributes, as these are only calculated for the intermediate models. The last spectrum is not + necessarily identical to any of the spectra in the list, as it is reconstructed according to + mask_rebuild instead of fit respecting mask_fit. + + # TODO are all of these behaviors for the last spec in the list the desired ones? + """ # The throughput and kernel is not needed here; set them so they have no effect on the extraction. def throughput(wavelength): @@ -777,7 +883,7 @@ def throughput(wavelength): # Find the tikhonov factor. # Initial pass with tikfac_range. - factors = np.logspace(tikfac_log_range[0], tikfac_log_range[-1] + 8, 10) + factors = np.logspace(tikfac_log_range[0], tikfac_log_range[-1], 10) all_tests = engine.get_tikho_tests(factors, data_order, err_order) tikfac = engine.best_tikho_factor(tests=all_tests, fit_mode='all') @@ -821,7 +927,6 @@ def throughput(wavelength): # Add the result to spec_list spec_list.append(spec_ord) - return model, spec_list diff --git a/jwst/extract_1d/soss_extract/tests/conftest.py b/jwst/extract_1d/soss_extract/tests/conftest.py index 4186336b55..e339bfa4ec 100644 --- a/jwst/extract_1d/soss_extract/tests/conftest.py +++ b/jwst/extract_1d/soss_extract/tests/conftest.py @@ -203,7 +203,7 @@ def mask_from_trace(trace_in, cut_low=None, cut_hi=None): @pytest.fixture(scope="package") -def detector_mask(wave_map): +def detector_mask(): """Add a few random bad pixels""" rng = np.random.default_rng(42) mask = np.zeros(DATA_SHAPE, dtype=bool) @@ -306,4 +306,12 @@ def refmodel(trace1d): "throughput": np.ones((501,)),} #peaks around 0.7 at value of 0.16 model.throughputs = [thru0, thru1] - return model \ No newline at end of file + return model + + +@pytest.fixture +def ref_files(refmodel): + ref_files = {"pastasoss": refmodel} + ref_files["subarray"] = "SUBSTRIP256" + ref_files["pwcpos"] = PWCPOS + return ref_files \ No newline at end of file diff --git a/jwst/extract_1d/soss_extract/tests/test_soss_extract.py b/jwst/extract_1d/soss_extract/tests/test_soss_extract.py index 7fde43b972..e439874db6 100644 --- a/jwst/extract_1d/soss_extract/tests/test_soss_extract.py +++ b/jwst/extract_1d/soss_extract/tests/test_soss_extract.py @@ -1,11 +1,13 @@ +from functools import partial import pytest import numpy as np from numpy.polynomial import Polynomial +from stdatamodels.jwst.datamodels import SpecModel, SossWaveGridModel from jwst.extract_1d.soss_extract.soss_extract import ( - _estim_flux_first_order, + _estim_flux_first_order, _model_image, _compute_box_weights, ) -from .conftest import SPECTRAL_SLOPE +from .conftest import SPECTRAL_SLOPE, DATA_SHAPE def test_estim_flux_first_order(imagemodel, @@ -35,3 +37,169 @@ def test_estim_flux_first_order(imagemodel, assert np.isclose(m, SPECTRAL_SLOPE, rtol=0.05) +@pytest.fixture +def monkeypatch_setup(monkeypatch, + wave_map, + trace_profile, + throughput, + webb_kernels, + trace1d,): + """Monkeypatch get_ref_file_args and get_trace_1d to return the miniature model detector""" + + def mock_get_ref_file_args(wave, trace, thru, kern, reffiles): + """Return the arrays from conftest instead of querying CRDS""" + return [wave, trace, thru, kern] + + def mock_trace1d(trace, reffiles, order): + """Return the traces from conftest instead of doing math that requires a full-sized detector""" + return trace[int(order)-1] + + monkeypatch.setattr("jwst.extract_1d.soss_extract.soss_extract.get_ref_file_args", + partial(mock_get_ref_file_args, wave_map, trace_profile, throughput, webb_kernels)) + monkeypatch.setattr("jwst.extract_1d.soss_extract.soss_extract._get_trace_1d", + partial(mock_trace1d, trace1d)) + + +def test_model_image(monkeypatch_setup, + imagemodel, + detector_mask, + ref_files,): + scidata, scierr = imagemodel + + refmask = np.zeros_like(detector_mask) + box_width = 5.0 + box_weights, wavelengths = _compute_box_weights(ref_files, DATA_SHAPE, box_width) + + tracemodels, tikfac, logl, wave_grid, spec_list = _model_image( + scidata, scierr, detector_mask, refmask, ref_files, box_weights, + tikfac=None, threshold=1e-4, n_os=2, wave_grid=None, + estimate=None, rtol=1e-3, max_grid_size=1000000, + ) + + # check output basics, types and shapes + assert len(tracemodels) == 2 + for order in tracemodels: + tm = tracemodels[order] + assert tm.dtype == np.float64 + assert tm.shape == DATA_SHAPE + # should be some nans in the trace model but not all + assert 0 < np.sum(np.isfinite(tm)) < tm.size + for x in [tikfac, logl]: + assert isinstance(x, float) + assert np.isfinite(x) + assert logl < 0 + assert wave_grid.dtype == np.float64 + for spec in spec_list: + assert isinstance(spec, SpecModel) + + factors = np.array([getattr(spec.meta.soss_extract1d, "factor", np.nan) for spec in spec_list]) + chi2s = np.array([getattr(spec.meta.soss_extract1d, "chi2", np.nan) for spec in spec_list]) + orders = np.array([spec.spectral_order for spec in spec_list]) + colors = np.array([spec.meta.soss_extract1d.color_range for spec in spec_list]) + + assert tikfac in factors + + # ensure outputs have the shapes we expect for each order and blue/red + n_good = [] + for order in [1,2]: + for color in ["RED", "BLUE"]: + good = (order == orders) & (color == colors) + # check that there's at least one good spectrum for all valid order-color combinations + if not np.any(good): + assert order == 1 + assert color == "BLUE" + continue + n_good.append(np.sum(good)) + + this_factors = factors[good] + this_chi2s = chi2s[good] + this_spec = np.array(spec_list)[good] + nochi = np.isnan(this_chi2s) + + # _model_single_order is set up so that the final/best spectrum is last in the list + # it lacks chi2 calculations + assert np.sum(nochi) == 1 + assert np.where(nochi)[0][0] == len(this_chi2s) - 1 + + # it represents the best tikhonov factor for that order-color combination + # which is not necessarily the same as the top-level tikfac for the blue part of order 2 + # but it is the same for the red part of order 1 and the red part of order 2 + if color == "RED": + assert this_factors[-1] == tikfac + + # check that the output spectra contain good data + for spec in this_spec: + spec = np.array([[s[0], s[1]] for s in spec.spec_table]) + assert np.sum(np.isfinite(spec)) == spec.size + + + # check that all order-color combinations have the same number of spectra + n_good = np.array(n_good) + assert np.all(n_good >= 1) + assert np.all(n_good - n_good[0] == 0) + + + # check that if tikfac is defined, output spectra is a single-element list + tikfac_in = 1e-7 + tracemodels, tikfac, logl, wave_grid, spec_list = _model_image( + scidata, scierr, detector_mask, refmask, ref_files, box_weights, + tikfac=tikfac_in, threshold=1e-4, n_os=2, wave_grid=None, + estimate=None, rtol=1e-3, max_grid_size=1000000, + ) + + +def test_model_image_tikfac_specified(monkeypatch_setup, + imagemodel, + detector_mask, + ref_files,): + """Ensure spec_list is a single-element list per order if tikfac is specified""" + scidata, scierr = imagemodel + + refmask = np.zeros_like(detector_mask) + box_width = 5.0 + box_weights, wavelengths = _compute_box_weights(ref_files, DATA_SHAPE, box_width) + + tikfac_in = 1e-7 + tracemodels, tikfac, logl, wave_grid, spec_list = _model_image( + scidata, scierr, detector_mask, refmask, ref_files, box_weights, + tikfac=tikfac_in, threshold=1e-4, n_os=2, wave_grid=None, + estimate=None, rtol=1e-3, max_grid_size=1000000, +) + assert len(spec_list) == 3 + assert tikfac == tikfac_in + + +def test_model_image_wavegrid_specified(monkeypatch_setup, + imagemodel, + detector_mask, + ref_files,): + """Ensure wave_grid is used if specified. + Also specify tikfac because it makes the code run faster to not have to re-derive it. + + Note the failure with SossWaveGridModel passed as input. What should be done about that? + """ + scidata, scierr = imagemodel + + refmask = np.zeros_like(detector_mask) + box_width = 5.0 + box_weights, wavelengths = _compute_box_weights(ref_files, DATA_SHAPE, box_width) + + tikfac_in = 1e-7 + # test np.array input + wave_grid_in = np.linspace(1.0, 2.5, 100) + tracemodels, tikfac, logl, wave_grid, spec_list = _model_image( + scidata, scierr, detector_mask, refmask, ref_files, box_weights, + tikfac=tikfac_in, threshold=1e-4, n_os=2, wave_grid=wave_grid_in, + estimate=None, rtol=1e-3, max_grid_size=1000000, + ) + assert np.allclose(wave_grid, wave_grid_in) + + # test SossWaveGridModel input + with pytest.raises(ValueError): + wave_grid_in = SossWaveGridModel() + wave_grid_in.wavegrid = np.linspace(1.0, 2.5, 100) + tracemodels, tikfac, logl, wave_grid, spec_list = _model_image( + scidata, scierr, detector_mask, refmask, ref_files, box_weights, + tikfac=tikfac_in, threshold=1e-4, n_os=2, wave_grid=wave_grid_in, + estimate=None, rtol=1e-3, max_grid_size=1000000, + ) From 74b5719149c4dae8c7b68d224b11474046db09d9 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Mon, 16 Dec 2024 12:23:42 -0500 Subject: [PATCH 34/35] small start to docs --- docs/jwst/extract_1d/description.rst | 6 ++++ jwst/extract_1d/soss_extract/soss_extract.py | 28 ++++++++----------- .../soss_extract/tests/test_soss_extract.py | 11 ++------ 3 files changed, 20 insertions(+), 25 deletions(-) diff --git a/docs/jwst/extract_1d/description.rst b/docs/jwst/extract_1d/description.rst index bd31735236..0d1894ac19 100644 --- a/docs/jwst/extract_1d/description.rst +++ b/docs/jwst/extract_1d/description.rst @@ -383,3 +383,9 @@ the data must be given. The steps to run this correction outside the pipeline ar flux_cor = rf1d(flux, wave, channel=4) where `flux` is the extracted spectral data, and the data are from channel 4 for this example. + +Extraction for NIRISS SOSS Data +------------------------------- +For NIRISS SOSS data, the two spectral orders overlap slightly, so a specialized extraction +algorithm known as ATOCA (Algorithm to Treat Order ContAmination) is used... +Link paper diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index 80440ed760..34fb487cca 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -545,13 +545,12 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, n_os : int, optional The oversampling factor of the wavelength grid used when solving for the uncontaminated flux. If not specified, defaults to 2. - wave_grid : str or SossWaveGridModel or None - Filename of reference file or SossWaveGridModel containing the wavelength grid used by ATOCA - to model each pixel valid pixel of the detector. If not given, the grid is determined - based on an estimate of the flux (estimate), the relative tolerance (rtol) - required on each pixel model and the maximum grid size (max_grid_size). - # TODO: none of the options work except None or an ndarray, including on main. - # docstring needs updates. + wave_grid : np.ndarray, optional + Wavelength grid used by ATOCA to model each pixel valid pixel of the detector. + If not given, the grid is determined based on an estimate of the flux (estimate), + the relative tolerance (rtol) required on each pixel model and + the maximum grid size (max_grid_size). + # TODO: none of the options specified on main work # Should we add support for these? If not, is SossWaveGridModel used for anything, # and can that be removed from stdatamodels and as a valid argument to soss_wave_grid_in? estimate : UnivariateSpline or None @@ -572,7 +571,8 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, logl : float Log likelihood value associated with the Tikhonov factor selected. wave_grid : 1d array - Same as wave_grid input. TODO: this isn't true if input wave_grid is None, update docstring + The wavelengths at which the spectra were extracted. Same as wave_grid + if specified as input. spec_list : list of SpecModel List of the underlying spectra for each integration and order. The tikhonov tests are also included. @@ -623,6 +623,7 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, if tikfac is None: log.info('Solving for the optimal Tikhonov factor.') + save_tiktests = True # Find the tikhonov factor. # Initial pass 8 orders of magnitude with 10 grid points. @@ -637,11 +638,9 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, factors = np.logspace(tikfac - 2, tikfac + 2, 20) tiktests = engine.get_tikho_tests(factors, scidata_bkg, scierr) tikfac = engine.best_tikho_factor(tiktests, fit_mode='d_chi2') - # Add all theses tests to previous ones all_tests = _append_tiktests(all_tests, tiktests) # Save spectra in a list of SingleSpecModels for optional output - save_tiktests = True for i_order, order in enumerate(order_list): for idx in range(len(all_tests['factors'])): f_k = all_tests['solution'][idx, :] @@ -686,9 +685,7 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, # Add the result to spec_list spec_list.append(spec_ord) - # ############################### - # Model remaining part of order 2 - # ############################### + # Model the remaining part of order 2 if ref_files['subarray'] != 'SUBSTRIP96': idx_order2 = 1 order = idx_order2 + 1 @@ -700,9 +697,6 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, # Mask for the fit. All valid pixels inside box aperture mask_fit = mask_trace_profile[idx_order2] | scimask -# # and extract only what was not already modeled -# already_modeled = np.isfinite(tracemodels[order_str]) -# mask_fit |= already_modeled # Build 1d spectrum integrated over pixels pixel_wave_grid, valid_cols = _get_native_grid_from_trace(ref_files, order) @@ -747,6 +741,7 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, def _compute_box_weights(ref_files, shape, width): + """Determine the weights for the box extraction.""" # Generate list of orders from pastasoss trace list order_list = [] @@ -1029,6 +1024,7 @@ def run_extract1d(input_model, pastasoss_ref_name, specprofile_ref_name, speckernel_ref_name, subarray, soss_filter, soss_kwargs): """Run the spectral extraction on NIRISS SOSS data. + Parameters ---------- input_model : DataModel diff --git a/jwst/extract_1d/soss_extract/tests/test_soss_extract.py b/jwst/extract_1d/soss_extract/tests/test_soss_extract.py index e439874db6..faf9cf7d79 100644 --- a/jwst/extract_1d/soss_extract/tests/test_soss_extract.py +++ b/jwst/extract_1d/soss_extract/tests/test_soss_extract.py @@ -139,15 +139,6 @@ def test_model_image(monkeypatch_setup, assert np.all(n_good - n_good[0] == 0) - # check that if tikfac is defined, output spectra is a single-element list - tikfac_in = 1e-7 - tracemodels, tikfac, logl, wave_grid, spec_list = _model_image( - scidata, scierr, detector_mask, refmask, ref_files, box_weights, - tikfac=tikfac_in, threshold=1e-4, n_os=2, wave_grid=None, - estimate=None, rtol=1e-3, max_grid_size=1000000, - ) - - def test_model_image_tikfac_specified(monkeypatch_setup, imagemodel, detector_mask, @@ -165,6 +156,7 @@ def test_model_image_tikfac_specified(monkeypatch_setup, tikfac=tikfac_in, threshold=1e-4, n_os=2, wave_grid=None, estimate=None, rtol=1e-3, max_grid_size=1000000, ) + # check that spec_list is a single-element list per order in this case assert len(spec_list) == 3 assert tikfac == tikfac_in @@ -195,6 +187,7 @@ def test_model_image_wavegrid_specified(monkeypatch_setup, assert np.allclose(wave_grid, wave_grid_in) # test SossWaveGridModel input + # the docs on main say this works, but I don't think it does even on main with pytest.raises(ValueError): wave_grid_in = SossWaveGridModel() wave_grid_in.wavegrid = np.linspace(1.0, 2.5, 100) From e0c88830900c04462a5cd37664e9d2f08a90fd86 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Tue, 17 Dec 2024 17:03:18 -0500 Subject: [PATCH 35/35] fixed a few typos during self review --- jwst/extract_1d/soss_extract/atoca.py | 7 +++--- jwst/extract_1d/soss_extract/atoca_utils.py | 22 +++++++------------ jwst/extract_1d/soss_extract/soss_extract.py | 15 +++++-------- .../extract_1d/soss_extract/tests/conftest.py | 3 ++- 4 files changed, 19 insertions(+), 28 deletions(-) diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index 36fca3b4d3..c3b69f2c89 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -579,7 +579,6 @@ def get_detector_model(self, data, error): def tikho_mat(self): """ Return the Tikhonov matrix. - Generate it with `atoca_utils.get_tikho_matrix` method if not defined yet. """ if self._tikho_mat is not None: return self._tikho_mat @@ -900,9 +899,9 @@ def _get_lo_hi(self, grid, wave_p, wave_m, mask): grid : array[float] Wave_grid to check. wave_p : array[float] - TODO: add here + Wavelengths on the higher side of each pixel. wave_m : array[float] - TODO: add here + Wavelengths on the lower side of each pixel. Returns ------- @@ -985,7 +984,7 @@ def get_w(self, i_order): ma = mask_ord[~self.mask] # Get lo hi - lo, hi = self._get_lo_hi(wave_grid, wave_p, wave_m, ma) # Get indexes + lo, hi = self._get_lo_hi(wave_grid, wave_p, wave_m, ma) # Get indices # Number of used pixels n_i = len(lo) diff --git a/jwst/extract_1d/soss_extract/atoca_utils.py b/jwst/extract_1d/soss_extract/atoca_utils.py index d04953c78b..a99577646c 100644 --- a/jwst/extract_1d/soss_extract/atoca_utils.py +++ b/jwst/extract_1d/soss_extract/atoca_utils.py @@ -64,8 +64,6 @@ def arange_2d(starts, stops): def sparse_k(val, k, n_k): """ - TODO: ensure test coverage. Probably sufficient to have test for compute_weights - in atoca.py Transform a 2D array `val` to a sparse matrix. Parameters @@ -380,11 +378,11 @@ def grid_from_map(wave_map, trace_profile, wave_range=None, n_os=1): Minimum and maximum boundary of the grid to generate, in microns. Wave_range must include some wavelengths of wave_map. Note wave_range is exclusive, in the sense that wave_range[0] and wave_range[1] - will not be between min(output) and mad(output). Instead, min(output) will be + will not be between min(output) and max(output). Instead, min(output) will be the smallest value in the extrapolated grid that is greater than wave_range[0] and max(output) will be the largest value that is less than wave_range[1]. n_os : int - Oversampling of the grid compare to the pixel sampling. + Oversampling of the grid compared to the pixel sampling. Returns ------- @@ -470,9 +468,6 @@ def _trim_grids(all_grids, grid_range): def make_combined_adaptive_grid(all_grids, all_estimates, grid_range, max_iter=10, rtol=10e-6, max_total_size=1000000): """ - TODO: can this be a class? e.g., class AdaptiveGrid? - q: why are there multiple grids passed in here in the first place - Return an irregular oversampled grid needed to reach a given precision when integrating over each intervals of `grid`. The grid is built by subdividing iteratively each intervals that @@ -1085,8 +1080,6 @@ def _get_wings(fct, grid, h_len, i_a, i_b): def _trpz_weight(grid, length, shape, i_a, i_b): """ - TODO: add to some integration class? - Compute weights due to trapezoidal integration Parameters @@ -1466,7 +1459,7 @@ def _get_interp_idx_array(idx, relative_range, max_length): return np.arange(*abs_range, 1) -def _minimize_on_grid(factors, val_to_minimize, interpolate=True, interp_index=[-2,4]): +def _minimize_on_grid(factors, val_to_minimize, interpolate=True, interp_index=None): """ Find minimum of a grid using akima spline interpolation to get a finer estimate Parameters @@ -1486,6 +1479,8 @@ def _minimize_on_grid(factors, val_to_minimize, interpolate=True, interp_index=[ min_fac : float The factor with minimized error/curvature. """ + if interp_index is None: + interp_index = [-2, 4] # Only keep finite values idx_finite = np.isfinite(val_to_minimize) @@ -1528,7 +1523,7 @@ def _minimize_on_grid(factors, val_to_minimize, interpolate=True, interp_index=[ return min_fac -def _find_intersect(factors, y_val, thresh, interpolate=True, search_range=[0,3]): +def _find_intersect(factors, y_val, thresh, interpolate=True, search_range=None): """ Find the root of y_val - thresh (so the intersection between thresh and y_val) Parameters ---------- @@ -1551,6 +1546,8 @@ def _find_intersect(factors, y_val, thresh, interpolate=True, search_range=[0,3] Factor corresponding to the best approximation of the intersection point. """ + if search_range is None: + search_range = [0, 3] # Only keep finite values idx_finite = np.isfinite(y_val) @@ -1821,9 +1818,6 @@ def try_solve_two_methods(matrix, result): class Tikhonov: """ - TODO: can we avoid all of this by using scipy.optimize.least_squares - like this? https://stackoverflow.com/questions/62768131/how-to-add-tikhonov-regularization-in-scipy-optimize-least-squares - Tikhonov regularization to solve the ill-posed problem A.x = b, where A is accidentally singular or close to singularity. Tikhonov regularization adds a regularization term in the equation and aim to minimize the diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py index 34fb487cca..5bdfac74d4 100644 --- a/jwst/extract_1d/soss_extract/soss_extract.py +++ b/jwst/extract_1d/soss_extract/soss_extract.py @@ -92,7 +92,6 @@ def get_ref_file_args(ref_files): # The spectral kernels. speckernel_ref = ref_files['speckernel'] - # ovs = speckernel_ref.meta.spectral_oversampling n_pix = 2 * speckernel_ref.meta.halfwidth + 1 # Take the centroid of each trace as a grid to project the WebbKernel @@ -237,10 +236,10 @@ def _get_native_grid_from_trace(ref_files, spectral_order): Returns ------- - wave : - Grid of the pixels boundaries at the native sampling (1d array) - col : - The column number of the pixel + wave : + Grid of the pixels boundaries at the native sampling (1d array) + col : + The column number of the pixel """ # From wavelength solution @@ -389,8 +388,6 @@ def _populate_tikho_attr(spec, tiktests, idx, sp_ord): def _f_to_spec(f_order, grid_order, ref_file_args, pixel_grid, mask, sp_ord): """ - TODO: would it be better to pass the engine directly in here somehow? - Bin the flux to the pixel grid and build a SpecModel. Parameters @@ -631,7 +628,7 @@ def _model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, log_guess = np.log10(guess_factor) factors = np.logspace(log_guess - 4, log_guess + 4, 10) all_tests = engine.get_tikho_tests(factors, scidata_bkg, scierr) - tikfac= engine.best_tikho_factor(all_tests, fit_mode='all') + tikfac = engine.best_tikho_factor(all_tests, fit_mode='all') # Refine across 4 orders of magnitude. tikfac = np.log10(tikfac) @@ -854,7 +851,7 @@ def _model_single_order(data_order, err_order, ref_file_args, mask_fit, The last spectrum in the list of SpecModels lacks the "chi2", "chi2_soft_l1", "chi2_cauchy", and "reg" attributes, as these are only calculated for the intermediate models. The last spectrum is not necessarily identical to any of the spectra in the list, as it is reconstructed according to - mask_rebuild instead of fit respecting mask_fit. + mask_rebuild instead of fit respecting mask_fit; that is, bad pixels are included. # TODO are all of these behaviors for the last spec in the list the desired ones? """ diff --git a/jwst/extract_1d/soss_extract/tests/conftest.py b/jwst/extract_1d/soss_extract/tests/conftest.py index e339bfa4ec..3e48416145 100644 --- a/jwst/extract_1d/soss_extract/tests/conftest.py +++ b/jwst/extract_1d/soss_extract/tests/conftest.py @@ -19,7 +19,8 @@ - Randomly-selected bad pixels in the data - Wave grid of size ~100 with varying resolution - Triangle function throughput for each spectral order -- Kernel set to unity (for now) +- Kernel is also a triangle function peaking at the center, or else unity for certain tests +- (partial) Mock of the Pastasoss reference model """ PWCPOS = 245.85932900002442