From c9156dc9123f61aebc79ac3a5f50dd720ef5a17a Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 7 Jan 2024 22:30:32 -0800 Subject: [PATCH] added detector plane resampling --- py4DSTEM/preprocess/preprocess.py | 4 +- .../magnetic_ptychographic_tomography.py | 31 ++++++-- .../process/phase/magnetic_ptychography.py | 33 +++++--- .../mixedstate_multislice_ptychography.py | 22 ++++-- .../process/phase/mixedstate_ptychography.py | 17 +++- .../process/phase/multislice_ptychography.py | 22 ++++-- py4DSTEM/process/phase/phase_base_class.py | 21 +++-- .../process/phase/ptychographic_methods.py | 78 +++++++++++++++---- .../process/phase/ptychographic_tomography.py | 31 ++++++-- .../phase/ptychographic_visualizations.py | 1 + .../process/phase/singleslice_ptychography.py | 22 ++++-- py4DSTEM/process/phase/utils.py | 64 +++++++++++++++ 12 files changed, 274 insertions(+), 72 deletions(-) diff --git a/py4DSTEM/preprocess/preprocess.py b/py4DSTEM/preprocess/preprocess.py index fb4983622..9db7895d3 100644 --- a/py4DSTEM/preprocess/preprocess.py +++ b/py4DSTEM/preprocess/preprocess.py @@ -576,7 +576,9 @@ def resample_data_diffraction( resampling_factor = np.array(output_size) / np.array(datacube.shape[-2:]) resampling_factor = np.concatenate(((1, 1), resampling_factor)) - datacube.data = zoom(datacube.data, resampling_factor, order=1) + datacube.data = zoom( + datacube.data, resampling_factor, order=1, mode="grid-wrap", grid_mode=True + ) datacube.calibration.set_Q_pixel_size( datacube.calibration.get_Q_pixel_size() / resampling_factor[2] ) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 903a0363c..2d98b6834 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -236,7 +236,8 @@ def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, fit_function: str = "plane", plot_probe_overlaps: bool = True, @@ -265,9 +266,12 @@ def preprocess( If None, no resampling of diffraction intenstities is performed reshaping_method: str, optional Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional + padded_diffraction_intensities_shape: (int,int), optional Padded diffraction intensities shape. If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) fit_function: str, optional @@ -308,7 +312,9 @@ def preprocess( # set additional metadata self._diffraction_intensities_shape = diffraction_intensities_shape self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) self._dp_mask = dp_mask if self._datacube is None: @@ -350,11 +356,20 @@ def preprocess( roi_shape = self._datacube[0].Qshape if diffraction_intensities_shape is not None: roi_shape = diffraction_intensities_shape - if probe_roi_shape is not None: - roi_shape = tuple(max(q, s) for q, s in zip(roi_shape, probe_roi_shape)) + if padded_diffraction_intensities_shape is not None: + roi_shape = tuple( + max(q, s) + for q, s in zip(roi_shape, padded_diffraction_intensities_shape) + ) self._amplitudes = xp.empty((self._num_diffraction_patterns,) + roi_shape) - self._region_of_interest_shape = np.array(roi_shape) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) # TO-DO: generalize this if force_com_shifts is None: @@ -381,7 +396,7 @@ def preprocess( self._datacube[index], diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, com_shifts=force_com_shifts[index], @@ -397,7 +412,7 @@ def preprocess( self._datacube[index], diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=None, dp_mask=None, com_shifts=force_com_shifts[index], diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 0e633ae71..5333960e9 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -207,7 +207,8 @@ def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, fit_function: str = "plane", plot_rotation: bool = True, @@ -244,10 +245,13 @@ def preprocess( Pixel dimensions (Qx',Qy') of the resampled diffraction intensities If None, no resampling of diffraction intenstities is performed reshaping_method: str, optional - Method to use for reshaping, either 'bin', 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + padded_diffraction_intensities_shape: (int,int), optional Padded diffraction intensities shape. If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) fit_function: str, optional @@ -291,7 +295,9 @@ def preprocess( # set additional metadata self._diffraction_intensities_shape = diffraction_intensities_shape self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) self._dp_mask = dp_mask if self._datacube is None: @@ -374,11 +380,20 @@ def preprocess( roi_shape = self._datacube[0].Qshape if diffraction_intensities_shape is not None: roi_shape = diffraction_intensities_shape - if probe_roi_shape is not None: - roi_shape = tuple(max(q, s) for q, s in zip(roi_shape, probe_roi_shape)) + if padded_diffraction_intensities_shape is not None: + roi_shape = tuple( + max(q, s) + for q, s in zip(roi_shape, padded_diffraction_intensities_shape) + ) self._amplitudes = xp.empty((self._num_diffraction_patterns,) + roi_shape) - self._region_of_interest_shape = np.array(roi_shape) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) # TO-DO: generalize this if force_com_shifts is None: @@ -408,7 +423,7 @@ def preprocess( self._datacube[index], diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, com_shifts=force_com_shifts[index], @@ -424,7 +439,7 @@ def preprocess( self._datacube[index], diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=None, dp_mask=None, com_shifts=force_com_shifts[index], diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index 931985224..630608013 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -276,7 +276,8 @@ def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, fit_function: str = "plane", plot_center_of_mass: str = "default", @@ -314,9 +315,12 @@ def preprocess( If None, no resampling of diffraction intenstities is performed reshaping_method: str, optional Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional + padded_diffraction_intensities_shape: (int,int), optional Padded diffraction intensities shape. If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) fit_function: str, optional @@ -363,7 +367,9 @@ def preprocess( # set additional metadata self._diffraction_intensities_shape = diffraction_intensities_shape self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) self._dp_mask = dp_mask if self._datacube is None: @@ -387,7 +393,7 @@ def preprocess( self._datacube, diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, com_shifts=force_com_shifts, @@ -460,7 +466,13 @@ def preprocess( # explicitly delete namespace self._num_diffraction_patterns = self._amplitudes.shape[0] - self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) del self._intensities # initialize probe positions diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index dbcb62e97..177898a0f 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -217,7 +217,8 @@ def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, fit_function: str = "plane", plot_center_of_mass: str = "default", @@ -304,7 +305,9 @@ def preprocess( # set additional metadata self._diffraction_intensities_shape = diffraction_intensities_shape self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) self._dp_mask = dp_mask if self._datacube is None: @@ -328,7 +331,7 @@ def preprocess( self._datacube, diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, com_shifts=force_com_shifts, @@ -401,7 +404,13 @@ def preprocess( # explicitly delete namespace self._num_diffraction_patterns = self._amplitudes.shape[0] - self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) del self._intensities # initialize probe positions diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index 69ad11330..3dd5c34e2 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -251,7 +251,8 @@ def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, fit_function: str = "plane", plot_center_of_mass: str = "default", @@ -289,9 +290,12 @@ def preprocess( If None, no resampling of diffraction intenstities is performed reshaping_method: str, optional Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional + padded_diffraction_intensities_shape: (int,int), optional Padded diffraction intensities shape. If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) fit_function: str, optional @@ -338,7 +342,9 @@ def preprocess( # set additional metadata self._diffraction_intensities_shape = diffraction_intensities_shape self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) self._dp_mask = dp_mask if self._datacube is None: @@ -362,7 +368,7 @@ def preprocess( self._datacube, diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, com_shifts=force_com_shifts, @@ -435,7 +441,13 @@ def preprocess( # explicitly delete namespace self._num_diffraction_patterns = self._amplitudes.shape[0] - self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) del self._intensities # initialize probe positions diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 7109e89a1..546dd7c38 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -136,7 +136,7 @@ def _preprocess_datacube_and_vacuum_probe( datacube, diffraction_intensities_shape=None, reshaping_method="fourier", - probe_roi_shape=None, + padded_diffraction_intensities_shape=None, vacuum_probe_intensity=None, dp_mask=None, com_shifts=None, @@ -153,13 +153,10 @@ def _preprocess_datacube_and_vacuum_probe( Note this does not affect the maximum scattering wavevector (Qx*dkx,Qy*dky) = (Sx*dkx',Sy*dky'), and thus the real-space sampling stays fixed. - The real space sampling, (dx, dy), combined with the resampled diffraction_intensities_shape, - sets the real-space probe region of interest (ROI) extent (dx*Sx, dy*Sy). - Occasionally, one may also want to specify a larger probe ROI extent, e.g when the probe - does not comfortably fit without self-ovelap artifacts, or when the scan step sizes are much - smaller than the real-space sampling (dx,dy). This can be achieved by specifying a - probe_roi_shape, which is larger than diffraction_intensities_shape, which will result in - zero-padding of the diffraction intensities. + Additionally, one may wish to zero-pad the diffraction intensity data. Note this does not increase + the information or resolution, but might be beneficial in a limited number of cases, e.g. when the + scan step sizes are much smaller than the real-space sampling (dx,dy). This can be achieved by specifying + a padded_diffraction_intensities_shape which is larger than diffraction_intensities_shape. Parameters ---------- @@ -170,7 +167,7 @@ def _preprocess_datacube_and_vacuum_probe( If None, no resamping is performed reshaping method: str, optional Reshaping method to use, one of 'bin', 'bilinear' or 'fourier' (default) - probe_roi_shape, (int,int), optional + padded_diffraction_intensities_shape, (int,int), optional Padded diffraction intensities shape. If None, no padding is performed vacuum_probe_intensity, np.ndarray, optional @@ -284,10 +281,10 @@ def _preprocess_datacube_and_vacuum_probe( ) ) - if probe_roi_shape is not None: + if padded_diffraction_intensities_shape is not None: Qx, Qy = datacube.shape[-2:] - Sx, Sy = probe_roi_shape - datacube = datacube.pad_Q(output_size=probe_roi_shape) + Sx, Sy = padded_diffraction_intensities_shape + datacube = datacube.pad_Q(output_size=padded_diffraction_intensities_shape) if vacuum_probe_intensity is not None or dp_mask is not None: pad_kx = Sx - Qx diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index e8c52e781..81b6c7946 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -13,6 +13,7 @@ partition_list, rotate_point, spatial_frequencies, + vectorized_bilinear_resample, ) from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex @@ -1215,6 +1216,7 @@ def show_probe( ] probe = list(partition_list(probe, W)) + probe = probe if len(probe) > 1 else probe[0] show_complex( probe, @@ -1298,6 +1300,7 @@ def show_fourier_probe( ] probe = list(partition_list(probe, W)) + probe = probe if len(probe) > 1 else probe[0] show_complex( probe, @@ -1497,15 +1500,28 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap): """ xp = self._xp + + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + overlap = vectorized_bilinear_resample( + overlap, output_size=amplitudes.shape[-2:], xp=xp + ) + fourier_overlap = xp.fft.fft2(overlap) farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) fourier_modified_overlap = amplitudes * xp.exp(1j * xp.angle(fourier_overlap)) - modified_overlap = xp.fft.ifft2(fourier_modified_overlap) + modified_overlap = xp.fft.ifft2(fourier_modified_overlap) exit_waves = modified_overlap - overlap + # resample back to region_of_interest_shape, note: this needs to happen in real-space + if self._resample_exit_waves: + exit_waves = vectorized_bilinear_resample( + exit_waves, output_size=self._region_of_interest_shape, xp=xp + ) + return exit_waves, error def _projection_sets_fourier_projection( @@ -1553,18 +1569,30 @@ def _projection_sets_fourier_projection( if exit_waves is None: exit_waves = overlap.copy() - fourier_overlap = xp.fft.fft2(overlap) - farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) - error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) - factor_to_be_projected = projection_c * overlap + projection_y * exit_waves + + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + factor_to_be_projected = vectorized_bilinear_resample( + factor_to_be_projected, output_size=amplitudes.shape[-2:], xp=xp + ) + fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) + farfield_amplitudes = self._return_farfield_amplitudes(fourier_projected_factor) + error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) fourier_projected_factor = amplitudes * xp.exp( 1j * xp.angle(fourier_projected_factor) ) + projected_factor = xp.fft.ifft2(fourier_projected_factor) + # resample back to region_of_interest_shape, note: this needs to happen in real-space + if self._resample_exit_waves: + projected_factor = vectorized_bilinear_resample( + projected_factor, output_size=self._region_of_interest_shape, xp=xp + ) + exit_waves = ( projection_x * exit_waves + projection_a * overlap @@ -2503,6 +2531,13 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap): """ xp = self._xp + + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + overlap = vectorized_bilinear_resample( + overlap, output_size=amplitudes.shape[-2:], xp=xp + ) + fourier_overlap = xp.fft.fft2(overlap) farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) @@ -2515,6 +2550,12 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap): exit_waves = modified_overlap - overlap + # resample back to region_of_interest_shape, note: this needs to happen in real-space + if self._resample_exit_waves: + exit_waves = vectorized_bilinear_resample( + exit_waves, output_size=self._region_of_interest_shape, xp=xp + ) + return exit_waves, error def _projection_sets_fourier_projection( @@ -2562,23 +2603,30 @@ def _projection_sets_fourier_projection( if exit_waves is None: exit_waves = overlap.copy() - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) - error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) - factor_to_be_projected = projection_c * overlap + projection_y * exit_waves + + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + factor_to_be_projected = vectorized_bilinear_resample( + factor_to_be_projected, output_size=amplitudes.shape[-2:], xp=xp + ) + fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) + farfield_amplitudes = self._return_farfield_amplitudes(fourier_projected_factor) + error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) - intensity_norm_projected = xp.sqrt( - xp.sum(xp.abs(fourier_projected_factor) ** 2, axis=1) - ) - intensity_norm_projected[intensity_norm_projected == 0.0] = np.inf + farfield_amplitudes[farfield_amplitudes == 0.0] = np.inf + amplitude_modification = amplitudes / farfield_amplitudes - amplitude_modification = amplitudes / intensity_norm_projected fourier_projected_factor *= amplitude_modification[:, None] - projected_factor = xp.fft.ifft2(fourier_projected_factor) + # resample back to region_of_interest_shape, note: this needs to happen in real-space + if self._resample_exit_waves: + projected_factor = vectorized_bilinear_resample( + projected_factor, output_size=self._region_of_interest_shape, xp=xp + ) + exit_waves = ( projection_x * exit_waves + projection_a * overlap diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 7d4964fa4..13bdf035d 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -230,7 +230,8 @@ def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, fit_function: str = "plane", plot_probe_overlaps: bool = True, @@ -260,9 +261,12 @@ def preprocess( If None, no resampling of diffraction intenstities is performed reshaping_method: str, optional Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional + padded_diffraction_intensities_shape: (int,int), optional Padded diffraction intensities shape. If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) fit_function: str, optional @@ -307,7 +311,9 @@ def preprocess( # set additional metadata self._diffraction_intensities_shape = diffraction_intensities_shape self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) self._dp_mask = dp_mask if self._datacube is None: @@ -349,11 +355,20 @@ def preprocess( roi_shape = self._datacube[0].Qshape if diffraction_intensities_shape is not None: roi_shape = diffraction_intensities_shape - if probe_roi_shape is not None: - roi_shape = tuple(max(q, s) for q, s in zip(roi_shape, probe_roi_shape)) + if padded_diffraction_intensities_shape is not None: + roi_shape = tuple( + max(q, s) + for q, s in zip(roi_shape, padded_diffraction_intensities_shape) + ) self._amplitudes = xp.empty((self._num_diffraction_patterns,) + roi_shape) - self._region_of_interest_shape = np.array(roi_shape) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) # TO-DO: generalize this if force_com_shifts is None: @@ -380,7 +395,7 @@ def preprocess( self._datacube[index], diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, com_shifts=force_com_shifts[index], @@ -396,7 +411,7 @@ def preprocess( self._datacube[index], diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=None, dp_mask=None, com_shifts=force_com_shifts[index], diff --git a/py4DSTEM/process/phase/ptychographic_visualizations.py b/py4DSTEM/process/phase/ptychographic_visualizations.py index 916ef0352..014c380f0 100644 --- a/py4DSTEM/process/phase/ptychographic_visualizations.py +++ b/py4DSTEM/process/phase/ptychographic_visualizations.py @@ -284,6 +284,7 @@ def _visualize_all_iterations( plot_convergence=plot_convergence, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, **kwargs, ) diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index 57b8ec93c..59535a19d 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -195,7 +195,8 @@ def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, fit_function: str = "plane", plot_center_of_mass: str = "default", @@ -233,9 +234,12 @@ def preprocess( If None, no resampling of diffraction intenstities is performed reshaping_method: str, optional Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional + padded_diffraction_intensities_shape: (int,int), optional Padded diffraction intensities shape. If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) fit_function: str, optional @@ -282,7 +286,9 @@ def preprocess( # set additional metadata self._diffraction_intensities_shape = diffraction_intensities_shape self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) self._dp_mask = dp_mask if self._datacube is None: @@ -306,7 +312,7 @@ def preprocess( self._datacube, diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, com_shifts=force_com_shifts, @@ -379,7 +385,13 @@ def preprocess( # explicitly delete intensities namespace self._num_diffraction_patterns = self._amplitudes.shape[0] - self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) del self._intensities # initialize probe positions diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 8894faaf3..c9db5f86b 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -4,12 +4,14 @@ import matplotlib.pyplot as plt import numpy as np from scipy.fft import dctn, idctn +from scipy.ndimage import zoom from scipy.optimize import curve_fit try: import cupy as cp from cupyx.scipy.fft import dctn as dctn_cp from cupyx.scipy.fft import idctn as idctn_cp + from cupyx.scipy.ndimage import zoom as zoom_cp except (ImportError, ModuleNotFoundError): cp = None @@ -2108,6 +2110,67 @@ def lanczos_kernel_density_estimate( return pix_output +def vectorized_bilinear_resample( + array, + scale=None, + output_size=None, + mode="grid-wrap", + grid_mode=True, + xp=np, +): + """ + Resize an array along its final two axes. + Note, this is vectorized and thus very memory-intensive. + + The scaling of the array can be specified by passing either `scale`, which sets + the scaling factor along both axes to be scaled; or by passing `output_size`, + which specifies the final dimensions of the scaled axes. + + Parameters + ---------- + array: np.ndarray + Input array to be resampled + scale: float + Scalar value giving the scaling factor for all dimensions + output_size: (int,int) + Tuple of two values giving the output size for the final two axes + xp: Callable + Array computing module + + Returns + ------- + resampled_array: np.ndarray + Resampled array + """ + + array_size = np.array(array.shape) + input_size = array_size[-2:].copy() + + if scale is not None: + scale = np.array(scale) + if scale.size == 1: + scale = np.tile(scale, 2) + + output_size = (input_size * scale).astype("int") + else: + if output_size is None: + raise ValueError("One of `scale` or `output_size` must be provided.") + output_size = np.array(output_size) + if output_size.size != 2: + raise ValueError("`output_size` must contain exactly two values.") + output_size = np.array(output_size) + + scale_output = tuple(output_size / input_size) + scale_output = (1,) * (array_size.size - input_size.size) + scale_output + + if xp is np: + array = zoom(array, scale_output, order=1, mode=mode, grid_mode=grid_mode) + else: + array = zoom_cp(array, scale_output, order=1, mode=mode, grid_mode=grid_mode) + + return array + + def vectorized_fourier_resample( array, scale=None, @@ -2153,6 +2216,7 @@ def vectorized_fourier_resample( else: if output_size is None: raise ValueError("One of `scale` or `output_size` must be provided.") + output_size = np.array(output_size) if output_size.size != 2: raise ValueError("`output_size` must contain exactly two values.") output_size = np.array(output_size)