diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index be6332c74..8265c1325 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -225,6 +225,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_probe_overlaps: bool = True, rotation_real_space_degrees: float = None, @@ -266,6 +267,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_probe_overlaps: bool, optional @@ -479,12 +483,14 @@ def preprocess( amplitudes, mean_diffraction_intensity_temp, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( intensities, com_fitted_x, com_fitted_y, self._positions_mask[index], crop_patterns, + in_place_datacube_modification, ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index aa89d09b3..975f6ac84 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -208,6 +208,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_rotation: bool = True, maximize_divergence: bool = False, @@ -259,6 +260,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_rotation: bool, optional @@ -373,10 +377,6 @@ def preprocess( f"datacube must be the same length as magnetic_contribution_sign, not length {len(self._datacube)}." ) - dc_shapes = [dc.shape for dc in self._datacube] - if dc_shapes.count(dc_shapes[0]) != self._num_measurements: - raise ValueError("datacube intensities must be the same size.") - if self._positions_mask is not None: self._positions_mask = np.asarray(self._positions_mask, dtype="bool") @@ -551,12 +551,14 @@ def preprocess( amplitudes, mean_diffraction_intensity_temp, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( intensities, com_fitted_x, com_fitted_y, self._positions_mask[index], crop_patterns, + in_place_datacube_modification, ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index d82a37eb4..3bacf1870 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -267,6 +267,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_center_of_mass: str = "default", plot_rotation: bool = True, @@ -318,6 +319,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_center_of_mass: str, optional @@ -486,17 +490,21 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( _intensities, self._com_fitted_x, self._com_fitted_y, self._positions_mask, crop_patterns, + in_place_datacube_modification, ) # explicitly transfer arrays to storage + if not in_place_datacube_modification: + del _intensities + self._amplitudes = copy_to_device(self._amplitudes, storage) - del _intensities self._num_diffraction_patterns = self._amplitudes.shape[0] self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index 7bbadf114..9b12d09e0 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -213,6 +213,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_center_of_mass: str = "default", plot_rotation: bool = True, @@ -264,6 +265,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_center_of_mass: str, optional @@ -432,17 +436,21 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( _intensities, self._com_fitted_x, self._com_fitted_y, self._positions_mask, crop_patterns, + in_place_datacube_modification, ) # explicitly transfer arrays to storage + if not in_place_datacube_modification: + del _intensities + self._amplitudes = copy_to_device(self._amplitudes, storage) - del _intensities self._num_diffraction_patterns = self._amplitudes.shape[0] self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index 87e3c1fe4..65a347b83 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -241,6 +241,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_center_of_mass: str = "default", plot_rotation: bool = True, @@ -292,6 +293,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_center_of_mass: str, optional @@ -460,17 +464,21 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( _intensities, self._com_fitted_x, self._com_fitted_y, self._positions_mask, crop_patterns, + in_place_datacube_modification, ) # explicitly transfer arrays to storage + if not in_place_datacube_modification: + del _intensities + self._amplitudes = copy_to_device(self._amplitudes, storage) - del _intensities self._num_diffraction_patterns = self._amplitudes.shape[0] self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index fa86f4dc2..e5768f3cc 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -19,13 +19,15 @@ AffineTransform, bilinear_kernel_density_estimate, bilinearly_interpolate_array, + calculate_aberration_gradient_basis, + generate_batches, lanczos_interpolate_array, lanczos_kernel_density_estimate, pixel_rolling_kernel_density_estimate, ) from py4DSTEM.process.utils.cross_correlate import align_images_fourier from py4DSTEM.process.utils.utils import electron_wavelength_angstrom -from py4DSTEM.visualize import return_scaled_histogram_ordering, show +from py4DSTEM.visualize import return_scaled_histogram_ordering from scipy.linalg import polar from scipy.ndimage import distance_transform_edt from scipy.optimize import minimize @@ -260,12 +262,14 @@ def preprocess( descan_correction_fit_function: str = None, defocus_guess: float = None, rotation_guess: float = None, + aligned_bf_image_guess: np.ndarray = None, plot_average_bf: bool = True, realspace_mask: np.ndarray = None, apply_realspace_mask_to_stack: bool = True, vectorized_com_calculation: bool = True, device: str = None, clear_fft_cache: bool = None, + max_batch_size: int = None, store_initial_arrays: bool = True, **kwargs, ): @@ -284,16 +288,18 @@ def preprocess( If True, bright images normalized to have a mean of 1 normalize_order: integer, optional Polynomial order for normalization. 0 means constant, 1 means linear, etc. - Higher orders not yet implemented. defocus_guess: float, optional Initial guess of defocus value (defocus dF) in A If None, first iteration is assumed to be in-focus + aligned_bf_image_guess: np.ndarray, optional + Guess for the reference BF image to cross-correlate against during the first iteration + If None, the incoherent BF image is used instead. + rotation_guess: float, optional + Initial guess of rotation value in degrees + If None, first iteration assumed to be 0 descan_correction_fit_function: str, optional If not None, descan correction will be performed using fit function. One of "constant", "plane", "parabola", or "bezier_two". - rotation_guess: float, optional - Initial guess of defocus value in degrees - If None, first iteration assumed to be 0 plot_average_bf: bool, optional If True, plots the average bright field image, using defocus_guess realspace_mask: np.array, optional @@ -308,6 +314,8 @@ def preprocess( if not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional If True, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of virtual BF images to use at once in computing cross-correlation store_initial_arrays: bool, optional If True, stores a copy of the arrays necessary to reinitialize in reconstruct @@ -474,7 +482,6 @@ def preprocess( self._stack_BF_unshifted = xp.ones(stack_shape, xp.float32) if normalize_order == 0: - # all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None] weights = xp.average( all_bfs.reshape((self._num_bf_images, -1)), weights=self._window_edge.ravel(), @@ -517,7 +524,6 @@ def preprocess( weights = np.sqrt(self._window_edge).ravel() for a0 in range(all_bfs.shape[0]): - # coefs = np.linalg.lstsq(basis, all_bfs[a0].ravel(), rcond=None) # weighted least squares coefs = np.linalg.lstsq( weights[:, None] * basis, @@ -574,7 +580,6 @@ def preprocess( weights = np.sqrt(self._window_edge).ravel() for a0 in range(all_bfs.shape[0]): - # coefs = np.linalg.lstsq(basis, all_bfs[a0].ravel(), rcond=None) # weighted least squares coefs = np.linalg.lstsq( weights[:, None] * basis, @@ -645,49 +650,84 @@ def preprocess( # Initialization utilities self._stack_mask = xp.tile(self._window_pad[None], (self._num_bf_images, 1, 1)) + + if max_batch_size is None: + max_batch_size = self._num_bf_images + + self._xy_shifts = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) + if defocus_guess is not None: - Gs = xp.fft.fft2(self._stack_BF_shifted) + for start, end in generate_batches( + self._num_bf_images, max_batch=max_batch_size + ): + shifted_BFs = self._stack_BF_shifted[start:end] + probe_angles = self._probe_angles[start:end] + stack_mask = self._stack_mask[start:end] - self._xy_shifts = ( - -self._probe_angles - * defocus_guess - / xp.array(self._scan_sampling, dtype=xp.float32) - ) + Gs = xp.fft.fft2(shifted_BFs) - if rotation_guess: - angle = xp.deg2rad(rotation_guess) - rotation_matrix = xp.array( - [[np.cos(angle), np.sin(angle)], [-np.sin(angle), np.cos(angle)]], - dtype=xp.float32, + xy_shifts = ( + -probe_angles + * defocus_guess + / xp.array(self._scan_sampling, dtype=xp.float32) ) - self._xy_shifts = xp.dot(self._xy_shifts, rotation_matrix) - dx = self._xy_shifts[:, 0] - dy = self._xy_shifts[:, 1] + if rotation_guess is not None: + angle = xp.deg2rad(rotation_guess) + rotation_matrix = xp.array( + [ + [np.cos(angle), np.sin(angle)], + [-np.sin(angle), np.cos(angle)], + ], + dtype=xp.float32, + ) + xy_shifts = xp.dot(xy_shifts, rotation_matrix) - shift_op = xp.exp( - self._qx_shift[None] * dx[:, None, None] - + self._qy_shift[None] * dy[:, None, None] - ) - self._stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) - self._stack_mask = xp.real( - xp.fft.ifft2(xp.fft.fft2(self._stack_mask) * shift_op) - ) + dx = xy_shifts[:, 0] + dy = xy_shifts[:, 1] - del Gs - else: - self._xy_shifts = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) + shift_op = xp.exp( + self._qx_shift[None] * dx[:, None, None] + + self._qy_shift[None] * dy[:, None, None] + ) + stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) + stack_mask = xp.real(xp.fft.ifft2(xp.fft.fft2(stack_mask) * shift_op)) + + self._xy_shifts[start:end] = xy_shifts + self._stack_BF_shifted[start:end] = stack_BF_shifted + self._stack_mask[start:end] = stack_mask + + del Gs self._stack_mean = xp.mean(self._stack_BF_shifted) self._mask_sum = xp.sum(self._window_edge) * self._num_bf_images - self._recon_mask = xp.sum(self._stack_mask, axis=0) + self._recon_mask = xp.mean(self._stack_mask, axis=0) mask_inv = 1 - xp.clip(self._recon_mask, 0, 1) - self._recon_BF = ( - self._stack_mean * mask_inv - + xp.sum(self._stack_BF_shifted * self._stack_mask, axis=0) - ) / (self._recon_mask + mask_inv) + if aligned_bf_image_guess is not None: + aligned_bf_image_guess = xp.asarray(aligned_bf_image_guess) + if normalize_images: + self._recon_BF = xp.ones(stack_shape[-2:], dtype=xp.float32) + aligned_bf_image_guess /= aligned_bf_image_guess.mean() + else: + self._recon_BF = xp.full(stack_shape[-2:], self._stack_mean) + + self._recon_BF[ + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = ( + self._window_inv * self._stack_mean + + self._window_edge * aligned_bf_image_guess + ) + + else: + self._recon_BF = ( + self._stack_mean * mask_inv + + xp.mean(self._stack_BF_shifted * self._stack_mask, axis=0) + ) / (self._recon_mask + mask_inv) self._recon_error = ( xp.atleast_1d( @@ -697,6 +737,7 @@ def preprocess( ) ) / self._mask_sum + / self._stack_mean ) if store_initial_arrays: @@ -737,6 +778,240 @@ def preprocess( return self + def guess_common_aberrations( + self, + rotation_angle_deg=0, + transpose=False, + kde_upsample_factor=None, + kde_sigma_px=0.125, + kde_lowpass_filter=False, + lanczos_interpolation_order=None, + defocus=0, + astigmatism=0, + astigmatism_angle_deg=0, + coma=0, + coma_angle_deg=0, + spherical_aberration=0, + max_batch_size=None, + plot_shifts_and_aligned_bf=True, + return_shifts_and_aligned_bf=False, + plot_arrow_freq=1, + scale_arrows=1, + **kwargs, + ): + """ + Generates analytical BF shifts and uses them to align the virtual BF stack, + based on the experimental geometry (rotation, transpose), and common aberrations. + + Parameters + ---------- + rotation_angle_deg: float, optional + Relative rotation between the scan and the diffraction space coordinate systems + transpose: bool, optional + Whether the diffraction intensities are transposed + kde_upsample_factor: int, optional + Real-space upsampling factor + kde_sigma_px: float, optional + KDE gaussian kernel bandwidth in non-upsampled pixels + kde_lowpass_filter: bool, optional + If True, the resulting KDE upsampled image is lowpass-filtered using a sinc-function + lanczos_interpolation_order: int, optional + If not None, Lanczos interpolation with the specified order is used instead of bilinear + defocus: float, optional + Defocus value to use in computing analytical BF shifts + astigmatism: float, optional + Astigmatism value to use in computing analytical BF shifts + astigmatism_angle_deg: float, optional + Astigmatism angle to use in computing analytical BF shifts + coma: float, optional + Coma value to use in computing analytical BF shifts + coma_angle_deg: float, optional + Coma angle to use in computing analytical BF shifts + spherical_aberration: float, optional + Spherical aberration value to use in computing analytical BF shifts + max_batch_size: int, optional + Max number of virtual BF images to use at once in computing cross-correlation + plot_shifts_and_aligned_bf: bool, optional + If True, the analytical shifts and the aligned virtual VF image are plotted + return_shifts_and_aligned_bf: bool, optional + If True, the analytical shifts and the aligned virtual VF image are returned + plot_arrow_freq: int, optional + Frequency of shifts to plot in quiver plot + scale_arrows: float, optional + Scale to multiply shifts by + + """ + xp = self._xp + asnumpy = self._asnumpy + + if not hasattr(self, "_recon_BF"): + raise ValueError( + ( + "Aberration guessing is meant to be ran after preprocessing. " + "Please run the `preprocess()` function first." + ) + ) + + # aberrations_coefs + aberrations_mn = [ + [1, 0, 0], + [1, 2, 0], + [1, 2, 1], + [2, 1, 0], + [2, 1, 1], + [3, 0, 0], + ] + astigmatism_x = astigmatism * np.cos(np.deg2rad(astigmatism_angle_deg) * 2) + astigmatism_y = astigmatism * np.sin(np.deg2rad(astigmatism_angle_deg) * 2) + coma_x = coma * np.cos(np.deg2rad(coma_angle_deg) * 1) + coma_y = coma * np.sin(np.deg2rad(coma_angle_deg) * 1) + aberrations_coefs = xp.array( + [ + -defocus, + astigmatism_x, + astigmatism_y, + coma_x, + coma_y, + spherical_aberration, + ] + ) + + # transpose rotation matrix + if transpose: + rotation_angle_deg *= -1 + + # aberrations_basis + sampling = 1 / ( + np.array(self._reciprocal_sampling) * self._region_of_interest_shape + ) + aberrations_basis, aberrations_basis_du, aberrations_basis_dv = ( + calculate_aberration_gradient_basis( + aberrations_mn, + sampling, + self._region_of_interest_shape, + self._wavelength, + rotation_angle=np.deg2rad(rotation_angle_deg), + xp=xp, + ) + ) + + # shifts + corner_indices = self._xy_inds - xp.array(self._region_of_interest_shape // 2) + raveled_indices = xp.ravel_multi_index( + corner_indices.T, self._region_of_interest_shape, mode="wrap" + ) + gradients = xp.array( + ( + aberrations_basis_du[raveled_indices, :], + aberrations_basis_dv[raveled_indices, :], + ) + ) + shifts_ang = xp.tensordot(gradients, aberrations_coefs, axes=1).T + + # transpose predicted shifts + if transpose: + shifts_ang = xp.flip(shifts_ang, axis=1) + + shifts_px = shifts_ang / xp.array(self._scan_sampling) + + # upsampled stack + if kde_upsample_factor is not None: + BF_size = np.array(self._stack_BF_unshifted.shape[-2:]) + pixel_output_shape = np.round(BF_size * kde_upsample_factor).astype("int") + + x = xp.arange(BF_size[0], dtype=xp.float32) + y = xp.arange(BF_size[1], dtype=xp.float32) + xa_init, ya_init = xp.meshgrid(x, y, indexing="ij") + + # kernel density output the upsampled BF image + xa = (xa_init + shifts_px[:, 0, None, None]) * kde_upsample_factor + ya = (ya_init + shifts_px[:, 1, None, None]) * kde_upsample_factor + + pix_output = self._kernel_density_estimate( + xa, + ya, + self._stack_BF_unshifted, + pixel_output_shape, + kde_sigma_px * kde_upsample_factor, + lanczos_alpha=lanczos_interpolation_order, + lowpass_filter=kde_lowpass_filter, + ) + + # hack since cropping requires "_kde_upsample_factor" + old_upsample_factor = getattr(self, "_kde_upsample_factor", None) + self._kde_upsample_factor = kde_upsample_factor + cropped_image = asnumpy( + self._crop_padded_object(pix_output, upsampled=True) + ) + if old_upsample_factor is not None: + self._kde_upsample_factor = old_upsample_factor + else: + del self._kde_upsample_factor + + # shifted stack + else: + kde_upsample_factor = 1 + aligned_stack = xp.zeros_like(self._stack_BF_shifted_initial[0]) + + if max_batch_size is None: + max_batch_size = self._num_bf_images + + for start, end in generate_batches( + self._num_bf_images, max_batch=max_batch_size + ): + shifted_BFs = self._stack_BF_shifted_initial[start:end] + + Gs = xp.fft.fft2(shifted_BFs) + + dx = shifts_px[start:end, 0] + dy = shifts_px[start:end, 1] + + shift_op = xp.exp( + self._qx_shift[None] * dx[:, None, None] + + self._qy_shift[None] * dy[:, None, None] + ) + stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) + aligned_stack += stack_BF_shifted.sum(0) + + cropped_image = asnumpy( + self._crop_padded_object(aligned_stack, upsampled=False) + ) + + if plot_shifts_and_aligned_bf: + figsize = kwargs.pop("figsize", (8, 4)) + color = kwargs.pop("color", (1, 0, 0, 1)) + cmap = kwargs.pop("cmap", "magma") + + fig, axs = plt.subplots(1, 2, figsize=figsize) + + self.show_shifts( + shifts_ang=shifts_ang, + plot_arrow_freq=plot_arrow_freq, + scale_arrows=scale_arrows, + plot_rotated_shifts=False, + color=color, + figax=(fig, axs[0]), + ) + + axs[0].set_title("Predicted BF Shifts") + + extent = [ + 0, + self._scan_sampling[1] * cropped_image.shape[1] / kde_upsample_factor, + self._scan_sampling[0] * cropped_image.shape[0] / kde_upsample_factor, + 0, + ] + + axs[1].imshow(cropped_image, cmap=cmap, extent=extent, **kwargs) + axs[1].set_ylabel("x [A]") + axs[1].set_xlabel("y [A]") + axs[1].set_title("Predicted Aligned BF Image") + + fig.tight_layout() + + if return_shifts_and_aligned_bf: + return shifts_ang, cropped_image + def reconstruct( self, max_alignment_bin: int = None, @@ -754,6 +1029,7 @@ def reconstruct( reset: bool = None, device: str = None, clear_fft_cache: bool = None, + max_batch_size: int = None, **kwargs, ): """ @@ -788,6 +1064,8 @@ def reconstruct( If True, the reconstruction is reset device: str, optional if not none, overwrites self._device to set device preprocess will be perfomed on. + max_batch_size: int, optional + Max number of virtual BF images to use at once in computing cross-correlation clear_fft_cache: bool, optional if true, and device = 'gpu', clears the cached fft plan at the end of function calls @@ -911,6 +1189,9 @@ def reconstruct( xy_center = (self._xy_inds - xp.median(self._xy_inds, axis=0)).astype("float") + if max_batch_size is None: + max_batch_size = self._num_bf_images + # Loop over all binning values for a0 in range(bin_vals.shape[0]): G_ref = xp.fft.fft2(self._recon_BF) @@ -981,31 +1262,33 @@ def reconstruct( shifts_update = xy_shifts_fit - self._xy_shifts # apply shifts - Gs = xp.fft.fft2(self._stack_BF_shifted) + for start, end in generate_batches( + self._num_bf_images, max_batch=max_batch_size + ): + shifted_BFs = self._stack_BF_shifted[start:end] + stack_mask = self._stack_mask[start:end] - dx = shifts_update[:, 0] - dy = shifts_update[:, 1] - self._xy_shifts[:, 0] += dx - self._xy_shifts[:, 1] += dy + Gs = xp.fft.fft2(shifted_BFs) - shift_op = xp.exp( - self._qx_shift[None] * dx[:, None, None] - + self._qy_shift[None] * dy[:, None, None] - ) + dx = shifts_update[start:end, 0] + dy = shifts_update[start:end, 1] - self._stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) - self._stack_mask = xp.real( - xp.fft.ifft2(xp.fft.fft2(self._stack_mask) * shift_op) - ) + shift_op = xp.exp( + self._qx_shift[None] * dx[:, None, None] + + self._qy_shift[None] * dy[:, None, None] + ) - self._stack_BF_shifted = xp.asarray( - self._stack_BF_shifted, dtype=xp.float32 - ) # numpy fft upcasts? - self._stack_mask = xp.asarray( - self._stack_mask, dtype=xp.float32 - ) # numpy fft upcasts? + stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) + stack_mask = xp.real(xp.fft.ifft2(xp.fft.fft2(stack_mask) * shift_op)) - del Gs + self._stack_BF_shifted[start:end] = xp.asarray( + stack_BF_shifted, dtype=xp.float32 + ) + self._stack_mask[start:end] = xp.asarray(stack_mask, dtype=xp.float32) + self._xy_shifts[start:end, 0] += dx + self._xy_shifts[start:end, 1] += dy + + del Gs # Center the shifts xy_shifts_median = xp.round(xp.median(self._xy_shifts, axis=0)).astype(int) @@ -1016,12 +1299,12 @@ def reconstruct( self._stack_mask = xp.roll(self._stack_mask, -xy_shifts_median, axis=(1, 2)) # Generate new estimate - self._recon_mask = xp.sum(self._stack_mask, axis=0) + self._recon_mask = xp.mean(self._stack_mask, axis=0) mask_inv = 1 - np.clip(self._recon_mask, 0, 1) self._recon_BF = ( self._stack_mean * mask_inv - + xp.sum(self._stack_BF_shifted * self._stack_mask, axis=0) + + xp.mean(self._stack_BF_shifted * self._stack_mask, axis=0) ) / (self._recon_mask + mask_inv) self._recon_error = ( @@ -1032,6 +1315,7 @@ def reconstruct( ) ) / self._mask_sum + / self._stack_mean ) self.error_iterations.append(float(self._recon_error)) @@ -2045,75 +2329,21 @@ def calculate_CTF_FFT(alpha_shape, *coefs): # Direct Shifts Fitting if fit_BF_shifts: - # FFT coordinates - sx = 1 / (self._reciprocal_sampling[0] * self._region_of_interest_shape[0]) - sy = 1 / (self._reciprocal_sampling[1] * self._region_of_interest_shape[1]) - qx = xp.fft.fftfreq(self._region_of_interest_shape[0], sx) - qy = xp.fft.fftfreq(self._region_of_interest_shape[1], sy) - qx, qy = np.meshgrid(qx, qy, indexing="ij") - - # passive rotation basis by -theta - rotation_angle = -self.rotation_Q_to_R_rads - qx, qy = qx * np.cos(rotation_angle) + qy * np.sin( - rotation_angle - ), -qx * np.sin(rotation_angle) + qy * np.cos(rotation_angle) - - qr2 = qx**2 + qy**2 - u = qx * self._wavelength - v = qy * self._wavelength - alpha = xp.sqrt(qr2) * self._wavelength - theta = xp.arctan2(qy, qx) - - # Aberration basis - self._aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) - self._aberrations_basis_du = xp.zeros((alpha.size, self._aberrations_num)) - self._aberrations_basis_dv = xp.zeros((alpha.size, self._aberrations_num)) - for a0 in range(self._aberrations_num): - m, n, a = self._aberrations_mn[a0] - - if n == 0: - # Radially symmetric basis - self._aberrations_basis[:, a0] = ( - alpha ** (m + 1) / (m + 1) - ).ravel() - self._aberrations_basis_du[:, a0] = (u * alpha ** (m - 1)).ravel() - self._aberrations_basis_dv[:, a0] = (v * alpha ** (m - 1)).ravel() - - elif a == 0: - # cos coef - self._aberrations_basis[:, a0] = ( - alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) - ).ravel() - self._aberrations_basis_du[:, a0] = ( - alpha ** (m - 1) - * ((m + 1) * u * xp.cos(n * theta) + n * v * xp.sin(n * theta)) - / (m + 1) - ).ravel() - self._aberrations_basis_dv[:, a0] = ( - alpha ** (m - 1) - * ((m + 1) * v * xp.cos(n * theta) - n * u * xp.sin(n * theta)) - / (m + 1) - ).ravel() - - else: - # sin coef - self._aberrations_basis[:, a0] = ( - alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) - ).ravel() - self._aberrations_basis_du[:, a0] = ( - alpha ** (m - 1) - * ((m + 1) * u * xp.sin(n * theta) - n * v * xp.cos(n * theta)) - / (m + 1) - ).ravel() - self._aberrations_basis_dv[:, a0] = ( - alpha ** (m - 1) - * ((m + 1) * v * xp.sin(n * theta) + n * u * xp.cos(n * theta)) - / (m + 1) - ).ravel() - - # global scaling - self._aberrations_basis *= 2 * np.pi / self._wavelength - self._aberrations_surface_shape = alpha.shape + sampling = 1 / ( + np.array(self._reciprocal_sampling) * self._region_of_interest_shape + ) + ( + self._aberrations_babis, + self._aberrations_basis_du, + self._aberrations_basis_dv, + ) = calculate_aberration_gradient_basis( + self._aberrations_mn, + sampling, + self._region_of_interest_shape, + self._wavelength, + rotation_angle=self.rotation_Q_to_R_rads, + xp=xp, + ) # CTF function def calculate_CTF(alpha_shape, *coefs): @@ -2202,19 +2432,6 @@ def score_CTF(coefs): # Plot the measured/fitted shifts comparison if plot_BF_shifts_comparison: - measured_shifts_sx = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - measured_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - self._xy_shifts_Ang[:, 0] - ) - - measured_shifts_sy = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - measured_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - self._xy_shifts_Ang[:, 1] - ) fitted_shifts = ( xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) @@ -2222,53 +2439,28 @@ def score_CTF(coefs): .T ) - fitted_shifts_sx = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - fitted_shifts[:, 0] - ) + scale_arrows = kwargs.pop("scale_arrows", 1) + plot_arrow_freq = kwargs.pop("plot_arrow_freq", 1) + figsize = kwargs.pop("figsize", (4, 4)) - fitted_shifts_sy = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( - fitted_shifts[:, 1] - ) + fig, ax = plt.subplots(figsize=figsize) - max_shift = xp.max( - xp.array( - [ - xp.abs(measured_shifts_sx).max(), - xp.abs(measured_shifts_sy).max(), - xp.abs(fitted_shifts_sx).max(), - xp.abs(fitted_shifts_sy).max(), - ] - ) + self.show_shifts( + shifts_ang=self._xy_shifts_Ang, + plot_rotated_shifts=False, + plot_arrow_freq=plot_arrow_freq, + scale_arrows=scale_arrows, + color=(1, 0, 0, 0.5), + figax=(fig, ax), ) - axsize = kwargs.pop("axsize", (4, 4)) - cmap = kwargs.pop("cmap", "PiYG") - vmin = kwargs.pop("vmin", -max_shift) - vmax = kwargs.pop("vmax", max_shift) - - show( - [ - [asnumpy(measured_shifts_sx), asnumpy(fitted_shifts_sx)], - [asnumpy(measured_shifts_sy), asnumpy(fitted_shifts_sy)], - ], - cmap=cmap, - vmin=vmin, - vmax=vmax, - intensity_range="absolute", - axsize=axsize, - ticks=False, - title=[ - "Measured Vertical Shifts", - "Fitted Vertical Shifts", - "Measured Horizontal Shifts", - "Fitted Horizontal Shifts", - ], + self.show_shifts( + shifts_ang=fitted_shifts, + plot_rotated_shifts=False, + plot_arrow_freq=plot_arrow_freq, + scale_arrows=scale_arrows, + color=(0, 0, 1, 0.5), + figax=(fig, ax), ) # Plot the CTF comparison between experiment and fit @@ -2826,9 +3018,11 @@ def _visualize_figax( def show_shifts( self, + shifts_ang=None, scale_arrows=1, plot_arrow_freq=1, plot_rotated_shifts=True, + figax=None, **kwargs, ): """ @@ -2836,31 +3030,58 @@ def show_shifts( Parameters ---------- + shifts_ang: np.ndarray, optional + If None, self._xy_shifts is used scale_arrows: float, optional Scale to multiply shifts by plot_arrow_freq: int, optional Frequency of shifts to plot in quiver plot + plot_rotated_shifts: bool, optional + If True, shifts are plotted with the relative rotation decomposed + figax: optional + Tuple of figure, axes to plot against """ xp = self._xp asnumpy = self._asnumpy color = kwargs.pop("color", (1, 0, 0, 1)) + + if shifts_ang is None: + shifts_px = self._xy_shifts + else: + shifts_px = shifts_ang / xp.array(self._scan_sampling) + + shifts = shifts_px * scale_arrows * xp.array(self._reciprocal_sampling) + if plot_rotated_shifts and hasattr(self, "rotation_Q_to_R_rads"): - figsize = kwargs.pop("figsize", (8, 4)) - fig, ax = plt.subplots(1, 2, figsize=figsize) - scaling_factor = ( - xp.array(self._reciprocal_sampling) - / xp.array(self._scan_sampling) - * scale_arrows + + if figax is None: + figsize = kwargs.pop("figsize", (8, 4)) + fig, ax = plt.subplots(1, 2, figsize=figsize) + else: + fig, ax = figax + + rotated_color = kwargs.pop("rotated_color", (0, 0, 0, 1)) + + if shifts_ang is None: + rotated_shifts_px = self._xy_shifts.copy() + else: + rotated_shifts_px = shifts_ang / xp.array(self._scan_sampling) + + if self.transpose: + rotated_shifts_px = xp.flip(rotated_shifts_px, axis=1) + + rotated_shifts = ( + rotated_shifts_px * scale_arrows * xp.array(self._reciprocal_sampling) ) - rotated_shifts = self._xy_shifts_Ang * scaling_factor else: - figsize = kwargs.pop("figsize", (4, 4)) - fig, ax = plt.subplots(figsize=figsize) - - shifts = self._xy_shifts * scale_arrows * self._reciprocal_sampling[0] + if figax is None: + figsize = kwargs.pop("figsize", (4, 4)) + fig, ax = plt.subplots(figsize=figsize) + else: + fig, ax = figax dp_mask_ind = xp.nonzero(self._dp_mask) yy, xx = xp.meshgrid( @@ -2903,6 +3124,7 @@ def show_shifts( angles="xy", scale_units="xy", scale=1, + color=rotated_color, **kwargs, ) diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index 149e8143b..886fd7972 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -135,7 +135,7 @@ def grid_search( Parameters ---------- - n_initial_points: int + n_points: int Number of uniformly spaced trial points to run on a grid error_metric: Callable or str Function used to compute the reconstruction error. @@ -233,7 +233,7 @@ def evaluation_callback(ptycho): ax.imshow(res[0], cmap=cmap) title_substrings = [ - f"{param.name}: {val}" + f"{param.name}: {val:.3e}" for param, val in zip(self._parameter_list, params) ] title_substrings.append(f"error: {res[1]:.3e}") @@ -458,7 +458,7 @@ def _split_static_and_optimization_vars(self, argdict): return static_args, optimization_args def _get_scan_positions(self, affine_transform, dataset): - scan_positions = self._init_static_args.pop("initial_scan_positions", None) + scan_positions = self._init_static_args.get("initial_scan_positions", None) if scan_positions is None: R_pixel_size = dataset.calibration.get_R_pixel_size() x, y = ( @@ -485,8 +485,10 @@ def _get_error_metric(self, error_metric: Union[Callable, str]) -> Callable: "log-converged", "linear-converged", "TV", + "TV-phase", "std", "std-phase", + "entropy", "entropy-phase", ), f"Error metric {error_metric} not recognized." @@ -519,10 +521,20 @@ def f(ptycho): elif error_metric == "TV": def f(ptycho): - gx, gy = np.gradient(ptycho.object_cropped, axis=(-2, -1)) - obj_mag = np.sum(np.abs(ptycho.object_cropped)) + array = np.abs(ptycho.object_cropped) + gx = array[..., 1:, :] - array[..., -1:, :] + gy = array[..., :, 1:] - array[..., :, -1:] + tv = np.sum(np.abs(gx)) + np.sum(np.abs(gy)) + return tv / array.size + + elif error_metric == "TV-phase": + + def f(ptycho): + array = np.angle(ptycho.object_cropped) + gx = array[..., 1:, :] - array[..., -1:, :] + gy = array[..., :, 1:] - array[..., :, -1:] tv = np.sum(np.abs(gx)) + np.sum(np.abs(gy)) - return tv / obj_mag + return tv / array.size elif error_metric == "std": @@ -534,16 +546,30 @@ def f(ptycho): def f(ptycho): return -np.std(np.angle(ptycho.object_cropped)) + elif error_metric == "entropy": + + def f(ptycho): + array = np.abs(ptycho.object_cropped) + normalized_array = (array - np.min(array)) / np.ptp(array) + # gx = normalized_array[..., 1:, :] - normalized_array[..., -1:, :] + # gy = normalized_array[..., :, 1:] - normalized_array[..., :, -1:] + gx, gy = np.gradient(normalized_array, axis=(-2, -1)) + ghist, _, _ = np.histogram2d(gx.ravel(), gy.ravel(), bins=array.shape) + ghist = ghist[ghist > 0] / array.size + S = np.sum(ghist * np.log2(ghist)) + return S + elif error_metric == "entropy-phase": def f(ptycho): - obj = np.angle(ptycho.object_cropped) - gx, gy = np.gradient(obj) - ghist, _, _ = np.histogram2d( - gx.ravel(), gy.ravel(), bins=obj.shape, density=True - ) - nz = ghist > 0 - S = np.sum(ghist[nz] * np.log2(ghist[nz])) + array = np.angle(ptycho.object_cropped) + normalized_array = (array - np.min(array)) / np.ptp(array) + # gx = normalized_array[..., 1:, :] - normalized_array[..., -1:, :] + # gy = normalized_array[..., :, 1:] - normalized_array[..., :, -1:] + gx, gy = np.gradient(normalized_array, axis=(-2, -1)) + ghist, _, _ = np.histogram2d(gx.ravel(), gy.ravel(), bins=array.shape) + ghist = ghist[ghist > 0] / array.size + S = np.sum(ghist * np.log2(ghist)) return S else: diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 27476cb43..c571dcd3d 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -1351,6 +1351,7 @@ def _normalize_diffraction_intensities( com_fitted_y, positions_mask, crop_patterns, + in_place_datacube_modification, ): """ Fix diffraction intensities CoM, shift to origin, and take square root @@ -1363,78 +1364,67 @@ def _normalize_diffraction_intensities( Best fit horizontal center of mass gradient com_fitted_y: (Rx,Ry) xp.ndarray Best fit vertical center of mass gradient - positions_mask: np.ndarray, optional + positions_mask: np.ndarray Boolean real space mask to select positions in datacube to skip for reconstruction crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns - when centering + If True, patterns are cropped to avoid wrap around of patterns + in_place_datacube_modification: bool + If True, the diffraction intensities are modified in-place Returns ------- - amplitudes: (Rx * Ry, Sx, Sy) np.ndarray + diffraction_intensities: (Rx * Ry, Sx, Sy) np.ndarray Flat array of normalized diffraction amplitudes mean_intensity: float Mean intensity value + crop_mask + Mask to crop diffraction patterns with """ # explicit read-only self attributes up-front asnumpy = self._asnumpy mean_intensity = 0 - - diffraction_intensities = asnumpy(diffraction_intensities) com_fitted_x = asnumpy(com_fitted_x) com_fitted_y = asnumpy(com_fitted_y) - if positions_mask is not None: - number_of_patterns = np.count_nonzero(positions_mask.ravel()) + if in_place_datacube_modification: + diff_intensities = diffraction_intensities else: - number_of_patterns = np.prod(diffraction_intensities.shape[:2]) + diff_intensities = diffraction_intensities.copy() # Aggressive cropping for when off-centered high scattering angle data was recorded if crop_patterns: crop_x = int( np.minimum( - diffraction_intensities.shape[2] - com_fitted_x.max(), + diff_intensities.shape[2] - com_fitted_x.max(), com_fitted_x.min(), ) ) crop_y = int( np.minimum( - diffraction_intensities.shape[3] - com_fitted_y.max(), + diff_intensities.shape[3] - com_fitted_y.max(), com_fitted_y.min(), ) ) crop_w = np.minimum(crop_y, crop_x) - diffraction_intensities_shape_crop = (crop_w * 2, crop_w * 2) - amplitudes = np.zeros( - ( - number_of_patterns, - crop_w * 2, - crop_w * 2, - ), - dtype=np.float32, - ) - crop_mask = np.zeros(diffraction_intensities.shape[-2:], dtype=np.bool_) + crop_mask = np.zeros(diff_intensities.shape[-2:], dtype="bool") crop_mask[:crop_w, :crop_w] = True crop_mask[-crop_w:, :crop_w] = True crop_mask[:crop_w:, -crop_w:] = True crop_mask[-crop_w:, -crop_w:] = True + crop_mask_shape = (crop_w * 2, crop_w * 2) + else: crop_mask = None - diffraction_intensities_shape_crop = diffraction_intensities.shape[-2:] - amplitudes = np.zeros( - (number_of_patterns,) + diffraction_intensities_shape_crop, - dtype=np.float32, - ) + crop_mask_shape = diff_intensities.shape[-2:] - counter = 0 for rx, ry in tqdmnd( - diffraction_intensities.shape[0], - diffraction_intensities.shape[1], + diff_intensities.shape[0], + diff_intensities.shape[1], desc="Normalizing amplitudes", unit="probe position", disable=not self._verbose, @@ -1442,28 +1432,32 @@ def _normalize_diffraction_intensities( if positions_mask is not None: if not positions_mask[rx, ry]: continue + intensities = get_shifted_ar( - diffraction_intensities[rx, ry], + diff_intensities[rx, ry], -com_fitted_x[rx, ry], -com_fitted_y[rx, ry], bilinear=True, device="cpu", ) - if crop_patterns: - intensities = intensities[crop_mask].reshape( - diffraction_intensities_shape_crop - ) - mean_intensity += np.sum(intensities) - amplitudes[counter] = np.sqrt(np.maximum(intensities, 0)) - counter += 1 + diff_intensities[rx, ry] = np.sqrt(np.maximum(intensities, 0)) - mean_intensity /= amplitudes.shape[0] + if positions_mask is not None: + diff_intensities = diff_intensities[positions_mask] + else: + qx, qy = diff_intensities.shape[-2:] + diff_intensities = diff_intensities.reshape((-1, qx, qy)) + + if crop_patterns: + diff_intensities = diff_intensities[:, crop_mask].reshape( + (-1,) + crop_mask_shape + ) - self._diffraction_intensities_shape_crop = diffraction_intensities_shape_crop + mean_intensity /= diff_intensities.shape[0] - return amplitudes, mean_intensity, crop_mask + return diff_intensities, mean_intensity, crop_mask, crop_mask_shape def show_complex_CoM( self, @@ -1557,6 +1551,7 @@ def to_h5(self, group): "semiangle_cutoff": self._semiangle_cutoff, "rolloff": self._rolloff, "object_padding_px": self._object_padding_px, + "object_fov_ang": self._object_fov_ang, "object_type": self._object_type, "verbose": self._verbose, "device": self._device, diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index b9eae9385..2e47a5e23 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -1022,6 +1022,7 @@ def _initialize_probe( device = self._device crop_mask = self._crop_mask + crop_mask_shape = self._crop_mask_shape region_of_interest_shape = self._region_of_interest_shape sampling = self.sampling energy = self._energy @@ -1049,7 +1050,7 @@ def _initialize_probe( if crop_patterns: vacuum_probe_intensity = vacuum_probe_intensity[crop_mask].reshape( - self._diffraction_intensities_shape_crop + crop_mask_shape ) sx, sy = vacuum_probe_intensity.shape diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 3639096dc..037ef4849 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -219,6 +219,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_probe_overlaps: bool = True, rotation_real_space_degrees: float = None, @@ -261,6 +262,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_probe_overlaps: bool, optional @@ -478,6 +482,7 @@ def preprocess( amplitudes, mean_diffraction_intensity_temp, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( intensities, com_fitted_x, diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index b220ba741..d391dd293 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -195,6 +195,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_center_of_mass: str = "default", plot_rotation: bool = True, @@ -236,6 +237,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_center_of_mass: str, optional @@ -405,17 +409,21 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( _intensities, self._com_fitted_x, self._com_fitted_y, self._positions_mask, crop_patterns, + in_place_datacube_modification, ) # explicitly transfer arrays to storage + if not in_place_datacube_modification: + del _intensities + self._amplitudes = copy_to_device(self._amplitudes, storage) - del _intensities self._num_diffraction_patterns = self._amplitudes.shape[0] self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index f932b40b5..bb960da62 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1504,6 +1504,85 @@ def step_model(radius, sig_0, rad_0, width): return probe_corr, polar_int, polar_int_corr, coefs_all +def calculate_aberration_gradient_basis( + aberrations_mn, + sampling, + gpts, + wavelength, + rotation_angle=0, + xp=np, +): + """ """ + sx, sy = sampling + nx, ny = gpts + qx = xp.fft.fftfreq(nx, sx) + qy = xp.fft.fftfreq(ny, sy) + qx, qy = xp.meshgrid(qx, qy, indexing="ij") + + # passive rotation + qx, qy = qx * xp.cos(-rotation_angle) + qy * xp.sin(-rotation_angle), -qx * xp.sin( + -rotation_angle + ) + qy * xp.cos(-rotation_angle) + + # coordinate system + qr2 = qx**2 + qy**2 + u = qx * wavelength + v = qy * wavelength + alpha = xp.sqrt(qr2) * wavelength + theta = xp.arctan2(qy, qx) + + _aberrations_n = len(aberrations_mn) + _aberrations_basis = xp.zeros((alpha.size, _aberrations_n)) + _aberrations_basis_du = xp.zeros((alpha.size, _aberrations_n)) + _aberrations_basis_dv = xp.zeros((alpha.size, _aberrations_n)) + + for a0 in range(_aberrations_n): + m, n, a = aberrations_mn[a0] + + if n == 0: + # Radially symmetric basis + _aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel() + _aberrations_basis_du[:, a0] = (u * alpha ** (m - 1)).ravel() + _aberrations_basis_dv[:, a0] = (v * alpha ** (m - 1)).ravel() + + elif a == 0: + # cos coef + _aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + _aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.cos(n * theta) + n * v * xp.sin(n * theta)) + / (m + 1) + ).ravel() + _aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.cos(n * theta) - n * u * xp.sin(n * theta)) + / (m + 1) + ).ravel() + + else: + # sin coef + _aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() + _aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.sin(n * theta) - n * v * xp.cos(n * theta)) + / (m + 1) + ).ravel() + _aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.sin(n * theta) + n * u * xp.cos(n * theta)) + / (m + 1) + ).ravel() + + # global scaling + _aberrations_basis *= 2 * np.pi / wavelength + + return _aberrations_basis, _aberrations_basis_du, _aberrations_basis_dv + + def aberrations_basis_function( probe_size, probe_sampling, diff --git a/py4DSTEM/process/phase/xray_magnetic_ptychography.py b/py4DSTEM/process/phase/xray_magnetic_ptychography.py index 91c0bbfaa..b1b8a5862 100644 --- a/py4DSTEM/process/phase/xray_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/xray_magnetic_ptychography.py @@ -206,6 +206,7 @@ def preprocess( padded_diffraction_intensities_shape: Tuple[int, int] = None, region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, + in_place_datacube_modification: bool = False, fit_function: str = "plane", plot_rotation: bool = True, maximize_divergence: bool = False, @@ -257,6 +258,9 @@ def preprocess( at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) + in_place_datacube_modification: bool, optional + If True, the datacube will be preprocessed in-place. Note this is not possible + when either crop_patterns or positions_mask are used. fit_function: str, optional 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' plot_rotation: bool, optional @@ -371,10 +375,6 @@ def preprocess( f"datacube must be the same length as magnetic_contribution_sign, not length {len(self._datacube)}." ) - dc_shapes = [dc.shape for dc in self._datacube] - if dc_shapes.count(dc_shapes[0]) != self._num_measurements: - raise ValueError("datacube intensities must be the same size.") - if self._positions_mask is not None: self._positions_mask = np.asarray(self._positions_mask, dtype="bool") @@ -549,12 +549,14 @@ def preprocess( amplitudes, mean_diffraction_intensity_temp, self._crop_mask, + self._crop_mask_shape, ) = self._normalize_diffraction_intensities( intensities, com_fitted_x, com_fitted_y, self._positions_mask[index], crop_patterns, + in_place_datacube_modification, ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp)