From dfc8aeebf8900dc9034e710508fd5ff43aec0e57 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 31 Dec 2023 13:50:53 -0800 Subject: [PATCH] cleaned up multislice preprocess --- .../iterative_multislice_ptychography.py | 242 ++++-------------- .../phase/iterative_ptychographic_methods.py | 131 +++++++++- .../iterative_singleslice_ptychography.py | 2 +- 3 files changed, 180 insertions(+), 195 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index bb46f25c8..92064871c 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -37,9 +37,7 @@ generate_batches, polar_aliases, polar_symbols, - spatial_frequencies, ) -from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar warnings.simplefilter(action="always", category=UserWarning) @@ -137,8 +135,8 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, - theta_x: float = 0, - theta_y: float = 0, + theta_x: float = None, + theta_y: float = None, middle_focus: bool = False, object_type: str = "complex", positions_mask: np.ndarray = None, @@ -244,92 +242,6 @@ def __init__( self._theta_x = theta_x self._theta_y = theta_y - def _precompute_propagator_arrays( - self, - gpts: Tuple[int, int], - sampling: Tuple[float, float], - energy: float, - slice_thicknesses: Sequence[float], - theta_x: float, - theta_y: float, - ): - """ - Precomputes propagator arrays complex wave-function will be convolved by, - for all slice thicknesses. - - Parameters - ---------- - gpts: Tuple[int,int] - Wavefunction pixel dimensions - sampling: Tuple[float,float] - Wavefunction sampling in A - energy: float - The electron energy of the wave functions in eV - slice_thicknesses: Sequence[float] - Array of slice thicknesses in A - theta_x: float - x tilt of propagator (in degrees) - theta_y: float - y tilt of propagator (in degrees) - - Returns - ------- - propagator_arrays: np.ndarray - (T,Sx,Sy) shape array storing propagator arrays - """ - xp = self._xp - - # Frequencies - kx, ky = spatial_frequencies(gpts, sampling) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) - - # Propagators - wavelength = electron_wavelength_angstrom(energy) - num_slices = slice_thicknesses.shape[0] - propagators = xp.empty( - (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 - ) - - theta_x = np.deg2rad(theta_x) - theta_y = np.deg2rad(theta_y) - - for i, dz in enumerate(slice_thicknesses): - propagators[i] = xp.exp( - 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) - ) - propagators[i] *= xp.exp( - 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) - ) - - return propagators - - def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): - """ - Propagates array by Fourier convolving array with propagator_array. - - Parameters - ---------- - array: np.ndarray - Wavefunction array to be convolved - propagator_array: np.ndarray - Propagator array to convolve array with - - Returns - ------- - propagated_array: np.ndarray - Fourier-convolved array - """ - xp = self._xp - - return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) - def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, @@ -340,7 +252,7 @@ def preprocess( plot_center_of_mass: str = "default", plot_rotation: bool = True, maximize_divergence: bool = False, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + rotation_angles_deg: np.ndarray = None, plot_probe_overlaps: bool = True, force_com_rotation: float = None, force_com_transpose: float = None, @@ -432,13 +344,10 @@ def preprocess( ) ) - if self._positions_mask is not None and self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) + if self._positions_mask is not None: self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + # preprocess datacube ( self._datacube, self._vacuum_probe_intensity, @@ -454,6 +363,7 @@ def preprocess( com_shifts=force_com_shifts, ) + # calibrations self._intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=True, @@ -462,6 +372,13 @@ def preprocess( force_reciprocal_sampling=force_reciprocal_sampling, ) + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # calculate CoM ( self._com_measured_x, self._com_measured_y, @@ -476,6 +393,7 @@ def preprocess( com_shifts=force_com_shifts, ) + # estimate rotation / transpose ( self._rotation_best_rad, self._rotation_best_transpose, @@ -497,15 +415,17 @@ def preprocess( **kwargs, ) + # corner-center amplitudes ( self._amplitudes, self._mean_diffraction_intensity, + self._crop_mask, ) = self._normalize_diffraction_intensities( self._intensities, self._com_fitted_x, self._com_fitted_y, - crop_patterns, self._positions_mask, + crop_patterns, ) # explicitly delete namespace @@ -513,41 +433,29 @@ def preprocess( self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) del self._intensities - self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions, self._positions_mask + # initialize probe positions + ( + self._positions_px, + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions, + self._positions_mask, + self._object_padding_px, ) - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - if self._object_type == "potential": - self._object = xp.zeros((self._num_slices, p, q), dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.ones((self._num_slices, p, q), dtype=xp.complex64) - else: - if self._object_type == "potential": - self._object = xp.asarray(self._object, dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.asarray(self._object, dtype=xp.complex64) + # initialize object + self._object = self._initialize_object( + self._object, + self._num_slices, + self._positions_px, + self._object_type, + ) self._object_initial = self._object.copy() self._object_type_initial = self._object_type self._object_shape = self._object.shape[-2:] + # center probe positions self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) self._positions_px_com = xp.mean(self._positions_px, axis=0) self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 @@ -561,76 +469,25 @@ def preprocess( self._positions_initial[:, 0] *= self.sampling[0] self._positions_initial[:, 1] *= self.sampling[1] - # Vectorized Patches + # set vectorized patches ( self._vectorized_patch_indices_row, self._vectorized_patch_indices_col, ) = self._extract_vectorized_patch_indices() - # Probe Initialization - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, - device=self._device, - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - self._probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) - - else: - if isinstance(self._probe, ComplexProbe): - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - self._probe = self._probe._array - else: - self._probe._xp = xp - self._probe = self._probe.build()._array - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - self._mean_diffraction_intensity / probe_intensity - ) - - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) + # initialize probe + self._probe, self._semiangle_cutoff = self._initialize_probe( + self._probe, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity, + self._semiangle_cutoff, + crop_patterns, + ) self._probe_initial = self._probe.copy() self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + # initialize aberrations self._known_aberrations_array = ComplexProbe( energy=self._energy, gpts=self._region_of_interest_shape, @@ -639,7 +496,7 @@ def preprocess( device=self._device, )._evaluate_ctf() - # Precomputed propagator arrays + # precompute propagator arrays self._propagator_arrays = self._precompute_propagator_arrays( self._region_of_interest_shape, self.sampling, @@ -653,10 +510,12 @@ def preprocess( shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) probe_intensities = xp.abs(shifted_probes) ** 2 probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) if object_fov_mask is None: - self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) + probe_overlap_blurred = self._gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) else: self._object_fov_mask = np.asarray(object_fov_mask) self._object_fov_mask_inverse = np.invert(self._object_fov_mask) @@ -664,11 +523,12 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - power=2, + power=power, chroma_boost=chroma_boost, ) @@ -681,7 +541,7 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - power=2, + power=power, chroma_boost=chroma_boost, ) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index 683905c88..28a66d971 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -1,11 +1,16 @@ -from typing import Tuple +from typing import Sequence, Tuple import matplotlib.pyplot as plt import numpy as np from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import make_axes_locatable -from py4DSTEM.process.phase.utils import AffineTransform, ComplexProbe, rotate_point -from py4DSTEM.process.utils import get_CoM, get_shifted_ar +from py4DSTEM.process.phase.utils import ( + AffineTransform, + ComplexProbe, + rotate_point, + spatial_frequencies, +) +from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex from scipy.ndimage import gaussian_filter, rotate @@ -187,6 +192,126 @@ class Object2p5DMethodsMixin: Overwrites ObjectNDMethodsMixin. """ + def _precompute_propagator_arrays( + self, + gpts: Tuple[int, int], + sampling: Tuple[float, float], + energy: float, + slice_thicknesses: Sequence[float], + theta_x: float = None, + theta_y: float = None, + ): + """ + Precomputes propagator arrays complex wave-function will be convolved by, + for all slice thicknesses. + + Parameters + ---------- + gpts: Tuple[int,int] + Wavefunction pixel dimensions + sampling: Tuple[float,float] + Wavefunction sampling in A + energy: float + The electron energy of the wave functions in eV + slice_thicknesses: Sequence[float] + Array of slice thicknesses in A + theta_x: float, optional + x tilt of propagator (in degrees) + theta_y: float, optional + y tilt of propagator (in degrees) + + Returns + ------- + propagator_arrays: np.ndarray + (T,Sx,Sy) shape array storing propagator arrays + """ + xp = self._xp + + # Frequencies + kx, ky = spatial_frequencies(gpts, sampling) + kx = xp.asarray(kx, dtype=xp.float32) + ky = xp.asarray(ky, dtype=xp.float32) + + # Propagators + wavelength = electron_wavelength_angstrom(energy) + num_slices = slice_thicknesses.shape[0] + propagators = xp.empty( + (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 + ) + + for i, dz in enumerate(slice_thicknesses): + propagators[i] = xp.exp( + 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) + ) + propagators[i] *= xp.exp( + 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) + ) + + if theta_x is not None: + theta_x = np.deg2rad(theta_x) + propagators[i] *= xp.exp( + 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) + ) + + if theta_y is not None: + theta_y = np.deg2rad(theta_y) + propagators[i] *= xp.exp( + 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) + ) + + return propagators + + def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): + """ + Propagates array by Fourier convolving array with propagator_array. + + Parameters + ---------- + array: np.ndarray + Wavefunction array to be convolved + propagator_array: np.ndarray + Propagator array to convolve array with + + Returns + ------- + propagated_array: np.ndarray + Fourier-convolved array + """ + xp = self._xp + + return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) + + def _initialize_object( + self, + initial_object, + num_slices, + positions_px, + object_type, + ): + """ """ + # explicit read-only self attributes up-front + xp = self._xp + object_padding_px = self._object_padding_px + region_of_interest_shape = self._region_of_interest_shape + + if initial_object is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + if object_type == "potential": + _object = xp.zeros((num_slices, p, q), dtype=xp.float32) + elif object_type == "complex": + _object = xp.ones((num_slices, p, q), dtype=xp.complex64) + else: + if object_type == "potential": + _object = xp.asarray(initial_object, dtype=xp.float32) + elif object_type == "complex": + _object = xp.asarray(initial_object, dtype=xp.complex64) + + return _object + def _return_projected_cropped_potential( self, ): diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 22f2c2a46..acd570424 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -198,7 +198,7 @@ def preprocess( plot_center_of_mass: str = "default", plot_rotation: bool = True, maximize_divergence: bool = False, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + rotation_angles_deg: np.ndarray = None, plot_probe_overlaps: bool = True, force_com_rotation: float = None, force_com_transpose: float = None,