From 1fc9c52ae199148f02663a539de68981fa035e0c Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Tue, 24 Oct 2023 15:25:02 -0700 Subject: [PATCH 01/26] silly parallax bug --- py4DSTEM/process/phase/iterative_parallax.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 74688fa0b..094209e7a 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1642,6 +1642,9 @@ def score_CTF(coefs): ), UserWarning, ) + else: + self._aberrations_coefs = asnumpy(aberrations_coefs) + self._rotated_shifts = rotated_shifts else: self._aberrations_coefs = asnumpy(aberrations_coefs) self._rotated_shifts = rotated_shifts From c41c386e64fc28bb8863f10c8289f94ee03c7d18 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Wed, 25 Oct 2023 15:17:51 -0700 Subject: [PATCH 02/26] some helpful deets --- py4DSTEM/process/phase/iterative_parallax.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 094209e7a..21af22a37 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1789,6 +1789,7 @@ def score_CTF(coefs): ) print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + print(f"Transpose = {self.transpose_detected}") if fit_CTF_FFT or fit_BF_shifts: print() From 20bc04157e1f48ea5f359ae3c36fd0854b40fc94 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 26 Oct 2023 11:23:56 -0700 Subject: [PATCH 03/26] middle focus for multislice --- .../iterative_multislice_ptychography.py | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 4515590fe..77a5c69ea 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -81,9 +81,11 @@ class MultislicePtychographicReconstruction(PtychographicReconstruction): Probe positions in Å for each diffraction intensity If None, initialized to a grid scan theta_x: float - x tilt of propagator (in angles) + x tilt of propagator (in angles) theta_y: float - y tilt of propagator (in angles) + y tilt of propagator (in angles) + middle_focus: bool + if True, adds half the sample thickness to the defocus object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') @@ -117,6 +119,7 @@ def __init__( initial_scan_positions: np.ndarray = None, theta_x: float = 0, theta_y: float = 0, + middle_focus: bool = False, object_type: str = "complex", verbose: bool = True, device: str = "cpu", @@ -150,6 +153,25 @@ def __init__( if (key not in polar_symbols) and (key not in polar_aliases.keys()): raise ValueError("{} not a recognized parameter".format(key)) + if np.isscalar(slice_thicknesses): + mean_slice_thickness = slice_thicknesses + else: + mean_slice_thickness = np.mean(slice_thicknesses) + + if middle_focus: + if "defocus" in kwargs: + kwargs["defocus"] += mean_slice_thickness * num_slices / 2 + elif "C10" in kwargs: + kwargs["C10"] -= mean_slice_thickness * num_slices / 2 + elif polar_parameters is not None and "defocus" in polar_parameters: + polar_parameters["defocus"] = ( + polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2 + ) + elif polar_parameters is not None and "C10" in polar_parameters: + polar_parameters["C10"] = ( + polar_parameters["C10"] - mean_slice_thickness * num_slices / 2 + ) + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) if polar_parameters is None: From f9e2423084daea1602142f5d869ae983f6d4f4e5 Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Fri, 27 Oct 2023 15:18:06 -0700 Subject: [PATCH 04/26] re-introducing probe intensity normalizations into constraints --- .../iterative_ptychographic_constraints.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 3eebdb068..0760087b4 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -433,8 +433,8 @@ def _probe_amplitude_constraint( xp = self._xp erf = self._erf - # probe_intensity = xp.abs(current_probe) ** 2 - # current_probe_sum = xp.sum(probe_intensity) + probe_intensity = xp.abs(current_probe) ** 2 + current_probe_sum = xp.sum(probe_intensity) X = xp.fft.fftfreq(current_probe.shape[0])[:, None] Y = xp.fft.fftfreq(current_probe.shape[1])[None] @@ -444,10 +444,10 @@ def _probe_amplitude_constraint( tophat_mask = 0.5 * (1 - erf(sigma * r / (1 - r**2))) updated_probe = current_probe * tophat_mask - # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe # * normalization + return updated_probe * normalization def _probe_fourier_amplitude_constraint( self, @@ -476,7 +476,7 @@ def _probe_fourier_amplitude_constraint( xp = self._xp asnumpy = self._asnumpy - # current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) current_probe_fft = xp.fft.fft2(current_probe) updated_probe_fft, _, _, _ = regularize_probe_amplitude( @@ -489,10 +489,10 @@ def _probe_fourier_amplitude_constraint( updated_probe_fft = xp.asarray(updated_probe_fft) updated_probe = xp.fft.ifft2(updated_probe_fft) - # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe # * normalization + return updated_probe * normalization def _probe_aperture_constraint( self, @@ -514,16 +514,16 @@ def _probe_aperture_constraint( """ xp = self._xp - # current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) current_probe_fft_phase = xp.angle(xp.fft.fft2(current_probe)) updated_probe = xp.fft.ifft2( xp.exp(1j * current_probe_fft_phase) * initial_probe_aperture ) - # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe # * normalization + return updated_probe * normalization def _probe_aberration_fitting_constraint( self, From 27a1c962335e757bcd2a6ebfcd0bff175cdfeedc Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Mon, 30 Oct 2023 09:03:51 -0700 Subject: [PATCH 05/26] bug in depth plotting --- py4DSTEM/process/phase/iterative_multislice_ptychography.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 77a5c69ea..6bcacd934 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -3097,7 +3097,7 @@ def show_depth( rotated_object = np.roll( rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), - int(x1_0), + -int(x1_0), axis=1, ) From 701b2755f76e21262cd74523e92402bf6d4ea176 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 30 Oct 2023 14:05:40 -0700 Subject: [PATCH 06/26] minor dpc bugfixes --- py4DSTEM/process/phase/iterative_base_class.py | 2 +- py4DSTEM/process/phase/iterative_dpc.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 04cfd6a60..f04a3c552 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1257,7 +1257,7 @@ def show_complex_CoM( if pixelsize is None: pixelsize = self._scan_sampling[0] if pixelunits is None: - pixelunits = r"$\AA$" + pixelunits = self._scan_units[0] figsize = kwargs.pop("figsize", (6, 6)) fig, ax = plt.subplots(figsize=figsize) diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/iterative_dpc.py index af3cbbb45..b390ce46d 100644 --- a/py4DSTEM/process/phase/iterative_dpc.py +++ b/py4DSTEM/process/phase/iterative_dpc.py @@ -799,6 +799,7 @@ def reconstruct( anti_gridding=anti_gridding, ) + self.error_iterations.append(self.error.item()) if store_iterations: self.object_phase_iterations.append( asnumpy( @@ -807,7 +808,6 @@ def reconstruct( ].copy() ) ) - self.error_iterations.append(self.error.item()) if self._step_size < stopping_criterion: if self._verbose: From dcc62a3dcc9edeac9f0e7f4daa97a047e2ce9ed0 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 30 Oct 2023 14:06:15 -0700 Subject: [PATCH 07/26] parallax DF limit bug, cropped property --- py4DSTEM/process/phase/iterative_parallax.py | 62 ++++++++++++++++---- 1 file changed, 50 insertions(+), 12 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 21af22a37..daab204a0 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1098,26 +1098,46 @@ def subpixel_alignment( BF_size = np.array(self._stack_BF_no_window.shape[-2:]) self._DF_upsample_limit = np.max( - self._region_of_interest_shape / self._scan_shape + 2 * self._region_of_interest_shape / self._scan_shape ) self._BF_upsample_limit = ( - 2 * self._kr.max() / self._reciprocal_sampling[0] + 4 * self._kr.max() / self._reciprocal_sampling[0] ) / self._scan_shape.max() if self._device == "gpu": self._BF_upsample_limit = self._BF_upsample_limit.item() if kde_upsample_factor is None: - kde_upsample_factor = np.minimum( - self._BF_upsample_limit * 3 / 2, self._DF_upsample_limit - ) + if self._BF_upsample_limit * 3 / 2 > self._DF_upsample_limit: + kde_upsample_factor = self._DF_upsample_limit - warnings.warn( - ( - f"Upsampling factor set to {kde_upsample_factor:.2f} (1.5 times the " - f"bright-field upsampling limit of {self._BF_upsample_limit:.2f})." - ), - UserWarning, - ) + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (the " + f"dark-field upsampling limit)." + ), + UserWarning, + ) + + elif self._BF_upsample_limit * 3 / 2 > 1: + kde_upsample_factor = self._BF_upsample_limit * 3 / 2 + + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (1.5 times the " + f"bright-field upsampling limit of {self._BF_upsample_limit:.2f})." + ), + UserWarning, + ) + else: + kde_upsample_factor = self._DF_upsample_limit * 2 / 3 + + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (2/3 times the " + f"dark-field upsampling limit of {self._DF_upsample_limit:.2f})." + ), + UserWarning, + ) if kde_upsample_factor < 1: raise ValueError("kde_upsample_factor must be larger than 1") @@ -2349,3 +2369,21 @@ def visualize( ax.set_title("Reconstructed Bright Field Image") return self + + @property + def object_cropped(self): + """cropped object""" + if hasattr(self, "_recon_phase_corrected"): + if hasattr(self, "_kde_upsample_factor"): + return self._crop_padded_object( + self._recon_phase_corrected, upsampled=True + ) + else: + return self._crop_padded_object(self._recon_phase_corrected) + else: + if hasattr(self, "_kde_upsample_factor"): + return self._crop_padded_object( + self._recon_BF_subpixel_aligned, upsampled=True + ) + else: + return self._crop_padded_object(self._recon_BF) From b7a7a5f15b589f573bad8491e12888585f813ada Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 30 Oct 2023 14:06:56 -0700 Subject: [PATCH 08/26] complex grid scalebar bug --- py4DSTEM/visualize/vis_special.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index d1efbd023..388b57e0a 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -829,7 +829,7 @@ def show_complex( for ax_flat in ax.flatten(): divider = make_axes_locatable(ax_flat) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") From b67a06422130c884cb4d6c951fa9b65f49d3f069 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 1 Nov 2023 12:00:36 -0700 Subject: [PATCH 09/26] adding self_consistency_errors property. not implemented for 3D yet --- .../process/phase/iterative_base_class.py | 31 ++++++++++++++++++- ...tive_mixedstate_multislice_ptychography.py | 28 +++++++++++++++++ .../iterative_mixedstate_ptychography.py | 28 +++++++++++++++++ .../iterative_overlap_magnetic_tomography.py | 5 +++ .../phase/iterative_overlap_tomography.py | 5 +++ .../iterative_simultaneous_ptychography.py | 29 +++++++++++++++++ 6 files changed, 125 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index f04a3c552..13c64d79d 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -2366,6 +2366,35 @@ def positions(self): @property def object_cropped(self): - """cropped and rotated object""" + """Cropped and rotated object""" return self._crop_rotate_object_fov(self._object) + + @property + def self_consistency_errors(self): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Re-initialize fractional positions and vector patches, max_batch_size = None + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + + # Normalized mean-squared errors + error = xp.sum( + xp.abs(self._amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + error /= self._mean_diffraction_intensity + + return asnumpy(error) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 3eeb07814..82155219a 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -3509,3 +3509,31 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + + @property + def self_consistency_errors(self): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Re-initialize fractional positions and vector patches, max_batch_size = None + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + # Normalized mean-squared errors + error = xp.sum(xp.abs(self._amplitudes - intensity_norm) ** 2, axis=(-2, -1)) + error /= self._mean_diffraction_intensity + + return asnumpy(error) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 2e9fbd076..25bee346c 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -2327,3 +2327,31 @@ def show_fourier_probe( chroma_boost=chroma_boost, **kwargs, ) + + @property + def self_consistency_errors(self): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Re-initialize fractional positions and vector patches, max_batch_size = None + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + # Normalized mean-squared errors + error = xp.sum(xp.abs(self._amplitudes - intensity_norm) ** 2, axis=(-2, -1)) + error /= self._mean_diffraction_intensity + + return asnumpy(error) diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 32b0f6fd4..cde84907c 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -3327,3 +3327,8 @@ def positions(self): positions_all.append(asnumpy(positions)) return np.asarray(positions_all) + + @property + def self_consistency_errors(self): + """Compute the self-consistency errors for each probe position""" + raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 66cf46487..e92211301 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -3207,3 +3207,8 @@ def positions(self): positions_all.append(asnumpy(positions)) return np.asarray(positions_all) + + @property + def self_consistency_errors(self): + """Compute the self-consistency errors for each probe position""" + raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 37438852f..757b2ffae 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -3357,3 +3357,32 @@ def visualize( ) return self + + @property + def self_consistency_errors(self): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Re-initialize fractional positions and vector patches, max_batch_size = None + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # Overlaps + _, _, overlap = self._warmup_overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap[0]) + + # Normalized mean-squared errors + error = xp.sum( + xp.abs(self._amplitudes[0] - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + error /= self._mean_diffraction_intensity + + return asnumpy(error) From 9f82c20fb4a44c158270c286c092d48fb220053e Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Wed, 1 Nov 2023 14:00:54 -0700 Subject: [PATCH 10/26] real space mask for positions to ignore --- py4DSTEM/process/phase/iterative_base_class.py | 15 ++++++++++++++- ...terative_mixedstate_multislice_ptychography.py | 6 +++++- .../phase/iterative_mixedstate_ptychography.py | 6 +++++- .../phase/iterative_multislice_ptychography.py | 6 +++++- .../iterative_overlap_magnetic_tomography.py | 6 +++++- .../process/phase/iterative_overlap_tomography.py | 6 +++++- .../phase/iterative_simultaneous_ptychography.py | 6 +++++- .../phase/iterative_singleslice_ptychography.py | 6 +++++- 8 files changed, 49 insertions(+), 8 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 13c64d79d..476216f79 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1535,7 +1535,9 @@ def _set_polar_parameters(self, parameters: dict): else: raise ValueError("{} not a recognized parameter".format(symbol)) - def _calculate_scan_positions_in_pixels(self, positions: np.ndarray): + def _calculate_scan_positions_in_pixels( + self, positions: np.ndarray, positions_mask + ): """ Method to compute the initial guess of scan positions in pixels. @@ -1544,6 +1546,8 @@ def _calculate_scan_positions_in_pixels(self, positions: np.ndarray): positions: (J,2) np.ndarray or None Input probe positions in Å. If None, a raster scan using experimental parameters is constructed. + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction Returns ------- @@ -1592,6 +1596,15 @@ def _calculate_scan_positions_in_pixels(self, positions: np.ndarray): positions = np.array([x.ravel(), y.ravel()]).T positions -= np.min(positions, axis=0) + if positions_mask is not None: + if positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converged to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + positions = positions[positions_mask.ravel()] + if self._object_padding_px is None: float_padding = self._region_of_interest_shape / 2 self._object_padding_px = (float_padding, float_padding) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 82155219a..98967ba89 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -85,6 +85,8 @@ class MixedstateMultislicePtychographicReconstruction(PtychographicReconstructio object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -115,6 +117,7 @@ def __init__( initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "multi-slice_ptychographic_reconstruction", @@ -201,6 +204,7 @@ def __init__( self._semiangle_cutoff_pixels = semiangle_cutoff_pixels self._rolloff = rolloff self._object_type = object_type + self._positions_mask = positions_mask self._object_padding_px = object_padding_px self._verbose = verbose self._device = device @@ -454,7 +458,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 25bee346c..195dace86 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -74,6 +74,8 @@ class MixedstatePtychographicReconstruction(PtychographicReconstruction): initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -102,6 +104,7 @@ def __init__( initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "mixed-state_ptychographic_reconstruction", @@ -178,6 +181,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -358,7 +362,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 6bcacd934..a137bbeb9 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -89,6 +89,8 @@ class MultislicePtychographicReconstruction(PtychographicReconstruction): object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -121,6 +123,7 @@ def __init__( theta_y: float = 0, middle_focus: bool = False, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "multi-slice_ptychographic_reconstruction", @@ -211,6 +214,7 @@ def __init__( self._semiangle_cutoff_pixels = semiangle_cutoff_pixels self._rolloff = rolloff self._object_type = object_type + self._positions_mask = positions_mask self._object_padding_px = object_padding_px self._verbose = verbose self._device = device @@ -481,7 +485,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index cde84907c..b4501d012 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -93,6 +93,8 @@ class OverlapMagneticTomographicReconstruction(PtychographicReconstruction): object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction name: str, optional Class name kwargs: @@ -115,6 +117,7 @@ def __init__( polar_parameters: Mapping[str, float] = None, object_padding_px: Tuple[int, int] = None, object_type: str = "potential", + positions_mask: np.ndarray = None, initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: Sequence[np.ndarray] = None, @@ -179,6 +182,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -615,7 +619,7 @@ def preprocess( tilt_index + 1 ] ] = self._calculate_scan_positions_in_pixels( - self._scan_positions[tilt_index] + self._scan_positions[tilt_index], self._positions_mask[tilt_index] ) # handle semiangle specified in pixels diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index e92211301..759b12602 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -88,6 +88,8 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions to ignore in reconstruction name: str, optional Class name kwargs: @@ -111,6 +113,7 @@ def __init__( polar_parameters: Mapping[str, float] = None, object_padding_px: Tuple[int, int] = None, object_type: str = "potential", + positions_mask: np.ndarray = None, initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: Sequence[np.ndarray] = None, @@ -188,6 +191,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -555,7 +559,7 @@ def preprocess( tilt_index + 1 ] ] = self._calculate_scan_positions_in_pixels( - self._scan_positions[tilt_index] + self._scan_positions[tilt_index], self._positions_mask[tilt_index] ) # handle semiangle specified in pixels diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 757b2ffae..35b2bb9ef 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -66,6 +66,8 @@ class SimultaneousPtychographicReconstruction(PtychographicReconstruction): object_padding_px: Tuple[int,int], optional Pixel dimensions to pad objects with If None, the padding is set to half the probe ROI dimensions + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction initial_object_guess: np.ndarray, optional Initial guess for complex-valued object of dimensions (Px,Py) If None, initialized to 1.0j @@ -102,6 +104,7 @@ def __init__( vacuum_probe_intensity: np.ndarray = None, polar_parameters: Mapping[str, float] = None, object_padding_px: Tuple[int, int] = None, + positions_mask: np.ndarray = None, initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, @@ -167,6 +170,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -607,7 +611,7 @@ def preprocess( self._region_of_interest_shape = np.array(self._amplitudes[0].shape[-2:]) self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 5dd19d7bd..8e66639b2 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -79,6 +79,8 @@ class SingleslicePtychographicReconstruction(PtychographicReconstruction): object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction name: str, optional Class name kwargs: @@ -102,6 +104,7 @@ def __init__( initial_scan_positions: np.ndarray = None, object_padding_px: Tuple[int, int] = None, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "ptychographic_reconstruction", @@ -163,6 +166,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -342,7 +346,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels From 67e15e7002234c0540cc9a69d8a6c60ff0d4c471 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Wed, 1 Nov 2023 14:40:00 -0700 Subject: [PATCH 11/26] amplitudes update for real space mask --- .../process/phase/iterative_base_class.py | 12 +++++----- ...tive_mixedstate_multislice_ptychography.py | 13 +++++++++- .../iterative_mixedstate_ptychography.py | 12 +++++++++- .../iterative_multislice_ptychography.py | 12 +++++++++- .../iterative_overlap_magnetic_tomography.py | 13 +++++++++- .../phase/iterative_overlap_tomography.py | 13 +++++++++- .../iterative_simultaneous_ptychography.py | 24 ++++++++++++++++--- .../iterative_singleslice_ptychography.py | 13 +++++++++- 8 files changed, 97 insertions(+), 15 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 476216f79..73021d8a9 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1132,6 +1132,7 @@ def _normalize_diffraction_intensities( com_fitted_x, com_fitted_y, crop_patterns, + positions_mask, ): """ Fix diffraction intensities CoM, shift to origin, and take square root @@ -1147,6 +1148,8 @@ def _normalize_diffraction_intensities( crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction Returns ------- @@ -1220,6 +1223,9 @@ def _normalize_diffraction_intensities( amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape) amplitudes = xp.asarray(amplitudes) + if positions_mask is not None: + amplitudes = amplitudes[positions_mask.ravel()] + mean_intensity /= amplitudes.shape[0] return amplitudes, mean_intensity @@ -1597,12 +1603,6 @@ def _calculate_scan_positions_in_pixels( positions -= np.min(positions, axis=0) if positions_mask is not None: - if positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converged to `bool` array"), - UserWarning, - ) - positions_mask = np.asarray(positions_mask, dtype="bool") positions = positions[positions_mask.ravel()] if self._object_padding_px is None: diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 98967ba89..2915acccb 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -189,6 +189,13 @@ def __init__( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + self.set_save_defaults() # Data @@ -449,7 +456,11 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns + self._intensities, + self._com_fitted_x, + self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 195dace86..01d70bf71 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -164,6 +164,12 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") self.set_save_defaults() @@ -353,7 +359,11 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns + self._intensities, + self._com_fitted_x, + self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index a137bbeb9..be24f067d 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -198,6 +198,12 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") self.set_save_defaults() @@ -476,7 +482,11 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns + self._intensities, + self._com_fitted_x, + self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index b4501d012..810352ce8 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -166,6 +166,13 @@ def __init__( if object_type != "potential": raise NotImplementedError() + if positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + self.set_save_defaults() # Data @@ -599,7 +606,11 @@ def preprocess( ], mean_diffraction_intensity_temp, ) = self._normalize_diffraction_intensities( - intensities, com_fitted_x, com_fitted_y, crop_patterns + intensities, + com_fitted_x, + com_fitted_y, + crop_patterns, + self._positions_mask[tilt_index], ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 759b12602..701267e81 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -175,6 +175,13 @@ def __init__( if object_type != "potential": raise NotImplementedError() + if positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + self.set_save_defaults() # Data @@ -539,7 +546,11 @@ def preprocess( ], mean_diffraction_intensity_temp, ) = self._normalize_diffraction_intensities( - intensities, com_fitted_x, com_fitted_y, crop_patterns + intensities, + com_fitted_x, + com_fitted_y, + crop_patterns, + self._positions_mask[tilt_index], ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 35b2bb9ef..ae1a3ecac 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -153,6 +153,12 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") self.set_save_defaults() @@ -408,7 +414,11 @@ def preprocess( amplitudes_0, mean_diffraction_intensity_0, ) = self._normalize_diffraction_intensities( - intensities_0, com_fitted_x_0, com_fitted_y_0, crop_patterns + intensities_0, + com_fitted_x_0, + com_fitted_y_0, + crop_patterns, + self._positions_mask, ) # explicitly delete namescapes @@ -489,7 +499,11 @@ def preprocess( amplitudes_1, mean_diffraction_intensity_1, ) = self._normalize_diffraction_intensities( - intensities_1, com_fitted_x_1, com_fitted_y_1, crop_patterns + intensities_1, + com_fitted_x_1, + com_fitted_y_1, + crop_patterns, + self._positions_mask, ) # explicitly delete namescapes @@ -571,7 +585,11 @@ def preprocess( amplitudes_2, mean_diffraction_intensity_2, ) = self._normalize_diffraction_intensities( - intensities_2, com_fitted_x_2, com_fitted_y_2, crop_patterns + intensities_2, + com_fitted_x_2, + com_fitted_y_2, + crop_patterns, + self._positions_mask, ) # explicitly delete namescapes diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 8e66639b2..ab16330da 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -150,6 +150,13 @@ def __init__( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + self.set_save_defaults() # Data @@ -337,7 +344,11 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns + self._intensities, + self._com_fitted_x, + self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace From 2d48616c7e5a0e83ad2f038c97c35fb6d4ddad24 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Wed, 1 Nov 2023 15:51:05 -0700 Subject: [PATCH 12/26] Thnks fr th Mmr(s) --- .../process/phase/iterative_base_class.py | 24 ++++++++++++------- ...tive_mixedstate_multislice_ptychography.py | 2 +- .../iterative_mixedstate_ptychography.py | 2 +- .../iterative_multislice_ptychography.py | 2 +- .../iterative_overlap_magnetic_tomography.py | 2 +- .../phase/iterative_overlap_tomography.py | 2 +- .../iterative_simultaneous_ptychography.py | 2 +- .../iterative_singleslice_ptychography.py | 2 +- 8 files changed, 23 insertions(+), 15 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 73021d8a9..497c7ae1c 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1163,6 +1163,12 @@ def _normalize_diffraction_intensities( mean_intensity = 0 diffraction_intensities = self._asnumpy(diffraction_intensities) + if positions_mask is not None: + number_of_patterns = np.count_nonzero(self._positions_mask.ravel()) + sx, sy = np.where(~self._positions_mask) + else: + number_of_patterns = np.prod(diffraction_intensities.shape[:2]) + if crop_patterns: crop_x = int( np.minimum( @@ -1181,8 +1187,7 @@ def _normalize_diffraction_intensities( region_of_interest_shape = (crop_w * 2, crop_w * 2) amplitudes = np.zeros( ( - diffraction_intensities.shape[0], - diffraction_intensities.shape[1], + number_of_patterns, crop_w * 2, crop_w * 2, ), @@ -1198,13 +1203,19 @@ def _normalize_diffraction_intensities( else: region_of_interest_shape = diffraction_intensities.shape[-2:] - amplitudes = np.zeros(diffraction_intensities.shape, dtype=np.float32) + amplitudes = np.zeros( + (number_of_patterns,) + region_of_interest_shape, dtype=np.float32 + ) com_fitted_x = self._asnumpy(com_fitted_x) com_fitted_y = self._asnumpy(com_fitted_y) + counter = 0 for rx in range(diffraction_intensities.shape[0]): for ry in range(diffraction_intensities.shape[1]): + if positions_mask is not None: + if rx in sx and ry in sy: + continue intensities = get_shifted_ar( diffraction_intensities[rx, ry], -com_fitted_x[rx, ry], @@ -1219,13 +1230,10 @@ def _normalize_diffraction_intensities( ) mean_intensity += np.sum(intensities) - amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0)) + amplitudes[counter] = np.sqrt(np.maximum(intensities, 0)) + counter += 1 - amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape) amplitudes = xp.asarray(amplitudes) - if positions_mask is not None: - amplitudes = amplitudes[positions_mask.ravel()] - mean_intensity /= amplitudes.shape[0] return amplitudes, mean_intensity diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 2915acccb..26b0d8cff 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -189,7 +189,7 @@ def __init__( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 01d70bf71..ebc40928d 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -164,7 +164,7 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index be24f067d..73f83558e 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -198,7 +198,7 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 810352ce8..582eea772 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -166,7 +166,7 @@ def __init__( if object_type != "potential": raise NotImplementedError() - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 701267e81..f4dfe5022 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -175,7 +175,7 @@ def __init__( if object_type != "potential": raise NotImplementedError() - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index ae1a3ecac..866ff0a89 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -153,7 +153,7 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index ab16330da..350d0a3cb 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -150,7 +150,7 @@ def __init__( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, From 3da6fdc3cda11baba1289abbd167ffa2d42627e5 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Wed, 1 Nov 2023 16:32:24 -0700 Subject: [PATCH 13/26] one more bug --- py4DSTEM/process/phase/iterative_base_class.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 497c7ae1c..1aa03559a 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1165,7 +1165,6 @@ def _normalize_diffraction_intensities( diffraction_intensities = self._asnumpy(diffraction_intensities) if positions_mask is not None: number_of_patterns = np.count_nonzero(self._positions_mask.ravel()) - sx, sy = np.where(~self._positions_mask) else: number_of_patterns = np.prod(diffraction_intensities.shape[:2]) @@ -1214,7 +1213,7 @@ def _normalize_diffraction_intensities( for rx in range(diffraction_intensities.shape[0]): for ry in range(diffraction_intensities.shape[1]): if positions_mask is not None: - if rx in sx and ry in sy: + if not self._positions_mask[rx,ry]: continue intensities = get_shifted_ar( diffraction_intensities[rx, ry], From 7a4e7a43e926c48aa0643b60bb1d0202e8aa65ea Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Wed, 1 Nov 2023 16:34:08 -0700 Subject: [PATCH 14/26] black format --- py4DSTEM/process/phase/iterative_base_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 1aa03559a..7437679a2 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1213,7 +1213,7 @@ def _normalize_diffraction_intensities( for rx in range(diffraction_intensities.shape[0]): for ry in range(diffraction_intensities.shape[1]): if positions_mask is not None: - if not self._positions_mask[rx,ry]: + if not self._positions_mask[rx, ry]: continue intensities = get_shifted_ar( diffraction_intensities[rx, ry], From eabd74257d553caf633e7dae2ecd1bd535e9c84f Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 2 Nov 2023 11:39:43 -0700 Subject: [PATCH 15/26] I've been plotting to update this function --- .../process/phase/iterative_base_class.py | 6 -- ...tive_mixedstate_multislice_ptychography.py | 70 ++++++++++++++++++- .../iterative_multislice_ptychography.py | 34 +++++++-- .../iterative_overlap_magnetic_tomography.py | 6 -- .../phase/iterative_overlap_tomography.py | 6 -- 5 files changed, 96 insertions(+), 26 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 7437679a2..56be2784a 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -2306,22 +2306,16 @@ def show_object_fft(self, obj=None, **kwargs): figsize = kwargs.pop("figsize", (6, 6)) cmap = kwargs.pop("cmap", "magma") - vmin = kwargs.pop("vmin", 0) - vmax = kwargs.pop("vmax", 1) - power = kwargs.pop("power", 0.2) pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, cmap=cmap, - vmin=vmin, - vmax=vmax, scalebar=True, pixelsize=pixelsize, ticks=False, pixelunits=r"$\AA^{-1}$", - power=power, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 26b0d8cff..6cbdca19e 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -82,6 +82,12 @@ class MixedstateMultislicePtychographicReconstruction(PtychographicReconstructio initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + theta_x: float + x tilt of propagator (in degrees) + theta_y: float + y tilt of propagator (in degrees) + middle_focus: bool + if True, adds half the sample thickness to the defocus object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') @@ -116,6 +122,9 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + theta_x: float = 0, + theta_y: float = 0, + middle_focus: bool = False, object_type: str = "complex", positions_mask: np.ndarray = None, verbose: bool = True, @@ -165,6 +174,25 @@ def __init__( if (key not in polar_symbols) and (key not in polar_aliases.keys()): raise ValueError("{} not a recognized parameter".format(key)) + if np.isscalar(slice_thicknesses): + mean_slice_thickness = slice_thicknesses + else: + mean_slice_thickness = np.mean(slice_thicknesses) + + if middle_focus: + if "defocus" in kwargs: + kwargs["defocus"] += mean_slice_thickness * num_slices / 2 + elif "C10" in kwargs: + kwargs["C10"] -= mean_slice_thickness * num_slices / 2 + elif polar_parameters is not None and "defocus" in polar_parameters: + polar_parameters["defocus"] = ( + polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2 + ) + elif polar_parameters is not None and "C10" in polar_parameters: + polar_parameters["C10"] = ( + polar_parameters["C10"] - mean_slice_thickness * num_slices / 2 + ) + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) if polar_parameters is None: @@ -221,6 +249,8 @@ def __init__( self._num_probes = num_probes self._num_slices = num_slices self._slice_thicknesses = slice_thicknesses + self._theta_x = theta_x + self._theta_y = theta_y def _precompute_propagator_arrays( self, @@ -243,6 +273,10 @@ def _precompute_propagator_arrays( The electron energy of the wave functions in eV slice_thicknesses: Sequence[float] Array of slice thicknesses in A + theta_x: float + x tilt of propagator (in degrees) + theta_y: float + y tilt of propagator (in degrees) Returns ------- @@ -262,6 +296,10 @@ def _precompute_propagator_arrays( propagators = xp.empty( (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 ) + + theta_x = np.deg2rad(theta_x) + theta_y = np.deg2rad(theta_y) + for i, dz in enumerate(slice_thicknesses): propagators[i] = xp.exp( 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) @@ -269,6 +307,12 @@ def _precompute_propagator_arrays( propagators[i] *= xp.exp( 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) ) + propagators[i] *= xp.exp( + 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) + ) + propagators[i] *= xp.exp( + 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) + ) return propagators @@ -3075,6 +3119,7 @@ def show_slices( common_color_scale: bool = True, padding: int = 0, num_cols: int = 3, + show_fft: bool = False, **kwargs, ): """ @@ -3090,12 +3135,20 @@ def show_slices( Padding to leave uncropped num_cols: int, optional Number of GridSpec columns + show_fft: bool, optional + if True, plots fft of object slices """ if ms_object is None: ms_object = self._object rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + if show_fft: + rotated_object = np.abs( + np.fft.fftshift( + np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) + ) + ) rotated_shape = rotated_object.shape if np.iscomplexobj(rotated_object): @@ -3113,8 +3166,21 @@ def show_slices( axsize = kwargs.pop("axsize", (3, 3)) cmap = kwargs.pop("cmap", "magma") - vmin = np.min(rotated_object) if common_color_scale else None - vmax = np.max(rotated_object) if common_color_scale else None + + if common_color_scale: + vals = np.sort(rotated_object.ravel()) + ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") + ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") + ind_vmin = np.max([0, ind_vmin]) + ind_vmax = np.min([len(vals) - 1, ind_vmax]) + vmin = vals[ind_vmin] + vmax = vals[ind_vmax] + if vmax == vmin: + vmin = vals[0] + vmax = vals[-1] + else: + vmax = None + vmin = None vmin = kwargs.pop("vmin", vmin) vmax = kwargs.pop("vmax", vmax) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 73f83558e..4b0d6881c 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -81,9 +81,9 @@ class MultislicePtychographicReconstruction(PtychographicReconstruction): Probe positions in Å for each diffraction intensity If None, initialized to a grid scan theta_x: float - x tilt of propagator (in angles) + x tilt of propagator (in degrees) theta_y: float - y tilt of propagator (in angles) + y tilt of propagator (in degrees) middle_focus: bool if True, adds half the sample thickness to the defocus object_type: str, optional @@ -256,9 +256,9 @@ def _precompute_propagator_arrays( slice_thicknesses: Sequence[float] Array of slice thicknesses in A theta_x: float - x tilt of propagator (in angles) + x tilt of propagator (in degrees) theta_y: float - y tilt of propagator (in angles) + y tilt of propagator (in degrees) Returns ------- @@ -2955,6 +2955,7 @@ def show_slices( common_color_scale: bool = True, padding: int = 0, num_cols: int = 3, + show_fft: bool = False, **kwargs, ): """ @@ -2970,12 +2971,20 @@ def show_slices( Padding to leave uncropped num_cols: int, optional Number of GridSpec columns + show_fft: bool, optional + if True, plots fft of object slices """ if ms_object is None: ms_object = self._object rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + if show_fft: + rotated_object = np.abs( + np.fft.fftshift( + np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) + ) + ) rotated_shape = rotated_object.shape if np.iscomplexobj(rotated_object): @@ -2993,8 +3002,21 @@ def show_slices( axsize = kwargs.pop("axsize", (3, 3)) cmap = kwargs.pop("cmap", "magma") - vmin = np.min(rotated_object) if common_color_scale else None - vmax = np.max(rotated_object) if common_color_scale else None + + if common_color_scale: + vals = np.sort(rotated_object.mean(0).ravel()) + ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") + ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") + ind_vmin = np.max([0, ind_vmin]) + ind_vmax = np.min([len(vals) - 1, ind_vmax]) + vmin = vals[ind_vmin] + vmax = vals[ind_vmax] + if vmax == vmin: + vmin = vals[0] + vmax = vals[-1] + else: + vmax = None + vmin = None vmin = kwargs.pop("vmin", vmin) vmax = kwargs.pop("vmax", vmax) diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 582eea772..7c96cb34c 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -3303,22 +3303,16 @@ def show_object_fft( figsize = kwargs.pop("figsize", (6, 6)) cmap = kwargs.pop("cmap", "magma") - vmin = kwargs.pop("vmin", 0) - vmax = kwargs.pop("vmax", 1) - power = kwargs.pop("power", 0.2) pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, cmap=cmap, - vmin=vmin, - vmax=vmax, scalebar=True, pixelsize=pixelsize, ticks=False, pixelunits=r"$\AA^{-1}$", - power=power, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index f4dfe5022..54b94010a 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -3183,22 +3183,16 @@ def show_object_fft( figsize = kwargs.pop("figsize", (6, 6)) cmap = kwargs.pop("cmap", "magma") - vmin = kwargs.pop("vmin", 0) - vmax = kwargs.pop("vmax", 1) - power = kwargs.pop("power", 0.2) pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, cmap=cmap, - vmin=vmin, - vmax=vmax, scalebar=True, pixelsize=pixelsize, ticks=False, pixelunits=r"$\AA^{-1}$", - power=power, **kwargs, ) From 3a6ee5a80c70cd9cf4c9d94b662c01bb82df8ce7 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 2 Nov 2023 11:46:23 -0700 Subject: [PATCH 16/26] correct propagation of arguments --- .../phase/iterative_mixedstate_multislice_ptychography.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 6cbdca19e..6cd74828e 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -258,6 +258,8 @@ def _precompute_propagator_arrays( sampling: Tuple[float, float], energy: float, slice_thicknesses: Sequence[float], + theta_x: float, + theta_y: float, ): """ Precomputes propagator arrays complex wave-function will be convolved by, @@ -656,6 +658,8 @@ def preprocess( self.sampling, self._energy, self._slice_thicknesses, + self._theta_x, + self._theta_y, ) # overlaps From a51594c80b9c22344760ab287b3d8f2a36492cb0 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 2 Nov 2023 11:56:39 -0700 Subject: [PATCH 17/26] one more bug fix --- py4DSTEM/process/phase/iterative_multislice_ptychography.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 4b0d6881c..764f0b4a0 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -3004,7 +3004,7 @@ def show_slices( cmap = kwargs.pop("cmap", "magma") if common_color_scale: - vals = np.sort(rotated_object.mean(0).ravel()) + vals = np.sort(rotated_object.ravel()) ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") ind_vmin = np.max([0, ind_vmin]) From 8323bc8f30d35a98de32e7521bfc3616a0d95706 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 2 Nov 2023 17:51:30 -0700 Subject: [PATCH 18/26] fft hanning window --- py4DSTEM/visualize/show.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index 4e99c0de5..b6077c412 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -366,7 +366,9 @@ def show( from py4DSTEM.visualize import show if show_fft: - ar = np.abs(np.fft.fftshift(np.fft.fft2(ar.copy()))) + n0 = ar.shape + w0 = np.hanning(n0[1]) * np.hanning(n0[0])[:, None] + ar = np.abs(np.fft.fftshift(np.fft.fft2(w0 * ar.copy()))) for a0 in range(num_images): im = show( ar[a0], From 9d5e83d14a6414d153014d29936e28ddd140e7c3 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 3 Nov 2023 16:19:54 -0700 Subject: [PATCH 19/26] ctf transpose bugfix - tested mostly for stig --- py4DSTEM/process/phase/iterative_parallax.py | 364 +++++++++++-------- 1 file changed, 209 insertions(+), 155 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index daab204a0..3758a64e8 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -136,7 +136,7 @@ def to_h5(self, group): if hasattr(self, "aberration_C1"): recon_metadata |= { "aberration_rotation_QR": self.rotation_Q_to_R_rads, - "aberration_transpose": self.transpose_detected, + "aberration_transpose": self.transpose, "aberration_C1": self.aberration_C1, "aberration_A1x": self.aberration_A1x, "aberration_A1y": self.aberration_A1y, @@ -236,7 +236,7 @@ def _populate_instance(self, group): if "aberration_C1" in reconstruction_md.keys: self.rotation_Q_to_R_rads = reconstruction_md["aberration_rotation_QR"] - self.transpose_detected = reconstruction_md["aberration_transpose"] + self.transpose = reconstruction_md["aberration_transpose"] self.aberration_C1 = reconstruction_md["aberration_C1"] self.aberration_A1x = reconstruction_md["aberration_A1x"] self.aberration_A1y = reconstruction_md["aberration_A1y"] @@ -1321,7 +1321,7 @@ def aberration_fit( plot_CTF_comparison: bool = None, plot_BF_shifts_comparison: bool = None, upsampled: bool = True, - force_transpose: bool = None, + force_transpose: bool = False, ): """ Fit aberrations to the measured image shifts. @@ -1362,17 +1362,13 @@ def aberration_fit( # Convert real space shifts to Angstroms - if force_transpose is None: - self.transpose_detected = False - else: - self.transpose_detected = force_transpose - if force_transpose is True: self._xy_shifts_Ang = xp.flip(self._xy_shifts, axis=1) * xp.array( self._scan_sampling ) else: self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) + self.transpose = force_transpose # Solve affine transformation m = asnumpy( @@ -1389,9 +1385,15 @@ def aberration_fit( np.mod(self.rotation_Q_to_R_rads, 2.0 * np.pi) - np.pi ) m_aberration = -1.0 * m_aberration + self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0 - self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 - self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 + + if self.transpose: + self.aberration_A1x = -(m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 + self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 + else: + self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 + self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 ### Second pass @@ -1437,12 +1439,26 @@ def aberration_fit( sx = self._scan_sampling[0] / self._kde_upsample_factor sy = self._scan_sampling[1] / self._kde_upsample_factor + reciprocal_extent = [ + -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + ] + else: im_FFT = xp.abs(xp.fft.fft2(self._recon_BF)) sx = self._scan_sampling[0] sy = self._scan_sampling[1] upsampled = False + reciprocal_extent = [ + -0.5 / self._scan_sampling[1], + 0.5 / self._scan_sampling[1], + 0.5 / self._scan_sampling[0], + -0.5 / self._scan_sampling[0], + ] + # FFT coordinates qx = xp.fft.fftfreq(im_FFT.shape[0], sx) qy = xp.fft.fftfreq(im_FFT.shape[1], sy) @@ -1494,12 +1510,19 @@ def calculate_CTF_FFT(alpha_shape, *coefs): sy = 1 / (self._reciprocal_sampling[1] * self._region_of_interest_shape[1]) qx = xp.fft.fftfreq(self._region_of_interest_shape[0], sx) qy = xp.fft.fftfreq(self._region_of_interest_shape[1], sy) - qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + qx, qy = np.meshgrid(qx, qy, indexing="ij") + + # passive rotation basis by -theta + rotation_angle = -self.rotation_Q_to_R_rads + qx, qy = qx * np.cos(rotation_angle) + qy * np.sin( + rotation_angle + ), -qx * np.sin(rotation_angle) + qy * np.cos(rotation_angle) - u = qx[:, None] * self._wavelength - v = qy[None, :] * self._wavelength + qr2 = qx**2 + qy**2 + u = qx * self._wavelength + v = qy * self._wavelength alpha = xp.sqrt(qr2) * self._wavelength - theta = xp.arctan2(qy[None, :], qx[:, None]) + theta = xp.arctan2(qy, qx) # Aberration basis self._aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) @@ -1561,10 +1584,17 @@ def calculate_CTF(alpha_shape, *coefs): # initial coefficients and plotting intensity range mask self._aberrations_coefs = np.zeros(self._aberrations_num) - ind = np.argmin( - np.abs(self._aberrations_mn[:, 0] - 1.0) + self._aberrations_mn[:, 1] - ) - self._aberrations_coefs[ind] = self.aberration_C1 + + aberrations_mn_list = self._aberrations_mn.tolist() + if [1, 0, 0] in aberrations_mn_list: + ind_C1 = aberrations_mn_list.index([1, 0, 0]) + self._aberrations_coefs[ind_C1] = self.aberration_C1 + + if [1, 2, 0] in aberrations_mn_list: + ind_A1x = aberrations_mn_list.index([1, 2, 0]) + ind_A1y = aberrations_mn_list.index([1, 2, 1]) + self._aberrations_coefs[ind_A1x] = self.aberration_A1x + self._aberrations_coefs[ind_A1y] = self.aberration_A1y # Refinement using CTF fitting / Thon rings if fit_CTF_FFT: @@ -1617,57 +1647,84 @@ def score_CTF(coefs): ) # (Relative) untransposed fit - tf = AffineTransform(angle=self.rotation_Q_to_R_rads) - rotated_shifts = tf(self._xy_shifts_Ang, xp=xp).T.ravel() + raveled_shifts = self._xy_shifts_Ang.T.ravel() aberrations_coefs, res = xp.linalg.lstsq( - gradients, rotated_shifts, rcond=None + gradients, raveled_shifts, rcond=None )[:2] - if force_transpose is None: - # (Relative) transposed fit - transposed_shifts = xp.flip(self._xy_shifts_Ang, axis=1) - m_T = asnumpy( - xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[ - 0 - ] + self._aberrations_coefs = asnumpy(aberrations_coefs) + + if self.transpose: + aberrations_to_flip = (self._aberrations_mn[:, 1] > 0) & ( + self._aberrations_mn[:, 2] == 0 ) - m_rotation_T, _ = polar(m_T, side="right") - rotation_Q_to_R_rads_T = -1 * np.arctan2( - m_rotation_T[1, 0], m_rotation_T[0, 0] + self._aberrations_coefs[aberrations_to_flip] *= -1 + + # Plot the measured/fitted shifts comparison + if plot_BF_shifts_comparison: + measured_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 ) - if np.abs( - np.mod(rotation_Q_to_R_rads_T + np.pi, 2.0 * np.pi) - np.pi - ) > (np.pi * 0.5): - rotation_Q_to_R_rads_T = ( - np.mod(rotation_Q_to_R_rads_T, 2.0 * np.pi) - np.pi - ) + measured_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 0] + + measured_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + measured_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 1] + + fitted_shifts = ( + xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) + .reshape((2, -1)) + .T + ) + + fitted_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 0] - tf_T = AffineTransform(angle=rotation_Q_to_R_rads_T) - rotated_shifts_T = tf_T(transposed_shifts, xp=xp).T.ravel() - aberrations_coefs_T, res_T = xp.linalg.lstsq( - gradients, rotated_shifts_T, rcond=None - )[:2] - - # Compare fits - if res_T.sum() < res.sum(): - self.rotation_Q_to_R_rads = rotation_Q_to_R_rads_T - self.transpose_detected = not self.transpose_detected - self._aberrations_coefs = asnumpy(aberrations_coefs_T) - self._rotated_shifts = rotated_shifts_T - - warnings.warn( - ( - "Data transpose detected. " - f"Overwriting rotation value to {np.rad2deg(rotation_Q_to_R_rads_T):.3f}" - ), - UserWarning, + fitted_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 1] + + max_shift = xp.max( + xp.array( + [ + xp.abs(measured_shifts_sx).max(), + xp.abs(measured_shifts_sy).max(), + xp.abs(fitted_shifts_sx).max(), + xp.abs(fitted_shifts_sy).max(), + ] ) - else: - self._aberrations_coefs = asnumpy(aberrations_coefs) - self._rotated_shifts = rotated_shifts - else: - self._aberrations_coefs = asnumpy(aberrations_coefs) - self._rotated_shifts = rotated_shifts + ) + + show( + [ + [asnumpy(measured_shifts_sx), asnumpy(fitted_shifts_sx)], + [asnumpy(measured_shifts_sy), asnumpy(fitted_shifts_sy)], + ], + cmap="PiYG", + vmin=-max_shift, + vmax=max_shift, + intensity_range="absolute", + axsize=(4, 4), + ticks=False, + title=[ + "Measured Vertical Shifts", + "Fitted Vertical Shifts", + "Measured Horizontal Shifts", + "Fitted Horizontal Shifts", + ], + ) # Plot the CTF comparison between experiment and fit if plot_CTF_comparison: @@ -1705,79 +1762,24 @@ def score_CTF(coefs): im_plot[:, :, 2] -= im_CTF im_plot = np.clip(im_plot, 0, 1) - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) - ax1.imshow(im_plot, vmin=int_range[0], vmax=int_range[1]) - - ax2.imshow(np.fft.fftshift(asnumpy(im_CTF_cos)), cmap="gray") - - fig.tight_layout() - - # Plot the measured/fitted shifts comparison - if plot_BF_shifts_comparison: - if not fit_BF_shifts: - raise ValueError() - - measured_shifts_sx = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - measured_shifts_sx[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = self._rotated_shifts[: self._xy_inds.shape[0]] - - measured_shifts_sy = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - measured_shifts_sy[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = self._rotated_shifts[self._xy_inds.shape[0] :] - - fitted_shifts = xp.tensordot( - gradients, xp.array(self._aberrations_coefs), axes=1 + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4)) + ax1.imshow( + im_plot, vmin=int_range[0], vmax=int_range[1], extent=reciprocal_extent ) - - fitted_shifts_sx = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 + ax2.imshow( + np.fft.fftshift(asnumpy(im_CTF_cos)), + cmap="gray", + extent=reciprocal_extent, ) - fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = fitted_shifts[ - : self._xy_inds.shape[0] - ] - fitted_shifts_sy = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = fitted_shifts[ - self._xy_inds.shape[0] : - ] + for ax in (ax1, ax2): + ax.set_ylabel(r"$k_x$ [$A^{-1}$]") + ax.set_xlabel(r"$k_y$ [$A^{-1}$]") - max_shift = xp.max( - xp.array( - [ - xp.abs(measured_shifts_sx).max(), - xp.abs(measured_shifts_sy).max(), - xp.abs(fitted_shifts_sx).max(), - xp.abs(fitted_shifts_sy).max(), - ] - ) - ) + ax1.set_title("Aligned Bright Field FFT") + ax2.set_title("Fitted CTF Zero-Crossings") - show( - [ - [asnumpy(measured_shifts_sx), asnumpy(measured_shifts_sy)], - [asnumpy(fitted_shifts_sx), asnumpy(fitted_shifts_sy)], - ], - cmap="PiYG", - vmin=-max_shift, - vmax=max_shift, - intensity_range="absolute", - axsize=(4, 4), - ticks=False, - title=[ - "Measured Vertical Shifts", - "Measured Horizontal Shifts", - "Fitted Vertical Shifts", - "Fitted Horizontal Shifts", - ], - ) + fig.tight_layout() self.aberration_dict = { tuple(self._aberrations_mn[a0]): { @@ -1809,7 +1811,7 @@ def score_CTF(coefs): ) print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") - print(f"Transpose = {self.transpose_detected}") + print(f"Transpose = {self.transpose}") if fit_CTF_FFT or fit_BF_shifts: print() @@ -2292,6 +2294,7 @@ def show_shifts( self, scale_arrows=1, plot_arrow_freq=1, + plot_rotated_shifts=True, **kwargs, ): """ @@ -2308,10 +2311,22 @@ def show_shifts( xp = self._xp asnumpy = self._asnumpy - figsize = kwargs.pop("figsize", (6, 6)) color = kwargs.pop("color", (1, 0, 0, 1)) + if plot_rotated_shifts and hasattr(self, "rotation_Q_to_R_rads"): + figsize = kwargs.pop("figsize", (8, 4)) + fig, ax = plt.subplots(1, 2, figsize=figsize) + scaling_factor = ( + xp.array(self._reciprocal_sampling) + / xp.array(self._scan_sampling) + * scale_arrows + ) + rotated_shifts = self._xy_shifts_Ang * scaling_factor - fig, ax = plt.subplots(figsize=figsize) + else: + figsize = kwargs.pop("figsize", (4, 4)) + fig, ax = plt.subplots(figsize=figsize) + + shifts = self._xy_shifts * scale_arrows * self._reciprocal_sampling[0] dp_mask_ind = xp.nonzero(self._dp_mask) yy, xx = xp.meshgrid( @@ -2321,29 +2336,68 @@ def show_shifts( masked_ind = xp.logical_and(freq_mask, self._dp_mask) plot_ind = masked_ind[dp_mask_ind] - ax.quiver( - asnumpy(self._kxy[plot_ind, 1]), - asnumpy(self._kxy[plot_ind, 0]), - asnumpy( - self._xy_shifts[plot_ind, 1] - * scale_arrows - * self._reciprocal_sampling[0] - ), - asnumpy( - self._xy_shifts[plot_ind, 0] - * scale_arrows - * self._reciprocal_sampling[1] - ), - color=color, - angles="xy", - scale_units="xy", - scale=1, - **kwargs, - ) - kr_max = xp.max(self._kr) - ax.set_xlim([-1.2 * kr_max, 1.2 * kr_max]) - ax.set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + if plot_rotated_shifts and hasattr(self, "rotation_Q_to_R_rads"): + ax[0].quiver( + asnumpy(self._kxy[plot_ind, 1]), + asnumpy(self._kxy[plot_ind, 0]), + asnumpy(shifts[plot_ind, 1]), + asnumpy(shifts[plot_ind, 0]), + color=color, + angles="xy", + scale_units="xy", + scale=1, + **kwargs, + ) + + ax[0].set_xlim([-1.2 * kr_max, 1.2 * kr_max]) + ax[0].set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + ax[0].set_title("Measured Bright Field Shifts") + ax[0].set_ylabel(r"$k_x$ [$A^{-1}$]") + ax[0].set_xlabel(r"$k_y$ [$A^{-1}$]") + ax[0].set_aspect("equal") + + # passive coordinate rotation + tf_T = AffineTransform(angle=-self.rotation_Q_to_R_rads) + rotated_kxy = tf_T(self._kxy[plot_ind], xp=xp) + ax[1].quiver( + asnumpy(rotated_kxy[:, 1]), + asnumpy(rotated_kxy[:, 0]), + asnumpy(rotated_shifts[plot_ind, 1]), + asnumpy(rotated_shifts[plot_ind, 0]), + angles="xy", + scale_units="xy", + scale=1, + **kwargs, + ) + + ax[1].set_xlim([-1.2 * kr_max, 1.2 * kr_max]) + ax[1].set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + ax[1].set_title("Rotated Bright Field Shifts") + ax[1].set_ylabel(r"$k_x$ [$A^{-1}$]") + ax[1].set_xlabel(r"$k_y$ [$A^{-1}$]") + ax[1].set_aspect("equal") + else: + ax.quiver( + asnumpy(self._kxy[plot_ind, 1]), + asnumpy(self._kxy[plot_ind, 0]), + asnumpy(shifts[plot_ind, 1]), + asnumpy(shifts[plot_ind, 0]), + color=color, + angles="xy", + scale_units="xy", + scale=1, + **kwargs, + ) + + ax.set_xlim([-1.2 * kr_max, 1.2 * kr_max]) + ax.set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + ax.set_title("Measured BF Shifts") + ax.set_ylabel(r"$k_x$ [$A^{-1}$]") + ax.set_xlabel(r"$k_y$ [$A^{-1}$]") + ax.set_aspect("equal") + + fig.tight_layout() def visualize( self, From 2e59d1c7b52c509e5eb5164f6af5a58da5457975 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 3 Nov 2023 16:20:11 -0700 Subject: [PATCH 20/26] making ptycho aberration fitting convention consistent --- py4DSTEM/process/phase/iterative_ptychographic_constraints.py | 2 +- py4DSTEM/process/phase/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 0760087b4..d29aa1747 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -566,7 +566,7 @@ def _probe_aberration_fitting_constraint( xp=xp, ) - fourier_probe = fourier_probe_abs * xp.exp(1.0j * fitted_angle) + fourier_probe = fourier_probe_abs * xp.exp(-1.0j * fitted_angle) current_probe = xp.fft.ifft2(fourier_probe) return current_probe diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index d29765d04..a1eb54c80 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1620,7 +1620,7 @@ def fit_aberration_surface( ): """ """ probe_amp = xp.abs(complex_probe) - probe_angle = xp.angle(complex_probe) + probe_angle = -xp.angle(complex_probe) if xp is np: probe_angle = probe_angle.astype(np.float64) From d32b18dc3151c8cf5c457c206c094d804f7b84b7 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 4 Nov 2023 12:37:30 -0700 Subject: [PATCH 21/26] update uncertainty viz --- .../process/phase/iterative_base_class.py | 257 +++++++++++++++--- py4DSTEM/visualize/vis_special.py | 31 +++ 2 files changed, 257 insertions(+), 31 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 56be2784a..0f342d5c8 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -8,7 +8,7 @@ import numpy as np from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid -from py4DSTEM.visualize import show, show_complex +from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex from scipy.ndimage import rotate try: @@ -23,7 +23,11 @@ from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( PtychographicConstraints, ) -from py4DSTEM.process.phase.utils import AffineTransform, polar_aliases +from py4DSTEM.process.phase.utils import ( + AffineTransform, + generate_batches, + polar_aliases, +) from py4DSTEM.process.utils import ( electron_wavelength_angstrom, fourier_resample, @@ -2237,6 +2241,226 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(asnumpy(obj)) return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + + if errors is None: + errors = self._return_self_consistency_errors(max_batch_size=max_batch_size) + + if kde_sigma is None: + kde_sigma = 0.5 * self._scan_sampling[0] / self.sampling[0] + + xp = self._xp + asnumpy = self._asnumpy + gaussian_filter = self._gaussian_filter + + ## Kernel Density Estimation + + # rotated basis + angle = ( + self._rotation_best_rad + if self._rotation_best_transpose + else -self._rotation_best_rad + ) + + tf = AffineTransform(angle=angle) + rotated_points = tf(self._positions_px, origin=self._positions_px_com, xp=xp) + + padding = xp.min(rotated_points, axis=0).astype("int") + + # bilinear sampling + pixel_output = np.array(self.object_cropped.shape) + asnumpy(2 * padding) + pixel_size = pixel_output.prod() + + xa = rotated_points[:, 0] + ya = rotated_points[:, 1] + + # bilinear sampling + xF = xp.floor(xa).astype("int") + yF = xp.floor(ya).astype("int") + dx = xa - xF + dy = ya - yF + + # resampling + inds_1D = xp.ravel_multi_index( + xp.hstack( + [ + [xF, yF], + [xF + 1, yF], + [xF, yF + 1], + [xF + 1, yF + 1], + ] + ), + pixel_output, + mode=["wrap", "wrap"], + ) + + weights = xp.hstack( + ( + (1 - dx) * (1 - dy), + (dx) * (1 - dy), + (1 - dx) * (dy), + (dx) * (dy), + ) + ) + + pix_count = xp.reshape( + xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output + ) + + pix_output = xp.reshape( + xp.bincount( + inds_1D, + weights=weights * xp.tile(xp.asarray(errors), 4), + minlength=pixel_size, + ), + pixel_output, + ) + + # kernel density estimate + pix_count = gaussian_filter(pix_count, kde_sigma, mode="wrap") + pix_count[pix_count == 0.0] = np.inf + pix_output = gaussian_filter(pix_output, kde_sigma, mode="wrap") + pix_output /= pix_count + pix_output = pix_output[padding[0] : -padding[0], padding[1] : -padding[1]] + pix_output, _, _ = return_scaled_histogram_ordering( + pix_output.get(), normalize=True + ) + + ## Visualization + if plot_histogram: + spec = GridSpec( + ncols=1, + nrows=2, + height_ratios=[1, 4], + hspace=0.15, + ) + auto_figsize = (4, 5.25) + else: + spec = GridSpec( + ncols=1, + nrows=1, + ) + auto_figsize = (4, 4) + + figsize = kwargs.pop("figsize", auto_figsize) + + fig = plt.figure(figsize=figsize) + + if plot_histogram: + ax_hist = fig.add_subplot(spec[0]) + + counts, bins = np.histogram(errors, bins=50) + ax_hist.hist(bins[:-1], bins, weights=counts, color="#5ac8c8", alpha=0.5) + ax_hist.set_ylabel("Counts") + ax_hist.set_xlabel("Normalized Squared Error") + + ax = fig.add_subplot(spec[-1]) + + cmap = kwargs.pop("cmap", "magma") + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + + cropped_object_angle, vmin, vmax = return_scaled_histogram_ordering( + np.angle(self.object_cropped), + vmin=vmin, + vmax=vmax, + ) + + extent = [ + 0, + self.sampling[1] * cropped_object_angle.shape[1], + self.sampling[0] * cropped_object_angle.shape[0], + 0, + ] + + ax.imshow( + cropped_object_angle, + vmin=vmin, + vmax=vmax, + extent=extent, + alpha=1 - pix_output, + cmap=cmap, + **kwargs, + ) + + if plot_contours: + aligned_points = asnumpy(rotated_points - padding) + aligned_points[:, 0] *= self.sampling[0] + aligned_points[:, 1] *= self.sampling[1] + + ax.tricontour( + aligned_points[:, 1], + aligned_points[:, 0], + errors, + colors="grey", + levels=5, + # linestyles='dashed', + linewidths=0.5, + ) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_xlim((extent[0], extent[1])) + ax.set_ylim((extent[2], extent[3])) + ax.xaxis.set_ticks_position("bottom") + + spec.tight_layout(fig) + def show_fourier_probe( self, probe=None, @@ -2383,32 +2607,3 @@ def object_cropped(self): """Cropped and rotated object""" return self._crop_rotate_object_fov(self._object) - - @property - def self_consistency_errors(self): - """Compute the self-consistency errors for each probe position""" - - xp = self._xp - asnumpy = self._asnumpy - - # Re-initialize fractional positions and vector patches, max_batch_size = None - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap) - - # Normalized mean-squared errors - error = xp.sum( - xp.abs(self._amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) - ) - error /= self._mean_diffraction_intensity - - return asnumpy(error) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 388b57e0a..da501c746 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -839,3 +839,34 @@ def show_complex( if returnfig: return fig, ax + + +def return_scaled_histogram_ordering(array, vmin=None, vmax=None, normalize=False): + if vmin is None: + vmin = 0.02 + if vmax is None: + vmax = 0.98 + + vals = np.sort(array.ravel()) + ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int") + ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int") + ind_vmin = np.max([0, ind_vmin]) + ind_vmax = np.min([len(vals) - 1, ind_vmax]) + vmin = vals[ind_vmin] + vmax = vals[ind_vmax] + + if vmax == vmin: + vmin = vals[0] + vmax = vals[-1] + + scaled_array = array.copy() + scaled_array = np.where(scaled_array < vmin, vmin, scaled_array) + scaled_array = np.where(scaled_array > vmax, vmax, scaled_array) + + if normalize: + scaled_array -= scaled_array.min() + scaled_array /= scaled_array.max() + vmin = 0 + vmax = 1 + + return scaled_array, vmin, vmax From f93576a5dad24f7816c9f4bd72010bd393b4cbff Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 4 Nov 2023 12:50:13 -0700 Subject: [PATCH 22/26] generalizing to accommodate other classes easier --- .../process/phase/iterative_base_class.py | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 0f342d5c8..772f6b133 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -2287,10 +2287,22 @@ def _return_self_consistency_errors( return asnumpy(errors) + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped) + else: + projected_cropped_potential = self.object_cropped + + return projected_cropped_potential + def show_uncertainty_visualization( self, errors=None, max_batch_size=None, + projected_cropped_potential=None, kde_sigma=None, plot_histogram=True, plot_contours=False, @@ -2301,6 +2313,9 @@ def show_uncertainty_visualization( if errors is None: errors = self._return_self_consistency_errors(max_batch_size=max_batch_size) + if projected_cropped_potential is None: + projected_cropped_potential = self._return_projected_cropped_potential() + if kde_sigma is None: kde_sigma = 0.5 * self._scan_sampling[0] / self.sampling[0] @@ -2323,7 +2338,9 @@ def show_uncertainty_visualization( padding = xp.min(rotated_points, axis=0).astype("int") # bilinear sampling - pixel_output = np.array(self.object_cropped.shape) + asnumpy(2 * padding) + pixel_output = np.array(projected_cropped_potential.shape) + asnumpy( + 2 * padding + ) pixel_size = pixel_output.prod() xa = rotated_points[:, 0] @@ -2415,21 +2432,21 @@ def show_uncertainty_visualization( vmin = kwargs.pop("vmin", None) vmax = kwargs.pop("vmax", None) - cropped_object_angle, vmin, vmax = return_scaled_histogram_ordering( - np.angle(self.object_cropped), + projected_cropped_potential, vmin, vmax = return_scaled_histogram_ordering( + projected_cropped_potential, vmin=vmin, vmax=vmax, ) extent = [ 0, - self.sampling[1] * cropped_object_angle.shape[1], - self.sampling[0] * cropped_object_angle.shape[0], + self.sampling[1] * projected_cropped_potential.shape[1], + self.sampling[0] * projected_cropped_potential.shape[0], 0, ] ax.imshow( - cropped_object_angle, + projected_cropped_potential, vmin=vmin, vmax=vmax, extent=extent, From 71cde33e67b5d4828e5b78456c7dbfd6af7c932b Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 4 Nov 2023 13:47:19 -0700 Subject: [PATCH 23/26] add uncertainty viz to all classes except OT --- ...tive_mixedstate_multislice_ptychography.py | 66 +++++++++++++----- .../iterative_mixedstate_ptychography.py | 55 ++++++++++----- .../iterative_multislice_ptychography.py | 11 +++ .../iterative_overlap_magnetic_tomography.py | 25 ++++++- .../phase/iterative_overlap_tomography.py | 25 ++++++- .../iterative_simultaneous_ptychography.py | 69 +++++++++++++++++++ 6 files changed, 211 insertions(+), 40 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 6cd74828e..f4c10cb13 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -3595,30 +3595,60 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) - @property - def self_consistency_errors(self): + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): """Compute the self-consistency errors for each probe position""" xp = self._xp asnumpy = self._asnumpy - # Re-initialize fractional positions and vector patches, max_batch_size = None - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() - # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) - # Normalized mean-squared errors - error = xp.sum(xp.abs(self._amplitudes - intensity_norm) ** 2, axis=(-2, -1)) - error /= self._mean_diffraction_intensity + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped).sum(0) + else: + projected_cropped_potential = self.object_cropped.sum(0) - return asnumpy(error) + return projected_cropped_potential diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index ebc40928d..d68291143 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -2342,30 +2342,49 @@ def show_fourier_probe( **kwargs, ) - @property - def self_consistency_errors(self): + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): """Compute the self-consistency errors for each probe position""" xp = self._xp asnumpy = self._asnumpy - # Re-initialize fractional positions and vector patches, max_batch_size = None - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() - # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) - # Normalized mean-squared errors - error = xp.sum(xp.abs(self._amplitudes - intensity_norm) ** 2, axis=(-2, -1)) - error /= self._mean_diffraction_intensity + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity - return asnumpy(error) + return asnumpy(errors) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 764f0b4a0..93e32b079 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -3426,3 +3426,14 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped).sum(0) + else: + projected_cropped_potential = self.object_cropped.sum(0) + + return projected_cropped_potential diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 7c96cb34c..c49a1faac 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -3337,7 +3337,28 @@ def positions(self): return np.asarray(positions_all) - @property - def self_consistency_errors(self): + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): """Compute the self-consistency errors for each probe position""" raise NotImplementedError() + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + raise NotImplementedError() + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + projected_cropped_potential=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 54b94010a..ddd13ac58 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -3217,7 +3217,28 @@ def positions(self): return np.asarray(positions_all) - @property - def self_consistency_errors(self): + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): """Compute the self-consistency errors for each probe position""" raise NotImplementedError() + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + raise NotImplementedError() + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + projected_cropped_potential=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 866ff0a89..233d34e45 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -351,6 +351,9 @@ def preprocess( ) ) + # Ensure plot_center_of_mass is not in kwargs + kwargs.pop("plot_center_of_mass", None) + # 1st measurement sets rotation angle and transposition ( measurement_0, @@ -3408,3 +3411,69 @@ def self_consistency_errors(self): error /= self._mean_diffraction_intensity return asnumpy(error) + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[0][start:end] + + # Overlaps + _, _, overlap = self._warmup_overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap[0]) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped[0]) + else: + projected_cropped_potential = self.object_cropped[0] + + return projected_cropped_potential + + @property + def object_cropped(self): + """Cropped and rotated object""" + + obj_e, obj_m = self._object + obj_e = self._crop_rotate_object_fov(obj_e) + obj_m = self._crop_rotate_object_fov(obj_m) + return (obj_e, obj_m) From 4b227dc9bd8ebaa9d53b3792e489b741ee97cad9 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 4 Nov 2023 13:47:58 -0700 Subject: [PATCH 24/26] small kde parallax bug --- py4DSTEM/process/phase/iterative_parallax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 3758a64e8..9f690c434 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1207,7 +1207,7 @@ def subpixel_alignment( # kernel density estimate sigma = kde_sigma * self._kde_upsample_factor pix_count = gaussian_filter(pix_count, sigma) - pix_count[pix_output == 0.0] = np.inf + pix_count[pix_count == 0.0] = np.inf pix_output = gaussian_filter(pix_output, sigma) pix_output /= pix_count From bab740671565f471474602c98fff1fc4f63251dc Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sun, 5 Nov 2023 09:17:23 -0800 Subject: [PATCH 25/26] more parallax plotting fun(ctionality) --- py4DSTEM/process/phase/iterative_parallax.py | 23 +++++++++++++----- py4DSTEM/visualize/vis_special.py | 25 ++++++++++++++++++++ 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 9f690c434..716e1d782 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -587,16 +587,27 @@ def preprocess( self.recon_BF = asnumpy(self._recon_BF) if plot_average_bf: - figsize = kwargs.pop("figsize", (6, 6)) + figsize = kwargs.pop("figsize", (6, 12)) - fig, ax = plt.subplots(figsize=figsize) + fig, ax = plt.subplots(1, 2, figsize=figsize) - self._visualize_figax(fig, ax, **kwargs) + self._visualize_figax(fig, ax[0], **kwargs) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Average Bright Field Image") + ax[0].set_ylabel("x [A]") + ax[0].set_xlabel("y [A]") + ax[0].set_title("Average Bright Field Image") + reciprocal_extent = [ + -0.5 * (self._reciprocal_sampling[1] * self._dp_mask.shape[1]), + 0.5 * (self._reciprocal_sampling[1] * self._dp_mask.shape[1]), + 0.5 * (self._reciprocal_sampling[0] * self._dp_mask.shape[0]), + -0.5 * (self._reciprocal_sampling[0] * self._dp_mask.shape[0]), + ] + ax[1].imshow(self._dp_mask, extent=reciprocal_extent, cmap="gray") + ax[1].set_title("DP mask") + ax[1].set_ylabel(r"$k_x$ [$A^{-1}$]") + ax[1].set_xlabel(r"$k_y$ [$A^{-1}$]") + plt.tight_layout() self._preprocessed = True if self._device == "gpu": diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index da501c746..1d46ebf44 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -842,6 +842,31 @@ def show_complex( def return_scaled_histogram_ordering(array, vmin=None, vmax=None, normalize=False): + """ + Utility function for calculating min and max values for plotting array + based on distribution of pixel values + + Parameters + ---------- + array: np.array + array to be plotted + vmin: float + lower fraction cut off of pixel values + vmax: float + upper fraction cut off of pixel values + normalize: bool + if True, rescales from 0 to 1 + + Returns + ---------- + scaled_array: np.array + array clipped outside vmin and vmax + vmin: float + lower value to be plotted + vmax: float + upper value to be plotted + """ + if vmin is None: vmin = 0.02 if vmax is None: From dd09924a14433ab7e86bd726b649440421d97e5e Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sun, 5 Nov 2023 09:20:27 -0800 Subject: [PATCH 26/26] black formatting --- py4DSTEM/visualize/vis_special.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 1d46ebf44..acacb6184 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -850,10 +850,10 @@ def return_scaled_histogram_ordering(array, vmin=None, vmax=None, normalize=Fals ---------- array: np.array array to be plotted - vmin: float - lower fraction cut off of pixel values + vmin: float + lower fraction cut off of pixel values vmax: float - upper fraction cut off of pixel values + upper fraction cut off of pixel values normalize: bool if True, rescales from 0 to 1