From 6db566e0f83bd1a75e112c95ba75232f28a7be30 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 3 Aug 2023 17:32:51 -0700 Subject: [PATCH 01/62] start for depth profile --- .../iterative_multislice_ptychography.py | 68 ++++++++++++++++++- py4DSTEM/process/phase/utils.py | 14 ++++ 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 92f8c0bf3..cce65f6c3 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -29,6 +29,7 @@ spatial_frequencies, ) from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar +from scipy.ndimage import rotate warnings.simplefilter(action="always", category=UserWarning) @@ -974,7 +975,7 @@ def _gradient_descent_adjoint( ) # back-transmit - exit_waves *= xp.conj(obj) #/ xp.abs(obj) ** 2 + exit_waves *= xp.conj(obj) # / xp.abs(obj) ** 2 if s > 0: # back-propagate @@ -1076,7 +1077,7 @@ def _projection_sets_adjoint( ) # back-transmit - exit_waves_copy *= xp.conj(obj) # / xp.abs(obj) ** 2 + exit_waves_copy *= xp.conj(obj) # / xp.abs(obj) ** 2 if s > 0: # back-propagate @@ -2841,6 +2842,67 @@ def show_slices( spec.tight_layout(fig) + def show_depth( + self, + x1: float, + x2: float, + y1: float, + y2: float, + gaussian_filter_sigma: float = None, + ms_object=None, + cbar: bool = False, + aspect: float = None, + **kwargs, + ): + """ + doc strings go here + """ + ms_obj = self.object_cropped + angle = np.arctan((x2 - x1) / (y2 - y1)) + + x0 = ms_obj.shape[1] / 2 + y0 = ms_obj.shape[2] / 2 + + from py4DSTEM.process.phase.utils import rotate_point + + x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) + x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) + + rotated_object = np.roll( + rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), + int(x1_0), + axis=1, + ) + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + if gaussian_filter_sigma is not None: + from scipy.ndimage import gaussian_filter + + rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) + + plot_im = rotated_object[:, 0, int(y1_0) : int(y2_0)] + + extent = [ + 0, + self.sampling[1] * plot_im.shape[1], + self._slice_thicknesses[0] * plot_im.shape[0], + 0, + ] + + fig, ax = plt.subplots() + im = ax.imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax.set_aspect(aspect) + ax.set_xlabel("y [A]") + ax.set_ylabel("x [A]") + ax.set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + def tune_num_slices_and_thicknesses( self, num_slices_guess=None, @@ -3067,4 +3129,4 @@ def _return_object_fft( obj = np.angle(obj) obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) \ No newline at end of file + return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index c2e1d3b77..118e9990a 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1610,3 +1610,17 @@ def fit_aberration_surface( fitted_angle = xp.tensordot(coeff, basis, axes=1) return fitted_angle, coeff + + +def rotate_point(origin, point, angle): + """ + Rotate a point counterclockwise by a given angle around a given origin. + + The angle should be given in radians. + """ + ox, oy = origin + px, py = point + + qx = ox + np.cos(angle) * (px - ox) - np.sin(angle) * (py - oy) + qy = oy + np.sin(angle) * (px - ox) + np.cos(angle) * (py - oy) + return qx, qy From 4eb461de2ffe9d245e64f8bcf96fb1361d625c12 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 7 Aug 2023 18:06:55 -0700 Subject: [PATCH 02/62] adding real-space kde upsampling --- py4DSTEM/process/phase/iterative_parallax.py | 224 ++++++++++++++++++- 1 file changed, 219 insertions(+), 5 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 80cdd8cd8..c2bfc8739 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -14,6 +14,7 @@ from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction from py4DSTEM.process.utils.cross_correlate import align_images_fourier from py4DSTEM.process.utils.utils import electron_wavelength_angstrom +from py4DSTEM.visualize import show from scipy.linalg import polar from scipy.special import comb @@ -246,6 +247,7 @@ def preprocess( ) if normalize_images: self._stack_BF = xp.ones(stack_shape) + self._stack_BF_no_window = xp.ones(stack_shape) if normalize_order == 0: all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None] @@ -259,6 +261,14 @@ def preprocess( self._window_inv[None] + self._window_edge[None] * all_bfs ) + self._stack_BF_no_window[ + :, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs + elif normalize_order == 1: x = xp.linspace(-0.5, 0.5, all_bfs.shape[1]) y = xp.linspace(-0.5, 0.5, all_bfs.shape[2]) @@ -285,9 +295,18 @@ def preprocess( basis @ coefs[0], all_bfs.shape[1:3] ) + self._stack_BF_no_window[ + a0, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs[a0] / xp.reshape(basis @ coefs[0], all_bfs.shape[1:3]) + else: all_means = xp.mean(all_bfs, axis=(1, 2)) self._stack_BF = xp.full(stack_shape, all_means[:, None, None]) + self._stack_BF_no_window = xp.full(stack_shape, all_means[:, None, None]) self._stack_BF[ :, self._object_padding_px[0] // 2 : self._grid_scan_shape[0] @@ -299,6 +318,14 @@ def preprocess( + self._window_edge[None] * all_bfs ) + self._stack_BF_no_window[ + :, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs + # Fourier space operators for image shifts qx = xp.fft.fftfreq(self._stack_BF.shape[1], d=1) qy = xp.fft.fftfreq(self._stack_BF.shape[2], d=1) @@ -533,9 +560,9 @@ def tune_angle_and_defocus( divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.05) - plt.colorbar(im, cax=cax) + fig.colorbar(im, cax=cax) - plt.tight_layout() + fig.tight_layout() if return_values: convergence = np.array(convergence).reshape( @@ -548,7 +575,7 @@ def reconstruct( max_alignment_bin: int = None, min_alignment_bin: int = 1, max_iter_at_min_bin: int = 2, - upsample_factor: int = 8, + cross_correlation_upsample_factor: int = 8, regularizer_matrix_size: Tuple[int, int] = (1, 1), regularize_shifts: bool = True, running_average: bool = True, @@ -570,7 +597,7 @@ def reconstruct( Minimum bin size for bright field alignment max_iter_at_min_bin: int, optional Number of iterations to run at the smallest bin size - upsample_factor: int, optional + cross_correlation_upsample_factor: int, optional DFT upsample factor for subpixel alignment regularizer_matrix_size: Tuple[int,int], optional Bernstein basis degree used for regularizing shifts @@ -730,7 +757,7 @@ def reconstruct( xy_shift = align_images_fourier( G_ref, G, - upsample_factor=upsample_factor, + upsample_factor=cross_correlation_upsample_factor, device=self._device, ) @@ -837,6 +864,193 @@ def reconstruct( return self + def subpixel_alignment( + self, + kde_upsample_factor=4, + kde_sigma=0.125, + plot_upsampled_BF_comparison: bool = True, + plot_upsampled_FFT_comparison: bool = False, + **kwargs, + ): + """ + Upsample and subpixel-align BFs using the measured image shifts. + Uses kernel density estimation (KDE) to align upsampled BFs. + + Parameters + ---------- + kde_upsample_factor: int, optional + Real-space upsampling factor + kde_sigma: float, optional + KDE gaussian kernel bandwidth + plot_upsampled_BF_comparison: bool, optional + If True, the pre/post alignment BF images are plotted for comparison + plot_upsampled_FFT_comparison: bool, optional + If True, the pre/post alignment BF FFTs are plotted for comparison + + """ + xp = self._xp + asnumpy = self._asnumpy + gaussian_filter = self._gaussian_filter + + xy_shifts = self._xy_shifts + BF_size = np.array(self._stack_BF_no_window.shape[-2:]) + + pixel_output = BF_size * kde_upsample_factor + pixel_size = pixel_output.prod() + + # shifted coordinates + x = xp.arange(BF_size[0]) + y = xp.arange(BF_size[1]) + + xa, ya = xp.meshgrid(x, y, indexing="ij") + xa = ((xa + xy_shifts[:, 0, None, None]) * kde_upsample_factor).ravel() + ya = ((ya + xy_shifts[:, 1, None, None]) * kde_upsample_factor).ravel() + + # bilinear sampling + xF = xp.floor(xa).astype("int") + yF = xp.floor(ya).astype("int") + dx = xa - xF + dy = ya - yF + + # resampling + inds_1D = xp.ravel_multi_index( + xp.hstack( + [ + [xF, yF], + [xF + 1, yF], + [xF, yF + 1], + [xF + 1, yF + 1], + ] + ), + pixel_output, + mode=["wrap", "wrap"], + ) + + weights = xp.hstack( + ( + (1 - dx) * (1 - dy), + (dx) * (1 - dy), + (1 - dx) * (dy), + (dx) * (dy), + ) + ) + + pix_count = xp.reshape( + xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output + ) + pix_output = xp.reshape( + xp.bincount( + inds_1D, + weights=weights * xp.tile(self._stack_BF_no_window.ravel(), 4), + minlength=pixel_size, + ), + pixel_output, + ) + + # kernel density estimate + sigma = kde_sigma * kde_upsample_factor + pix_count = gaussian_filter(pix_count, sigma) + pix_count[pix_output == 0.0] = np.inf + pix_output = gaussian_filter(pix_output, sigma) + pix_output /= pix_count + + self._recon_BF_subpixel_aligned = pix_output + self.recon_BF_subpixel_aligned = asnumpy(self._recon_BF_subpixel_aligned) + + # plotting + if plot_upsampled_BF_comparison: + if plot_upsampled_FFT_comparison: + figsize = kwargs.pop("figsize", (8, 8)) + fig, axs = plt.subplots(2, 2, figsize=figsize) + else: + figsize = kwargs.pop("figsize", (8, 4)) + fig, axs = plt.subplots(1, 2, figsize=figsize) + + axs = axs.flat + cmap = kwargs.pop("cmap", "magma") + + cropped_object = self._crop_padded_object(self._recon_BF) + upsampled_pad_x = self._object_padding_px[0] * kde_upsample_factor // 2 + upsampled_pad_y = self._object_padding_px[1] * kde_upsample_factor // 2 + cropped_object_aligned = self.recon_BF_subpixel_aligned[ + upsampled_pad_x:-upsampled_pad_x, + upsampled_pad_y:-upsampled_pad_y, + ] + + extent = [ + 0, + self._scan_sampling[1] * cropped_object.shape[1], + self._scan_sampling[0] * cropped_object.shape[0], + 0, + ] + + axs[0].imshow( + cropped_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + axs[0].set_title("Aligned Bright Field") + + axs[1].imshow( + cropped_object_aligned, + extent=extent, + cmap=cmap, + **kwargs, + ) + axs[1].set_title("Upsampled Bright Field") + + for ax in axs[:2]: + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if plot_upsampled_FFT_comparison: + recon_fft = xp.fft.fft2(self._recon_BF) + recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF))) + pad_x = BF_size[0] * (kde_upsample_factor - 1) // 2 + pad_y = BF_size[1] * (kde_upsample_factor - 1) // 2 + pad_recon_fft = asnumpy( + xp.pad(recon_fft, ((pad_x, pad_x), (pad_y, pad_y))) + ) + + upsampled_fft = asnumpy( + xp.fft.fftshift( + xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + ) + ) + + reciprocal_extent = [ + 0, + self._reciprocal_sampling[1] * cropped_object_aligned.shape[1], + self._reciprocal_sampling[0] * cropped_object_aligned.shape[0], + 0, + ] + + show( + pad_recon_fft, + figax=(fig, axs[2]), + extent=reciprocal_extent, + cmap="gray", + title="Aligned Bright Field FFT", + **kwargs, + ) + + show( + upsampled_fft, + figax=(fig, axs[3]), + extent=reciprocal_extent, + cmap="gray", + title="Upsampled Bright Field FFT", + **kwargs, + ) + + for ax in axs[2:]: + ax.set_ylabel(r"$k_x$ [$A^{-1}$]") + ax.set_xlabel(r"$k_y$ [$A^{-1}$]") + ax.xaxis.set_ticks_position("bottom") + + fig.tight_layout() + def aberration_fit( self, plot_CTF_compare: bool = False, From 8c61a5b6591d410ea27377ce264eba839313c2e7 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Tue, 8 Aug 2023 11:58:24 -0700 Subject: [PATCH 03/62] more depth profile --- .../iterative_multislice_ptychography.py | 78 +++++++++++++++---- py4DSTEM/process/phase/utils.py | 16 +++- 2 files changed, 78 insertions(+), 16 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index cce65f6c3..111f012ad 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -2852,12 +2852,31 @@ def show_depth( ms_object=None, cbar: bool = False, aspect: float = None, + plot_line_profile: bool = False, **kwargs, ): """ - doc strings go here + Displays line profile depth section + + Parameters + -------- + x1, x2, y1, y2: floats + line profile for dpeth seciton runs from (x1,y1) to (x2,y2) + gaussian_filter_sigma: float (optional) + Standard deviation of gaussian kernel in A + ms_object: np.array + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + aspect: float, optional + aspect ratio for depth profile plot + plot_line_profile: bool + If True, also plots line profile showing where depth profile is taken """ - ms_obj = self.object_cropped + if ms_object is not None: + ms_obj = ms_object + else: + ms_obj = self.object_cropped angle = np.arctan((x2 - x1) / (y2 - y1)) x0 = ms_obj.shape[1] / 2 @@ -2879,6 +2898,7 @@ def show_depth( if gaussian_filter_sigma is not None: from scipy.ndimage import gaussian_filter + gaussian_filter_sigma /= self.sampling[0] rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) plot_im = rotated_object[:, 0, int(y1_0) : int(y2_0)] @@ -2890,18 +2910,48 @@ def show_depth( 0, ] - fig, ax = plt.subplots() - im = ax.imshow(plot_im, cmap="magma", extent=extent) - if aspect is not None: - ax.set_aspect(aspect) - ax.set_xlabel("y [A]") - ax.set_ylabel("x [A]") - ax.set_title("Multislice depth profile") - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) + if plot_line_profile == False: + fig, ax = plt.subplots() + im = ax.imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax.set_aspect(aspect) + ax.set_xlabel("y [A]") + ax.set_ylabel("x [A]") + ax.set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + else: + extent2 = [ + 0, + self.sampling[0] * ms_obj.shape[1], + self.sampling[1] * ms_obj.shape[2], + 0, + ] + fig, ax = plt.subplots(2, 1) + ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) + ax[0].plot( + [y1 / self.sampling[0], y2 / self.sampling[1]], + [x1 / self.sampling[0], x2 / self.sampling[1]], + color="red", + ) + ax[0].set_xlabel("y [A]") + ax[0].set_ylabel("x [A]") + ax[0].set_title("Multislice depth profile location") + + im = ax[1].imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax[1].set_aspect(aspect) + ax[1].set_xlabel("y [A]") + ax[1].set_ylabel("x [A]") + ax[1].set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax[1]) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) def tune_num_slices_and_thicknesses( self, diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 118e9990a..a8a702f89 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1614,9 +1614,21 @@ def fit_aberration_surface( def rotate_point(origin, point, angle): """ - Rotate a point counterclockwise by a given angle around a given origin. + Rotate a point (x1, y1) counterclockwise by a given angle around + a given origin (x0, y0). + + Parameters + -------- + origin: 2-tuple of floats + (x0, y0) + point: 2-tuple of floats + (x1, y1) + angle: float (radians) + + Returns + -------- + rotated points (2-tuple) - The angle should be given in radians. """ ox, oy = origin px, py = point From 22d1bb489320cbcfcd0a089ea7ca2e384d542536 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Tue, 8 Aug 2023 12:01:38 -0700 Subject: [PATCH 04/62] saving error --- .../process/phase/iterative_multislice_ptychography.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 111f012ad..e000910c3 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -2860,8 +2860,8 @@ def show_depth( Parameters -------- - x1, x2, y1, y2: floats - line profile for dpeth seciton runs from (x1,y1) to (x2,y2) + x1, x2, y1, y2: floats (pixels) + Line profile for depth section runs from (x1,y1) to (x2,y2) gaussian_filter_sigma: float (optional) Standard deviation of gaussian kernel in A ms_object: np.array @@ -2933,8 +2933,8 @@ def show_depth( fig, ax = plt.subplots(2, 1) ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) ax[0].plot( - [y1 / self.sampling[0], y2 / self.sampling[1]], - [x1 / self.sampling[0], x2 / self.sampling[1]], + [y1 * self.sampling[0], y2 * self.sampling[1]], + [x1 * self.sampling[0], x2 * self.sampling[1]], color="red", ) ax[0].set_xlabel("y [A]") @@ -2952,6 +2952,7 @@ def show_depth( ax_cb = divider.append_axes("right", size="5%", pad="2.5%") fig.add_axes(ax_cb) fig.colorbar(im, cax=ax_cb) + plt.tight_layout() def tune_num_slices_and_thicknesses( self, From 5f919bdfb26196e30409a2852b633d413ebd2e13 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Tue, 8 Aug 2023 12:04:18 -0700 Subject: [PATCH 05/62] small name changes --- .../process/phase/iterative_multislice_ptychography.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index e000910c3..438c9d1fb 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -2915,8 +2915,8 @@ def show_depth( im = ax.imshow(plot_im, cmap="magma", extent=extent) if aspect is not None: ax.set_aspect(aspect) - ax.set_xlabel("y [A]") - ax.set_ylabel("x [A]") + ax.set_xlabel("r [A]") + ax.set_ylabel("z [A]") ax.set_title("Multislice depth profile") if cbar: divider = make_axes_locatable(ax) @@ -2944,8 +2944,8 @@ def show_depth( im = ax[1].imshow(plot_im, cmap="magma", extent=extent) if aspect is not None: ax[1].set_aspect(aspect) - ax[1].set_xlabel("y [A]") - ax[1].set_ylabel("x [A]") + ax[1].set_xlabel("r [A]") + ax[1].set_ylabel("z [A]") ax[1].set_title("Multislice depth profile") if cbar: divider = make_axes_locatable(ax[1]) From 68b31a7c8dd0700c611eff6c56232558fe2e9514 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Wed, 16 Aug 2023 17:40:21 -0700 Subject: [PATCH 06/62] single slice tv denoise --- .../iterative_ptychographic_constraints.py | 81 ++++++++++++++++--- .../iterative_singleslice_ptychography.py | 12 +++ setup.py | 1 + 3 files changed, 81 insertions(+), 13 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 9af22ba92..6ae9e176d 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -1,4 +1,5 @@ import numpy as np +import pylops from py4DSTEM.process.phase.utils import ( array_slice, estimate_global_transformation_ransac, @@ -7,6 +8,7 @@ regularize_probe_amplitude, ) from py4DSTEM.process.utils import get_CoM +import warnings class PtychographicConstraints: @@ -183,6 +185,59 @@ def _object_butterworth_constraint( return current_object + def _object_denoise_tv_pylops(self, current_object, weight): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weight : float, optional + Denoising weight. The greater `weight`, the more denoising (at + the expense of fidelity to `input`). + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("tv_denoise currently for potential objects only"), + UserWarning, + ) + + else: + nx, ny = current_object.shape + niter_out = 40 + niter_in = 1 + Iop = pylops.Identity(nx * ny) + xy_laplacian = pylops.Laplacian( + (nx, ny), axes=(0, 1), edge=False, kind="backward" + ) + + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weight], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + current_object_tv = current_object_tv.reshape(current_object.shape) + + return current_object_tv + def _object_denoise_tv_chambolle( self, current_object, @@ -363,8 +418,8 @@ def _probe_amplitude_constraint( xp = self._xp erf = self._erf - probe_intensity = xp.abs(current_probe) ** 2 - #current_probe_sum = xp.sum(probe_intensity) + # probe_intensity = xp.abs(current_probe) ** 2 + # current_probe_sum = xp.sum(probe_intensity) X = xp.fft.fftfreq(current_probe.shape[0])[:, None] Y = xp.fft.fftfreq(current_probe.shape[1])[None] @@ -374,10 +429,10 @@ def _probe_amplitude_constraint( tophat_mask = 0.5 * (1 - erf(sigma * r / (1 - r**2))) updated_probe = current_probe * tophat_mask - #updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - #normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe #* normalization + return updated_probe # * normalization def _probe_fourier_amplitude_constraint( self, @@ -406,7 +461,7 @@ def _probe_fourier_amplitude_constraint( xp = self._xp asnumpy = self._asnumpy - #current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + # current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) current_probe_fft = xp.fft.fft2(current_probe) updated_probe_fft, _, _, _ = regularize_probe_amplitude( @@ -419,10 +474,10 @@ def _probe_fourier_amplitude_constraint( updated_probe_fft = xp.asarray(updated_probe_fft) updated_probe = xp.fft.ifft2(updated_probe_fft) - #updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - #normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe #* normalization + return updated_probe # * normalization def _probe_aperture_constraint( self, @@ -444,16 +499,16 @@ def _probe_aperture_constraint( """ xp = self._xp - #current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + # current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) current_probe_fft_phase = xp.angle(xp.fft.fft2(current_probe)) updated_probe = xp.fft.ifft2( xp.exp(1j * current_probe_fft_phase) * initial_probe_aperture ) - #updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - #normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe #* normalization + return updated_probe # * normalization def _probe_aberration_fitting_constraint( self, diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 0480bae8a..0c9af9649 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -1023,6 +1023,8 @@ def _constraints( q_lowpass, q_highpass, butterworth_order, + tv_denoise, + tv_denoise_weight, object_positivity, shrinkage_rad, object_mask, @@ -1108,6 +1110,12 @@ def _constraints( butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weight, + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( current_object, @@ -1198,6 +1206,8 @@ def reconstruct( q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -1618,6 +1628,8 @@ def reconstruct( q_lowpass=q_lowpass, q_highpass=q_highpass, butterworth_order=butterworth_order, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse diff --git a/setup.py b/setup.py index b0c7fa081..d8baff354 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ 'dask >= 2.3.0', 'distributed >= 2.3.0', 'emdfile >= 0.0.10', + 'pylops >= 2.1.0' ], extras_require={ 'ipyparallel': ['ipyparallel >= 6.2.4', 'dill >= 0.3.3'], From 5acdd5809e51910c6f1a4242248eb4f12c4be5a7 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 17 Aug 2023 14:56:10 -0700 Subject: [PATCH 07/62] multislice tv denoise... will test more before adding to other classes... --- .../iterative_multislice_ptychography.py | 119 +++++++++++-- .../iterative_ptychographic_constraints.py | 163 ++++++++++-------- .../iterative_singleslice_ptychography.py | 8 + 3 files changed, 197 insertions(+), 93 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 438c9d1fb..5966ca07a 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np +import pylops from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex @@ -1450,6 +1451,70 @@ def _object_identical_slices_constraint(self, current_object): return current_object + def _object_denoise_tv_pylops(self, current_object, weights): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("tv_denoise currently for potential objects only"), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = 40 + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + current_object_tv = current_object_tv.reshape(current_object.shape) + + # remove padding + + return current_object_tv[1:-1] + def _constraints( self, current_object, @@ -1482,9 +1547,11 @@ def _constraints( shrinkage_rad, object_mask, pure_phase_object, + tv_denoise_chambolle, + tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle, tv_denoise, - tv_denoise_weight, - tv_denoise_pad, + tv_denoise_weights, ): """ Ptychographic constraints operator. @@ -1549,12 +1616,17 @@ def _constraints( If not None, used to calculate additional shrinkage using masked-mean of object pure_phase_object: bool If True, object amplitude is set to unity - tv_denoise: bool + tv_denoise_chambolle: bool If True, performs TV denoising along z - tv_denoise_weight: float + tv_denoise_weight_chambolle: float weight of tv denoising constraint - tv_denoise_pad: bool + tv_denoise_pad_chambolle: bool if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. Returns -------- @@ -1586,13 +1658,16 @@ def _constraints( current_object, kz_regularization_gamma ) elif tv_denoise: - if self._object_type == "complex": - raise NotImplementedError() + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + ) + elif tv_denoise_chambolle: current_object = self._object_denoise_tv_chambolle( current_object, - tv_denoise_weight, + tv_denoise_weight_chambolle, axis=0, - pad_object=tv_denoise_pad, + pad_object=tv_denoise_pad_chambolle, ) if shrinkage_rad > 0.0 or object_mask is not None: @@ -1691,9 +1766,11 @@ def reconstruct( shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, pure_phase_object_iter: int = 0, + tv_denoise_iter_chambolle=np.inf, + tv_denoise_weight_chambolle=None, + tv_denoise_pad_chambolle=True, tv_denoise_iter=np.inf, - tv_denoise_weight=None, - tv_denoise_pad=True, + tv_denoise_weights=None, switch_object_iter: int = np.inf, store_iterations: bool = False, progress_bar: bool = True, @@ -1786,12 +1863,17 @@ def reconstruct( If true, the potential mean outside the FOV is forced to zero at each iteration pure_phase_object_iter: int, optional Number of iterations where object amplitude is set to unity - tv_denoise_iter: bool + tv_denoise_iter_chambolle: bool Number of iterations with TV denoisining - tv_denoise_weight: float + tv_denoise_weight_chambolle: float weight of tv denoising constraint - tv_denoise_pad: bool + tv_denoise_pad_chambolle: bool if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. switch_object_iter: int, optional Iteration to switch object type between 'complex' and 'potential' or between 'potential' and 'complex' @@ -2134,9 +2216,12 @@ def reconstruct( else None, pure_phase_object=a0 < pure_phase_object_iter and self._object_type == "complex", - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, - tv_denoise_weight=tv_denoise_weight, - tv_denoise_pad=tv_denoise_pad, + tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle + and tv_denoise_weight_chambolle is not None, + tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, ) self.error_iterations.append(error.item()) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 6ae9e176d..e300e1154 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -193,7 +193,7 @@ def _object_denoise_tv_pylops(self, current_object, weight): ---------- current_object: np.ndarray Current object estimate - weight : float, optional + weight : float Denoising weight. The greater `weight`, the more denoising (at the expense of fidelity to `input`). Returns @@ -284,90 +284,101 @@ def _object_denoise_tv_chambolle( Adapted skimage.restoration.denoise_tv_chambolle. """ xp = self._xp - - current_object_sum = xp.sum(current_object) - if axis is None: - ndim = xp.arange(current_object.ndim).tolist() - elif isinstance(axis, tuple): - ndim = list(axis) + if xp.iscomplexobj(current_object): + updated_object = current_object + warnings.warn( + ("tv_denoise currently for potential objects only"), + UserWarning, + ) else: - ndim = [axis] - - if pad_object: - pad_width = ((0, 0),) * current_object.ndim - pad_width = list(pad_width) - for ax in range(len(ndim)): - pad_width[ndim[ax]] = (1, 1) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" + + current_object_sum = xp.sum(current_object) + if axis is None: + ndim = xp.arange(current_object.ndim).tolist() + elif isinstance(axis, tuple): + ndim = list(axis) + else: + ndim = [axis] + + if pad_object: + pad_width = ((0, 0),) * current_object.ndim + pad_width = list(pad_width) + for ax in range(len(ndim)): + pad_width[ndim[ax]] = (1, 1) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + p = xp.zeros( + (current_object.ndim,) + current_object.shape, + dtype=current_object.dtype, ) + g = xp.zeros_like(p) + d = xp.zeros_like(current_object) + + i = 0 + while i < max_num_iter: + if i > 0: + # d will be the (negative) divergence of p + d = -p.sum(0) + slices_d = [ + slice(None), + ] * current_object.ndim + slices_p = [ + slice(None), + ] * (current_object.ndim + 1) + for ax in range(len(ndim)): + slices_d[ndim[ax]] = slice(1, None) + slices_p[ndim[ax] + 1] = slice(0, -1) + slices_p[0] = ndim[ax] + d[tuple(slices_d)] += p[tuple(slices_p)] + slices_d[ndim[ax]] = slice(None) + slices_p[ndim[ax] + 1] = slice(None) + updated_object = current_object + d + else: + updated_object = current_object + E = (d**2).sum() - p = xp.zeros( - (current_object.ndim,) + current_object.shape, dtype=current_object.dtype - ) - g = xp.zeros_like(p) - d = xp.zeros_like(current_object) - - i = 0 - while i < max_num_iter: - if i > 0: - # d will be the (negative) divergence of p - d = -p.sum(0) - slices_d = [ - slice(None), - ] * current_object.ndim - slices_p = [ + # g stores the gradients of updated_object along each axis + # e.g. g[0] is the first order finite difference along axis 0 + slices_g = [ slice(None), ] * (current_object.ndim + 1) for ax in range(len(ndim)): - slices_d[ndim[ax]] = slice(1, None) - slices_p[ndim[ax] + 1] = slice(0, -1) - slices_p[0] = ndim[ax] - d[tuple(slices_d)] += p[tuple(slices_p)] - slices_d[ndim[ax]] = slice(None) - slices_p[ndim[ax] + 1] = slice(None) - updated_object = current_object + d - else: - updated_object = current_object - E = (d**2).sum() - - # g stores the gradients of updated_object along each axis - # e.g. g[0] is the first order finite difference along axis 0 - slices_g = [ - slice(None), - ] * (current_object.ndim + 1) - for ax in range(len(ndim)): - slices_g[ndim[ax] + 1] = slice(0, -1) - slices_g[0] = ndim[ax] - g[tuple(slices_g)] = xp.diff(updated_object, axis=ndim[ax]) - slices_g[ndim[ax] + 1] = slice(None) - if scaling is not None: - scaling /= xp.max(scaling) - g *= xp.array(scaling)[:, xp.newaxis, xp.newaxis] - norm = xp.sqrt((g**2).sum(axis=0))[xp.newaxis, ...] - E += weight * norm.sum() - tau = 1.0 / (2.0 * len(ndim)) - norm *= tau / weight - norm += 1.0 - p -= tau * g - p /= norm - E /= float(current_object.size) - if i == 0: - E_init = E - E_previous = E - else: - if xp.abs(E_previous - E) < eps * E_init: - break - else: + slices_g[ndim[ax] + 1] = slice(0, -1) + slices_g[0] = ndim[ax] + g[tuple(slices_g)] = xp.diff(updated_object, axis=ndim[ax]) + slices_g[ndim[ax] + 1] = slice(None) + if scaling is not None: + scaling /= xp.max(scaling) + g *= xp.array(scaling)[:, xp.newaxis, xp.newaxis] + norm = xp.sqrt((g**2).sum(axis=0))[xp.newaxis, ...] + E += weight * norm.sum() + tau = 1.0 / (2.0 * len(ndim)) + norm *= tau / weight + norm += 1.0 + p -= tau * g + p /= norm + E /= float(current_object.size) + if i == 0: + E_init = E E_previous = E - i += 1 + else: + if xp.abs(E_previous - E) < eps * E_init: + break + else: + E_previous = E + i += 1 - if pad_object: - for ax in range(len(ndim)): - slices = array_slice(ndim[ax], current_object.ndim, 1, -1) - updated_object = updated_object[slices] + if pad_object: + for ax in range(len(ndim)): + slices = array_slice(ndim[ax], current_object.ndim, 1, -1) + updated_object = updated_object[slices] + updated_object = ( + updated_object / xp.sum(updated_object) * current_object_sum + ) - return updated_object / xp.sum(updated_object) * current_object_sum + return updated_object def _probe_center_of_mass_constraint(self, current_probe): """ diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 0c9af9649..2480974f3 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -1080,6 +1080,10 @@ def _constraints( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. object_positivity: bool If True, clips negative potential values shrinkage_rad: float @@ -1294,6 +1298,10 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float From aac2d12a3c7d17ca471be504c828119ab3a127a7 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 21 Aug 2023 08:43:11 -0700 Subject: [PATCH 08/62] minor tv bugfix --- .../process/phase/iterative_multislice_ptychography.py | 8 ++++---- .../process/phase/iterative_ptychographic_constraints.py | 5 ++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 5966ca07a..307095960 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -1474,7 +1474,7 @@ def _object_denoise_tv_pylops(self, current_object, weights): if xp.iscomplexobj(current_object): current_object_tv = current_object warnings.warn( - ("tv_denoise currently for potential objects only"), + ("TV denoising is currently only supported for potential objects."), UserWarning, ) @@ -1484,6 +1484,7 @@ def _object_denoise_tv_pylops(self, current_object, weights): current_object = xp.pad( current_object, pad_width=pad_width, mode="constant" ) + # run tv denoising nz, nx, ny = current_object.shape niter_out = 40 @@ -1509,11 +1510,10 @@ def _object_denoise_tv_pylops(self, current_object, weights): show=False, )[0] - current_object_tv = current_object_tv.reshape(current_object.shape) - # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] - return current_object_tv[1:-1] + return current_object_tv def _constraints( self, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index e300e1154..95c2a9531 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -207,7 +207,7 @@ def _object_denoise_tv_pylops(self, current_object, weight): if xp.iscomplexobj(current_object): current_object_tv = current_object warnings.warn( - ("tv_denoise currently for potential objects only"), + ("TV denoising is currently only supported for potential objects."), UserWarning, ) @@ -287,11 +287,10 @@ def _object_denoise_tv_chambolle( if xp.iscomplexobj(current_object): updated_object = current_object warnings.warn( - ("tv_denoise currently for potential objects only"), + ("TV denoising is currently only supported for potential objects."), UserWarning, ) else: - current_object_sum = xp.sum(current_object) if axis is None: ndim = xp.arange(current_object.ndim).tolist() From da85db0da8524d935d4caa2a798c7f9e0c2cb079 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Mon, 21 Aug 2023 09:53:13 -0700 Subject: [PATCH 09/62] improvements for depth plotting --- .../iterative_multislice_ptychography.py | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 307095960..ddedac229 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -2933,6 +2933,7 @@ def show_depth( x2: float, y1: float, y2: float, + specify_calibrated: bool = False, gaussian_filter_sigma: float = None, ms_object=None, cbar: bool = False, @@ -2947,6 +2948,9 @@ def show_depth( -------- x1, x2, y1, y2: floats (pixels) Line profile for depth section runs from (x1,y1) to (x2,y2) + Specified in pixels unless specify_calibrated is True + specify_calibrated: bool (optional) + If True, specify x1, x2, y1, y2 in A values instead of pixels gaussian_filter_sigma: float (optional) Standard deviation of gaussian kernel in A ms_object: np.array @@ -2962,11 +2966,31 @@ def show_depth( ms_obj = ms_object else: ms_obj = self.object_cropped - angle = np.arctan((x2 - x1) / (y2 - y1)) + + if specify_calibrated: + x1 /= self.sampling[0] + x2 /= self.sampling[0] + y1 /= self.sampling[1] + y2 /= self.sampling[1] + + if x2 == x1: + angle = 0 + elif y2 == y1: + angle = np.pi / 2 + else: + angle = np.arctan((x2 - x1) / (y2 - y1)) x0 = ms_obj.shape[1] / 2 y0 = ms_obj.shape[2] / 2 + if ( + x1 > ms_obj.shape[1] + or x2 > ms_obj.shape[1] + or y1 > ms_obj.shape[2] + or y2 > ms_obj.shape[2] + ): + raise ValueError("depth section must be in field of view of object") + from py4DSTEM.process.phase.utils import rotate_point x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) From a5822129726f28992844f98e94cbcb45b79d9786 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Mon, 21 Aug 2023 19:26:26 -0700 Subject: [PATCH 10/62] tv fixes --- .../iterative_multislice_ptychography.py | 91 ++++++++++++++----- .../iterative_ptychographic_constraints.py | 8 +- .../iterative_singleslice_ptychography.py | 14 ++- 3 files changed, 86 insertions(+), 27 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index ddedac229..d9bf6bacf 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -1451,7 +1451,7 @@ def _object_identical_slices_constraint(self, current_object): return current_object - def _object_denoise_tv_pylops(self, current_object, weights): + def _object_denoise_tv_pylops(self, current_object, weights, iterations): """ Performs second order TV denoising along x and y @@ -1462,6 +1462,9 @@ def _object_denoise_tv_pylops(self, current_object, weights): weights : [float, float] Denoising weights[z weight, r weight]. The greater `weight`, the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops Returns ------- @@ -1487,28 +1490,66 @@ def _object_denoise_tv_pylops(self, current_object, weights): # run tv denoising nz, nx, ny = current_object.shape - niter_out = 40 + niter_out = iterations niter_in = 1 Iop = pylops.Identity(nx * ny * nz) - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [z_gradient, xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=weights, - tol=1e-4, - tau=1.0, - show=False, - )[0] + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] # remove padding current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] @@ -1552,6 +1593,7 @@ def _constraints( tv_denoise_pad_chambolle, tv_denoise, tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. @@ -1627,6 +1669,8 @@ def _constraints( tv_denoise_weights: [float,float] Denoising weights[z weight, r weight]. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1661,6 +1705,7 @@ def _constraints( current_object = self._object_denoise_tv_pylops( current_object, tv_denoise_weights, + tv_denoise_inner_iter, ) elif tv_denoise_chambolle: current_object = self._object_denoise_tv_chambolle( @@ -1771,6 +1816,7 @@ def reconstruct( tv_denoise_pad_chambolle=True, tv_denoise_iter=np.inf, tv_denoise_weights=None, + tv_denoise_inner_iter=40, switch_object_iter: int = np.inf, store_iterations: bool = False, progress_bar: bool = True, @@ -1874,6 +1920,8 @@ def reconstruct( tv_denoise_weights: [float,float] Denoising weights[z weight, r weight]. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising switch_object_iter: int, optional Iteration to switch object type between 'complex' and 'potential' or between 'potential' and 'complex' @@ -2222,6 +2270,7 @@ def reconstruct( tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 95c2a9531..217253945 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -185,7 +185,7 @@ def _object_butterworth_constraint( return current_object - def _object_denoise_tv_pylops(self, current_object, weight): + def _object_denoise_tv_pylops(self, current_object, weight, iterations): """ Performs second order TV denoising along x and y @@ -196,6 +196,10 @@ def _object_denoise_tv_pylops(self, current_object, weight): weight : float Denoising weight. The greater `weight`, the more denoising (at the expense of fidelity to `input`). + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + Returns ------- constrained_object: np.ndarray @@ -213,7 +217,7 @@ def _object_denoise_tv_pylops(self, current_object, weight): else: nx, ny = current_object.shape - niter_out = 40 + niter_out = iterations niter_in = 1 Iop = pylops.Identity(nx * ny) xy_laplacian = pylops.Laplacian( diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 2480974f3..97c7a3e5d 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -1025,6 +1025,7 @@ def _constraints( butterworth_order, tv_denoise, tv_denoise_weight, + tv_denoise_inner_iter, object_positivity, shrinkage_rad, object_mask, @@ -1082,8 +1083,10 @@ def _constraints( Butterworth filter order. Smaller gives a smoother filter tv_denoise: bool If True, applies TV denoising on object - tv_denoise_weight: float + tv_denoise_weight: float Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool If True, clips negative potential values shrinkage_rad: float @@ -1116,8 +1119,7 @@ def _constraints( if tv_denoise: current_object = self._object_denoise_tv_pylops( - current_object, - tv_denoise_weight, + current_object, tv_denoise_weight, tv_denoise_inner_iter ) if shrinkage_rad > 0.0 or object_mask is not None: @@ -1212,6 +1214,7 @@ def reconstruct( butterworth_order: float = 2, tv_denoise_iter: int = np.inf, tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -1300,8 +1303,10 @@ def reconstruct( Butterworth filter order. Smaller gives a smoother filter tv_denoise_iter: int, optional Number of iterations to run using tv denoise filter on object - tv_denoise_weight: float + tv_denoise_weight: float Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -1638,6 +1643,7 @@ def reconstruct( butterworth_order=butterworth_order, tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse From eaa3699a9a8c00a48bb18a6e8e8efe6f595f2397 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 24 Aug 2023 17:22:14 -0700 Subject: [PATCH 11/62] everyone gets TV denoising --- .../iterative_mixedstate_ptychography.py | 26 +++ .../iterative_overlap_magnetic_tomography.py | 168 +++++++++++++++++- .../phase/iterative_overlap_tomography.py | 147 ++++++++++++++- .../iterative_simultaneous_ptychography.py | 30 ++++ 4 files changed, 361 insertions(+), 10 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 56fec1004..d066c7f3f 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -1125,6 +1125,9 @@ def _constraints( q_lowpass, q_highpass, butterworth_order, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, orthogonalize_probe, object_positivity, shrinkage_rad, @@ -1183,6 +1186,12 @@ def _constraints( Butterworth filter order. Smaller gives a smoother filter orthogonalize_probe: bool If True, probe will be orthogonalized + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool If True, clips negative potential values shrinkage_rad: float @@ -1213,6 +1222,11 @@ def _constraints( butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, tv_denoise_weight, tv_denoise_inner_iter + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( current_object, @@ -1290,6 +1304,9 @@ def reconstruct( q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -1373,6 +1390,12 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -1707,6 +1730,9 @@ def reconstruct( q_highpass=q_highpass, butterworth_order=butterworth_order, orthogonalize_probe=orthogonalize_probe, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 8691a121d..2642b7193 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np +import pylops from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import make_axes_locatable from py4DSTEM.visualize import show @@ -1679,6 +1680,111 @@ def _divergence_free_constraint(self, vector_field): return vector_field + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + def _constraints( self, current_object, @@ -1710,6 +1816,9 @@ def _constraints( object_positivity, shrinkage_rad, object_mask, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. @@ -1771,6 +1880,15 @@ def _constraints( If True, forces object to be positive shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration + object_mask: np.ndarray (boolean) + If not None, used to calculate additional shrinkage using masked-mean of object + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1822,6 +1940,31 @@ def _constraints( butterworth_order, ) + elif tv_denoise: + current_object[0] = self._object_denoise_tv_pylops( + current_object[0], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[1] = self._object_denoise_tv_pylops( + current_object[1], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[2] = self._object_denoise_tv_pylops( + current_object[2], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[3] = self._object_denoise_tv_pylops( + current_object[3], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object[0] = self._object_shrinkage_constraint( current_object[0], @@ -1913,6 +2056,9 @@ def reconstruct( object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, collective_tilt_updates: bool = False, store_iterations: bool = False, progress_bar: bool = True, @@ -1998,6 +2144,15 @@ def reconstruct( Butterworth filter order. Smaller gives a smoother filter object_positivity: bool, optional If True, forces object to be positive + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + collective_tilt_updates: bool + if True perform collective tilt updates shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration store_iterations: bool, optional @@ -2477,6 +2632,10 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter + and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) # Normalize Error Over Tilts @@ -2487,11 +2646,7 @@ def reconstruct( if collective_tilt_updates: self._object += collective_object / self._num_tilts - ( - self._object, - self._probe, - _, - ) = self._constraints( + (self._object, self._probe, _,) = self._constraints( self._object, self._probe, None, @@ -2530,6 +2685,9 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + v_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index d6bee12fd..9cfec2b39 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np +import pylops from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable from py4DSTEM.visualize import show @@ -1527,6 +1528,111 @@ def _object_butterworth_constraint( current_object += current_object_mean return xp.real(current_object) + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + def _constraints( self, current_object, @@ -1555,6 +1661,9 @@ def _constraints( object_positivity, shrinkage_rad, object_mask, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. @@ -1611,6 +1720,13 @@ def _constraints( Phase shift in radians to be subtracted from the potential at each iteration object_mask: np.ndarray (boolean) If not None, used to calculate additional shrinkage using masked-mean of object + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1634,6 +1750,12 @@ def _constraints( q_highpass, butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( @@ -1723,6 +1845,9 @@ def reconstruct( object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, collective_tilt_updates: bool = False, store_iterations: bool = False, progress_bar: bool = True, @@ -1806,6 +1931,15 @@ def reconstruct( Butterworth filter order. Smaller gives a smoother filter object_positivity: bool, optional If True, forces object to be positive + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + collective_tilt_updates: bool + if True perform collective tilt updates shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration store_iterations: bool, optional @@ -2203,6 +2337,10 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter + and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) # Normalize Error Over Tilts @@ -2211,11 +2349,7 @@ def reconstruct( if collective_tilt_updates: self._object += collective_object / self._num_tilts - ( - self._object, - self._probe, - _, - ) = self._constraints( + (self._object, self._probe, _,) = self._constraints( self._object, self._probe, None, @@ -2251,6 +2385,9 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + v_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 8881d021c..a19fc82d3 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -2232,6 +2232,9 @@ def _constraints( q_highpass_e, q_highpass_m, butterworth_order, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, warmup_iteration, object_positivity, shrinkage_rad, @@ -2300,6 +2303,12 @@ def _constraints( Cut-off frequency in A^-1 for high-pass filtering magnetic object butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising warmup_iteration: bool If True, constraints electrostatic object only object_positivity: bool @@ -2349,6 +2358,15 @@ def _constraints( if self._object_type == "complex": magnetic_obj = magnetic_obj.real + if tv_denoise: + electrostatic_obj = self._object_denoise_tv_pylops( + electrostatic_obj, tv_denoise_weight, tv_denoise_inner_iter + ) + + if not warmup_iteration: + magnetic_obj = self._object_denoise_tv_pylops( + magnetic_obj, tv_denoise_weight, tv_denoise_inner_iter + ) if shrinkage_rad > 0.0 or object_mask is not None: electrostatic_obj = self._object_shrinkage_constraint( @@ -2446,6 +2464,9 @@ def reconstruct( q_highpass_e: float = None, q_highpass_m: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -2538,6 +2559,12 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass filtering magnetic object butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -2899,6 +2926,9 @@ def reconstruct( q_highpass_e=q_highpass_e, q_highpass_m=q_highpass_m, butterworth_order=butterworth_order, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse From cc364907faf66d6351e10861c5991ad492c8b380 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Tue, 29 Aug 2023 13:30:43 -0700 Subject: [PATCH 12/62] subpixel alignment phase correct, part 1 --- py4DSTEM/process/phase/iterative_parallax.py | 151 ++++++++++++++++++- 1 file changed, 143 insertions(+), 8 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index c2bfc8739..9fe1fbc90 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -895,7 +895,8 @@ def subpixel_alignment( xy_shifts = self._xy_shifts BF_size = np.array(self._stack_BF_no_window.shape[-2:]) - pixel_output = BF_size * kde_upsample_factor + self._kde_upsample_factor = kde_upsample_factor + pixel_output = BF_size * self._kde_upsample_factor pixel_size = pixel_output.prod() # shifted coordinates @@ -903,8 +904,8 @@ def subpixel_alignment( y = xp.arange(BF_size[1]) xa, ya = xp.meshgrid(x, y, indexing="ij") - xa = ((xa + xy_shifts[:, 0, None, None]) * kde_upsample_factor).ravel() - ya = ((ya + xy_shifts[:, 1, None, None]) * kde_upsample_factor).ravel() + xa = ((xa + xy_shifts[:, 0, None, None]) * self._kde_upsample_factor).ravel() + ya = ((ya + xy_shifts[:, 1, None, None]) * self._kde_upsample_factor).ravel() # bilinear sampling xF = xp.floor(xa).astype("int") @@ -948,7 +949,7 @@ def subpixel_alignment( ) # kernel density estimate - sigma = kde_sigma * kde_upsample_factor + sigma = kde_sigma * self._kde_upsample_factor pix_count = gaussian_filter(pix_count, sigma) pix_count[pix_output == 0.0] = np.inf pix_output = gaussian_filter(pix_output, sigma) @@ -970,8 +971,12 @@ def subpixel_alignment( cmap = kwargs.pop("cmap", "magma") cropped_object = self._crop_padded_object(self._recon_BF) - upsampled_pad_x = self._object_padding_px[0] * kde_upsample_factor // 2 - upsampled_pad_y = self._object_padding_px[1] * kde_upsample_factor // 2 + upsampled_pad_x = ( + self._object_padding_px[0] * self._kde_upsample_factor // 2 + ) + upsampled_pad_y = ( + self._object_padding_px[1] * self._kde_upsample_factor // 2 + ) cropped_object_aligned = self.recon_BF_subpixel_aligned[ upsampled_pad_x:-upsampled_pad_x, upsampled_pad_y:-upsampled_pad_y, @@ -1007,8 +1012,8 @@ def subpixel_alignment( if plot_upsampled_FFT_comparison: recon_fft = xp.fft.fft2(self._recon_BF) recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF))) - pad_x = BF_size[0] * (kde_upsample_factor - 1) // 2 - pad_y = BF_size[1] * (kde_upsample_factor - 1) // 2 + pad_x = BF_size[0] * (self._kde_upsample_factor - 1) // 2 + pad_y = BF_size[1] * (self._kde_upsample_factor - 1) // 2 pad_recon_fft = asnumpy( xp.pad(recon_fft, ((pad_x, pad_x), (pad_y, pad_y))) ) @@ -1318,6 +1323,136 @@ def aberration_correct( ax.set_xlabel("y [A]") ax.set_title("Parallax-Corrected Phase Image") + def subpixel_aberration_correct( + self, + plot_corrected_phase: bool = True, + k_info_limit: float = None, + k_info_power: float = 1.0, + Wiener_filter=False, + Wiener_signal_noise_ratio=1.0, + Wiener_filter_low_only=False, + **kwargs, + ): + """ + CTF correction of the phase image using the measured defocus aberration. + + Parameters + ---------- + plot_corrected_phase: bool, optional + If True, the CTF-corrected phase is plotted + k_info_limit: float, optional + maximum allowed frequency in butterworth filter + k_info_power: float, optional + power of butterworth filter + Wiener_filter: bool, optional + Use Wiener filtering instead of CTF sign correction. + Wiener_signal_noise_ratio: float, optional + Signal to noise radio at k = 0 for Wiener filter + Wiener_filter_low_only: bool, optional + Apply Wiener filtering only to the CTF portions before the 1st CTF maxima. + """ + + xp = self._xp + asnumpy = self._asnumpy + + if not hasattr(self, "aberration_C1"): + raise ValueError( + ( + "CTF correction is meant to be ran after alignment and aberration fitting. " + "Please run the `reconstruct()` and `aberration_fit()` functions first." + ) + ) + + # Fourier coordinates + kx = xp.fft.fftfreq( + self._recon_BF_subpixel_aligned.shape[0], + self._scan_sampling[0] / self._kde_upsample_factor, + ) + ky = xp.fft.fftfreq( + self._recon_BF_subpixel_aligned.shape[1], + self._scan_sampling[1] / self._kde_upsample_factor, + ) + kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2 + + # CTF + sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) + + if Wiener_filter: + SNR_inv = ( + xp.sqrt( + 1 + (kra2**k_info_power) / ((k_info_limit) ** (2 * k_info_power)) + ) + / Wiener_signal_noise_ratio + ) + CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv) + if Wiener_filter_low_only: + # limit Wiener filter to only the part of the CTF before 1st maxima + k_thresh = 1 / xp.sqrt( + 2.0 * self._wavelength * xp.abs(self.aberration_C1) + ) + k_mask = kra2 >= k_thresh**2 + CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) + + # apply correction to mean reconstructed BF image + im_fft_corr = xp.fft.fft2(self._recon_BF_subpixel_aligned) * CTF_corr + + else: + # CTF without tilt correction (beyond the parallax operator) + CTF_corr = xp.sign(sin_chi) + CTF_corr[0, 0] = 0 + + # apply correction to mean reconstructed BF image + im_fft_corr = xp.fft.fft2(self._recon_BF_subpixel_aligned) * CTF_corr + print(self._recon_BF_subpixel_aligned.shape) + # if needed, add low pass filter output image + if k_info_limit is not None: + im_fft_corr /= 1 + (kra2**k_info_power) / ( + (k_info_limit) ** (2 * k_info_power) + ) + + # Output phase image + self._recon_phase_corrected_subpixel_aligned = xp.real( + xp.fft.ifft2(im_fft_corr) + ) + self.recon_phase_corrected_subpixel_aligned = asnumpy( + self._recon_phase_corrected_subpixel_aligned + ) + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + + # plotting + if plot_corrected_phase: + figsize = kwargs.pop("figsize", (6, 6)) + cmap = kwargs.pop("cmap", "magma") + + fig, ax = plt.subplots(figsize=figsize) + + cropped_object = self._crop_padded_object(self._recon_phase_corrected) + + extent = [ + 0, + self._scan_sampling[1] + / self._kde_upsample_factor + * cropped_object.shape[1], + self._scan_sampling[0] + / self._kde_upsample_factor + * cropped_object.shape[0], + 0, + ] + + ax.imshow( + cropped_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Parallax-Corrected Phase Image Subpixel Aligned") + def depth_section( self, depth_angstroms=np.arange(-250, 260, 100), From 41f319d09720e2cbc2fcfb8d04dd70f351905789 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Tue, 29 Aug 2023 13:31:53 -0700 Subject: [PATCH 13/62] removing print statement --- py4DSTEM/process/phase/iterative_parallax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 9fe1fbc90..36d90a6f0 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1403,7 +1403,6 @@ def subpixel_aberration_correct( # apply correction to mean reconstructed BF image im_fft_corr = xp.fft.fft2(self._recon_BF_subpixel_aligned) * CTF_corr - print(self._recon_BF_subpixel_aligned.shape) # if needed, add low pass filter output image if k_info_limit is not None: im_fft_corr /= 1 + (kra2**k_info_power) / ( From 95beba51cfe84ba8a85614b5a7fc7f20dfd9222f Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 7 Sep 2023 05:38:41 -0700 Subject: [PATCH 14/62] parallax plotting fix --- py4DSTEM/process/phase/iterative_parallax.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 36d90a6f0..d8824b770 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1428,7 +1428,9 @@ def subpixel_aberration_correct( fig, ax = plt.subplots(figsize=figsize) - cropped_object = self._crop_padded_object(self._recon_phase_corrected) + cropped_object = self._crop_padded_object( + self._recon_BF_subpixel_aligned, upsampled=True + ) extent = [ 0, @@ -1594,6 +1596,7 @@ def _crop_padded_object( self, padded_object: np.ndarray, remaining_padding: int = 0, + upsampled: bool = False, ): """ Utility function to crop padded object @@ -1617,6 +1620,10 @@ def _crop_padded_object( pad_x = self._object_padding_px[0] // 2 - remaining_padding pad_y = self._object_padding_px[1] // 2 - remaining_padding + if upsampled == True: + pad_x *= self._kde_upsample_factor + pad_y *= self._kde_upsample_factor + return asnumpy(padded_object[pad_x:-pad_x, pad_y:-pad_y]) def _visualize_figax( From c25e24264249da686a3ad8d14f01cb098aee25ea Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 20 Sep 2023 08:58:06 -0700 Subject: [PATCH 15/62] generalizing overlap tomo to orientation matrices --- .../phase/iterative_overlap_tomography.py | 107 +++++++++++------- 1 file changed, 66 insertions(+), 41 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index d6bee12fd..98bfb7b5f 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -55,8 +55,8 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): The electron energy of the wave functions in eV num_slices: int Number of slices to use in the forward model - tilt_angles_deg: Sequence[float] - List of tilt angles in degrees, + tilt_orientation_matrices: Sequence[np.ndarray] + List of orientation matrices for each tilt semiangle_cutoff: float, optional Semiangle cutoff for the initial probe guess in mrad semiangle_cutoff_pixels: float, optional @@ -94,13 +94,18 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): """ # Class-specific Metadata - _class_specific_metadata = ("_num_slices", "_tilt_angles_deg") + _class_specific_metadata = ("_num_slices", "_tilt_orientation_matrices") + _swap_zxy_to_xyz = np.array([ + [0,1,0], + [0,0,1], + [1,0,0] + ]) def __init__( self, energy: float, num_slices: int, - tilt_angles_deg: Sequence[float], + tilt_orientation_matrices: Sequence[np.ndarray], datacube: Sequence[DataCube] = None, semiangle_cutoff: float = None, semiangle_cutoff_pixels: float = None, @@ -122,22 +127,24 @@ def __init__( if device == "cpu": self._xp = np self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter, rotate, zoom + from scipy.ndimage import gaussian_filter, rotate, zoom, affine_transform self._gaussian_filter = gaussian_filter self._zoom = zoom self._rotate = rotate + self._affine_transform = affine_transform from scipy.special import erf self._erf = erf elif device == "gpu": self._xp = cp self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter, rotate, zoom + from cupyx.scipy.ndimage import gaussian_filter, rotate, zoom, affine_transform self._gaussian_filter = gaussian_filter self._zoom = zoom self._rotate = rotate + self._affine_transform = affine_transform from cupyx.scipy.special import erf self._erf = erf @@ -156,7 +163,7 @@ def __init__( polar_parameters.update(kwargs) self._set_polar_parameters(polar_parameters) - num_tilts = len(tilt_angles_deg) + num_tilts = len(tilt_orientation_matrices) if initial_scan_positions is None: initial_scan_positions = [None] * num_tilts @@ -185,7 +192,7 @@ def __init__( # Class-specific Metadata self._num_slices = num_slices - self._tilt_angles_deg = tuple(tilt_angles_deg) + self._tilt_orientation_matrices = tuple(tilt_orientation_matrices) self._num_tilts = num_tilts def _precompute_propagator_arrays( @@ -323,6 +330,30 @@ def _expand_sliced_object(self, array: np.ndarray, output_z): normalized_array = array / xp.asarray(voxels_in_slice)[:, None, None] return xp.repeat(normalized_array, voxels_per_slice, axis=0)[:output_z] + def _rotate_zxy_volume( + self, + volume_array, + rot_matrix, + ): + """ + """ + + xp = self._xp + affine_transform = self._affine_transform + swap_zxy_to_xyz = self._swap_zxy_to_xyz + + volume = volume_array.copy() + volume_shape = xp.asarray(volume.shape) + tf = xp.asarray(swap_zxy_to_xyz.T@rot_matrix.T@swap_zxy_to_xyz) + + in_center = (volume_shape - 1) / 2 + out_center = tf @ in_center + offset = in_center - out_center + + volume = affine_transform(volume,tf,offset=offset,order=3) + + return volume + def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, @@ -663,15 +694,15 @@ def preprocess( # overlaps if object_fov_mask is None: probe_overlap_3D = xp.zeros_like(self._object) + old_rot_matrix = np.eye(3) # identity for tilt_index in np.arange(self._num_tilts): - current_angle_deg = self._tilt_angles_deg[tilt_index] - probe_overlap_3D = self._rotate( + + rot_matrix = self._tilt_orientation_matrices[tilt_index] + + probe_overlap_3D = self._rotate_zxy_volume( probe_overlap_3D, - current_angle_deg, - axes=(0, 2), - reshape=False, - order=2, + rot_matrix@old_rot_matrix.T, ) self._positions_px = self._positions_px_all[ @@ -691,14 +722,12 @@ def preprocess( ) probe_overlap_3D += probe_overlap[None] - - probe_overlap_3D = self._rotate( - probe_overlap_3D, - -current_angle_deg, - axes=(0, 2), - reshape=False, - order=2, - ) + old_rot_matrix = rot_matrix + + probe_overlap_3D = self._rotate_zxy_volume( + probe_overlap_3D, + old_rot_matrix.T, + ) probe_overlap_3D = self._gaussian_filter(probe_overlap_3D, 1.0) self._object_fov_mask = asnumpy( @@ -2018,17 +2047,17 @@ def reconstruct( tilt_indices = np.arange(self._num_tilts) np.random.shuffle(tilt_indices) + old_rot_matrix = np.eye(3) # identity + for tilt_index in tilt_indices: self._active_tilt_index = tilt_index tilt_error = 0.0 - self._object = self._rotate( + rot_matrix = self._tilt_orientation_matrices[self._active_tilt_index] + self._object = self._rotate_zxy_volume( self._object, - self._tilt_angles_deg[self._active_tilt_index], - axes=(0, 2), - reshape=False, - order=3, + rot_matrix @ old_rot_matrix.T, ) object_sliced = self._project_sliced_object( @@ -2132,24 +2161,15 @@ def reconstruct( ) if collective_tilt_updates: - collective_object += self._rotate( + collective_object += self._rotate_zxy_volume( object_update, - -self._tilt_angles_deg[self._active_tilt_index], - axes=(0, 2), - reshape=False, - order=3, + rot_matrix.T ) else: self._object += object_update - - self._object = self._rotate( - self._object, - -self._tilt_angles_deg[self._active_tilt_index], - axes=(0, 2), - reshape=False, - order=3, - ) - + + old_rot_matrix = rot_matrix + # Normalize Error tilt_error /= ( self._mean_diffraction_intensity[self._active_tilt_index] @@ -2205,6 +2225,11 @@ def reconstruct( else None, ) + self._object = self._rotate_zxy_volume( + self._object, + old_rot_matrix.T + ) + # Normalize Error Over Tilts error /= self._num_tilts From 31b5525b25bb7c07b15a18e20c918b4edb7ec8d5 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 20 Sep 2023 09:03:45 -0700 Subject: [PATCH 16/62] black formatting --- .../iterative_overlap_magnetic_tomography.py | 6 +- .../phase/iterative_overlap_tomography.py | 61 +++++++++---------- .../iterative_ptychographic_constraints.py | 2 +- 3 files changed, 36 insertions(+), 33 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 2642b7193..712a35647 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -2646,7 +2646,11 @@ def reconstruct( if collective_tilt_updates: self._object += collective_object / self._num_tilts - (self._object, self._probe, _,) = self._constraints( + ( + self._object, + self._probe, + _, + ) = self._constraints( self._object, self._probe, None, diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index e8cde608e..110a547b6 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -96,11 +96,7 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): # Class-specific Metadata _class_specific_metadata = ("_num_slices", "_tilt_orientation_matrices") - _swap_zxy_to_xyz = np.array([ - [0,1,0], - [0,0,1], - [1,0,0] - ]) + _swap_zxy_to_xyz = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) def __init__( self, @@ -140,7 +136,12 @@ def __init__( elif device == "gpu": self._xp = cp self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter, rotate, zoom, affine_transform + from cupyx.scipy.ndimage import ( + gaussian_filter, + rotate, + zoom, + affine_transform, + ) self._gaussian_filter = gaussian_filter self._zoom = zoom @@ -335,24 +336,23 @@ def _rotate_zxy_volume( self, volume_array, rot_matrix, - ): - """ - """ - + ): + """ """ + xp = self._xp affine_transform = self._affine_transform swap_zxy_to_xyz = self._swap_zxy_to_xyz - + volume = volume_array.copy() volume_shape = xp.asarray(volume.shape) - tf = xp.asarray(swap_zxy_to_xyz.T@rot_matrix.T@swap_zxy_to_xyz) - + tf = xp.asarray(swap_zxy_to_xyz.T @ rot_matrix.T @ swap_zxy_to_xyz) + in_center = (volume_shape - 1) / 2 out_center = tf @ in_center offset = in_center - out_center - - volume = affine_transform(volume,tf,offset=offset,order=3) - + + volume = affine_transform(volume, tf, offset=offset, order=3) + return volume def preprocess( @@ -695,15 +695,14 @@ def preprocess( # overlaps if object_fov_mask is None: probe_overlap_3D = xp.zeros_like(self._object) - old_rot_matrix = np.eye(3) # identity + old_rot_matrix = np.eye(3) # identity for tilt_index in np.arange(self._num_tilts): - rot_matrix = self._tilt_orientation_matrices[tilt_index] probe_overlap_3D = self._rotate_zxy_volume( probe_overlap_3D, - rot_matrix@old_rot_matrix.T, + rot_matrix @ old_rot_matrix.T, ) self._positions_px = self._positions_px_all[ @@ -724,7 +723,7 @@ def preprocess( probe_overlap_3D += probe_overlap[None] old_rot_matrix = rot_matrix - + probe_overlap_3D = self._rotate_zxy_volume( probe_overlap_3D, old_rot_matrix.T, @@ -2181,8 +2180,8 @@ def reconstruct( tilt_indices = np.arange(self._num_tilts) np.random.shuffle(tilt_indices) - old_rot_matrix = np.eye(3) # identity - + old_rot_matrix = np.eye(3) # identity + for tilt_index in tilt_indices: self._active_tilt_index = tilt_index @@ -2296,14 +2295,13 @@ def reconstruct( if collective_tilt_updates: collective_object += self._rotate_zxy_volume( - object_update, - rot_matrix.T + object_update, rot_matrix.T ) else: self._object += object_update - + old_rot_matrix = rot_matrix - + # Normalize Error tilt_error /= ( self._mean_diffraction_intensity[self._active_tilt_index] @@ -2363,10 +2361,7 @@ def reconstruct( tv_denoise_inner_iter=tv_denoise_inner_iter, ) - self._object = self._rotate_zxy_volume( - self._object, - old_rot_matrix.T - ) + self._object = self._rotate_zxy_volume(self._object, old_rot_matrix.T) # Normalize Error Over Tilts error /= self._num_tilts @@ -2374,7 +2369,11 @@ def reconstruct( if collective_tilt_updates: self._object += collective_object / self._num_tilts - (self._object, self._probe, _,) = self._constraints( + ( + self._object, + self._probe, + _, + ) = self._constraints( self._object, self._probe, None, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 217253945..ba9f28332 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -197,7 +197,7 @@ def _object_denoise_tv_pylops(self, current_object, weight, iterations): Denoising weight. The greater `weight`, the more denoising (at the expense of fidelity to `input`). iterations: float - Number of iterations to run in denoising algorithm. + Number of iterations to run in denoising algorithm. `niter_out` in pylops Returns From d4363f4eb081b3deb54ce559d3987caf7967782f Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 20 Sep 2023 09:30:35 -0700 Subject: [PATCH 17/62] flake8 6.1.0 found some more issues --- py4DSTEM/process/phase/iterative_multislice_ptychography.py | 4 ++-- py4DSTEM/process/phase/iterative_overlap_tomography.py | 4 ++-- py4DSTEM/process/phase/iterative_parallax.py | 2 +- py4DSTEM/process/phase/iterative_ptychographic_constraints.py | 3 ++- py4DSTEM/process/phase/utils.py | 4 ++-- 5 files changed, 9 insertions(+), 8 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index d9bf6bacf..3e36978e4 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -2254,7 +2254,7 @@ def reconstruct( and kz_regularization_gamma is not None, kz_regularization_gamma=kz_regularization_gamma[a0] if kz_regularization_gamma is not None - and type(kz_regularization_gamma) == np.ndarray + and isinstance(kz_regularization_gamma, np.ndarray) else kz_regularization_gamma, identical_slices=a0 < identical_slices_iter, object_positivity=object_positivity, @@ -3068,7 +3068,7 @@ def show_depth( 0, ] - if plot_line_profile == False: + if not plot_line_profile: fig, ax = plt.subplots() im = ax.imshow(plot_im, cmap="magma", extent=extent) if aspect is not None: diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 110a547b6..fd9af0bb2 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -124,7 +124,7 @@ def __init__( if device == "cpu": self._xp = np self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter, rotate, zoom, affine_transform + from scipy.ndimage import affine_transform, gaussian_filter, rotate, zoom self._gaussian_filter = gaussian_filter self._zoom = zoom @@ -137,10 +137,10 @@ def __init__( self._xp = cp self._asnumpy = cp.asnumpy from cupyx.scipy.ndimage import ( + affine_transform, gaussian_filter, rotate, zoom, - affine_transform, ) self._gaussian_filter = gaussian_filter diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index d8824b770..b23fe2cae 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1620,7 +1620,7 @@ def _crop_padded_object( pad_x = self._object_padding_px[0] // 2 - remaining_padding pad_y = self._object_padding_px[1] // 2 - remaining_padding - if upsampled == True: + if upsampled: pad_x *= self._kde_upsample_factor pad_y *= self._kde_upsample_factor diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index ba9f28332..4721ed12b 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np import pylops from py4DSTEM.process.phase.utils import ( @@ -8,7 +10,6 @@ regularize_probe_amplitude, ) from py4DSTEM.process.utils import get_CoM -import warnings class PtychographicConstraints: diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index a8a702f89..d06db111c 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1614,9 +1614,9 @@ def fit_aberration_surface( def rotate_point(origin, point, angle): """ - Rotate a point (x1, y1) counterclockwise by a given angle around + Rotate a point (x1, y1) counterclockwise by a given angle around a given origin (x0, y0). - + Parameters -------- origin: 2-tuple of floats From e60071a3859283667cefeaad9b9f95495928418a Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 20 Sep 2023 14:28:18 -0700 Subject: [PATCH 18/62] adding mixed-state multi-slice ptycho class --- py4DSTEM/process/phase/__init__.py | 30 +- ...tive_mixedstate_multislice_ptychography.py | 3513 +++++++++++++++++ 2 files changed, 3521 insertions(+), 22 deletions(-) create mode 100644 py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index 178079349..1005a619d 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -3,28 +3,14 @@ _emd_hook = True from py4DSTEM.process.phase.iterative_dpc import DPCReconstruction -from py4DSTEM.process.phase.iterative_mixedstate_ptychography import ( - MixedstatePtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_multislice_ptychography import ( - MultislicePtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import ( - OverlapMagneticTomographicReconstruction, -) -from py4DSTEM.process.phase.iterative_overlap_tomography import ( - OverlapTomographicReconstruction, -) +from py4DSTEM.process.phase.iterative_mixedstate_multislice_ptychography import MixedstateMultislicePtychographicReconstruction +from py4DSTEM.process.phase.iterative_mixedstate_ptychography import MixedstatePtychographicReconstruction +from py4DSTEM.process.phase.iterative_multislice_ptychography import MultislicePtychographicReconstruction +from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import OverlapMagneticTomographicReconstruction +from py4DSTEM.process.phase.iterative_overlap_tomography import OverlapTomographicReconstruction from py4DSTEM.process.phase.iterative_parallax import ParallaxReconstruction -from py4DSTEM.process.phase.iterative_simultaneous_ptychography import ( - SimultaneousPtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_singleslice_ptychography import ( - SingleslicePtychographicReconstruction, -) -from py4DSTEM.process.phase.parameter_optimize import ( - OptimizationParameter, - PtychographyOptimizer, -) +from py4DSTEM.process.phase.iterative_simultaneous_ptychography import SimultaneousPtychographicReconstruction +from py4DSTEM.process.phase.iterative_singleslice_ptychography import SingleslicePtychographicReconstruction +from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer # fmt: on diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py new file mode 100644 index 000000000..acb9f12a2 --- /dev/null +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -0,0 +1,3513 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely multislice ptychography. +""" + +import warnings +from typing import Mapping, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +import pylops +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex + +try: + import cupy as cp +except ImportError: + cp = None + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, + spatial_frequencies, +) +from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar +from scipy.ndimage import rotate + +warnings.simplefilter(action="always", category=UserWarning) + + +class MixedstateMultislicePtychographicReconstruction(PtychographicReconstruction): + """ + Mixed-State Multislice Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (N,Sx,Sy) + Reconstructed object dimensions : (T,Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our N probes + and (Px,Py) is the padded-object size we position our ROI around in + each of the T slices. + + Parameters + ---------- + energy: float + The electron energy of the wave functions in eV + num_probes: int, optional + Number of mixed-state probes + num_slices: int + Number of slices to use in the forward model + slice_thicknesses: float or Sequence[float] + Slice thicknesses in angstroms. If float, all slices are assigned the same thickness + datacube: DataCube, optional + Input 4D diffraction pattern intensities + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (Px,Py) + If None, initialized to 1.0j + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Å for each diffraction intensity + If None, initialized to a grid scan + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + verbose: bool, optional + If True, class methods will inherit this and print additional information + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. + """ + + # Class-specific Metadata + _class_specific_metadata = ("_num_probes", "_num_slices", "_slice_thicknesses") + + def __init__( + self, + energy: float, + num_slices: int, + slice_thicknesses: Union[float, Sequence[float]], + num_probes: int = None, + datacube: DataCube = None, + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + object_type: str = "complex", + verbose: bool = True, + device: str = "cpu", + name: str = "multi-slice_ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe): + if num_probes is None: + raise ValueError( + ( + "If initial_probe_guess is None, or a ComplexProbe object, " + "num_probes must be specified." + ) + ) + else: + if len(initial_probe_guess.shape) != 3: + raise ValueError( + "Specified initial_probe_guess must have dimensions (N,Sx,Sy)." + ) + num_probes = initial_probe_guess.shape[0] + + if device == "cpu": + self._xp = np + self._asnumpy = np.asarray + from scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from scipy.special import erf + + self._erf = erf + elif device == "gpu": + self._xp = cp + self._asnumpy = cp.asnumpy + from cupyx.scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from cupyx.scipy.special import erf + + self._erf = erf + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + slice_thicknesses = np.array(slice_thicknesses) + if slice_thicknesses.shape == (): + slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1) + elif slice_thicknesses.shape[0] != (num_slices - 1): + raise ValueError( + ( + f"slice_thicknesses must have length {num_slices - 1}, " + f"not {slice_thicknesses.shape[0]}." + ) + ) + + if object_type != "potential" and object_type != "complex": + raise ValueError( + f"object_type must be either 'potential' or 'complex', not {object_type}" + ) + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._object_padding_px = object_padding_px + self._verbose = verbose + self._device = device + self._preprocessed = False + + # Class-specific Metadata + self._num_probes = num_probes + self._num_slices = num_slices + self._slice_thicknesses = slice_thicknesses + + def _precompute_propagator_arrays( + self, + gpts: Tuple[int, int], + sampling: Tuple[float, float], + energy: float, + slice_thicknesses: Sequence[float], + ): + """ + Precomputes propagator arrays complex wave-function will be convolved by, + for all slice thicknesses. + + Parameters + ---------- + gpts: Tuple[int,int] + Wavefunction pixel dimensions + sampling: Tuple[float,float] + Wavefunction sampling in A + energy: float + The electron energy of the wave functions in eV + slice_thicknesses: Sequence[float] + Array of slice thicknesses in A + + Returns + ------- + propagator_arrays: np.ndarray + (T,Sx,Sy) shape array storing propagator arrays + """ + xp = self._xp + + # Frequencies + kx, ky = spatial_frequencies(gpts, sampling) + kx = xp.asarray(kx, dtype=xp.float32) + ky = xp.asarray(ky, dtype=xp.float32) + + # Propagators + wavelength = electron_wavelength_angstrom(energy) + num_slices = slice_thicknesses.shape[0] + propagators = xp.empty( + (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 + ) + for i, dz in enumerate(slice_thicknesses): + propagators[i] = xp.exp( + 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) + ) + propagators[i] *= xp.exp( + 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) + ) + + return propagators + + def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): + """ + Propagates array by Fourier convolving array with propagator_array. + + Parameters + ---------- + array: np.ndarray + Wavefunction array to be convolved + propagator_array: np.ndarray + Propagator array to convolve array with + + Returns + ------- + propagated_array: np.ndarray + Fourier-convolved array + """ + xp = self._xp + + return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "fourier", + probe_roi_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_center_of_mass: str = "default", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + plot_probe_overlaps: bool = True, + force_com_rotation: float = None, + force_com_transpose: float = None, + force_com_shifts: float = None, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + object_fov_mask: np.ndarray = None, + **kwargs, + ): + """ + Ptychographic preprocessing step. + Calls the base class methods: + + _extract_intensities_and_calibrations_from_datacube, + _compute_center_of_mass(), + _solve_CoM_rotation(), + _normalize_diffraction_intensities() + _calculate_scan_positions_in_px() + + Additionally, it initializes an (T,Px,Py) array of 1.0j + and a complex probe using the specified polar parameters. + + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + probe_roi_shape, (int,int), optional + Padded diffraction intensities shape. + If None, no padding is performed + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' + plot_center_of_mass: str, optional + If 'default', the corrected CoM arrays will be displayed + If 'all', the computed and fitted CoM arrays will be displayed + plot_rotation: bool, optional + If True, the CoM curl minimization search result will be displayed + maximize_divergence: bool, optional + If True, the divergence of the CoM gradient vector field is maximized + rotation_angles_deg: np.darray, optional + Array of angles in degrees to perform curl minimization over + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + force_com_rotation: float (degrees), optional + Force relative rotation angle between real and reciprocal space + force_com_transpose: bool, optional + Force whether diffraction intensities need to be transposed. + force_com_shifts: tuple of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + + Returns + -------- + self: MixedstateMultislicePtychographicReconstruction + Self to accommodate chaining + """ + xp = self._xp + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._probe_roi_shape = probe_roi_shape + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first." + ) + ) + + ( + self._datacube, + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts, + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube, + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + probe_roi_shape=self._probe_roi_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts, + ) + + self._intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube, + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + ( + self._com_measured_x, + self._com_measured_y, + self._com_fitted_x, + self._com_fitted_y, + self._com_normalized_x, + self._com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + self._intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts, + ) + + ( + self._rotation_best_rad, + self._rotation_best_transpose, + self._com_x, + self._com_y, + self.com_x, + self.com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + self._com_measured_x, + self._com_measured_y, + self._com_normalized_x, + self._com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=plot_center_of_mass, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + + ( + self._amplitudes, + self._mean_diffraction_intensity, + ) = self._normalize_diffraction_intensities( + self._intensities, + self._com_fitted_x, + self._com_fitted_y, + ) + + # explicitly delete namespace + self._num_diffraction_patterns = self._amplitudes.shape[0] + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + del self._intensities + + self._positions_px = self._calculate_scan_positions_in_pixels( + self._scan_positions + ) + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # Object Initialization + if self._object is None: + pad_x = self._object_padding_px[0][1] + pad_y = self._object_padding_px[1][1] + p, q = np.round(np.max(self._positions_px, axis=0)) + p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( + "int" + ) + q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( + "int" + ) + if self._object_type == "potential": + self._object = xp.zeros((self._num_slices, p, q), dtype=xp.float32) + elif self._object_type == "complex": + self._object = xp.ones((self._num_slices, p, q), dtype=xp.complex64) + else: + if self._object_type == "potential": + self._object = xp.asarray(self._object, dtype=xp.float32) + elif self._object_type == "complex": + self._object = xp.asarray(self._object, dtype=xp.complex64) + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + + self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) + self._positions_px_com = xp.mean(self._positions_px, axis=0) + self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 + self._positions_px_com = xp.mean(self._positions_px, axis=0) + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + self._positions_px_initial = self._positions_px.copy() + self._positions_initial = self._positions_px_initial.copy() + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # Vectorized Patches + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # Probe Initialization + if self._probe is None or isinstance(self._probe, ComplexProbe): + if self._probe is None: + if self._vacuum_probe_intensity is not None: + self._semiangle_cutoff = np.inf + self._vacuum_probe_intensity = xp.asarray( + self._vacuum_probe_intensity, dtype=xp.float32 + ) + probe_x0, probe_y0 = get_CoM( + self._vacuum_probe_intensity, + device=self._device, + ) + self._vacuum_probe_intensity = get_shifted_ar( + self._vacuum_probe_intensity, + -probe_x0, + -probe_y0, + bilinear=True, + device=self._device, + ) + + _probe = ( + ComplexProbe( + gpts=self._region_of_interest_shape, + sampling=self.sampling, + energy=self._energy, + semiangle_cutoff=self._semiangle_cutoff, + rolloff=self._rolloff, + vacuum_probe_intensity=self._vacuum_probe_intensity, + parameters=self._polar_parameters, + device=self._device, + ) + .build() + ._array + ) + + else: + if self._probe._gpts != self._region_of_interest_shape: + raise ValueError() + if hasattr(self._probe, "_array"): + _probe = self._probe._array + else: + self._probe._xp = xp + _probe = self._probe.build()._array + + self._probe = xp.zeros( + (self._num_probes,) + tuple(self._region_of_interest_shape), + dtype=xp.complex64, + ) + sx, sy = self._region_of_interest_shape + self._probe[0] = _probe + + # Randomly shift phase of other probes + for i_probe in range(1, self._num_probes): + shift_x = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx) + ) + shift_y = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy) + ) + self._probe[i_probe] = ( + self._probe[i_probe - 1] * shift_x[:, None] * shift_y[None] + ) + + # Normalize probe to match mean diffraction intensity + probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe[0])) ** 2) + self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) + + else: + self._probe = xp.asarray(self._probe, dtype=xp.complex64) + + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = None # Doesn't really make sense for mixed-state + + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + # Precomputed propagator arrays + self._propagator_arrays = self._precompute_propagator_arrays( + self._region_of_interest_shape, + self.sampling, + self._energy, + self._slice_thicknesses, + ) + + # overlaps + shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp) + probe_intensities = xp.abs(shifted_probes) ** 2 + probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) + probe_overlap = self._gaussian_filter(probe_overlap, 1.0) + + if object_fov_mask is None: + self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + if plot_probe_overlaps: + figsize = kwargs.pop("figsize", (13, 4)) + cmap = kwargs.pop("cmap", "Greys_r") + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + hue_start = kwargs.pop("hue_start", 0) + invert = kwargs.pop("invert", False) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered[0], + vmin=vmin, + vmax=vmax, + hue_start=hue_start, + invert=invert, + ) + + # propagated + propagated_probe = self._probe[0].copy() + + for s in range(self._num_slices - 1): + propagated_probe = self._propagate_array( + propagated_probe, self._propagator_arrays[s] + ) + complex_propagated_rgb = Complex2RGB( + asnumpy(self._return_centered_probe(propagated_probe)), + vmin=vmin, + vmax=vmax, + hue_start=hue_start, + invert=invert, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + **kwargs, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + ) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial Probe[0]") + + ax2.imshow( + complex_propagated_rgb, + extent=probe_extent, + **kwargs, + ) + + divider = make_axes_locatable(ax2) + cax2 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + ) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_title("Propagated Probe[0]") + + ax3.imshow( + asnumpy(probe_overlap), + extent=extent, + cmap=cmap, + **kwargs, + ) + ax3.scatter( + self.positions[:, 1], + self.positions[:, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax3.set_ylabel("x [A]") + ax3.set_xlabel("y [A]") + ax3.set_xlim((extent[0], extent[1])) + ax3.set_ylim((extent[2], extent[3])) + ax3.set_title("Object Field of View") + + fig.tight_layout() + + self._preprocessed = True + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + + return self + + def _overlap_projection(self, current_object, current_probe): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + propagated_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + """ + + xp = self._xp + + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object + + object_patches = complex_object[ + :, + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ] + + num_probe_positions = object_patches.shape[1] + + propagated_shape = ( + self._num_slices, + num_probe_positions, + self._num_probes, + self._region_of_interest_shape[0], + self._region_of_interest_shape[1], + ) + propagated_probes = xp.empty(propagated_shape, dtype=object_patches.dtype) + propagated_probes[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes = ( + xp.expand_dims(object_patches[s], axis=1) * propagated_probes[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes[s + 1] = self._propagate_array( + transmitted_probes, self._propagator_arrays[s] + ) + + return propagated_probes, object_patches, transmitted_probes + + def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): + """ + Ptychographic fourier projection method for GD method. + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + + Returns + -------- + exit_waves:np.ndarray + Exit wave difference + error: float + Reconstruction error + """ + + xp = self._xp + fourier_exit_waves = xp.fft.fft2(transmitted_probes) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) + error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) + + intensity_norm[intensity_norm == 0.0] = np.inf + amplitude_modification = amplitudes / intensity_norm + + fourier_modified_overlap = amplitude_modification[:, None] * fourier_exit_waves + modified_exit_wave = xp.fft.ifft2(fourier_modified_overlap) + + exit_waves = modified_exit_wave - transmitted_probes + + return exit_waves, error + + def _projection_sets_fourier_projection( + self, + amplitudes, + transmitted_probes, + exit_waves, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic fourier projection method for DM_AP and RAAR methods. + Generalized projection using three parameters: a,b,c + + DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha + DM: DM_AP(1.0), AP: DM_AP(0.0) + + RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 + DM : RAAR(1.0) + + RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 + DM: RRR(1.0) + + SUPERFLIP : a = 0, b = 1, c = 2 + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + exit_waves: np.ndarray + previously estimated exit waves + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + exit_waves:np.ndarray + Updated exit wave difference + error: float + Reconstruction error + """ + + xp = self._xp + projection_x = 1 - projection_a - projection_b + projection_y = 1 - projection_c + + if exit_waves is None: + exit_waves = transmitted_probes.copy() + + fourier_exit_waves = xp.fft.fft2(transmitted_probes) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) + error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) + + factor_to_be_projected = ( + projection_c * transmitted_probes + projection_y * exit_waves + ) + fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) + + intensity_norm_projected = xp.sqrt( + xp.sum(xp.abs(fourier_projected_factor) ** 2, axis=1) + ) + intensity_norm_projected[intensity_norm_projected == 0.0] = np.inf + + amplitude_modification = amplitudes / intensity_norm_projected + fourier_projected_factor *= amplitude_modification[:, None] + + projected_factor = xp.fft.ifft2(fourier_projected_factor) + + exit_waves = ( + projection_x * exit_waves + + projection_a * transmitted_probes + + projection_b * projected_factor + ) + + return exit_waves, error + + def _forward( + self, + current_object, + current_probe, + amplitudes, + exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic forward operator. + Calls _overlap_projection() and the appropriate _fourier_projection(). + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + amplitudes: np.ndarray + Normalized measured amplitudes + exit_waves: np.ndarray + previously estimated exit waves + use_projection_scheme: bool, + If True, use generalized projection update + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + propagated_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + exit_waves:np.ndarray + Updated exit_waves + error: float + Reconstruction error + """ + + ( + propagated_probes, + object_patches, + transmitted_probes, + ) = self._overlap_projection(current_object, current_probe) + + if use_projection_scheme: + exit_waves, error = self._projection_sets_fourier_projection( + amplitudes, + transmitted_probes, + exit_waves, + projection_a, + projection_b, + projection_c, + ) + + else: + exit_waves, error = self._gradient_descent_fourier_projection( + amplitudes, transmitted_probes + ) + + return propagated_probes, object_patches, transmitted_probes, exit_waves, error + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + for s in reversed(range(self._num_slices)): + probe = propagated_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2 + ) + + if self._object_type == "potential": + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves[:, i_probe] + ) + ) + ) + else: + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves[:, i_probe] + ) + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] += object_update * probe_normalization + + # back-transmit + exit_waves *= xp.expand_dims(xp.conj(obj), axis=1) # / xp.abs(obj) ** 2 + + if s > 0: + # back-propagate + exit_waves = self._propagate_array( + exit_waves, xp.conj(self._propagator_arrays[s - 1]) + ) + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += ( + step_size + * xp.sum( + exit_waves, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + # careful not to modify exit_waves in-place for projection set methods + exit_waves_copy = exit_waves.copy() + for s in reversed(range(self._num_slices)): + probe = propagated_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2 + ) + + if self._object_type == "potential": + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves_copy[:, i_probe] + ) + ) + ) + else: + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe] + ) + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] = object_update * probe_normalization + + # back-transmit + exit_waves_copy *= xp.expand_dims( + xp.conj(obj), axis=1 + ) # / xp.abs(obj) ** 2 + + if s > 0: + # back-propagate + exit_waves_copy = self._propagate_array( + exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) + ) + + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + exit_waves_copy, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + use_projection_scheme: bool, + step_size: float, + normalization_min: float, + fix_probe: bool, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + use_projection_scheme: bool, + If True, use generalized projection update + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + + if use_projection_scheme: + current_object, current_probe = self._projection_sets_adjoint( + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + normalization_min, + fix_probe, + ) + else: + current_object, current_probe = self._gradient_descent_adjoint( + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ) + + return current_object, current_probe + + def _position_correction( + self, + current_object, + current_probe, + transmitted_probes, + amplitudes, + current_positions, + positions_step_size, + constrain_position_distance, + ): + """ + Position correction using estimated intensity gradient. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe:np.ndarray + fractionally-shifted probes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + amplitudes: np.ndarray + Measured amplitudes + current_positions: np.ndarray + Current positions estimate + positions_step_size: float + Positions step size + constrain_position_distance: float + Distance to constrain position correction within original + field of view in A + + Returns + -------- + updated_positions: np.ndarray + Updated positions estimate + """ + + xp = self._xp + + # Intensity gradient + exit_waves_fft = xp.fft.fft2(transmitted_probes) + exit_waves_fft_conj = xp.conj(exit_waves_fft) + estimated_intensity = xp.abs(exit_waves_fft) ** 2 + measured_intensity = amplitudes**2 + + flat_shape = (transmitted_probes.shape[0], -1) + difference_intensity = (measured_intensity - estimated_intensity).reshape( + flat_shape + ) + + # Computing perturbed exit waves one at a time to save on memory + + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object + + # dx + obj_rolled_patches = complex_object[ + :, + (self._vectorized_patch_indices_row + 1) % self._object_shape[0], + self._vectorized_patch_indices_col, + ] + + propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) + propagated_probes_perturbed[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes_perturbed = ( + obj_rolled_patches[s] * propagated_probes_perturbed[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes_perturbed[s + 1] = self._propagate_array( + transmitted_probes_perturbed, self._propagator_arrays[s] + ) + + exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) + + # dy + obj_rolled_patches = complex_object[ + :, + self._vectorized_patch_indices_row, + (self._vectorized_patch_indices_col + 1) % self._object_shape[1], + ] + + propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) + propagated_probes_perturbed[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes_perturbed = ( + obj_rolled_patches[s] * propagated_probes_perturbed[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes_perturbed[s + 1] = self._propagate_array( + transmitted_probes_perturbed, self._propagator_arrays[s] + ) + + exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) + + partial_intensity_dx = 2 * xp.real( + exit_waves_dx_fft * exit_waves_fft_conj + ).reshape(flat_shape) + partial_intensity_dy = 2 * xp.real( + exit_waves_dy_fft * exit_waves_fft_conj + ).reshape(flat_shape) + + coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) + + # positions_update = xp.einsum( + # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity + # ) + + coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) + positions_update = ( + xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) + @ coefficients_matrix_T + @ difference_intensity[..., None] + ) + + if constrain_position_distance is not None: + constrain_position_distance /= xp.sqrt( + self.sampling[0] ** 2 + self.sampling[1] ** 2 + ) + x1 = (current_positions - positions_step_size * positions_update[..., 0])[ + :, 0 + ] + y1 = (current_positions - positions_step_size * positions_update[..., 0])[ + :, 1 + ] + x0 = self._positions_px_initial[:, 0] + y0 = self._positions_px_initial[:, 1] + if self._rotation_best_transpose: + x0, y0 = xp.array([y0, x0]) + x1, y1 = xp.array([y1, x1]) + + if self._rotation_best_rad is not None: + rotation_angle = self._rotation_best_rad + x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( + -rotation_angle + ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) + x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( + -rotation_angle + ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) + + outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( + x1 < (xp.min(x0) - constrain_position_distance) + ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( + y1 < (xp.min(y0) - constrain_position_distance) + ) > 0 + + positions_update[..., 0][outlier_ind] = 0 + + current_positions -= positions_step_size * positions_update[..., 0] + + return current_positions + + def _probe_center_of_mass_constraint(self, current_probe): + """ + Ptychographic center of mass constraint. + Used for centering corner-centered probe intensity. + + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Constrained probe estimate + """ + xp = self._xp + probe_intensity = xp.abs(current_probe[0]) ** 2 + + probe_x0, probe_y0 = get_CoM( + probe_intensity, device=self._device, corner_centered=True + ) + shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) + + return shifted_probe + + def _probe_orthogonalization_constraint(self, current_probe): + """ + Ptychographic probe-orthogonalization constraint. + Used to ensure mixed states are orthogonal to each other. + Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690 + + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Orthogonalized probe estimate + """ + xp = self._xp + n_probes = self._num_probes + + # compute upper half of P* @ P + pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype) + + for i in range(n_probes): + for j in range(i, n_probes): + pairwise_dot_product[i, j] = xp.sum( + current_probe[i].conj() * current_probe[j] + ) + + # compute eigenvectors (effectively cheaper way of computing V* from SVD) + _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U") + current_probe = xp.tensordot(evecs.T, current_probe, axes=1) + + # sort by real-space intensity + intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1)) + intensities_order = xp.argsort(intensities, axis=None)[::-1] + return current_probe[intensities_order] + + def _object_butterworth_constraint( + self, current_object, q_lowpass, q_highpass, butterworth_order + ): + """ + 2D Butterworth filter + Used for low/high-pass filtering object. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + qya, qxa = xp.meshgrid(qy, qx) + qra = xp.sqrt(qxa**2 + qya**2) + + env = xp.ones_like(qra) + if q_highpass: + env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) + if q_lowpass: + env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) + + current_object_mean = xp.mean(current_object) + current_object -= current_object_mean + current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env[None]) + current_object += current_object_mean + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_kz_regularization_constraint( + self, current_object, kz_regularization_gamma + ): + """ + Arctan regularization filter + + Parameters + -------- + current_object: np.ndarray + Current object estimate + kz_regularization_gamma: float + Slice regularization strength + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + + current_object = xp.pad( + current_object, pad_width=((1, 0), (0, 0), (0, 0)), mode="constant" + ) + + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]) + + kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0] + + qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") + qz2 = qza**2 * kz_regularization_gamma**2 + qr2 = qxa**2 + qya**2 + + w = 1 - 2 / np.pi * xp.arctan2(qz2, qr2) + + current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * w) + current_object = current_object[1:] + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_identical_slices_constraint(self, current_object): + """ + Strong regularization forcing all slices to be identical + + Parameters + -------- + current_object: np.ndarray + Current object estimate + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + object_mean = current_object.mean(0, keepdims=True) + current_object[:] = object_mean + + return current_object + + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + + def _constraints( + self, + current_object, + current_probe, + current_positions, + fix_com, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + fix_probe_aperture, + initial_probe_aperture, + fix_positions, + global_affine_transformation, + gaussian_filter, + gaussian_filter_sigma, + butterworth_filter, + q_lowpass, + q_highpass, + butterworth_order, + kz_regularization_filter, + kz_regularization_gamma, + identical_slices, + object_positivity, + shrinkage_rad, + object_mask, + pure_phase_object, + tv_denoise_chambolle, + tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + orthogonalize_probe, + ): + """ + Ptychographic constraints operator. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + current_positions: np.ndarray + Current positions estimate + fix_com: bool + If True, probe CoM is fixed to the center + fit_probe_aberrations: bool + If True, fits the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + constrain_probe_amplitude: bool + If True, probe amplitude is constrained by top hat function + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool + If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_probe_aperture: bool + If True, probe Fourier amplitude is replaced by initial_probe_aperture + initial_probe_aperture: np.ndarray + Initial probe aperture to use in replacing probe Fourier amplitude + fix_positions: bool + If True, positions are not updated + gaussian_filter: bool + If True, applies real-space gaussian filter in A + gaussian_filter_sigma: float + Standard deviation of gaussian kernel + butterworth_filter: bool + If True, applies fourier-space butterworth filter + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + kz_regularization_filter: bool + If True, applies fourier-space arctan regularization filter + kz_regularization_gamma: float + Slice regularization strength + identical_slices: bool + If True, forces all object slices to be identical + object_positivity: bool + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + object_mask: np.ndarray (boolean) + If not None, used to calculate additional shrinkage using masked-mean of object + pure_phase_object: bool + If True, object amplitude is set to unity + tv_denoise_chambolle: bool + If True, performs TV denoising along z + tv_denoise_weight_chambolle: float + weight of tv denoising constraint + tv_denoise_pad_chambolle: bool + if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + orthogonalize_probe: bool + If True, probe will be orthogonalized + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + constrained_probe: np.ndarray + Constrained probe estimate + constrained_positions: np.ndarray + Constrained positions estimate + """ + + if gaussian_filter: + current_object = self._object_gaussian_constraint( + current_object, gaussian_filter_sigma, pure_phase_object + ) + + if butterworth_filter: + current_object = self._object_butterworth_constraint( + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ) + + if identical_slices: + current_object = self._object_identical_slices_constraint(current_object) + elif kz_regularization_filter: + current_object = self._object_kz_regularization_constraint( + current_object, kz_regularization_gamma + ) + elif tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) + elif tv_denoise_chambolle: + current_object = self._object_denoise_tv_chambolle( + current_object, + tv_denoise_weight_chambolle, + axis=0, + pad_object=tv_denoise_pad_chambolle, + ) + + if shrinkage_rad > 0.0 or object_mask is not None: + current_object = self._object_shrinkage_constraint( + current_object, + shrinkage_rad, + object_mask, + ) + + if self._object_type == "complex": + current_object = self._object_threshold_constraint( + current_object, pure_phase_object + ) + elif object_positivity: + current_object = self._object_positivity_constraint(current_object) + + if fix_com: + current_probe = self._probe_center_of_mass_constraint(current_probe) + + # These constraints don't _really_ make sense for mixed-state + if fix_probe_aperture: + raise NotImplementedError() + elif constrain_probe_fourier_amplitude: + raise NotImplementedError() + if fit_probe_aberrations: + raise NotImplementedError() + if constrain_probe_amplitude: + raise NotImplementedError() + + if orthogonalize_probe: + current_probe = self._probe_orthogonalization_constraint(current_probe) + + if not fix_positions: + current_positions = self._positions_center_of_mass_constraint( + current_positions + ) + + if global_affine_transformation: + current_positions = self._positions_affine_transformation_constraint( + self._positions_px_initial, current_positions + ) + + return current_object, current_probe, current_positions + + def reconstruct( + self, + max_iter: int = 64, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + fix_com: bool = True, + orthogonalize_probe: bool = True, + fix_probe_iter: int = 0, + fix_probe_aperture_iter: int = 0, + constrain_probe_amplitude_iter: int = 0, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions_iter: int = np.inf, + constrain_position_distance: float = None, + global_affine_transformation: bool = True, + gaussian_filter_sigma: float = None, + gaussian_filter_iter: int = np.inf, + fit_probe_aberrations_iter: int = 0, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + butterworth_filter_iter: int = np.inf, + q_lowpass: float = None, + q_highpass: float = None, + butterworth_order: float = 2, + kz_regularization_filter_iter: int = np.inf, + kz_regularization_gamma: Union[float, np.ndarray] = None, + identical_slices_iter: int = 0, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + pure_phase_object_iter: int = 0, + tv_denoise_iter_chambolle=np.inf, + tv_denoise_weight_chambolle=None, + tv_denoise_pad_chambolle=True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, + switch_object_iter: int = np.inf, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + ): + """ + Ptychographic reconstruction main method. + + Parameters + -------- + max_iter: int, optional + Maximum number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. + max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + fix_com: bool, optional + If True, fixes center of mass of probe + fix_probe_iter: int, optional + Number of iterations to run with a fixed probe before updating probe estimate + fix_probe_aperture_iter: int, optional + Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate + constrain_probe_amplitude_iter: int, optional + Number of iterations to run while constraining the real-space probe with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude_iter: int, optional + Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_positions_iter: int, optional + Number of iterations to run with fixed positions before updating positions estimate + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma: float, optional + Standard deviation of gaussian kernel in A + gaussian_filter_iter: int, optional + Number of iterations to run using object smoothness constraint + fit_probe_aberrations_iter: int, optional + Number of iterations to run while fitting the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + butterworth_filter_iter: int, optional + Number of iterations to run using high-pass butteworth filter + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + kz_regularization_filter_iter: int, optional + Number of iterations to run using kz regularization filter + kz_regularization_gamma, float, optional + kz regularization strength + identical_slices_iter: int, optional + Number of iterations to run using identical slices + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + pure_phase_object_iter: int, optional + Number of iterations where object amplitude is set to unity + tv_denoise_iter_chambolle: bool + Number of iterations with TV denoisining + tv_denoise_weight_chambolle: float + weight of tv denoising constraint + tv_denoise_pad_chambolle: bool + if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + switch_object_iter: int, optional + Iteration to switch object type between 'complex' and 'potential' or between + 'potential' and 'complex' + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + + Returns + -------- + self: MultislicePtychographicReconstruction + Self to accommodate chaining + """ + asnumpy = self._asnumpy + xp = self._xp + + # Reconstruction method + + if reconstruction_method == "generalized-projections": + if ( + reconstruction_parameter_a is None + or reconstruction_parameter_b is None + or reconstruction_parameter_c is None + ): + raise ValueError( + ( + "reconstruction_parameter_a/b/c must all be specified " + "when using reconstruction_method='generalized-projections'." + ) + ) + + use_projection_scheme = True + projection_a = reconstruction_parameter_a + projection_b = reconstruction_parameter_b + projection_c = reconstruction_parameter_c + step_size = None + elif ( + reconstruction_method == "DM_AP" + or reconstruction_method == "difference-map_alternating-projections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = 1 + projection_c = 1 + reconstruction_parameter + step_size = None + elif ( + reconstruction_method == "RAAR" + or reconstruction_method == "relaxed-averaged-alternating-reflections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = 1 - 2 * reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "RRR" + or reconstruction_method == "relax-reflect-reflect" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: + raise ValueError("reconstruction_parameter must be between 0-2.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "SUPERFLIP" + or reconstruction_method == "charge-flipping" + ): + use_projection_scheme = True + projection_a = 0 + projection_b = 1 + projection_c = 2 + reconstruction_parameter = None + step_size = None + elif ( + reconstruction_method == "GD" or reconstruction_method == "gradient-descent" + ): + use_projection_scheme = False + projection_a = None + projection_b = None + projection_c = None + reconstruction_parameter = None + else: + raise ValueError( + ( + "reconstruction_method must be one of 'generalized-projections', " + "'DM_AP' (or 'difference-map_alternating-projections'), " + "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " + "'RRR' (or 'relax-reflect-reflect'), " + "'SUPERFLIP' (or 'charge-flipping'), " + f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." + ) + ) + + if self._verbose: + if switch_object_iter > max_iter: + first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " + else: + switch_object_type = ( + "complex" if self._object_type == "potential" else "potential" + ) + first_line = ( + f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " + f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " + ) + if max_batch_size is not None: + if use_projection_scheme: + raise ValueError( + ( + "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " + "Use reconstruction_method='GD' or set max_batch_size=None." + ) + ) + else: + print( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}, " + f"in batches of max {max_batch_size} measurements." + ) + ) + + else: + if reconstruction_parameter is not None: + if np.array(reconstruction_parameter).shape == (3,): + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." + ) + ) + else: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." + ) + ) + else: + if step_size is not None: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min}." + ) + ) + else: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}." + ) + ) + + # Batching + shuffled_indices = np.arange(self._num_diffraction_patterns) + unshuffled_indices = np.zeros_like(shuffled_indices) + + if max_batch_size is not None: + xp.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + # initialization + if store_iterations and (not hasattr(self, "object_iterations") or reset): + self.object_iterations = [] + self.probe_iterations = [] + + if reset: + self.error_iterations = [] + self._object = self._object_initial.copy() + self._probe = self._probe_initial.copy() + self._positions_px = self._positions_px_initial.copy() + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + self._exit_waves = None + self._object_type = self._object_type_initial + elif reset is None: + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + else: + self.error_iterations = [] + self._exit_waves = None + + # main loop + for a0 in tqdmnd( + max_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + if a0 == switch_object_iter: + if self._object_type == "potential": + self._object_type = "complex" + self._object = xp.exp(1j * self._object) + elif self._object_type == "complex": + self._object_type = "potential" + self._object = xp.angle(self._object) + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + unshuffled_indices[shuffled_indices] = np.arange( + self._num_diffraction_patterns + ) + positions_px = self._positions_px.copy()[shuffled_indices] + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[shuffled_indices[start:end]] + + # forward operator + ( + propagated_probes, + object_patches, + self._transmitted_probes, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + self._probe, + amplitudes, + self._exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) + + # adjoint operator + self._object, self._probe = self._adjoint( + self._object, + self._probe, + object_patches, + propagated_probes, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=a0 < fix_probe_iter, + ) + + # position correction + if a0 >= fix_positions_iter: + positions_px[start:end] = self._position_correction( + self._object, + self._probe[0], + self._transmitted_probes[:, 0], + amplitudes, + self._positions_px, + positions_step_size, + constrain_position_distance, + ) + + error += batch_error + + # Normalize Error + error /= self._mean_diffraction_intensity * self._num_diffraction_patterns + + # constraints + self._positions_px = positions_px.copy()[unshuffled_indices] + self._object, self._probe, self._positions_px = self._constraints( + self._object, + self._probe, + self._positions_px, + fix_com=fix_com and a0 >= fix_probe_iter, + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter + and a0 >= fix_probe_iter, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=self._probe_initial_aperture, + fix_positions=a0 < fix_positions_iter, + global_affine_transformation=global_affine_transformation, + gaussian_filter=a0 < gaussian_filter_iter + and gaussian_filter_sigma is not None, + gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=a0 < butterworth_filter_iter + and (q_lowpass is not None or q_highpass is not None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + kz_regularization_filter=a0 < kz_regularization_filter_iter + and kz_regularization_gamma is not None, + kz_regularization_gamma=kz_regularization_gamma[a0] + if kz_regularization_gamma is not None + and isinstance(kz_regularization_gamma, np.ndarray) + else kz_regularization_gamma, + identical_slices=a0 < identical_slices_iter, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 + else None, + pure_phase_object=a0 < pure_phase_object_iter + and self._object_type == "complex", + tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle + and tv_denoise_weight_chambolle is not None, + tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + orthogonalize_probe=orthogonalize_probe, + ) + + self.error_iterations.append(error.item()) + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + + return self + + def _visualize_last_iteration_figax( + self, + fig, + object_ax, + convergence_ax, + cbar: bool, + padding: int = 0, + **kwargs, + ): + """ + Displays last reconstructed object on a given fig/ax. + + Parameters + -------- + fig: Figure + Matplotlib figure object_ax lives in + object_ax: Axes + Matplotlib axes to plot reconstructed object in + convergence_ax: Axes, optional + Matplotlib axes to plot convergence plot in + cbar: bool, optional + If true, displays a colorbar + padding : int, optional + Pixels to pad by post rotating-cropping object + """ + cmap = kwargs.pop("cmap", "magma") + + if self._object_type == "complex": + obj = np.angle(self.object) + else: + obj = self.object + + rotated_object = self._crop_rotate_object_fov( + np.sum(obj, axis=0), padding=padding + ) + rotated_shape = rotated_object.shape + + extent = [ + 0, + self.sampling[1] * rotated_shape[1], + self.sampling[0] * rotated_shape[0], + 0, + ] + + im = object_ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + + if cbar: + divider = make_axes_locatable(object_ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if convergence_ax is not None and hasattr(self, "error_iterations"): + errors = np.array(self.error_iterations) + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + errors = self.error_iterations + + convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) + + def _visualize_last_iteration( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + padding: int, + **kwargs, + ): + """ + Displays last reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + padding : int, optional + Pixels to pad by post rotating-cropping object + + """ + figsize = kwargs.pop("figsize", (8, 5)) + cmap = kwargs.pop("cmap", "magma") + invert = kwargs.pop("invert", False) + hue_start = kwargs.pop("hue_start", 0) + + if self._object_type == "complex": + obj = np.angle(self.object) + else: + obj = self.object + + rotated_object = self._crop_rotate_object_fov( + np.sum(obj, axis=0), padding=padding + ) + rotated_shape = rotated_object.shape + + extent = [ + 0, + self.sampling[1] * rotated_shape[1], + self.sampling[0] * rotated_shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=2, + nrows=2, + height_ratios=[4, 1], + hspace=0.15, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=2, + nrows=1, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + if plot_probe or plot_fourier_probe: + # Object + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Probe + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + + ax = fig.add_subplot(spec[0, 1]) + if plot_fourier_probe: + probe_array = Complex2RGB( + self.probe_fourier[0], hue_start=hue_start, invert=invert + ) + ax.set_title("Reconstructed Fourier probe[0]") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + self.probe[0], hue_start=hue_start, invert=invert + ) + ax.set_title("Reconstructed probe[0]") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + **kwargs, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + + else: + ax = fig.add_subplot(spec[0]) + im = ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if plot_convergence and hasattr(self, "error_iterations"): + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + errors = np.array(self.error_iterations) + if plot_probe: + ax = fig.add_subplot(spec[1, :]) + else: + ax = fig.add_subplot(spec[1]) + ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax.set_ylabel("NMSE") + ax.set_xlabel("Iteration Number") + ax.yaxis.tick_right() + + fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + spec.tight_layout(fig) + + def _visualize_all_iterations( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + iterations_grid: Tuple[int, int], + padding: int, + **kwargs, + ): + """ + Displays all reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + padding : int, optional + Pixels to pad by post rotating-cropping object + """ + asnumpy = self._asnumpy + + if not hasattr(self, "object_iterations"): + raise ValueError( + ( + "Object and probe iterations were not saved during reconstruction. " + "Please re-run using store_iterations=True." + ) + ) + + if iterations_grid == "auto": + num_iter = len(self.error_iterations) + + if num_iter == 1: + return self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + cbar=cbar, + padding=padding, + **kwargs, + ) + elif plot_probe or plot_fourier_probe: + iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) + else: + iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) + else: + if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: + raise ValueError() + + auto_figsize = ( + (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) + if plot_convergence + else (3 * iterations_grid[1], 3 * iterations_grid[0]) + ) + figsize = kwargs.pop("figsize", auto_figsize) + cmap = kwargs.pop("cmap", "magma") + invert = kwargs.pop("invert", False) + hue_start = kwargs.pop("hue_start", 0) + + errors = np.array(self.error_iterations) + + objects = [] + object_type = [] + + for obj in self.object_iterations: + if np.iscomplexobj(obj): + obj = np.angle(obj) + object_type.append("phase") + else: + object_type.append("potential") + objects.append( + self._crop_rotate_object_fov(np.sum(obj, axis=0), padding=padding) + ) + + if plot_probe or plot_fourier_probe: + total_grids = (np.prod(iterations_grid) / 2).astype("int") + probes = self.probe_iterations + else: + total_grids = np.prod(iterations_grid) + max_iter = len(objects) - 1 + grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) + + extent = [ + 0, + self.sampling[1] * objects[0].shape[1], + self.sampling[0] * objects[0].shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=2) + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + grid = ImageGrid( + fig, + spec[0], + nrows_ncols=(1, iterations_grid[1]) + if (plot_probe or plot_fourier_probe) + else iterations_grid, + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + im = ax.imshow( + objects[grid_range[n]], + extent=extent, + cmap=cmap, + **kwargs, + ) + ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if cbar: + grid.cbar_axes[n].colorbar(im) + + if plot_probe or plot_fourier_probe: + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + grid = ImageGrid( + fig, + spec[1], + nrows_ncols=(1, iterations_grid[1]), + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + if plot_fourier_probe: + probe_array = Complex2RGB( + asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]][0] + ) + ), + hue_start=hue_start, + invert=invert, + ) + ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + probes[grid_range[n]][0], hue_start=hue_start, invert=invert + ) + ax.set_title(f"Iter: {grid_range[n]} probe[0]") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + **kwargs, + ) + + if cbar: + add_colorbar_arg( + grid.cbar_axes[n], hue_start=hue_start, invert=invert + ) + + if plot_convergence: + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + if plot_probe: + ax2 = fig.add_subplot(spec[2]) + else: + ax2 = fig.add_subplot(spec[1]) + ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax2.set_ylabel("NMSE") + ax2.set_xlabel("Iteration Number") + ax2.yaxis.tick_right() + + spec.tight_layout(fig) + + def visualize( + self, + fig=None, + iterations_grid: Tuple[int, int] = None, + plot_convergence: bool = True, + plot_probe: bool = True, + plot_fourier_probe: bool = False, + cbar: bool = True, + padding: int = 0, + **kwargs, + ): + """ + Displays reconstructed object and probe. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + padding : int, optional + Pixels to pad by post rotating-cropping object + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + + if iterations_grid is None: + self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + cbar=cbar, + padding=padding, + **kwargs, + ) + else: + self._visualize_all_iterations( + fig=fig, + plot_convergence=plot_convergence, + iterations_grid=iterations_grid, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + cbar=cbar, + padding=padding, + **kwargs, + ) + return self + + def show_fourier_probe( + self, probe=None, scalebar=True, pixelsize=None, pixelunits=None, **kwargs + ): + """ + Plot probe in fourier space + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses the `probe_fourier` property + scalebar: bool, optional + if True, adds scalebar to probe + pixelunits: str, optional + units for scalebar, default is A^-1 + pixelsize: float, optional + default is probe reciprocal sampling + """ + asnumpy = self._asnumpy + + if probe is None: + probe = list(self.probe_fourier) + else: + if isinstance(probe, np.ndarray) and probe.ndim == 2: + probe = [probe] + probe = [asnumpy(self._return_fourier_probe(pr)) for pr in probe] + + if pixelsize is None: + pixelsize = self._reciprocal_sampling[1] + if pixelunits is None: + pixelunits = r"$\AA^{-1}$" + + show_complex( + probe if len(probe) > 1 else probe[0], + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=False, + **kwargs, + ) + + def show_transmitted_probe( + self, + plot_fourier_probe: bool = False, + **kwargs, + ): + """ + Plots the min, max, and mean transmitted probe after propagation and transmission. + + Parameters + ---------- + plot_fourier_probe: boolean, optional + If True, the transmitted probes are also plotted in Fourier space + kwargs: + Passed to show_complex + """ + + xp = self._xp + asnumpy = self._asnumpy + + transmitted_probe_intensities = xp.sum( + xp.abs(self._transmitted_probes[:, 0]) ** 2, axis=(-2, -1) + ) + min_intensity_transmitted = self._transmitted_probes[ + xp.argmin(transmitted_probe_intensities), 0 + ] + max_intensity_transmitted = self._transmitted_probes[ + xp.argmax(transmitted_probe_intensities), 0 + ] + mean_transmitted = self._transmitted_probes[:, 0].mean(0) + probes = [ + asnumpy(self._return_centered_probe(probe)) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + title = [ + "Mean Transmitted Probe", + "Min Intensity Transmitted Probe", + "Max Intensity Transmitted Probe", + ] + + if plot_fourier_probe: + bottom_row = [ + asnumpy(self._return_fourier_probe(probe)) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + probes = [probes, bottom_row] + + title += [ + "Mean Transmitted Fourier Probe", + "Min Intensity Transmitted Fourier Probe", + "Max Intensity Transmitted Fourier Probe", + ] + + title = kwargs.get("title", title) + show_complex( + probes, + title=title, + **kwargs, + ) + + def show_slices( + self, + ms_object=None, + cbar: bool = True, + common_color_scale: bool = True, + padding: int = 0, + num_cols: int = 3, + **kwargs, + ): + """ + Displays reconstructed slices of object + + Parameters + -------- + ms_object: nd.array, optional + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + padding: int, optional + Padding to leave uncropped + num_cols: int, optional + Number of GridSpec columns + """ + + if ms_object is None: + ms_object = self._object + + rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + rotated_shape = rotated_object.shape + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + + extent = [ + 0, + self.sampling[1] * rotated_shape[2], + self.sampling[0] * rotated_shape[1], + 0, + ] + + num_rows = np.ceil(self._num_slices / num_cols).astype("int") + wspace = 0.35 if cbar else 0.15 + + axsize = kwargs.pop("axsize", (3, 3)) + cmap = kwargs.pop("cmap", "magma") + vmin = np.min(rotated_object) if common_color_scale else None + vmax = np.max(rotated_object) if common_color_scale else None + vmin = kwargs.pop("vmin", vmin) + vmax = kwargs.pop("vmax", vmax) + + spec = GridSpec( + ncols=num_cols, + nrows=num_rows, + hspace=0.15, + wspace=wspace, + ) + + figsize = (axsize[0] * num_cols, axsize[1] * num_rows) + fig = plt.figure(figsize=figsize) + + for flat_index, obj_slice in enumerate(rotated_object): + row_index, col_index = np.unravel_index(flat_index, (num_rows, num_cols)) + ax = fig.add_subplot(spec[row_index, col_index]) + im = ax.imshow( + obj_slice, + cmap=cmap, + vmin=vmin, + vmax=vmax, + extent=extent, + **kwargs, + ) + + ax.set_title(f"Slice index: {flat_index}") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if row_index < num_rows - 1: + ax.set_xticks([]) + else: + ax.set_xlabel("y [A]") + + if col_index > 0: + ax.set_yticks([]) + else: + ax.set_ylabel("x [A]") + + spec.tight_layout(fig) + + def show_depth( + self, + x1: float, + x2: float, + y1: float, + y2: float, + specify_calibrated: bool = False, + gaussian_filter_sigma: float = None, + ms_object=None, + cbar: bool = False, + aspect: float = None, + plot_line_profile: bool = False, + **kwargs, + ): + """ + Displays line profile depth section + + Parameters + -------- + x1, x2, y1, y2: floats (pixels) + Line profile for depth section runs from (x1,y1) to (x2,y2) + Specified in pixels unless specify_calibrated is True + specify_calibrated: bool (optional) + If True, specify x1, x2, y1, y2 in A values instead of pixels + gaussian_filter_sigma: float (optional) + Standard deviation of gaussian kernel in A + ms_object: np.array + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + aspect: float, optional + aspect ratio for depth profile plot + plot_line_profile: bool + If True, also plots line profile showing where depth profile is taken + """ + if ms_object is not None: + ms_obj = ms_object + else: + ms_obj = self.object_cropped + + if specify_calibrated: + x1 /= self.sampling[0] + x2 /= self.sampling[0] + y1 /= self.sampling[1] + y2 /= self.sampling[1] + + if x2 == x1: + angle = 0 + elif y2 == y1: + angle = np.pi / 2 + else: + angle = np.arctan((x2 - x1) / (y2 - y1)) + + x0 = ms_obj.shape[1] / 2 + y0 = ms_obj.shape[2] / 2 + + if ( + x1 > ms_obj.shape[1] + or x2 > ms_obj.shape[1] + or y1 > ms_obj.shape[2] + or y2 > ms_obj.shape[2] + ): + raise ValueError("depth section must be in field of view of object") + + from py4DSTEM.process.phase.utils import rotate_point + + x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) + x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) + + rotated_object = np.roll( + rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), + int(x1_0), + axis=1, + ) + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + if gaussian_filter_sigma is not None: + from scipy.ndimage import gaussian_filter + + gaussian_filter_sigma /= self.sampling[0] + rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) + + plot_im = rotated_object[:, 0, int(y1_0) : int(y2_0)] + + extent = [ + 0, + self.sampling[1] * plot_im.shape[1], + self._slice_thicknesses[0] * plot_im.shape[0], + 0, + ] + + if not plot_line_profile: + fig, ax = plt.subplots() + im = ax.imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax.set_aspect(aspect) + ax.set_xlabel("r [A]") + ax.set_ylabel("z [A]") + ax.set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + else: + extent2 = [ + 0, + self.sampling[0] * ms_obj.shape[1], + self.sampling[1] * ms_obj.shape[2], + 0, + ] + fig, ax = plt.subplots(2, 1) + ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) + ax[0].plot( + [y1 * self.sampling[0], y2 * self.sampling[1]], + [x1 * self.sampling[0], x2 * self.sampling[1]], + color="red", + ) + ax[0].set_xlabel("y [A]") + ax[0].set_ylabel("x [A]") + ax[0].set_title("Multislice depth profile location") + + im = ax[1].imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax[1].set_aspect(aspect) + ax[1].set_xlabel("r [A]") + ax[1].set_ylabel("z [A]") + ax[1].set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax[1]) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + plt.tight_layout() + + def tune_num_slices_and_thicknesses( + self, + num_slices_guess=None, + thicknesses_guess=None, + num_slices_step_size=1, + thicknesses_step_size=20, + num_slices_values=3, + num_thicknesses_values=3, + update_defocus=False, + max_iter=5, + plot_reconstructions=True, + plot_convergence=True, + return_values=False, + **kwargs, + ): + """ + Run reconstructions over a parameters space of number of slices + and slice thicknesses. Should be run after the preprocess step. + + Parameters + ---------- + num_slices_guess: float, optional + initial starting guess for number of slices, rounds to nearest integer + if None, uses current initialized values + thicknesses_guess: float (A), optional + initial starting guess for thicknesses of slices assuming same + thickness for each slice + if None, uses current initialized values + num_slices_step_size: float, optional + size of change of number of slices for each step in parameter space + thicknesses_step_size: float (A), optional + size of change of slice thicknesses for each step in parameter space + num_slices_values: int, optional + number of number of slice values to test, must be >= 1 + num_thicknesses_values: int,optional + number of thicknesses values to test, must be >= 1 + update_defocus: bool, optional + if True, updates defocus based on estimated total thickness + max_iter: int, optional + number of iterations to run in ptychographic reconstruction + plot_reconstructions: bool, optional + if True, plot phase of reconstructed objects + plot_convergence: bool, optional + if True, plots error for each iteration for each reconstruction + return_values: bool, optional + if True, returns objects, convergence + + Returns + ------- + objects: list + reconstructed objects + convergence: np.ndarray + array of convergence values from reconstructions + """ + + # calculate number of slices and thicknesses values to test + if num_slices_guess is None: + num_slices_guess = self._num_slices + if thicknesses_guess is None: + thicknesses_guess = np.mean(self._slice_thicknesses) + + if num_slices_values == 1: + num_slices_step_size = 0 + + if num_thicknesses_values == 1: + thicknesses_step_size = 0 + + num_slices = np.linspace( + num_slices_guess - num_slices_step_size * (num_slices_values - 1) / 2, + num_slices_guess + num_slices_step_size * (num_slices_values - 1) / 2, + num_slices_values, + ) + + thicknesses = np.linspace( + thicknesses_guess + - thicknesses_step_size * (num_thicknesses_values - 1) / 2, + thicknesses_guess + + thicknesses_step_size * (num_thicknesses_values - 1) / 2, + num_thicknesses_values, + ) + + if return_values: + convergence = [] + objects = [] + + # current initialized values + current_verbose = self._verbose + current_num_slices = self._num_slices + current_thicknesses = self._slice_thicknesses + current_rotation_deg = self._rotation_best_rad * 180 / np.pi + current_transpose = self._rotation_best_transpose + current_defocus = -self._polar_parameters["C10"] + + # Gridspec to plot on + if plot_reconstructions: + if plot_convergence: + spec = GridSpec( + ncols=num_thicknesses_values, + nrows=num_slices_values * 2, + height_ratios=[1, 1 / 4] * num_slices_values, + hspace=0.15, + wspace=0.35, + ) + figsize = kwargs.get( + "figsize", (4 * num_thicknesses_values, 5 * num_slices_values) + ) + else: + spec = GridSpec( + ncols=num_thicknesses_values, + nrows=num_slices_values, + hspace=0.15, + wspace=0.35, + ) + figsize = kwargs.get( + "figsize", (4 * num_thicknesses_values, 4 * num_slices_values) + ) + + fig = plt.figure(figsize=figsize) + + progress_bar = kwargs.pop("progress_bar", False) + # run loop and plot along the way + self._verbose = False + for flat_index, (slices, thickness) in enumerate( + tqdmnd(num_slices, thicknesses, desc="Tuning angle and defocus") + ): + slices = int(slices) + self._num_slices = slices + self._slice_thicknesses = np.tile(thickness, slices - 1) + self._probe = None + self._object = None + if update_defocus: + defocus = current_defocus + slices / 2 * thickness + self._polar_parameters["C10"] = -defocus + + self.preprocess( + plot_center_of_mass=False, + plot_rotation=False, + plot_probe_overlaps=False, + force_com_rotation=current_rotation_deg, + force_com_transpose=current_transpose, + ) + self.reconstruct( + reset=True, + store_iterations=True if plot_convergence else False, + max_iter=max_iter, + progress_bar=progress_bar, + **kwargs, + ) + + if plot_reconstructions: + row_index, col_index = np.unravel_index( + flat_index, (num_slices_values, num_thicknesses_values) + ) + + if plot_convergence: + object_ax = fig.add_subplot(spec[row_index * 2, col_index]) + convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index]) + self._visualize_last_iteration_figax( + fig, + object_ax=object_ax, + convergence_ax=convergence_ax, + cbar=True, + ) + convergence_ax.yaxis.tick_right() + else: + object_ax = fig.add_subplot(spec[row_index, col_index]) + self._visualize_last_iteration_figax( + fig, + object_ax=object_ax, + convergence_ax=None, + cbar=True, + ) + + object_ax.set_title( + f" num slices = {slices:.0f}, slices thickness = {thickness:.1f} A \n error = {self.error:.3e}" + ) + object_ax.set_xticks([]) + object_ax.set_yticks([]) + + if return_values: + objects.append(self.object) + convergence.append(self.error_iterations.copy()) + + # initialize back to pre-tuning values + self._probe = None + self._object = None + self._num_slices = current_num_slices + self._slice_thicknesses = np.tile(current_thicknesses, current_num_slices - 1) + self._polar_parameters["C10"] = -current_defocus + self.preprocess( + force_com_rotation=current_rotation_deg, + force_com_transpose=current_transpose, + plot_center_of_mass=False, + plot_rotation=False, + plot_probe_overlaps=False, + ) + self._verbose = current_verbose + + if plot_reconstructions: + spec.tight_layout(fig) + + if return_values: + return objects, convergence + + def _return_object_fft( + self, + obj=None, + ): + """ + Returns obj fft shifted to center of array + + Parameters + ---------- + obj: array, optional + if None is specified, uses self._object + """ + asnumpy = self._asnumpy + + if obj is None: + obj = self._object + + obj = asnumpy(obj) + if np.iscomplexobj(obj): + obj = np.angle(obj) + + obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) + return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) From 807ac1503da5bedf5f8e761827fd239a08192db3 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Fri, 22 Sep 2023 14:25:45 -0700 Subject: [PATCH 19/62] small reset bug --- .../phase/iterative_mixedstate_multislice_ptychography.py | 2 ++ py4DSTEM/process/phase/iterative_mixedstate_ptychography.py | 2 ++ py4DSTEM/process/phase/iterative_multislice_ptychography.py | 2 ++ .../process/phase/iterative_overlap_magnetic_tomography.py | 3 ++- py4DSTEM/process/phase/iterative_overlap_tomography.py | 3 ++- py4DSTEM/process/phase/iterative_simultaneous_ptychography.py | 2 ++ py4DSTEM/process/phase/iterative_singleslice_ptychography.py | 2 ++ 7 files changed, 14 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index acb9f12a2..6620cf71f 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -2249,6 +2249,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index d066c7f3f..6acbf7fc3 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -1598,6 +1598,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 3e36978e4..f663b9905 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -2118,6 +2118,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 712a35647..3aac7edc7 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -2326,12 +2326,13 @@ def reconstruct( self.error_iterations = [] self._probe = self._probe_initial.copy() self._positions_px_all = self._positions_px_initial_all.copy() + if hasattr(self, "_tf"): + del self._tf if use_projection_scheme: self._exit_waves = [None] * self._num_tilts else: self._exit_waves = None - elif reset is None: if hasattr(self, "error"): warnings.warn( diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index fd9af0bb2..79a477da2 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -2143,12 +2143,13 @@ def reconstruct( self.error_iterations = [] self._probe = self._probe_initial.copy() self._positions_px_all = self._positions_px_initial_all.copy() + if hasattr(self, "_tf"): + del self._tf if use_projection_scheme: self._exit_waves = [None] * self._num_tilts else: self._exit_waves = None - elif reset is None: if hasattr(self, "error"): warnings.warn( diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index a19fc82d3..584edfa6c 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -2775,6 +2775,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = (None,) * self._num_sim_measurements self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 97c7a3e5d..97e607cb6 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -1509,6 +1509,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( From 8b35dbb3ff1955bfec50abc4b932a12ed62f211a Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 23 Sep 2023 14:46:25 -0700 Subject: [PATCH 20/62] fixed NaN bug --- .../process/phase/iterative_base_class.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index ae4c92d4b..b7aa61af0 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -484,9 +484,14 @@ def _calculate_intensities_center_of_mass( ) if com_shifts is None: + com_measured_x_np = asnumpy(com_measured_x) + com_measured_y_np = asnumpy(com_measured_y) + finite_mask = np.isfinite(com_measured_x_np) + com_shifts = fit_origin( - (asnumpy(com_measured_x), asnumpy(com_measured_y)), + (com_measured_x_np, com_measured_y_np), fitfunction=fit_function, + mask=finite_mask, ) # Fit function to center of mass @@ -494,12 +499,12 @@ def _calculate_intensities_center_of_mass( com_fitted_y = xp.asarray(com_shifts[1], dtype=xp.float32) # fix CoM units - com_normalized_x = (com_measured_x - com_fitted_x) * self._reciprocal_sampling[ - 0 - ] - com_normalized_y = (com_measured_y - com_fitted_y) * self._reciprocal_sampling[ - 1 - ] + com_normalized_x = ( + xp.nan_to_num(com_measured_x - com_fitted_x) * self._reciprocal_sampling[0] + ) + com_normalized_y = ( + xp.nan_to_num(com_measured_y - com_fitted_y) * self._reciprocal_sampling[1] + ) return ( com_measured_x, From 6cb62f1baac4e6e9bf6b773ffb43419d3ff5ffd4 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 23 Sep 2023 15:53:47 -0700 Subject: [PATCH 21/62] changed complex plotting --- py4DSTEM/visualize/vis_special.py | 122 +++++++++++++----------------- setup.py | 3 +- 2 files changed, 55 insertions(+), 70 deletions(-) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 43cf7fff8..d46788472 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -1,6 +1,5 @@ from matplotlib import cm, colors as mcolors, pyplot as plt import numpy as np -from matplotlib.colors import hsv_to_rgb from matplotlib.patches import Wedge from mpl_toolkits.axes_grid1 import make_axes_locatable from scipy.spatial import Voronoi @@ -18,9 +17,7 @@ from py4DSTEM.visualize.vis_grid import show_image_grid from py4DSTEM.visualize.vis_RQ import ax_addaxes,ax_addaxes_QtoR - - - +from colorspacious import cspace_convert def show_elliptical_fit(ar,fitradii,p_ellipse,fill=True, color_ann='y',color_ell='r',alpha_ann=0.2,alpha_ell=0.7, @@ -717,15 +714,21 @@ def show_selected_dps(datacube,positions,im,bragg_pos=None, get_pointcolors=lambda i:colors[i], **kwargs) -def Complex2RGB(complex_data, vmin=None, vmax = None, hue_start = 0, invert=False): +def Complex2RGB(complex_data, vmin=None, vmax = None, power = None): """ complex_data (array): complex array to plot vmin (float) : minimum absolute value vmax (float) : maximum absolute value - hue_start (float) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme + power (float) : power to raise amplitude to """ - amp = np.abs(complex_data) + if power is None: + norm = mcolors.Normalize() + else: + norm = mcolors.PowerNorm(power) + + amp = norm(np.abs(complex_data)).data + phase = np.angle(complex_data) + if np.isclose(np.max(amp),np.min(amp)): if vmin is None: vmin = 0 @@ -746,35 +749,37 @@ def Complex2RGB(complex_data, vmin=None, vmax = None, hue_start = 0, invert=Fals amp = np.where(amp < vmin, vmin, amp) amp = np.where(amp > vmax, vmax, amp) + + J = amp*100 + C = np.where(J<61.5,98*J/123,1400/11-14*J/11) + h = np.rad2deg(phase)+180 - phase = np.angle(complex_data) + np.deg2rad(hue_start) - amp /= np.max(amp) - rgb = np.zeros(phase.shape +(3,)) - rgb[...,0] = 0.5*(np.sin(phase)+1)*amp - rgb[...,1] = 0.5*(np.sin(phase+np.pi/2)+1)*amp - rgb[...,2] = 0.5*(-np.sin(phase)+1)*amp + JCh = np.stack((J,C,h), axis=-1) + rgb = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) - return 1-rgb if invert else rgb + return rgb -def add_colorbar_arg(cax, vmin = None, vmax = None, hue_start = 0, invert = False): +def add_colorbar_arg(cax, c = 49, j = 61.5): """ - cax : axis to add cbar too - vmin (float) : minimum absolute value - vmax (float) : maximum absolute value - hue_start (float) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme + cax : axis to add cbar to + c : constant chroma value + j : constant luminance value """ - z = np.exp(1j * np.linspace(-np.pi, np.pi, 200)) - rgb_vals = Complex2RGB(z, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert) + + h = np.linspace(0, 360, 256,endpoint=False) + J = np.full_like(h,j) + C = np.full_like(h,c) + JCh = np.stack((J,C,h), axis=-1) + rgb_vals = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) newcmp = mcolors.ListedColormap(rgb_vals) norm = mcolors.Normalize(vmin=-np.pi, vmax=np.pi) - cb1 = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=newcmp), cax=cax) + cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=newcmp), cax=cax) - cb1.set_label("arg", rotation=0, ha="center", va="bottom") - cb1.ax.yaxis.set_label_coords(0.5, 1.01) - cb1.set_ticks(np.array([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])) - cb1.set_ticklabels( + cb.set_label("arg", rotation=0, ha="center", va="bottom") + cb.ax.yaxis.set_label_coords(0.5, 1.01) + cb.set_ticks(np.array([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])) + cb.set_ticklabels( [r"$-\pi$", r"$-\dfrac{\pi}{2}$", "$0$", r"$\dfrac{\pi}{2}$", r"$\pi$"] ) @@ -787,8 +792,7 @@ def show_complex( pixelunits="pixels", pixelsize=1, returnfig=False, - hue_start = 0, - invert=False, + power=None, **kwargs ): """ @@ -801,13 +805,12 @@ def show_complex( vmax (float, optional) : maximum absolute value if None, vmin/vmax are set to fractions of the distribution of pixel values in the array, e.g. vmin=0.02 will set the minumum display value to saturate the lower 2% of pixels - cbar (bool, optional) : if True, include color wheel + cbar (bool, optional) : if True, include color bar scalebar (bool, optional) : if True, adds scale bar pixelunits (str, optional) : units for scalebar pixelsize (float, optional) : size of one pixel in pixelunits for scalebar returnfig (bool, optional) : if True, the function returns the tuple (figure,axis) - hue_start (float, optional) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme + power (float,optional) : power to raise amplitude to Returns: if returnfig==False (default), the figure is plotted and nothing is returned. @@ -817,12 +820,12 @@ def show_complex( ar_complex = ar_complex[0] if (isinstance(ar_complex,list) and len(ar_complex) == 1) else ar_complex if isinstance(ar_complex, list): if isinstance(ar_complex[0], list): - rgb = [Complex2RGB(ar, vmin, vmax, hue_start = hue_start, invert=invert) for sublist in ar_complex for ar in sublist] + rgb = [Complex2RGB(ar, vmin, vmax, power=power) for sublist in ar_complex for ar in sublist] H = len(ar_complex) W = len(ar_complex[0]) else: - rgb = [Complex2RGB(ar, vmin, vmax, hue_start=hue_start, invert=invert) for ar in ar_complex] + rgb = [Complex2RGB(ar, vmin, vmax, power=power) for ar in ar_complex] if len(rgb[0].shape) == 4: H = len(ar_complex) W = rgb[0].shape[0] @@ -831,7 +834,7 @@ def show_complex( W = len(ar_complex) is_grid = True else: - rgb = Complex2RGB(ar_complex, vmin, vmax, hue_start=hue_start, invert=invert) + rgb = Complex2RGB(ar_complex, vmin, vmax, power=power) if len(rgb.shape) == 4: is_grid = True H = 1 @@ -882,37 +885,18 @@ def show_complex( add_scalebar(ax, scalebar) # add color bar - if cbar == True: - ax0 = fig.add_axes([1, 0.35, 0.3, 0.3]) - - # create wheel - AA = 1000 - kx = np.fft.fftshift(np.fft.fftfreq(AA)) - ky = np.fft.fftshift(np.fft.fftfreq(AA)) - kya, kxa = np.meshgrid(ky, kx) - kra = (kya**2 + kxa**2) ** 0.5 - ktheta = np.arctan2(-kxa, kya) - ktheta = kra * np.exp(1j * ktheta) - - # convert to hsv - rgb = Complex2RGB(ktheta, 0, 0.4, hue_start = hue_start, invert=invert) - ind = kra > 0.4 - rgb[ind] = [1, 1, 1] - - # plot - ax0.imshow(rgb) - - # add axes - ax0.axhline(AA / 2, 0, AA, color="k") - ax0.axvline(AA / 2, 0, AA, color="k") - ax0.axis("off") - - label_size = 16 - - ax0.text(AA, AA / 2, 1, fontsize=label_size) - ax0.text(AA / 2, 0, "i", fontsize=label_size) - ax0.text(AA / 2, AA, "-i", fontsize=label_size) - ax0.text(0, AA / 2, -1, fontsize=label_size) - - if returnfig == True: + if cbar: + if is_grid: + for ax_flat in ax.flatten(): + divider = make_axes_locatable(ax_flat) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb) + else: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb) + + fig.tight_layout() + + if returnfig: return fig, ax diff --git a/setup.py b/setup.py index d8baff354..40255d5bf 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,8 @@ 'dask >= 2.3.0', 'distributed >= 2.3.0', 'emdfile >= 0.0.10', - 'pylops >= 2.1.0' + 'pylops >= 2.1.0', + 'colorspacious >= 1.1.2', ], extras_require={ 'ipyparallel': ['ipyparallel >= 6.2.4', 'dill >= 0.3.3'], From ad79416f4545659ddf39bcad05e6d94cbc5e78e1 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 23 Sep 2023 17:15:24 -0700 Subject: [PATCH 22/62] updated complex plotting phase calls --- ...tive_mixedstate_multislice_ptychography.py | 85 ++++++------------- .../iterative_mixedstate_ptychography.py | 48 ++++------- .../iterative_multislice_ptychography.py | 64 ++++---------- .../iterative_overlap_magnetic_tomography.py | 32 ++----- .../phase/iterative_overlap_tomography.py | 61 +++++-------- .../iterative_simultaneous_ptychography.py | 37 +++----- .../iterative_singleslice_ptychography.py | 54 ++++-------- py4DSTEM/visualize/vis_special.py | 8 +- 8 files changed, 124 insertions(+), 265 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 6620cf71f..ea10050dd 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -609,19 +609,11 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered[0], - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) # propagated @@ -633,10 +625,7 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) extent = [ @@ -658,38 +647,33 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe[0]") + ax1.set_title("Initial probe[0] intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax2) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe[0]") + ax2.set_title("Propagated probe[0] intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[:, 1], @@ -701,7 +685,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -1126,23 +1110,17 @@ def _projection_sets_adjoint( ) if self._object_type == "potential": - object_update += ( - step_size - * self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(obj) - * xp.conj(probe[:, i_probe]) - * exit_waves_copy[:, i_probe] - ) + object_update += self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves_copy[:, i_probe] ) ) else: - object_update += ( - step_size - * self._sum_overlapping_patches_bincounts( - xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe] - ) + object_update += self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe] ) probe_normalization = 1 / xp.sqrt( @@ -2519,8 +2497,6 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) if self._object_type == "complex": obj = np.angle(self.object) @@ -2615,30 +2591,25 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: - probe_array = Complex2RGB( - self.probe_fourier[0], hue_start=hue_start, invert=invert - ) + probe_array = Complex2RGB(self.probe_fourier[0]) ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB( - self.probe[0], hue_start=hue_start, invert=invert - ) - ax.set_title("Reconstructed probe[0]") + probe_array = Complex2RGB(self.probe[0], power=2) + ax.set_title("Reconstructed probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb) else: ax = fig.add_subplot(spec[0]) @@ -2671,10 +2642,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2746,8 +2717,6 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) errors = np.array(self.error_iterations) @@ -2852,15 +2821,14 @@ def _visualize_all_iterations( probes[grid_range[n]][0] ) ), - hue_start=hue_start, - invert=invert, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]][0], hue_start=hue_start, invert=invert + probes[grid_range[n]][0], + power=2, ) ax.set_title(f"Iter: {grid_range[n]} probe[0]") ax.set_ylabel("x [A]") @@ -2869,12 +2837,11 @@ def _visualize_all_iterations( im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], ) if plot_convergence: @@ -2886,7 +2853,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 6acbf7fc3..6dbccff06 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -505,19 +505,11 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (4.5 * self._num_probes + 4, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) extent = [ @@ -540,23 +532,19 @@ def preprocess( axs[i].imshow( complex_probe_rgb[i], extent=probe_extent, - **kwargs, ) axs[i].set_ylabel("x [A]") axs[i].set_xlabel("y [A]") - axs[i].set_title(f"Initial Probe[{i}]") + axs[i].set_title(f"Initial probe[{i}] intensity") divider = make_axes_locatable(axs[i]) cax = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax) axs[-1].imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) axs[-1].scatter( self.positions[:, 1], @@ -568,7 +556,7 @@ def preprocess( axs[-1].set_xlabel("y [A]") axs[-1].set_xlim((extent[0], extent[1])) axs[-1].set_ylim((extent[2], extent[3])) - axs[-1].set_title("Object Field of View") + axs[-1].set_title("Object field of view") fig.tight_layout() @@ -1849,8 +1837,6 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) if self._object_type == "complex": obj = np.angle(self.object) @@ -1943,29 +1929,29 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier[0], hue_start=hue_start, invert=invert + self.probe_fourier[0], ) ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe[0], hue_start=hue_start, invert=invert + self.probe[0], + power=2, ) - ax.set_title("Reconstructed probe[0]") + ax.set_title("Reconstructed probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb) else: ax = fig.add_subplot(spec[0]) @@ -1998,10 +1984,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2177,24 +2163,22 @@ def _visualize_all_iterations( probes[grid_range[n]][0] ) ), - hue_start=hue_start, - invert=invert, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]][0], hue_start=hue_start, invert=invert + probes[grid_range[n]][0], + power=2, ) - ax.set_title(f"Iter: {grid_range[n]} probe[0]") + ax.set_title(f"Iter: {grid_range[n]} probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: @@ -2211,7 +2195,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index f663b9905..b3614c0ad 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -577,19 +577,11 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) # propagated @@ -601,10 +593,7 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) extent = [ @@ -626,38 +615,33 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax1) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[:, 1], @@ -669,7 +653,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -2387,8 +2371,6 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) if self._object_type == "complex": obj = np.angle(self.object) @@ -2484,29 +2466,26 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert - ) - ax.set_title("Reconstructed probe") + probe_array = Complex2RGB(self.probe, power=2) + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb) else: ax = fig.add_subplot(spec[0]) @@ -2539,10 +2518,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2614,8 +2593,6 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) errors = np.array(self.error_iterations) @@ -2720,29 +2697,24 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - hue_start=hue_start, - invert=invert, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB( - probes[grid_range[n]], hue_start=hue_start, invert=invert - ) - ax.set_title(f"Iter: {grid_range[n]} probe") + probe_array = Complex2RGB(probes[grid_range[n]], power=2) + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], ) if plot_convergence: @@ -2754,7 +2726,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 3aac7edc7..d2934497c 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -808,19 +808,11 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) # propagated @@ -832,10 +824,7 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) extent = [ @@ -857,38 +846,35 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[0, :, 1], @@ -900,7 +886,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -3092,7 +3078,7 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[-1, :]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() spec.tight_layout(fig) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 79a477da2..3d5982e9e 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -748,19 +748,11 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) # propagated @@ -772,10 +764,7 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) extent = [ @@ -797,38 +786,35 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[0, :, 1], @@ -840,7 +826,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -2593,8 +2579,6 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) asnumpy = self._asnumpy @@ -2696,16 +2680,17 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, + power=2, ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2718,7 +2703,9 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg( + ax_cb, + ) else: ax = fig.add_subplot(spec[0]) im = ax.imshow( @@ -2747,10 +2734,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2834,8 +2821,6 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) errors = np.array(self.error_iterations) @@ -2950,29 +2935,27 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - hue_start=hue_start, - invert=invert, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]], hue_start=hue_start, invert=invert + probes[grid_range[n]], + power=2, ) - ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], ) if plot_convergence: @@ -2984,7 +2967,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 584edfa6c..85e9a0b18 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -746,19 +746,11 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) extent = [ @@ -780,23 +772,21 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax2.scatter( self.positions[:, 1], @@ -808,7 +798,7 @@ def preprocess( ax2.set_xlabel("y [A]") ax2.set_xlim((extent[0], extent[1])) ax2.set_ylim((extent[2], extent[3])) - ax2.set_title("Object Field of View") + ax2.set_title("Object field of view") fig.tight_layout() @@ -3061,8 +3051,6 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (12, 5)) cmap_e = kwargs.pop("cmap_e", "magma") cmap_m = kwargs.pop("cmap_m", "PuOr") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) if self._object_type == "complex": obj_e = np.angle(self.object[0]) @@ -3188,29 +3176,26 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 2]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert - ) - ax.set_title("Reconstructed probe") + probe_array = Complex2RGB(self.probe, power=2) + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb) else: # Electrostatic Object @@ -3261,10 +3246,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1, :]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 97e607cb6..3843da983 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -474,19 +474,11 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) extent = [ @@ -508,23 +500,19 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax1) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="gray", ) ax2.scatter( self.positions[:, 1], @@ -536,7 +524,7 @@ def preprocess( ax2.set_xlabel("y [A]") ax2.set_xlim((extent[0], extent[1])) ax2.set_ylim((extent[2], extent[3])) - ax2.set_title("Object Field of View") + ax2.set_title("Object field of view") fig.tight_layout() @@ -1762,8 +1750,6 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) if self._object_type == "complex": obj = np.angle(self.object) @@ -1856,29 +1842,29 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, + power=2, ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb) else: ax = fig.add_subplot(spec[0]) @@ -1911,10 +1897,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -1985,9 +1971,7 @@ def _visualize_all_iterations( else (3 * iterations_grid[1], 3 * iterations_grid[0]) ) figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "inferno") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + cmap = kwargs.pop("cmap", "magma") errors = np.array(self.error_iterations) @@ -2091,8 +2075,6 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - hue_start=hue_start, - invert=invert, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") @@ -2100,21 +2082,21 @@ def _visualize_all_iterations( else: probe_array = Complex2RGB( - probes[grid_range[n]], hue_start=hue_start, invert=invert + probes[grid_range[n]], + power=2, ) - ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], ) if plot_convergence: @@ -2126,7 +2108,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index d46788472..722b55800 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -728,7 +728,7 @@ def Complex2RGB(complex_data, vmin=None, vmax = None, power = None): amp = norm(np.abs(complex_data)).data phase = np.angle(complex_data) - + if np.isclose(np.max(amp),np.min(amp)): if vmin is None: vmin = 0 @@ -736,9 +736,9 @@ def Complex2RGB(complex_data, vmin=None, vmax = None, power = None): vmax = np.max(amp) else: if vmin is None: - vmin = 0.02 + vmin = 0.025 if vmax is None: - vmax = 0.98 + vmax = 0.975 vals = np.sort(amp[~np.isnan(amp)]) ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int") ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int") @@ -750,7 +750,7 @@ def Complex2RGB(complex_data, vmin=None, vmax = None, power = None): amp = np.where(amp < vmin, vmin, amp) amp = np.where(amp > vmax, vmax, amp) - J = amp*100 + J = amp*61.5 # Note we restrict luminance to 61.5 C = np.where(J<61.5,98*J/123,1400/11-14*J/11) h = np.rad2deg(phase)+180 From 356a3a180e4699f2c1db703b250b0d856b76de6a Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 23 Sep 2023 17:58:08 -0700 Subject: [PATCH 23/62] adding complex CoM plotting and various dpc plotting bugs --- .../process/phase/iterative_base_class.py | 81 +++++++++++++++---- py4DSTEM/process/phase/iterative_dpc.py | 48 ++++++----- py4DSTEM/visualize/vis_special.py | 6 +- 3 files changed, 96 insertions(+), 39 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index b7aa61af0..96b3d5088 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -295,13 +295,14 @@ def _extract_intensities_and_calibrations_from_datacube( if require_calibrations: raise ValueError("Real-space calibrations must be given in 'A'") - warnings.warn( - ( - "Iterative reconstruction will not be quantitative unless you specify " - "real-space calibrations in 'A'" - ), - UserWarning, - ) + if self._verbose: + warnings.warn( + ( + "Iterative reconstruction will not be quantitative unless you specify " + "real-space calibrations in 'A'" + ), + UserWarning, + ) self._scan_sampling = (1.0, 1.0) self._scan_units = ("pixels",) * 2 @@ -359,13 +360,14 @@ def _extract_intensities_and_calibrations_from_datacube( "Reciprocal-space calibrations must be given in in 'A^-1' or 'mrad'" ) - warnings.warn( - ( - "Iterative reconstruction will not be quantitative unless you specify " - "appropriate reciprocal-space calibrations" - ), - UserWarning, - ) + if self._verbose: + warnings.warn( + ( + "Iterative reconstruction will not be quantitative unless you specify " + "appropriate reciprocal-space calibrations" + ), + UserWarning, + ) self._angular_sampling = (1.0, 1.0) self._angular_units = ("pixels",) * 2 @@ -1134,6 +1136,57 @@ def _normalize_diffraction_intensities( return amplitudes, mean_intensity + def show_complex_CoM( + self, + com=None, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, + ): + """ + Plot complex-valued CoM image + + Parameters + ---------- + + com = (CoM_x, CoM_y) tuple + If None is specified, uses (self.com_x, self.com_y) instead + cbar: bool, optional + if True, adds colorbar + scalebar: bool, optional + if True, adds scalebar to probe + pixelunits: str, optional + units for scalebar, default is A + pixelsize: float, optional + default is scan sampling + """ + + if com is None: + com = (self.com_x, self.com_y) + + if pixelsize is None: + pixelsize = self._scan_sampling[0] + if pixelunits is None: + pixelunits = r"$\AA$" + + figsize = kwargs.pop("figsize", (6, 6)) + fig, ax = plt.subplots(figsize=figsize) + + complex_com = com[0] + 1j * com[1] + + show_complex( + complex_com, + cbar=cbar, + figax=(fig, ax), + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=False, + **kwargs, + ) + class PtychographicReconstruction(PhaseReconstruction, PtychographicConstraints): """ diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/iterative_dpc.py index 4c80ed177..20796160a 100644 --- a/py4DSTEM/process/phase/iterative_dpc.py +++ b/py4DSTEM/process/phase/iterative_dpc.py @@ -718,24 +718,26 @@ def reconstruct( xp = self._xp asnumpy = self._asnumpy - if reset is None and hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." - ), - UserWarning, - ) - # Restart if store_iterations and (not hasattr(self, "object_phase_iterations") or reset): self.object_phase_iterations = [] - self.error_iterations = [] if reset: self.error = np.inf + self.error_iterations = [] self._step_size = step_size if step_size is not None else 0.5 self._padded_object_phase = self._padded_object_phase_initial.copy() + elif reset is None: + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + else: + self.error_iterations = [] self.error = getattr(self, "error", np.inf) @@ -770,7 +772,8 @@ def reconstruct( if (new_error > self.error) and backtrack: self._padded_object_phase = previous_iteration self._step_size /= 2 - print(f"Iteration {a0}, step reduced to {self._step_size}") + if self._verbose: + print(f"Iteration {a0}, step reduced to {self._step_size}") continue self.error = new_error @@ -807,10 +810,11 @@ def reconstruct( self.error_iterations.append(self.error.item()) if self._step_size < stopping_criterion: - warnings.warn( - f"Step-size has decreased below stopping criterion {stopping_criterion}.", - UserWarning, - ) + if self._verbose: + warnings.warn( + f"Step-size has decreased below stopping criterion {stopping_criterion}.", + UserWarning, + ) # crop result self._object_phase = self._padded_object_phase[ @@ -840,7 +844,7 @@ def _visualize_last_iteration( If true, the NMSE error plot is displayed """ - figsize = kwargs.pop("figsize", (8, 8)) + figsize = kwargs.pop("figsize", (5, 6)) cmap = kwargs.pop("cmap", "magma") if plot_convergence: @@ -862,7 +866,7 @@ def _visualize_last_iteration( im = ax1.imshow(self.object_phase, extent=extent, cmap=cmap, **kwargs) ax1.set_ylabel(f"x [{self._scan_units[0]}]") ax1.set_xlabel(f"y [{self._scan_units[1]}]") - ax1.set_title(f"DPC Phase Reconstruction - NMSE error: {self.error:.3e}") + ax1.set_title(f"DPC phase reconstruction - NMSE error: {self.error:.3e}") if cbar: divider = make_axes_locatable(ax1) @@ -870,11 +874,11 @@ def _visualize_last_iteration( fig.add_axes(ax_cb) fig.colorbar(im, cax=ax_cb) - if plot_convergence and hasattr(self, "_error_iterations"): - errors = self._error_iterations + if plot_convergence: + errors = self.error_iterations ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(len(errors)), errors, **kwargs) - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.set_ylabel("Log NMSE error") ax2.yaxis.tick_right() @@ -979,7 +983,7 @@ def _visualize_all_iterations( if plot_convergence: ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(len(errors)), errors, **kwargs) - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.set_ylabel("Log NMSE error") ax2.yaxis.tick_right() @@ -990,7 +994,7 @@ def visualize( fig=None, iterations_grid: Tuple[int, int] = None, plot_convergence: bool = True, - cbar: bool = False, + cbar: bool = True, **kwargs, ): """ diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 722b55800..0792ee3c8 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -736,9 +736,9 @@ def Complex2RGB(complex_data, vmin=None, vmax = None, power = None): vmax = np.max(amp) else: if vmin is None: - vmin = 0.025 + vmin = 0.0 if vmax is None: - vmax = 0.975 + vmax = 1.0 vals = np.sort(amp[~np.isnan(amp)]) ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int") ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int") @@ -751,7 +751,7 @@ def Complex2RGB(complex_data, vmin=None, vmax = None, power = None): amp = np.where(amp > vmax, vmax, amp) J = amp*61.5 # Note we restrict luminance to 61.5 - C = np.where(J<61.5,98*J/123,1400/11-14*J/11) + C = np.where(J<61.5,98*J/123,1400/11-14*J/11) # Min uniform chroma h = np.rad2deg(phase)+180 JCh = np.stack((J,C,h), axis=-1) From 2bc9da98a6ca3bb0388cd88a7a4a3dd9f35d2dd5 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Fri, 29 Sep 2023 17:47:49 -0700 Subject: [PATCH 24/62] parallax descan correct --- py4DSTEM/process/phase/iterative_parallax.py | 36 ++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index b23fe2cae..0d2f4cfc9 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -11,6 +11,7 @@ from emdfile import Custom, tqdmnd from matplotlib.gridspec import GridSpec from py4DSTEM import DataCube +from py4DSTEM.preprocess.utils import get_shifted_ar from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction from py4DSTEM.process.utils.cross_correlate import align_images_fourier from py4DSTEM.process.utils.utils import electron_wavelength_angstrom @@ -112,6 +113,7 @@ def preprocess( threshold_intensity: float = 0.8, normalize_images: bool = True, normalize_order=0, + descan_correct: bool = False, defocus_guess: float = None, rotation_guess: float = None, plot_average_bf: bool = True, @@ -134,6 +136,8 @@ def preprocess( defocus_guess: float, optional Initial guess of defocus value (defocus dF) in A If None, first iteration is assumed to be in-focus + descan_correct: float, optional + If True, aligns bright field stack based on measured descan rotation_guess: float, optional Initial guess of defocus value in degrees If None, first iteration assumed to be 0 @@ -180,6 +184,38 @@ def preprocess( raise ValueError( "dp_mean must match the datacube shape. Try setting dp_mean = None." ) + # descan correct + if descan_correct: + from py4DSTEM.process.phase import DPCReconstruction + + dpc = DPCReconstruction( + energy=self._energy, + datacube=self._datacube, + verbose=False, + ).preprocess( + force_com_rotation=0, + force_com_transpose=False, + plot_center_of_mass=False, + ) + + intensities_shifted = self._intensities.copy() + + center_x = np.mean(dpc._com_measured_x) + center_y = np.mean(dpc._com_measured_y) + for rx in range(intensities_shifted.shape[0]): + for ry in range(intensities_shifted.shape[1]): + intensity_shifted = get_shifted_ar( + self._intensities[rx, ry], + -dpc._com_measured_x[rx, ry] + center_x, + -dpc._com_measured_y[rx, ry] + center_y, + bilinear=True, + device="cpu", + ) + + intensities_shifted[rx, ry] = intensity_shifted + + self._intensities = intensities_shifted + self._dp_mean = intensities_shifted.mean((0, 1)) # select virtual detector pixels self._dp_mask = self._dp_mean >= (xp.max(self._dp_mean) * threshold_intensity) From 3a99d5ae0cadf48d5beb8e829797bab1195557a4 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 29 Sep 2023 21:01:55 -0700 Subject: [PATCH 25/62] complex plotting improvements, formatting --- py4DSTEM/visualize/vis_special.py | 921 +++++++++++++++++++----------- 1 file changed, 577 insertions(+), 344 deletions(-) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 0792ee3c8..6dd980bce 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -15,13 +15,25 @@ add_scalebar, ) from py4DSTEM.visualize.vis_grid import show_image_grid -from py4DSTEM.visualize.vis_RQ import ax_addaxes,ax_addaxes_QtoR +from py4DSTEM.visualize.vis_RQ import ax_addaxes, ax_addaxes_QtoR from colorspacious import cspace_convert -def show_elliptical_fit(ar,fitradii,p_ellipse,fill=True, - color_ann='y',color_ell='r',alpha_ann=0.2,alpha_ell=0.7, - linewidth_ann=2,linewidth_ell=2,returnfig=False,**kwargs): + +def show_elliptical_fit( + ar, + fitradii, + p_ellipse, + fill=True, + color_ann="y", + color_ell="r", + alpha_ann=0.2, + alpha_ell=0.7, + linewidth_ann=2, + linewidth_ell=2, + returnfig=False, + **kwargs +): """ Plots an elliptical curve over its annular fit region. @@ -39,35 +51,55 @@ def show_elliptical_fit(ar,fitradii,p_ellipse,fill=True, linewidth_ann: linewidth_ell: """ - Ri,Ro = fitradii - qx0,qy0,a,b,theta = p_ellipse - fig,ax = show(ar, - annulus={'center':(qx0,qy0), - 'radii':(Ri,Ro), - 'fill':fill, - 'color':color_ann, - 'alpha':alpha_ann, - 'linewidth':linewidth_ann}, - ellipse={'center':(qx0,qy0), - 'a':a, - 'b':b, - 'theta':theta, - 'color':color_ell, - 'alpha':alpha_ell, - 'linewidth':linewidth_ell}, - returnfig=True,**kwargs) + Ri, Ro = fitradii + qx0, qy0, a, b, theta = p_ellipse + fig, ax = show( + ar, + annulus={ + "center": (qx0, qy0), + "radii": (Ri, Ro), + "fill": fill, + "color": color_ann, + "alpha": alpha_ann, + "linewidth": linewidth_ann, + }, + ellipse={ + "center": (qx0, qy0), + "a": a, + "b": b, + "theta": theta, + "color": color_ell, + "alpha": alpha_ell, + "linewidth": linewidth_ell, + }, + returnfig=True, + **kwargs, + ) if not returnfig: plt.show() return else: - return fig,ax + return fig, ax -def show_amorphous_ring_fit(dp,fitradii,p_dsg,N=12,cmap=('gray','gray'), - fitborder=True,fitbordercolor='k',fitborderlw=0.5, - scaling='log',ellipse=False,ellipse_color='r', - ellipse_alpha=0.7,ellipse_lw=2,returnfig=False,**kwargs): +def show_amorphous_ring_fit( + dp, + fitradii, + p_dsg, + N=12, + cmap=("gray", "gray"), + fitborder=True, + fitbordercolor="k", + fitborderlw=0.5, + scaling="log", + ellipse=False, + ellipse_color="r", + ellipse_alpha=0.7, + ellipse_lw=2, + returnfig=False, + **kwargs +): """ Display a diffraction pattern with a fit to its amorphous ring, interleaving the data and the fit in a pinwheel pattern. @@ -90,75 +122,112 @@ def show_amorphous_ring_fit(dp,fitradii,p_dsg,N=12,cmap=('gray','gray'), """ from py4DSTEM.process.calibration import double_sided_gaussian from py4DSTEM.process.utils import convert_ellipse_params - assert(len(p_dsg)==11) - assert(isinstance(N,(int,np.integer))) - if isinstance(cmap,tuple): - cmap_data,cmap_fit = cmap[0],cmap[1] + + assert len(p_dsg) == 11 + assert isinstance(N, (int, np.integer)) + if isinstance(cmap, tuple): + cmap_data, cmap_fit = cmap[0], cmap[1] else: - cmap_data,cmap_fit = cmap,cmap - Q_Nx,Q_Ny = dp.shape - qmin,qmax = fitradii + cmap_data, cmap_fit = cmap, cmap + Q_Nx, Q_Ny = dp.shape + qmin, qmax = fitradii # Make coords - qx0,qy0 = p_dsg[6],p_dsg[7] - qyy,qxx = np.meshgrid(np.arange(Q_Ny),np.arange(Q_Nx)) - qx,qy = qxx-qx0,qyy-qy0 - q = np.hypot(qx,qy) - theta = np.arctan2(qy,qx) + qx0, qy0 = p_dsg[6], p_dsg[7] + qyy, qxx = np.meshgrid(np.arange(Q_Ny), np.arange(Q_Nx)) + qx, qy = qxx - qx0, qyy - qy0 + q = np.hypot(qx, qy) + theta = np.arctan2(qy, qx) # Make mask - thetas = np.linspace(-np.pi,np.pi,2*N+1) - pinwheel = np.zeros((Q_Nx,Q_Ny),dtype=bool) + thetas = np.linspace(-np.pi, np.pi, 2 * N + 1) + pinwheel = np.zeros((Q_Nx, Q_Ny), dtype=bool) for i in range(N): - pinwheel += (theta>thetas[2*i]) * (theta<=thetas[2*i+1]) - mask = pinwheel * (q>qmin) * (q<=qmax) + pinwheel += (theta > thetas[2 * i]) * (theta <= thetas[2 * i + 1]) + mask = pinwheel * (q > qmin) * (q <= qmax) # Get fit data fit = double_sided_gaussian(p_dsg, qxx, qyy) # Show - (fig,ax),(vmin,vmax) = show(dp,scaling=scaling,cmap=cmap_data, - mask=np.logical_not(mask),mask_color='empty', - returnfig=True,returnclipvals=True,**kwargs) - show(fit,scaling=scaling,figax=(fig,ax),clipvals='manual',min=vmin,max=vmax, - cmap=cmap_fit,mask=mask,mask_color='empty',**kwargs) + (fig, ax), (vmin, vmax) = show( + dp, + scaling=scaling, + cmap=cmap_data, + mask=np.logical_not(mask), + mask_color="empty", + returnfig=True, + returnclipvals=True, + **kwargs, + ) + show( + fit, + scaling=scaling, + figax=(fig, ax), + clipvals="manual", + min=vmin, + max=vmax, + cmap=cmap_fit, + mask=mask, + mask_color="empty", + **kwargs, + ) if fitborder: - if N%2==1: thetas += (thetas[1]-thetas[0])/2 - if (N//2%2)==0: thetas = np.roll(thetas,-1) + if N % 2 == 1: + thetas += (thetas[1] - thetas[0]) / 2 + if (N // 2 % 2) == 0: + thetas = np.roll(thetas, -1) for i in range(N): - ax.add_patch(Wedge((qy0,qx0),qmax,np.degrees(thetas[2*i]), - np.degrees(thetas[2*i+1]),width=qmax-qmin,fill=None, - color=fitbordercolor,lw=fitborderlw)) + ax.add_patch( + Wedge( + (qy0, qx0), + qmax, + np.degrees(thetas[2 * i]), + np.degrees(thetas[2 * i + 1]), + width=qmax - qmin, + fill=None, + color=fitbordercolor, + lw=fitborderlw, + ) + ) # Add ellipse overlay if ellipse: - A,B,C = p_dsg[8],p_dsg[9],p_dsg[10] - a,b,theta = convert_ellipse_params(A,B,C) - ellipse={'center':(qx0,qy0),'a':a,'b':b,'theta':theta, - 'color':ellipse_color,'alpha':ellipse_alpha,'linewidth':ellipse_lw} - add_ellipses(ax,ellipse) + A, B, C = p_dsg[8], p_dsg[9], p_dsg[10] + a, b, theta = convert_ellipse_params(A, B, C) + ellipse = { + "center": (qx0, qy0), + "a": a, + "b": b, + "theta": theta, + "color": ellipse_color, + "alpha": ellipse_alpha, + "linewidth": ellipse_lw, + } + add_ellipses(ax, ellipse) if not returnfig: plt.show() return else: - return fig,ax + return fig, ax def show_qprofile( q, intensity, ymax=None, - figsize=(12,4), + figsize=(12, 4), returnfig=False, - color='k', - xlabel='q (pixels)', - ylabel='Intensity (A.U.)', + color="k", + xlabel="q (pixels)", + ylabel="Intensity (A.U.)", labelsize=16, ticklabelsize=14, grid=True, label=None, - **kwargs): + **kwargs +): """ Plots a diffraction space radial profile. Params: @@ -174,148 +243,167 @@ def show_qprofile( label a legend label for the plotted curve """ if ymax is None: - ymax = np.max(intensity)*1.05 + ymax = np.max(intensity) * 1.05 - fig,ax = plt.subplots(figsize=figsize) - ax.plot(q,intensity,color=color,label=label) + fig, ax = plt.subplots(figsize=figsize) + ax.plot(q, intensity, color=color, label=label) ax.grid(grid) - ax.set_ylim(0,ymax) - ax.tick_params(axis='x',labelsize=ticklabelsize) + ax.set_ylim(0, ymax) + ax.tick_params(axis="x", labelsize=ticklabelsize) ax.set_yticklabels([]) - ax.set_xlabel(xlabel,size=labelsize) - ax.set_ylabel(ylabel,size=labelsize) + ax.set_xlabel(xlabel, size=labelsize) + ax.set_ylabel(ylabel, size=labelsize) if not returnfig: plt.show() return else: - return fig,ax + return fig, ax -def show_kernel( - kernel, - R, - L, - W, - figsize=(12,6), - returnfig=False, - **kwargs): + +def show_kernel(kernel, R, L, W, figsize=(12, 6), returnfig=False, **kwargs): """ Plots, side by side, the probe kernel and its line profile. R is the kernel plot's window size. L and W are the length and width of the lineprofile. """ - lineprofile_1 = np.concatenate([ - np.sum(kernel[-L:,:W],axis=1), - np.sum(kernel[:L,:W],axis=1) - ]) - lineprofile_2 = np.concatenate([ - np.sum(kernel[:W,-L:],axis=0), - np.sum(kernel[:W,:L],axis=0) - ]) - - im_kernel = np.vstack([ - np.hstack([ - kernel[-int(R):,-int(R):], - kernel[-int(R):,:int(R)] - ]), - np.hstack([ - kernel[:int(R),-int(R):], - kernel[:int(R),:int(R)] - ]), - ]) - - fig,axs = plt.subplots(1,2,figsize=figsize) - axs[0].matshow(im_kernel,cmap='gray') - axs[0].plot( - np.ones(2*R)*R, - np.arange(2*R), - c='r') - axs[0].plot( - np.arange(2*R), - np.ones(2*R)*R, - c='c') - - - axs[1].plot( - np.arange(len(lineprofile_1)), - lineprofile_1, - c='r') - axs[1].plot( - np.arange(len(lineprofile_2)), - lineprofile_2, - c='c') + lineprofile_1 = np.concatenate( + [np.sum(kernel[-L:, :W], axis=1), np.sum(kernel[:L, :W], axis=1)] + ) + lineprofile_2 = np.concatenate( + [np.sum(kernel[:W, -L:], axis=0), np.sum(kernel[:W, :L], axis=0)] + ) + + im_kernel = np.vstack( + [ + np.hstack([kernel[-int(R) :, -int(R) :], kernel[-int(R) :, : int(R)]]), + np.hstack([kernel[: int(R), -int(R) :], kernel[: int(R), : int(R)]]), + ] + ) + + fig, axs = plt.subplots(1, 2, figsize=figsize) + axs[0].matshow(im_kernel, cmap="gray") + axs[0].plot(np.ones(2 * R) * R, np.arange(2 * R), c="r") + axs[0].plot(np.arange(2 * R), np.ones(2 * R) * R, c="c") + + axs[1].plot(np.arange(len(lineprofile_1)), lineprofile_1, c="r") + axs[1].plot(np.arange(len(lineprofile_2)), lineprofile_2, c="c") if not returnfig: plt.show() return else: - return fig,axs + return fig, axs -def show_voronoi(ar,x,y,color_points='r',color_lines='w',max_dist=None, - returnfig=False,**kwargs): + +def show_voronoi( + ar, + x, + y, + color_points="r", + color_lines="w", + max_dist=None, + returnfig=False, + **kwargs +): """ words """ from py4DSTEM.process.utils import get_voronoi_vertices - Nx,Ny = ar.shape - points = np.vstack((x,y)).T + + Nx, Ny = ar.shape + points = np.vstack((x, y)).T voronoi = Voronoi(points) - vertices = get_voronoi_vertices(voronoi,Nx,Ny) + vertices = get_voronoi_vertices(voronoi, Nx, Ny) if max_dist is None: - fig,ax = show(ar,returnfig=True,**kwargs) + fig, ax = show(ar, returnfig=True, **kwargs) else: - centers = [(x[i],y[i]) for i in range(len(x))] - fig,ax = show(ar,returnfig=True,**kwargs, - circle={'center':centers,'R':max_dist,'fill':False,'color':color_points}) + centers = [(x[i], y[i]) for i in range(len(x))] + fig, ax = show( + ar, + returnfig=True, + **kwargs, + circle={ + "center": centers, + "R": max_dist, + "fill": False, + "color": color_points, + }, + ) - ax.scatter(voronoi.points[:,1],voronoi.points[:,0],color=color_points) + ax.scatter(voronoi.points[:, 1], voronoi.points[:, 0], color=color_points) for region in range(len(vertices)): vertices_curr = vertices[region] for i in range(len(vertices_curr)): - x0,y0 = vertices_curr[i,:] - xf,yf = vertices_curr[(i+1)%len(vertices_curr),:] - ax.plot((y0,yf),(x0,xf),color=color_lines) - ax.set_xlim([0,Ny]) - ax.set_ylim([0,Nx]) + x0, y0 = vertices_curr[i, :] + xf, yf = vertices_curr[(i + 1) % len(vertices_curr), :] + ax.plot((y0, yf), (x0, xf), color=color_lines) + ax.set_xlim([0, Ny]) + ax.set_ylim([0, Nx]) plt.gca().invert_yaxis() if not returnfig: plt.show() return else: - return fig,ax + return fig, ax + -def show_class_BPs(ar,x,y,s,s2,color='r',color2='y',**kwargs): +def show_class_BPs(ar, x, y, s, s2, color="r", color2="y", **kwargs): """ words """ N = len(x) - assert(N==len(y)==len(s)) + assert N == len(y) == len(s) - fig,ax = show(ar,returnfig=True,**kwargs) - ax.scatter(y,x,s=s2,color=color2) - ax.scatter(y,x,s=s,color=color) + fig, ax = show(ar, returnfig=True, **kwargs) + ax.scatter(y, x, s=s2, color=color2) + ax.scatter(y, x, s=s, color=color) plt.show() return -def show_class_BPs_grid(ar,H,W,x,y,get_s,s2,color='r',color2='y',returnfig=False, - axsize=(6,6),titlesize=0,get_bordercolor=None,**kwargs): + +def show_class_BPs_grid( + ar, + H, + W, + x, + y, + get_s, + s2, + color="r", + color2="y", + returnfig=False, + axsize=(6, 6), + titlesize=0, + get_bordercolor=None, + **kwargs +): """ words """ - fig,axs = show_image_grid(lambda i:ar,H,W,axsize=axsize,titlesize=titlesize, - get_bordercolor=get_bordercolor,returnfig=True,**kwargs) + fig, axs = show_image_grid( + lambda i: ar, + H, + W, + axsize=axsize, + titlesize=titlesize, + get_bordercolor=get_bordercolor, + returnfig=True, + **kwargs, + ) for i in range(H): for j in range(W): - ax = axs[i,j] - N = i*W+j + ax = axs[i, j] + N = i * W + j s = get_s(N) - ax.scatter(y,x,s=s2,color=color2) - ax.scatter(y,x,s=s,color=color) + ax.scatter(y, x, s=s2, color=color2) + ax.scatter(y, x, s=s, color=color) if not returnfig: plt.show() return else: - return fig,axs + return fig, axs + def show_strain( strainmap, @@ -323,10 +411,10 @@ def show_strain( vrange_theta, vrange_exy=None, vrange_eyy=None, - flip_theta = False, + flip_theta=False, bkgrd=True, - show_cbars=('exx','eyy','exy','theta'), - bordercolor='k', + show_cbars=("exx", "eyy", "exy", "theta"), + bordercolor="k", borderwidth=1, titlesize=24, ticklabelsize=16, @@ -339,20 +427,21 @@ def show_strain( xaxis_y=0, axes_length=10, axes_width=1, - axes_color='r', - xaxis_space='Q', + axes_color="r", + xaxis_space="Q", labelaxes=True, QR_rotation=0, axes_labelsize=12, - axes_labelcolor='r', - axes_plots=('exx'), - cmap='RdBu_r', + axes_labelcolor="r", + axes_plots=("exx"), + cmap="RdBu_r", layout=0, - figsize=(12,12), - returnfig=False): + figsize=(12, 12), + returnfig=False, +): """ Display a strain map, showing the 4 strain components (e_xx,e_yy,e_xy,theta), and - masking each image with strainmap.get_slice('mask') + masking each image with strainmap.get_slice('mask') Args: strainmap (RealSlice): @@ -360,7 +449,7 @@ def show_strain( vrange_theta (length 2 list or tuple): vrange_exy (length 2 list or tuple): vrange_eyy (length 2 list or tuple): - flip_theta (bool): if True, take negative of angle + flip_theta (bool): if True, take negative of angle bkgrd (bool): show_cbars (tuple of strings): Show colorbars for the specified axes. Must be a tuple containing any, all, or none of ('exx','eyy','exy','theta'). @@ -394,11 +483,11 @@ def show_strain( returnfig (bool): """ # Lookup table for different layouts - assert(layout in (0,1,2)) + assert layout in (0, 1, 2) layout_lookup = { - 0:['left','right','left','right'], - 1:['bottom','bottom','bottom','bottom'], - 2:['right','right','right','right'], + 0: ["left", "right", "left", "right"], + 1: ["bottom", "bottom", "bottom", "bottom"], + 2: ["right", "right", "right", "right"], } layout_p = layout_lookup[layout] @@ -407,141 +496,204 @@ def show_strain( vrange_exy = vrange_exx if vrange_eyy is None: vrange_eyy = vrange_exx - for vrange in (vrange_exx,vrange_eyy,vrange_exy,vrange_theta): - assert(len(vrange)==2), 'vranges must have length 2' - vmin_exx,vmax_exx = vrange_exx[0]/100.,vrange_exx[1]/100. - vmin_eyy,vmax_eyy = vrange_eyy[0]/100.,vrange_eyy[1]/100. - vmin_exy,vmax_exy = vrange_exy[0]/100.,vrange_exy[1]/100. + for vrange in (vrange_exx, vrange_eyy, vrange_exy, vrange_theta): + assert len(vrange) == 2, "vranges must have length 2" + vmin_exx, vmax_exx = vrange_exx[0] / 100.0, vrange_exx[1] / 100.0 + vmin_eyy, vmax_eyy = vrange_eyy[0] / 100.0, vrange_eyy[1] / 100.0 + vmin_exy, vmax_exy = vrange_exy[0] / 100.0, vrange_exy[1] / 100.0 # theta is plotted in units of degrees - vmin_theta,vmax_theta = vrange_theta[0]/(180.0/np.pi),vrange_theta[1]/(180.0/np.pi) + vmin_theta, vmax_theta = vrange_theta[0] / (180.0 / np.pi), vrange_theta[1] / ( + 180.0 / np.pi + ) # Get images - e_xx = np.ma.array(strainmap.get_slice('e_xx').data,mask=strainmap.get_slice('mask').data==False) - e_yy = np.ma.array(strainmap.get_slice('e_yy').data,mask=strainmap.get_slice('mask').data==False) - e_xy = np.ma.array(strainmap.get_slice('e_xy').data,mask=strainmap.get_slice('mask').data==False) - theta = np.ma.array(strainmap.get_slice('theta').data,mask=strainmap.get_slice('mask').data==False) - if flip_theta == True: - theta = - theta + e_xx = np.ma.array( + strainmap.get_slice("e_xx").data, mask=strainmap.get_slice("mask").data == False + ) + e_yy = np.ma.array( + strainmap.get_slice("e_yy").data, mask=strainmap.get_slice("mask").data == False + ) + e_xy = np.ma.array( + strainmap.get_slice("e_xy").data, mask=strainmap.get_slice("mask").data == False + ) + theta = np.ma.array( + strainmap.get_slice("theta").data, + mask=strainmap.get_slice("mask").data == False, + ) + if flip_theta == True: + theta = -theta # Plot - if layout==0: - fig,((ax11,ax12),(ax21,ax22)) = plt.subplots(2,2,figsize=figsize) - elif layout==1: - fig,(ax11,ax12,ax21,ax22) = plt.subplots(1,4,figsize=figsize) + if layout == 0: + fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=figsize) + elif layout == 1: + fig, (ax11, ax12, ax21, ax22) = plt.subplots(1, 4, figsize=figsize) else: - fig,(ax11,ax12,ax21,ax22) = plt.subplots(4,1,figsize=figsize) + fig, (ax11, ax12, ax21, ax22) = plt.subplots(4, 1, figsize=figsize) cax11 = show( e_xx, - figax=(fig,ax11), + figax=(fig, ax11), vmin=vmin_exx, vmax=vmax_exx, - intensity_range='absolute', + intensity_range="absolute", cmap=cmap, - returncax=True) + returncax=True, + ) cax12 = show( e_yy, - figax=(fig,ax12), + figax=(fig, ax12), vmin=vmin_eyy, vmax=vmax_eyy, - intensity_range='absolute', + intensity_range="absolute", cmap=cmap, - returncax=True) + returncax=True, + ) cax21 = show( e_xy, - figax=(fig,ax21), + figax=(fig, ax21), vmin=vmin_exy, vmax=vmax_exy, - intensity_range='absolute', + intensity_range="absolute", cmap=cmap, - returncax=True) + returncax=True, + ) cax22 = show( theta, - figax=(fig,ax22), + figax=(fig, ax22), vmin=vmin_theta, vmax=vmax_theta, - intensity_range='absolute', + intensity_range="absolute", cmap=cmap, - returncax=True) - ax11.set_title(r'$\epsilon_{xx}$',size=titlesize) - ax12.set_title(r'$\epsilon_{yy}$',size=titlesize) - ax21.set_title(r'$\epsilon_{xy}$',size=titlesize) - ax22.set_title(r'$\theta$',size=titlesize) + returncax=True, + ) + ax11.set_title(r"$\epsilon_{xx}$", size=titlesize) + ax12.set_title(r"$\epsilon_{yy}$", size=titlesize) + ax21.set_title(r"$\epsilon_{xy}$", size=titlesize) + ax22.set_title(r"$\theta$", size=titlesize) # Add black background if bkgrd: mask = np.ma.masked_where( - strainmap.get_slice('mask').data.astype(bool), - np.zeros_like(strainmap.get_slice('mask').data)) - ax11.matshow(mask,cmap='gray') - ax12.matshow(mask,cmap='gray') - ax21.matshow(mask,cmap='gray') - ax22.matshow(mask,cmap='gray') + strainmap.get_slice("mask").data.astype(bool), + np.zeros_like(strainmap.get_slice("mask").data), + ) + ax11.matshow(mask, cmap="gray") + ax12.matshow(mask, cmap="gray") + ax21.matshow(mask, cmap="gray") + ax22.matshow(mask, cmap="gray") # Colorbars - show_cbars = np.array(['exx' in show_cbars,'eyy' in show_cbars, - 'exy' in show_cbars,'theta' in show_cbars]) + show_cbars = np.array( + [ + "exx" in show_cbars, + "eyy" in show_cbars, + "exy" in show_cbars, + "theta" in show_cbars, + ] + ) if np.any(show_cbars): divider11 = make_axes_locatable(ax11) divider12 = make_axes_locatable(ax12) divider21 = make_axes_locatable(ax21) divider22 = make_axes_locatable(ax22) - cbax11 = divider11.append_axes(layout_p[0],size="4%",pad=0.15) - cbax12 = divider12.append_axes(layout_p[1],size="4%",pad=0.15) - cbax21 = divider21.append_axes(layout_p[2],size="4%",pad=0.15) - cbax22 = divider22.append_axes(layout_p[3],size="4%",pad=0.15) - for (ind,show_cbar,cax,cbax,vmin,vmax,tickside,tickunits) in zip( + cbax11 = divider11.append_axes(layout_p[0], size="4%", pad=0.15) + cbax12 = divider12.append_axes(layout_p[1], size="4%", pad=0.15) + cbax21 = divider21.append_axes(layout_p[2], size="4%", pad=0.15) + cbax22 = divider22.append_axes(layout_p[3], size="4%", pad=0.15) + for ind, show_cbar, cax, cbax, vmin, vmax, tickside, tickunits in zip( range(4), show_cbars, - (cax11,cax12,cax21,cax22), - (cbax11,cbax12,cbax21,cbax22), - (vmin_exx,vmin_eyy,vmin_exy,vmin_theta), - (vmax_exx,vmax_eyy,vmax_exy,vmax_theta), - (layout_p[0],layout_p[1],layout_p[2],layout_p[3]), - ('% ',' %','% ',r' $^\circ$')): + (cax11, cax12, cax21, cax22), + (cbax11, cbax12, cbax21, cbax22), + (vmin_exx, vmin_eyy, vmin_exy, vmin_theta), + (vmax_exx, vmax_eyy, vmax_exy, vmax_theta), + (layout_p[0], layout_p[1], layout_p[2], layout_p[3]), + ("% ", " %", "% ", r" $^\circ$"), + ): if show_cbar: - ticks = np.linspace(vmin,vmax,ticknumber,endpoint=True) + ticks = np.linspace(vmin, vmax, ticknumber, endpoint=True) if ind < 3: - ticklabels = np.round(np.linspace( - 100*vmin,100*vmax,ticknumber,endpoint=True),decimals=2).astype(str) + ticklabels = np.round( + np.linspace(100 * vmin, 100 * vmax, ticknumber, endpoint=True), + decimals=2, + ).astype(str) else: - ticklabels = np.round(np.linspace( - (180/np.pi)*vmin,(180/np.pi)*vmax,ticknumber,endpoint=True),decimals=2).astype(str) - - if tickside in ('left','right'): - cb = plt.colorbar(cax,cax=cbax,ticks=ticks,orientation='vertical') - cb.ax.set_yticklabels(ticklabels,size=ticklabelsize) + ticklabels = np.round( + np.linspace( + (180 / np.pi) * vmin, + (180 / np.pi) * vmax, + ticknumber, + endpoint=True, + ), + decimals=2, + ).astype(str) + + if tickside in ("left", "right"): + cb = plt.colorbar( + cax, cax=cbax, ticks=ticks, orientation="vertical" + ) + cb.ax.set_yticklabels(ticklabels, size=ticklabelsize) cbax.yaxis.set_ticks_position(tickside) - cbax.set_ylabel(tickunits,size=unitlabelsize,rotation=0) + cbax.set_ylabel(tickunits, size=unitlabelsize, rotation=0) cbax.yaxis.set_label_position(tickside) else: - cb = plt.colorbar(cax,cax=cbax,ticks=ticks,orientation='horizontal') - cb.ax.set_xticklabels(ticklabels,size=ticklabelsize) + cb = plt.colorbar( + cax, cax=cbax, ticks=ticks, orientation="horizontal" + ) + cb.ax.set_xticklabels(ticklabels, size=ticklabelsize) cbax.xaxis.set_ticks_position(tickside) - cbax.set_xlabel(tickunits,size=unitlabelsize,rotation=0) + cbax.set_xlabel(tickunits, size=unitlabelsize, rotation=0) cbax.xaxis.set_label_position(tickside) else: - cbax.axis('off') + cbax.axis("off") # Add coordinate axes if show_axes: - assert(xaxis_space in ('R','Q')), "xaxis_space must be 'R' or 'Q'" - show_which_axes = np.array(['exx' in axes_plots,'eyy' in axes_plots, - 'exy' in axes_plots,'theta' in axes_plots]) - for _show,_ax in zip(show_which_axes,(ax11,ax12,ax21,ax22)): + assert xaxis_space in ("R", "Q"), "xaxis_space must be 'R' or 'Q'" + show_which_axes = np.array( + [ + "exx" in axes_plots, + "eyy" in axes_plots, + "exy" in axes_plots, + "theta" in axes_plots, + ] + ) + for _show, _ax in zip(show_which_axes, (ax11, ax12, ax21, ax22)): if _show: - if xaxis_space=='R': - ax_addaxes(_ax,xaxis_x,xaxis_y,axes_length,axes_x0,axes_y0, - width=axes_width,color=axes_color,labelaxes=labelaxes, - labelsize=axes_labelsize,labelcolor=axes_labelcolor) + if xaxis_space == "R": + ax_addaxes( + _ax, + xaxis_x, + xaxis_y, + axes_length, + axes_x0, + axes_y0, + width=axes_width, + color=axes_color, + labelaxes=labelaxes, + labelsize=axes_labelsize, + labelcolor=axes_labelcolor, + ) else: - ax_addaxes_QtoR(_ax,xaxis_x,xaxis_y,axes_length,axes_x0,axes_y0,QR_rotation, - width=axes_width,color=axes_color,labelaxes=labelaxes, - labelsize=axes_labelsize,labelcolor=axes_labelcolor) + ax_addaxes_QtoR( + _ax, + xaxis_x, + xaxis_y, + axes_length, + axes_x0, + axes_y0, + QR_rotation, + width=axes_width, + color=axes_color, + labelaxes=labelaxes, + labelsize=axes_labelsize, + labelcolor=axes_labelcolor, + ) # Add borders if bordercolor is not None: - for ax in (ax11,ax12,ax21,ax22): - for s in ['bottom','top','left','right']: + for ax in (ax11, ax12, ax21, ax22): + for s in ["bottom", "top", "left", "right"]: ax.spines[s].set_color(bordercolor) ax.spines[s].set_linewidth(borderwidth) ax.set_xticks([]) @@ -551,54 +703,87 @@ def show_strain( plt.show() return else: - axs = ((ax11,ax12),(ax21,ax22)) - return fig,axs + axs = ((ax11, ax12), (ax21, ax22)) + return fig, axs -def show_pointlabels(ar,x,y,color='lightblue',size=20,alpha=1,returnfig=False,**kwargs): +def show_pointlabels( + ar, x, y, color="lightblue", size=20, alpha=1, returnfig=False, **kwargs +): """ Show enumerated index labels for a set of points """ - fig,ax = show(ar,returnfig=True,**kwargs) - d = {'x':x,'y':y,'size':size,'color':color,'alpha':alpha} - add_pointlabels(ax,d) + fig, ax = show(ar, returnfig=True, **kwargs) + d = {"x": x, "y": y, "size": size, "color": color, "alpha": alpha} + add_pointlabels(ax, d) if returnfig: - return fig,ax + return fig, ax else: plt.show() return -def select_point(ar,x,y,i,color='lightblue',color_selected='r',size=20,returnfig=False,**kwargs): +def select_point( + ar, + x, + y, + i, + color="lightblue", + color_selected="r", + size=20, + returnfig=False, + **kwargs +): """ Show enumerated index labels for a set of points, with one selected point highlighted """ - fig,ax = show(ar,returnfig=True,**kwargs) - d1 = {'x':x,'y':y,'size':size,'color':color} - d2 = {'x':x[i],'y':y[i],'size':size,'color':color_selected,'fontweight':'bold'} - add_pointlabels(ax,d1) - add_pointlabels(ax,d2) + fig, ax = show(ar, returnfig=True, **kwargs) + d1 = {"x": x, "y": y, "size": size, "color": color} + d2 = { + "x": x[i], + "y": y[i], + "size": size, + "color": color_selected, + "fontweight": "bold", + } + add_pointlabels(ax, d1) + add_pointlabels(ax, d2) if returnfig: - return fig,ax + return fig, ax else: plt.show() return -def show_max_peak_spacing(ar,spacing,braggdirections,color='g',lw=2,returnfig=False,**kwargs): - """ Show a circle of radius `spacing` about each Bragg direction - """ - centers = [(braggdirections.data['qx'][i],braggdirections.data['qy'][i]) for i in range(braggdirections.length)] - fig,ax = show(ar,circle={'center':centers,'R':spacing,'color':color,'fill':False,'lw':lw}, - returnfig=True,**kwargs) +def show_max_peak_spacing( + ar, spacing, braggdirections, color="g", lw=2, returnfig=False, **kwargs +): + """Show a circle of radius `spacing` about each Bragg direction""" + centers = [ + (braggdirections.data["qx"][i], braggdirections.data["qy"][i]) + for i in range(braggdirections.length) + ] + fig, ax = show( + ar, + circle={ + "center": centers, + "R": spacing, + "color": color, + "fill": False, + "lw": lw, + }, + returnfig=True, + **kwargs, + ) if returnfig: - return fig,ax + return fig, ax else: plt.show() return + def show_origin_meas(data): """ Show the measured positions of the origin. @@ -608,17 +793,19 @@ def show_origin_meas(data): """ from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube - if isinstance(data,tuple): - assert len(data)==2 - qx,qy = data - elif isinstance(data,DataCube): - qx,qy = data.calibration.get_origin_meas() - elif isinstance(data,Calibration): - qx,qy = data.get_origin_meas() + + if isinstance(data, tuple): + assert len(data) == 2 + qx, qy = data + elif isinstance(data, DataCube): + qx, qy = data.calibration.get_origin_meas() + elif isinstance(data, Calibration): + qx, qy = data.get_origin_meas() else: raise Exception("data must be of type Datacube or Calibration or tuple") - show_image_grid(get_ar = lambda i:[qx,qy][i],H=1,W=2,cmap='RdBu') + show_image_grid(get_ar=lambda i: [qx, qy][i], H=1, W=2, cmap="RdBu") + def show_origin_fit(data): """ @@ -630,29 +817,49 @@ def show_origin_fit(data): """ from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube - if isinstance(data,tuple): - assert len(data)==3 - qx0_meas,qy_meas = data[0] - qx0_fit,qy0_fit = data[1] - qx0_residuals,qy0_residuals = data[2] - elif isinstance(data,DataCube): - qx0_meas,qy0_meas = data.calibration.get_origin_meas() - qx0_fit,qy0_fit = data.calibration.get_origin() - qx0_residuals,qy0_residuals = data.calibration.get_origin_residuals() - elif isinstance(data,Calibration): - qx0_meas,qy0_meas = data.get_origin_meas() - qx0_fit,qy0_fit = data.get_origin() - qx0_residuals,qy0_residuals = data.get_origin_residuals() + + if isinstance(data, tuple): + assert len(data) == 3 + qx0_meas, qy_meas = data[0] + qx0_fit, qy0_fit = data[1] + qx0_residuals, qy0_residuals = data[2] + elif isinstance(data, DataCube): + qx0_meas, qy0_meas = data.calibration.get_origin_meas() + qx0_fit, qy0_fit = data.calibration.get_origin() + qx0_residuals, qy0_residuals = data.calibration.get_origin_residuals() + elif isinstance(data, Calibration): + qx0_meas, qy0_meas = data.get_origin_meas() + qx0_fit, qy0_fit = data.get_origin() + qx0_residuals, qy0_residuals = data.get_origin_residuals() else: raise Exception("data must be of type Datacube or Calibration or tuple") - show_image_grid(get_ar = lambda i:[qx0_meas,qx0_fit,qx0_residuals, - qy0_meas,qy0_fit,qy0_residuals][i], - H=2,W=3,cmap='RdBu') + show_image_grid( + get_ar=lambda i: [ + qx0_meas, + qx0_fit, + qx0_residuals, + qy0_meas, + qy0_fit, + qy0_residuals, + ][i], + H=2, + W=3, + cmap="RdBu", + ) -def show_selected_dps(datacube,positions,im,bragg_pos=None, - colors=None,HW=None,figsize_im=(6,6),figsize_dp=(4,4), - **kwargs): + +def show_selected_dps( + datacube, + positions, + im, + bragg_pos=None, + colors=None, + HW=None, + figsize_im=(6, 6), + figsize_dp=(4, 4), + **kwargs +): """ Shows two plots: first, a real space image overlaid with colored dots at the specified positions; second, a grid of diffraction patterns @@ -673,72 +880,87 @@ def show_selected_dps(datacube,positions,im,bragg_pos=None, *diffraction patterns*. Default is `scaling='log'` """ from py4DSTEM.datacube import DataCube - assert isinstance(datacube,DataCube) + + assert isinstance(datacube, DataCube) N = len(positions) - assert(all([len(x)==2 for x in positions])), "Improperly formated argument `positions`" + assert all( + [len(x) == 2 for x in positions] + ), "Improperly formated argument `positions`" if bragg_pos is not None: show_disk_pos = True - assert(len(bragg_pos)==N) + assert len(bragg_pos) == N else: show_disk_pos = False if colors is None: from matplotlib.cm import gist_ncar - linsp = np.linspace(0,1,N,endpoint=False) + + linsp = np.linspace(0, 1, N, endpoint=False) colors = [gist_ncar(i) for i in linsp] - assert(len(colors)==N), "Number of positions and colors don't match" + assert len(colors) == N, "Number of positions and colors don't match" from matplotlib.colors import is_color_like - assert([is_color_like(i) for i in colors]) + + assert [is_color_like(i) for i in colors] if HW is None: W = int(np.ceil(np.sqrt(N))) - if W<3: W=3 - H = int(np.ceil(N/W)) + if W < 3: + W = 3 + H = int(np.ceil(N / W)) else: - H,W = HW - assert(all([isinstance(x,(int,np.integer)) for x in (H,W)])) + H, W = HW + assert all([isinstance(x, (int, np.integer)) for x in (H, W)]) x = [i[0] for i in positions] y = [i[1] for i in positions] - if 'scaling' not in kwargs.keys(): - kwargs['scaling'] = 'log' + if "scaling" not in kwargs.keys(): + kwargs["scaling"] = "log" if not show_disk_pos: - fig,ax = show(im,figsize=figsize_im,returnfig=True) - add_points(ax,d = {'x':x,'y':y,'pointcolor':colors}) - show_image_grid(get_ar=lambda i:datacube.data[x[i],y[i],:,:],H=H,W=W, - get_bordercolor=lambda i:colors[i],axsize=figsize_dp, - **kwargs) + fig, ax = show(im, figsize=figsize_im, returnfig=True) + add_points(ax, d={"x": x, "y": y, "pointcolor": colors}) + show_image_grid( + get_ar=lambda i: datacube.data[x[i], y[i], :, :], + H=H, + W=W, + get_bordercolor=lambda i: colors[i], + axsize=figsize_dp, + **kwargs, + ) else: - show_image_grid(get_ar=lambda i:datacube.data[x[i],y[i],:,:],H=H,W=W, - get_bordercolor=lambda i:colors[i],axsize=figsize_dp, - get_x=lambda i:bragg_pos[i].data['qx'], - get_y=lambda i:bragg_pos[i].data['qy'], - get_pointcolors=lambda i:colors[i], - **kwargs) + show_image_grid( + get_ar=lambda i: datacube.data[x[i], y[i], :, :], + H=H, + W=W, + get_bordercolor=lambda i: colors[i], + axsize=figsize_dp, + get_x=lambda i: bragg_pos[i].data["qx"], + get_y=lambda i: bragg_pos[i].data["qy"], + get_pointcolors=lambda i: colors[i], + **kwargs, + ) + -def Complex2RGB(complex_data, vmin=None, vmax = None, power = None): +def Complex2RGB(complex_data, vmin=None, vmax=None, power=None): """ complex_data (array): complex array to plot - vmin (float) : minimum absolute value - vmax (float) : maximum absolute value + vmin (float) : minimum absolute value + vmax (float) : maximum absolute value power (float) : power to raise amplitude to """ - if power is None: - norm = mcolors.Normalize() - else: - norm = mcolors.PowerNorm(power) - - amp = norm(np.abs(complex_data)).data + amp = np.abs(complex_data) phase = np.angle(complex_data) - if np.isclose(np.max(amp),np.min(amp)): + if power is not None: + amp = amp**power + + if np.isclose(np.max(amp), np.min(amp)): if vmin is None: vmin = 0 if vmax is None: vmax = np.max(amp) else: if vmin is None: - vmin = 0.0 + vmin = 0.02 if vmax is None: - vmax = 1.0 + vmax = 0.98 vals = np.sort(amp[~np.isnan(amp)]) ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int") ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int") @@ -749,27 +971,29 @@ def Complex2RGB(complex_data, vmin=None, vmax = None, power = None): amp = np.where(amp < vmin, vmin, amp) amp = np.where(amp > vmax, vmax, amp) - - J = amp*61.5 # Note we restrict luminance to 61.5 - C = np.where(J<61.5,98*J/123,1400/11-14*J/11) # Min uniform chroma - h = np.rad2deg(phase)+180 + amp = ((amp - vmin) / vmax).clip(1e-16, 1) - JCh = np.stack((J,C,h), axis=-1) + J = amp * 61.5 # Note we restrict luminance to the monotonic chroma cutoff + C = np.where(J < 61.5, 98 * J / 123, 1400 / 11 - 14 * J / 11) # Min uniform chroma + h = np.rad2deg(phase) + 180 + + JCh = np.stack((J, C, h), axis=-1) rgb = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) - + return rgb -def add_colorbar_arg(cax, c = 49, j = 61.5): + +def add_colorbar_arg(cax, c=49, j=61.5): """ cax : axis to add cbar to c : constant chroma value j : constant luminance value """ - h = np.linspace(0, 360, 256,endpoint=False) - J = np.full_like(h,j) - C = np.full_like(h,c) - JCh = np.stack((J,C,h), axis=-1) + h = np.linspace(0, 360, 256, endpoint=False) + J = np.full_like(h, j) + C = np.full_like(h, c) + JCh = np.stack((J, C, h), axis=-1) rgb_vals = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) newcmp = mcolors.ListedColormap(rgb_vals) norm = mcolors.Normalize(vmin=-np.pi, vmax=np.pi) @@ -783,6 +1007,7 @@ def add_colorbar_arg(cax, c = 49, j = 61.5): [r"$-\pi$", r"$-\dfrac{\pi}{2}$", "$0$", r"$\dfrac{\pi}{2}$", r"$\pi$"] ) + def show_complex( ar_complex, vmin=None, @@ -803,7 +1028,7 @@ def show_complex( such as [array1, array2], then arrays are horizonally plotted in one figure vmin (float, optional) : minimum absolute value vmax (float, optional) : maximum absolute value - if None, vmin/vmax are set to fractions of the distribution of pixel values in the array, + if None, vmin/vmax are set to fractions of the distribution of pixel values in the array, e.g. vmin=0.02 will set the minumum display value to saturate the lower 2% of pixels cbar (bool, optional) : if True, include color bar scalebar (bool, optional) : if True, adds scale bar @@ -811,16 +1036,24 @@ def show_complex( pixelsize (float, optional) : size of one pixel in pixelunits for scalebar returnfig (bool, optional) : if True, the function returns the tuple (figure,axis) power (float,optional) : power to raise amplitude to - + Returns: if returnfig==False (default), the figure is plotted and nothing is returned. if returnfig==True, return the figure and the axis. """ # convert to complex colors - ar_complex = ar_complex[0] if (isinstance(ar_complex,list) and len(ar_complex) == 1) else ar_complex + ar_complex = ( + ar_complex[0] + if (isinstance(ar_complex, list) and len(ar_complex) == 1) + else ar_complex + ) if isinstance(ar_complex, list): if isinstance(ar_complex[0], list): - rgb = [Complex2RGB(ar, vmin, vmax, power=power) for sublist in ar_complex for ar in sublist] + rgb = [ + Complex2RGB(ar, vmin, vmax, power=power) + for sublist in ar_complex + for ar in sublist + ] H = len(ar_complex) W = len(ar_complex[0]) @@ -843,7 +1076,7 @@ def show_complex( is_grid = True H = rgb.shape[0] W = rgb.shape[1] - rgb = rgb.reshape((-1,)+rgb.shape[-3:]) + rgb = rgb.reshape((-1,) + rgb.shape[-3:]) else: is_grid = False # plot From d4099ea28f5c8579f9ca549644b8c6029bcacb1b Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Tue, 3 Oct 2023 17:10:47 -0700 Subject: [PATCH 26/62] change to fitted intensities --- py4DSTEM/process/phase/iterative_parallax.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 0d2f4cfc9..3882fb1e4 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -200,14 +200,14 @@ def preprocess( intensities_shifted = self._intensities.copy() - center_x = np.mean(dpc._com_measured_x) - center_y = np.mean(dpc._com_measured_y) + center_x = np.mean(dpc._com_fitted_x) + center_y = np.mean(dpc._com_fitted_y) for rx in range(intensities_shifted.shape[0]): for ry in range(intensities_shifted.shape[1]): intensity_shifted = get_shifted_ar( self._intensities[rx, ry], - -dpc._com_measured_x[rx, ry] + center_x, - -dpc._com_measured_y[rx, ry] + center_y, + -dpc._com_fitted_x[rx, ry] + center_x, + -dpc._com_fitted_y[rx, ry] + center_y, bilinear=True, device="cpu", ) From f2c21d5581febf3fddb44a020305d89c55b38cca Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sun, 8 Oct 2023 08:56:07 -0700 Subject: [PATCH 27/62] preprocessing dtype bug --- py4DSTEM/process/phase/iterative_base_class.py | 2 +- py4DSTEM/process/phase/iterative_mixedstate_ptychography.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 96b3d5088..4ccb21226 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1108,7 +1108,7 @@ def _normalize_diffraction_intensities( xp = self._xp mean_intensity = 0 - amplitudes = xp.zeros_like(diffraction_intensities) + amplitudes = xp.zeros(diffraction_intensities.shape, dtype=xp.float32) region_of_interest_shape = diffraction_intensities.shape[-2:] com_fitted_x = self._asnumpy(com_fitted_x) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 6dbccff06..2a482faf0 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -1283,6 +1283,7 @@ def reconstruct( constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, global_affine_transformation: bool = True, + constrain_position_distance: float = None, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, fit_probe_aberrations_iter: int = 0, @@ -1358,6 +1359,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + constrain_position_distance: float + Distance to constrain position correction within original field of view in A global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -1680,6 +1683,7 @@ def reconstruct( amplitudes, self._positions_px, positions_step_size, + constrain_position_distance, ) error += batch_error From 7c3a0d8ebd8bb3b12cc421b4c1ce2f3d9c88fca6 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Tue, 10 Oct 2023 08:25:12 -0700 Subject: [PATCH 28/62] adding tilt to propagators --- .../iterative_multislice_ptychography.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index b3614c0ad..365bd8b8f 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -80,6 +80,10 @@ class MultislicePtychographicReconstruction(PtychographicReconstruction): initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + theta_x: float + x tilt of propagator (in angles) + theta_y: float + y tilt of propagator (in angles) object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') @@ -111,6 +115,8 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + theta_x: float = 0, + theta_y: float = 0, object_type: str = "complex", verbose: bool = True, device: str = "cpu", @@ -191,6 +197,8 @@ def __init__( # Class-specific Metadata self._num_slices = num_slices self._slice_thicknesses = slice_thicknesses + self._theta_x = theta_x + self._theta_y = theta_y def _precompute_propagator_arrays( self, @@ -198,6 +206,8 @@ def _precompute_propagator_arrays( sampling: Tuple[float, float], energy: float, slice_thicknesses: Sequence[float], + theta_x: float, + theta_y: float, ): """ Precomputes propagator arrays complex wave-function will be convolved by, @@ -213,6 +223,10 @@ def _precompute_propagator_arrays( The electron energy of the wave functions in eV slice_thicknesses: Sequence[float] Array of slice thicknesses in A + theta_x: float + x tilt of propagator (in angles) + theta_y: float + y tilt of propagator (in angles) Returns ------- @@ -232,6 +246,10 @@ def _precompute_propagator_arrays( propagators = xp.empty( (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 ) + + theta_x = np.deg2rad(theta_x) + theta_y = np.deg2rad(theta_y) + for i, dz in enumerate(slice_thicknesses): propagators[i] = xp.exp( 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) @@ -239,6 +257,10 @@ def _precompute_propagator_arrays( propagators[i] *= xp.exp( 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) ) + propagators[i] *= xp.exp( + 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) + ) + propagators[i] *= xp.exp(1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y))) return propagators @@ -561,6 +583,8 @@ def preprocess( self.sampling, self._energy, self._slice_thicknesses, + self._theta_x, + self._theta_y, ) # overlaps @@ -1859,6 +1883,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + constrain_position_distance: float + Distance to constrain position correction within original field of view in A global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional From 23204aa5cee2243a879f9d5a5602f9e26dbb17af Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 12 Oct 2023 14:22:41 -0400 Subject: [PATCH 29/62] single slice crop patterns --- .../process/phase/iterative_base_class.py | 49 +++++++++++++++++-- .../iterative_multislice_ptychography.py | 4 +- .../iterative_singleslice_ptychography.py | 12 +++-- 3 files changed, 57 insertions(+), 8 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 4ccb21226..2e6e0a917 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1084,6 +1084,7 @@ def _normalize_diffraction_intensities( diffraction_intensities, com_fitted_x, com_fitted_y, + crop_patterns, ): """ Fix diffraction intensities CoM, shift to origin, and take square root @@ -1096,6 +1097,8 @@ def _normalize_diffraction_intensities( Best fit horizontal center of mass gradient com_fitted_y: (Rx,Ry) xp.ndarray Best fit vertical center of mass gradient + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns Returns ------- @@ -1108,13 +1111,46 @@ def _normalize_diffraction_intensities( xp = self._xp mean_intensity = 0 - amplitudes = xp.zeros(diffraction_intensities.shape, dtype=xp.float32) - region_of_interest_shape = diffraction_intensities.shape[-2:] + diffraction_intensities = self._asnumpy(diffraction_intensities) + if crop_patterns: + crop_x = int( + np.minimum( + diffraction_intensities.shape[2] - com_fitted_x.max(), + com_fitted_x.min(), + ) + ) + crop_y = int( + np.minimum( + diffraction_intensities.shape[3] - com_fitted_y.max(), + com_fitted_y.min(), + ) + ) + + crop_w = np.minimum(crop_y, crop_x) + region_of_interest_shape = (crop_w * 2, crop_w * 2) + amplitudes = np.zeros( + ( + diffraction_intensities.shape[0], + diffraction_intensities.shape[1], + crop_w * 2, + crop_w * 2, + ), + dtype=np.float32, + ) + + crop_mask = np.zeros(diffraction_intensities.shape[-2:], dtype=np.bool_) + crop_mask[:crop_w, :crop_w] = True + crop_mask[-crop_w:, :crop_w] = True + crop_mask[:crop_w:, -crop_w:] = True + crop_mask[-crop_w:, -crop_w:] = True + self._crop_mask = crop_mask + + else: + region_of_interest_shape = diffraction_intensities.shape[-2:] + amplitudes = np.zeros(diffraction_intensities.shape, dtype=np.float32) com_fitted_x = self._asnumpy(com_fitted_x) com_fitted_y = self._asnumpy(com_fitted_y) - diffraction_intensities = self._asnumpy(diffraction_intensities) - amplitudes = self._asnumpy(amplitudes) for rx in range(diffraction_intensities.shape[0]): for ry in range(diffraction_intensities.shape[1]): @@ -1126,6 +1162,11 @@ def _normalize_diffraction_intensities( device="cpu", ) + if crop_patterns: + intensities = intensities[crop_mask].reshape( + region_of_interest_shape + ) + mean_intensity += np.sum(intensities) amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0)) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 365bd8b8f..be6cbd6ed 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -260,7 +260,9 @@ def _precompute_propagator_arrays( propagators[i] *= xp.exp( 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) ) - propagators[i] *= xp.exp(1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y))) + propagators[i] *= xp.exp( + 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) + ) return propagators diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 3843da983..0cc6b65d5 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -188,6 +188,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -245,6 +246,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns Returns -------- @@ -330,9 +333,7 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, + self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns ) # explicitly delete namespace @@ -412,6 +413,11 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) + self._probe = ( ComplexProbe( gpts=self._region_of_interest_shape, From 84a2067238f0d512856b139866050a46c774ac21 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 13 Oct 2023 03:17:38 -0700 Subject: [PATCH 30/62] tv_denoise typo --- py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py | 2 +- py4DSTEM/process/phase/iterative_overlap_tomography.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index d2934497c..5a1c5dde3 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -2678,7 +2678,7 @@ def reconstruct( else None, tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, tv_denoise_weights=tv_denoise_weights, - v_denoise_inner_iter=tv_denoise_inner_iter, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 3d5982e9e..0157fa422 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -2398,7 +2398,7 @@ def reconstruct( else None, tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, tv_denoise_weights=tv_denoise_weights, - v_denoise_inner_iter=tv_denoise_inner_iter, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) From 0978c4b4ff928acc42487e7f91a8cd41afda50d8 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 13 Oct 2023 04:28:09 -0700 Subject: [PATCH 31/62] revisting casting inconsistencies --- .../process/phase/iterative_base_class.py | 10 ++-- py4DSTEM/process/phase/iterative_parallax.py | 46 +++++++++++++------ 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 4ccb21226..4b9d905d1 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -278,7 +278,9 @@ def _extract_intensities_and_calibrations_from_datacube( """ # Copies intensities to device casting to float32 - intensities = datacube.data + xp = self._xp + + intensities = xp.asarray(datacube.data, dtype=xp.float32) self._grid_scan_shape = intensities.shape[:2] # Extracts calibrations @@ -450,8 +452,6 @@ def _calculate_intensities_center_of_mass( xp = self._xp asnumpy = self._asnumpy - intensities = xp.asarray(intensities, dtype=xp.float32) - # for ptycho if com_measured: com_measured_x, com_measured_y = com_measured @@ -1108,7 +1108,7 @@ def _normalize_diffraction_intensities( xp = self._xp mean_intensity = 0 - amplitudes = xp.zeros(diffraction_intensities.shape, dtype=xp.float32) + amplitudes = xp.zeros_like(diffraction_intensities) region_of_interest_shape = diffraction_intensities.shape[-2:] com_fitted_x = self._asnumpy(com_fitted_x) @@ -1129,8 +1129,6 @@ def _normalize_diffraction_intensities( mean_intensity += np.sum(intensities) amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0)) - amplitudes = xp.asarray(amplitudes, dtype=xp.float32) - amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape) mean_intensity /= amplitudes.shape[0] diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 3882fb1e4..aa2a4a6b0 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -176,7 +176,7 @@ def preprocess( self._datacube, require_calibrations=True, ) - self._intensities = xp.asarray(self._intensities, dtype=xp.float32) + # make sure mean diffraction pattern is shaped correctly if (self._dp_mean.shape[0] != self._intensities.shape[2]) or ( self._dp_mean.shape[1] != self._intensities.shape[3] @@ -224,14 +224,16 @@ def preprocess( # diffraction space coordinates self._xy_inds = np.argwhere(self._dp_mask) - self._kxy = (self._xy_inds - xp.mean(self._xy_inds, axis=0)[None]) * xp.array( - self._reciprocal_sampling - )[None] + self._kxy = xp.asarray( + (self._xy_inds - xp.mean(self._xy_inds, axis=0)[None]) + * xp.array(self._reciprocal_sampling)[None], + dtype=xp.float32, + ) self._probe_angles = self._kxy * self._wavelength self._kr = xp.sqrt(xp.sum(self._kxy**2, axis=1)) # Window function - x = xp.linspace(-1, 1, self._grid_scan_shape[0] + 1)[1:] + x = xp.linspace(-1, 1, self._grid_scan_shape[0] + 1, dtype=xp.float32)[1:] x -= (x[1] - x[0]) / 2 wx = ( xp.sin( @@ -242,7 +244,7 @@ def preprocess( ) ** 2 ) - y = xp.linspace(-1, 1, self._grid_scan_shape[1] + 1)[1:] + y = xp.linspace(-1, 1, self._grid_scan_shape[1] + 1, dtype=xp.float32)[1:] y -= (y[1] - y[0]) / 2 wy = ( xp.sin( @@ -259,7 +261,8 @@ def preprocess( ( self._grid_scan_shape[0] + self._object_padding_px[0], self._grid_scan_shape[1] + self._object_padding_px[1], - ) + ), + dtype=xp.float32, ) self._window_pad[ self._object_padding_px[0] // 2 : self._grid_scan_shape[0] @@ -282,8 +285,8 @@ def preprocess( self._grid_scan_shape[1] + self._object_padding_px[1], ) if normalize_images: - self._stack_BF = xp.ones(stack_shape) - self._stack_BF_no_window = xp.ones(stack_shape) + self._stack_BF = xp.ones(stack_shape, dtype=xp.float32) + self._stack_BF_no_window = xp.ones(stack_shape, xp.float32) if normalize_order == 0: all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None] @@ -306,12 +309,12 @@ def preprocess( ] = all_bfs elif normalize_order == 1: - x = xp.linspace(-0.5, 0.5, all_bfs.shape[1]) - y = xp.linspace(-0.5, 0.5, all_bfs.shape[2]) + x = xp.linspace(-0.5, 0.5, all_bfs.shape[1], xp.float32) + y = xp.linspace(-0.5, 0.5, all_bfs.shape[2], xp.float32) ya, xa = xp.meshgrid(y, x) basis = np.vstack( ( - xp.ones(xa.size), + xp.ones_like(xa), xa.ravel(), ya.ravel(), ) @@ -364,7 +367,11 @@ def preprocess( # Fourier space operators for image shifts qx = xp.fft.fftfreq(self._stack_BF.shape[1], d=1) + qx = xp.asarray(qx, dtype=xp.float32) + qy = xp.fft.fftfreq(self._stack_BF.shape[2], d=1) + qy = xp.asarray(qy, dtype=xp.float32) + qxa, qya = xp.meshgrid(qx, qy, indexing="ij") self._qx_shift = -2j * xp.pi * qxa self._qy_shift = -2j * xp.pi * qya @@ -399,7 +406,7 @@ def preprocess( del Gs else: - self._xy_shifts = xp.zeros((self._num_bf_images, 2)) + self._xy_shifts = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) self._stack_mean = xp.mean(self._stack_BF) self._mask_sum = xp.sum(self._window_edge) * self._num_bf_images @@ -686,7 +693,8 @@ def reconstruct( ( self._num_bf_images, (regularizer_matrix_size[0] + 1) * (regularizer_matrix_size[1] + 1), - ) + ), + dtype=xp.float32, ) for ii in np.arange(regularizer_matrix_size[0] + 1): Bi = ( @@ -771,7 +779,7 @@ def reconstruct( # Sort by radial order, from center to outer edge inds_order = xp.argsort(xp.sum(xy_vals**2, axis=1)) - shifts_update = xp.zeros((self._num_bf_images, 2)) + shifts_update = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) for a1 in tqdmnd( xy_vals.shape[0], @@ -840,11 +848,19 @@ def reconstruct( self._qx_shift[None] * dx[:, None, None] + self._qy_shift[None] * dy[:, None, None] ) + self._stack_BF = xp.real(xp.fft.ifft2(Gs * shift_op)) self._stack_mask = xp.real( xp.fft.ifft2(xp.fft.fft2(self._stack_mask) * shift_op) ) + self._stack_BF = xp.asarray( + self._stack_BF, dtype=xp.float32 + ) # numpy fft upcasts? + self._stack_mask = xp.asarray( + self._stack_mask, dtype=xp.float32 + ) # numpy fft upcasts? + del Gs # Center the shifts From c9ac5db8de4221b24448a3615b83826e13f04862 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 14 Oct 2023 08:02:20 -0400 Subject: [PATCH 32/62] crop pattern option for all classes --- py4DSTEM/process/phase/iterative_base_class.py | 1 + .../iterative_mixedstate_multislice_ptychography.py | 12 ++++++++---- .../phase/iterative_mixedstate_ptychography.py | 11 ++++++++--- .../phase/iterative_multislice_ptychography.py | 11 ++++++++--- .../phase/iterative_overlap_magnetic_tomography.py | 11 ++++++++--- .../process/phase/iterative_overlap_tomography.py | 11 ++++++++--- .../phase/iterative_simultaneous_ptychography.py | 13 ++++++++++--- .../phase/iterative_singleslice_ptychography.py | 2 +- 8 files changed, 52 insertions(+), 20 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 209d33436..6ddfea643 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1099,6 +1099,7 @@ def _normalize_diffraction_intensities( Best fit vertical center of mass gradient crop_patterns: bool if True, crop patterns to avoid wrap around of patterns + when centering Returns ------- diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index ea10050dd..306f47f77 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -300,6 +300,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -357,6 +358,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -442,9 +445,7 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, + self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns ) # explicitly delete namespace @@ -525,7 +526,10 @@ def preprocess( bilinear=True, device=self._device, ) - + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) _probe = ( ComplexProbe( gpts=self._region_of_interest_shape, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 2a482faf0..658079c3e 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -204,6 +204,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -261,6 +262,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -346,9 +349,7 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, + self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns ) # explicitly delete namespace @@ -429,6 +430,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) _probe = ( ComplexProbe( diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index be6cbd6ed..382efedcd 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -305,6 +305,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -362,6 +363,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -447,9 +450,7 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, + self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns ) # explicitly delete namespace @@ -529,6 +530,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 5a1c5dde3..459b0ae8c 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -431,6 +431,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -475,6 +476,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -592,9 +595,7 @@ def preprocess( ], mean_diffraction_intensity_temp, ) = self._normalize_diffraction_intensities( - intensities, - com_fitted_x, - com_fitted_y, + intensities, com_fitted_x, com_fitted_y, crop_patterns ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) @@ -685,6 +686,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 0157fa422..bb3ee09c2 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -372,6 +372,7 @@ def preprocess( force_reciprocal_sampling: float = None, progress_bar: bool = True, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -416,6 +417,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -532,9 +535,7 @@ def preprocess( ], mean_diffraction_intensity_temp, ) = self._normalize_diffraction_intensities( - intensities, - com_fitted_x, - com_fitted_y, + intensities, com_fitted_x, com_fitted_y, crop_patterns ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) @@ -625,6 +626,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 85e9a0b18..084a6fcb8 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -192,6 +192,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -246,6 +247,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -401,9 +404,7 @@ def preprocess( amplitudes_0, mean_diffraction_intensity_0, ) = self._normalize_diffraction_intensities( - intensities_0, - com_fitted_x_0, - com_fitted_y_0, + intensities_0, com_fitted_x_0, com_fitted_y_0, crop_patterns ) # explicitly delete namescapes @@ -487,6 +488,7 @@ def preprocess( intensities_1, com_fitted_x_1, com_fitted_y_1, + crop_patterns ) # explicitly delete namescapes @@ -571,6 +573,7 @@ def preprocess( intensities_2, com_fitted_x_2, com_fitted_y_2, + crop_patterns ) # explicitly delete namescapes @@ -683,6 +686,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 0cc6b65d5..0dc2cd053 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -247,7 +247,7 @@ def preprocess( Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- From 00991c99638f86c85cb1bd1571385b51b59cdc6a Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 14 Oct 2023 08:28:20 -0400 Subject: [PATCH 33/62] fix for gpu --- py4DSTEM/process/phase/iterative_base_class.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 6ddfea643..62cf3a3a1 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1172,6 +1172,7 @@ def _normalize_diffraction_intensities( amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0)) amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape) + amplitudes = xp.asarray(amplitudes) mean_intensity /= amplitudes.shape[0] return amplitudes, mean_intensity From 31af42990caa970e0fd36adf2f2df5076d9c491a Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 16 Oct 2023 00:42:19 -0700 Subject: [PATCH 34/62] cleaned up parallax descan --- py4DSTEM/process/phase/iterative_parallax.py | 44 ++++++++++++-------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index aa2a4a6b0..67815cd14 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -184,38 +184,46 @@ def preprocess( raise ValueError( "dp_mean must match the datacube shape. Try setting dp_mean = None." ) + # descan correct if descan_correct: - from py4DSTEM.process.phase import DPCReconstruction - - dpc = DPCReconstruction( - energy=self._energy, - datacube=self._datacube, - verbose=False, - ).preprocess( - force_com_rotation=0, - force_com_transpose=False, - plot_center_of_mass=False, + ( + _, + _, + com_fitted_x, + com_fitted_y, + _, + _, + ) = self._calculate_intensities_center_of_mass( + self._intensities, + dp_mask=None, + fit_function="plane", + com_shifts=None, + com_measured=None, ) - intensities_shifted = self._intensities.copy() + com_fitted_x = asnumpy(com_fitted_x) + com_fitted_y = asnumpy(com_fitted_y) + intensities = asnumpy(self._intensities) + intensities_shifted = np.zeros_like(intensities) + + center_x = np.mean(com_fitted_x) + center_y = np.mean(com_fitted_y) - center_x = np.mean(dpc._com_fitted_x) - center_y = np.mean(dpc._com_fitted_y) for rx in range(intensities_shifted.shape[0]): for ry in range(intensities_shifted.shape[1]): intensity_shifted = get_shifted_ar( - self._intensities[rx, ry], - -dpc._com_fitted_x[rx, ry] + center_x, - -dpc._com_fitted_y[rx, ry] + center_y, + intensities[rx, ry], + -com_fitted_x[rx, ry] + center_x, + -com_fitted_y[rx, ry] + center_y, bilinear=True, device="cpu", ) intensities_shifted[rx, ry] = intensity_shifted - self._intensities = intensities_shifted - self._dp_mean = intensities_shifted.mean((0, 1)) + self._intensities = xp.asarray(intensities_shifted, xp.float32) + self._dp_mean = self._intensities.mean((0, 1)) # select virtual detector pixels self._dp_mask = self._dp_mean >= (xp.max(self._dp_mean) * threshold_intensity) From 43289e3fc4c82419f477822e459f7bbf26fb1101 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 16 Oct 2023 01:47:13 -0700 Subject: [PATCH 35/62] added support for float upsampling --- py4DSTEM/process/phase/iterative_parallax.py | 67 ++++++++++++++++---- 1 file changed, 55 insertions(+), 12 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 67815cd14..5cec2e95a 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -177,6 +177,9 @@ def preprocess( require_calibrations=True, ) + self._region_of_interest_shape = np.array(self._intensities.shape[-2:]) + self._scan_shape = np.array(self._intensities.shape[:2]) + # make sure mean diffraction pattern is shaped correctly if (self._dp_mean.shape[0] != self._intensities.shape[2]) or ( self._dp_mean.shape[1] != self._intensities.shape[3] @@ -207,8 +210,10 @@ def preprocess( intensities = asnumpy(self._intensities) intensities_shifted = np.zeros_like(intensities) - center_x = np.mean(com_fitted_x) - center_y = np.mean(com_fitted_y) + # center_x = np.mean(com_fitted_x) + # center_y = np.mean(com_fitted_y) + + center_x, center_y = self._region_of_interest_shape / 2 for rx in range(intensities_shifted.shape[0]): for ry in range(intensities_shifted.shape[1]): @@ -926,7 +931,7 @@ def reconstruct( def subpixel_alignment( self, - kde_upsample_factor=4, + kde_upsample_factor=None, kde_sigma=0.125, plot_upsampled_BF_comparison: bool = True, plot_upsampled_FFT_comparison: bool = False, @@ -955,8 +960,42 @@ def subpixel_alignment( xy_shifts = self._xy_shifts BF_size = np.array(self._stack_BF_no_window.shape[-2:]) + self._DF_upsample_limit = np.max( + self._region_of_interest_shape / self._scan_shape + ) + self._BF_upsample_limit = ( + 2 * self._kr.max() / self._reciprocal_sampling[0] + ) / self._scan_shape.max() + if self._device == "gpu": + self._BF_upsample_limit = self._BF_upsample_limit.item() + + if kde_upsample_factor is None: + kde_upsample_factor = np.minimum( + self._BF_upsample_limit * 3 / 2, self._DF_upsample_limit + ) + + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (1.5 times the " + f"bright-field upsampling limit of {self._BF_upsample_limit:.2f})." + ), + UserWarning, + ) + + if kde_upsample_factor < 1: + raise ValueError("kde_upsample_factor must be larger than 1") + + if kde_upsample_factor > self._DF_upsample_limit: + warnings.warn( + ( + "Requested upsampling factor exceeds " + f"dark-field upsampling limit of {self._DF_upsample_limit:.2f}." + ), + UserWarning, + ) + self._kde_upsample_factor = kde_upsample_factor - pixel_output = BF_size * self._kde_upsample_factor + pixel_output = np.round(BF_size * self._kde_upsample_factor).astype("int") pixel_size = pixel_output.prod() # shifted coordinates @@ -1031,12 +1070,12 @@ def subpixel_alignment( cmap = kwargs.pop("cmap", "magma") cropped_object = self._crop_padded_object(self._recon_BF) - upsampled_pad_x = ( - self._object_padding_px[0] * self._kde_upsample_factor // 2 - ) - upsampled_pad_y = ( - self._object_padding_px[1] * self._kde_upsample_factor // 2 - ) + upsampled_pad_x = np.round( + self._object_padding_px[0] * self._kde_upsample_factor / 2 + ).astype("int") + upsampled_pad_y = np.round( + self._object_padding_px[1] * self._kde_upsample_factor / 2 + ).astype("int") cropped_object_aligned = self.recon_BF_subpixel_aligned[ upsampled_pad_x:-upsampled_pad_x, upsampled_pad_y:-upsampled_pad_y, @@ -1072,8 +1111,12 @@ def subpixel_alignment( if plot_upsampled_FFT_comparison: recon_fft = xp.fft.fft2(self._recon_BF) recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF))) - pad_x = BF_size[0] * (self._kde_upsample_factor - 1) // 2 - pad_y = BF_size[1] * (self._kde_upsample_factor - 1) // 2 + pad_x = np.round( + BF_size[0] * (self._kde_upsample_factor - 1) / 2 + ).astype("int") + pad_y = np.round( + BF_size[1] * (self._kde_upsample_factor - 1) / 2 + ).astype("int") pad_recon_fft = asnumpy( xp.pad(recon_fft, ((pad_x, pad_x), (pad_y, pad_y))) ) From 35d076fb05b0131a3aa012c885847ad4a601724e Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 16 Oct 2023 01:50:47 -0700 Subject: [PATCH 36/62] making descan correction the default --- py4DSTEM/process/phase/iterative_parallax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 5cec2e95a..ba751cf9b 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -113,7 +113,7 @@ def preprocess( threshold_intensity: float = 0.8, normalize_images: bool = True, normalize_order=0, - descan_correct: bool = False, + descan_correct: bool = True, defocus_guess: float = None, rotation_guess: float = None, plot_average_bf: bool = True, From c06ca467a2628339d7a2b95d74c74aea2580b9f3 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 16 Oct 2023 02:10:17 -0700 Subject: [PATCH 37/62] removing redundant if statement --- py4DSTEM/process/phase/iterative_parallax.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index ba751cf9b..825366a5e 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1225,12 +1225,8 @@ def aberration_fit( f"{self.aberration_A1y:.0f}) Ang" ) ) - if self.aberration_C1 > 0: - print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") - print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") - else: - print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") - print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") + print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") # Plot the CTF comparison between experiment and fit if plot_CTF_compare: From 9be2e9a03368436f20e5eb14e06f218c828d263d Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 16 Oct 2023 03:37:43 -0700 Subject: [PATCH 38/62] removed separate ctf corrections and other subpixel improvements --- py4DSTEM/process/phase/iterative_parallax.py | 229 +++++-------------- 1 file changed, 63 insertions(+), 166 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 825366a5e..a8cbd0998 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1070,16 +1070,9 @@ def subpixel_alignment( cmap = kwargs.pop("cmap", "magma") cropped_object = self._crop_padded_object(self._recon_BF) - upsampled_pad_x = np.round( - self._object_padding_px[0] * self._kde_upsample_factor / 2 - ).astype("int") - upsampled_pad_y = np.round( - self._object_padding_px[1] * self._kde_upsample_factor / 2 - ).astype("int") - cropped_object_aligned = self.recon_BF_subpixel_aligned[ - upsampled_pad_x:-upsampled_pad_x, - upsampled_pad_y:-upsampled_pad_y, - ] + cropped_object_aligned = self._crop_padded_object( + self._recon_BF_subpixel_aligned, upsampled=True + ) extent = [ 0, @@ -1109,7 +1102,6 @@ def subpixel_alignment( ax.set_xlabel("y [A]") if plot_upsampled_FFT_comparison: - recon_fft = xp.fft.fft2(self._recon_BF) recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF))) pad_x = np.round( BF_size[0] * (self._kde_upsample_factor - 1) / 2 @@ -1128,10 +1120,10 @@ def subpixel_alignment( ) reciprocal_extent = [ - 0, - self._reciprocal_sampling[1] * cropped_object_aligned.shape[1], - self._reciprocal_sampling[0] * cropped_object_aligned.shape[0], - 0, + -self._reciprocal_sampling[1] * cropped_object_aligned.shape[1] / 2, + self._reciprocal_sampling[1] * cropped_object_aligned.shape[1] / 2, + self._reciprocal_sampling[0] * cropped_object_aligned.shape[0] / 2, + -self._reciprocal_sampling[0] * cropped_object_aligned.shape[0] / 2, ] show( @@ -1312,8 +1304,9 @@ def aberration_correct( k_info_limit: float = None, k_info_power: float = 1.0, Wiener_filter=False, - Wiener_signal_noise_ratio=1.0, - Wiener_filter_low_only=False, + Wiener_signal_noise_ratio: float = 1.0, + Wiener_filter_low_only: bool = False, + upsampled: bool = True, **kwargs, ): """ @@ -1346,9 +1339,19 @@ def aberration_correct( ) ) + if upsampled and hasattr(self, "_kde_upsample_factor"): + im = self._recon_BF_subpixel_aligned + sx = self._scan_sampling[0] / self._kde_upsample_factor + sy = self._scan_sampling[1] / self._kde_upsample_factor + else: + upsampled = False + im = self._recon_BF + sx = self._scan_sampling[0] + sy = self._scan_sampling[1] + # Fourier coordinates - kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0]) - ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1]) + kx = xp.fft.fftfreq(im.shape[0], sx) + ky = xp.fft.fftfreq(im.shape[1], sy) kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2 # CTF @@ -1371,7 +1374,7 @@ def aberration_correct( CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(self._recon_BF) * CTF_corr + im_fft_corr = xp.fft.fft2(im) * CTF_corr else: # CTF without tilt correction (beyond the parallax operator) @@ -1379,7 +1382,7 @@ def aberration_correct( CTF_corr[0, 0] = 0 # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(self._recon_BF) * CTF_corr + im_fft_corr = xp.fft.fft2(im) * CTF_corr # if needed, add low pass filter output image if k_info_limit is not None: @@ -1391,131 +1394,6 @@ def aberration_correct( self._recon_phase_corrected = xp.real(xp.fft.ifft2(im_fft_corr)) self.recon_phase_corrected = asnumpy(self._recon_phase_corrected) - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - # plotting - if plot_corrected_phase: - figsize = kwargs.pop("figsize", (6, 6)) - cmap = kwargs.pop("cmap", "magma") - - fig, ax = plt.subplots(figsize=figsize) - - cropped_object = self._crop_padded_object(self._recon_phase_corrected) - - extent = [ - 0, - self._scan_sampling[1] * cropped_object.shape[1], - self._scan_sampling[0] * cropped_object.shape[0], - 0, - ] - - ax.imshow( - cropped_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Parallax-Corrected Phase Image") - - def subpixel_aberration_correct( - self, - plot_corrected_phase: bool = True, - k_info_limit: float = None, - k_info_power: float = 1.0, - Wiener_filter=False, - Wiener_signal_noise_ratio=1.0, - Wiener_filter_low_only=False, - **kwargs, - ): - """ - CTF correction of the phase image using the measured defocus aberration. - - Parameters - ---------- - plot_corrected_phase: bool, optional - If True, the CTF-corrected phase is plotted - k_info_limit: float, optional - maximum allowed frequency in butterworth filter - k_info_power: float, optional - power of butterworth filter - Wiener_filter: bool, optional - Use Wiener filtering instead of CTF sign correction. - Wiener_signal_noise_ratio: float, optional - Signal to noise radio at k = 0 for Wiener filter - Wiener_filter_low_only: bool, optional - Apply Wiener filtering only to the CTF portions before the 1st CTF maxima. - """ - - xp = self._xp - asnumpy = self._asnumpy - - if not hasattr(self, "aberration_C1"): - raise ValueError( - ( - "CTF correction is meant to be ran after alignment and aberration fitting. " - "Please run the `reconstruct()` and `aberration_fit()` functions first." - ) - ) - - # Fourier coordinates - kx = xp.fft.fftfreq( - self._recon_BF_subpixel_aligned.shape[0], - self._scan_sampling[0] / self._kde_upsample_factor, - ) - ky = xp.fft.fftfreq( - self._recon_BF_subpixel_aligned.shape[1], - self._scan_sampling[1] / self._kde_upsample_factor, - ) - kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2 - - # CTF - sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) - - if Wiener_filter: - SNR_inv = ( - xp.sqrt( - 1 + (kra2**k_info_power) / ((k_info_limit) ** (2 * k_info_power)) - ) - / Wiener_signal_noise_ratio - ) - CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv) - if Wiener_filter_low_only: - # limit Wiener filter to only the part of the CTF before 1st maxima - k_thresh = 1 / xp.sqrt( - 2.0 * self._wavelength * xp.abs(self.aberration_C1) - ) - k_mask = kra2 >= k_thresh**2 - CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) - - # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(self._recon_BF_subpixel_aligned) * CTF_corr - - else: - # CTF without tilt correction (beyond the parallax operator) - CTF_corr = xp.sign(sin_chi) - CTF_corr[0, 0] = 0 - - # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(self._recon_BF_subpixel_aligned) * CTF_corr - # if needed, add low pass filter output image - if k_info_limit is not None: - im_fft_corr /= 1 + (kra2**k_info_power) / ( - (k_info_limit) ** (2 * k_info_power) - ) - - # Output phase image - self._recon_phase_corrected_subpixel_aligned = xp.real( - xp.fft.ifft2(im_fft_corr) - ) - self.recon_phase_corrected_subpixel_aligned = asnumpy( - self._recon_phase_corrected_subpixel_aligned - ) - if self._device == "gpu": xp._default_memory_pool.free_all_blocks() xp.clear_memo() @@ -1528,17 +1406,13 @@ def subpixel_aberration_correct( fig, ax = plt.subplots(figsize=figsize) cropped_object = self._crop_padded_object( - self._recon_BF_subpixel_aligned, upsampled=True + self._recon_phase_corrected, upsampled=upsampled ) extent = [ 0, - self._scan_sampling[1] - / self._kde_upsample_factor - * cropped_object.shape[1], - self._scan_sampling[0] - / self._kde_upsample_factor - * cropped_object.shape[0], + sy * cropped_object.shape[1], + sx * cropped_object.shape[0], 0, ] @@ -1551,7 +1425,7 @@ def subpixel_aberration_correct( ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") - ax.set_title("Parallax-Corrected Phase Image Subpixel Aligned") + ax.set_title("Parallax-Corrected Phase Image") def depth_section( self, @@ -1716,12 +1590,19 @@ def _crop_padded_object( asnumpy = self._asnumpy - pad_x = self._object_padding_px[0] // 2 - remaining_padding - pad_y = self._object_padding_px[1] // 2 - remaining_padding - if upsampled: - pad_x *= self._kde_upsample_factor - pad_y *= self._kde_upsample_factor + pad_x = np.round( + self._object_padding_px[0] / 2 * self._kde_upsample_factor + ).astype("int") + pad_y = np.round( + self._object_padding_px[1] / 2 * self._kde_upsample_factor + ).astype("int") + else: + pad_x = self._object_padding_px[0] // 2 + pad_y = self._object_padding_px[1] // 2 + + pad_x -= remaining_padding + pad_y -= remaining_padding return asnumpy(padded_object[pad_x:-pad_x, pad_y:-pad_y]) @@ -1730,6 +1611,7 @@ def _visualize_figax( fig, ax, remaining_padding: int = 0, + upsampled: bool = False, **kwargs, ): """ @@ -1748,14 +1630,29 @@ def _visualize_figax( cmap = kwargs.pop("cmap", "magma") - cropped_object = self._crop_padded_object(self._recon_BF, remaining_padding) + cropped_object = self._crop_padded_object( + self._recon_BF, remaining_padding, upsampled + ) + + if upsampled: + extent = [ + 0, + self._scan_sampling[1] + * cropped_object.shape[1] + / self._kde_upsample_factor, + self._scan_sampling[0] + * cropped_object.shape[0] + / self._kde_upsample_factor, + 0, + ] - extent = [ - 0, - self._scan_sampling[1] * cropped_object.shape[1], - self._scan_sampling[0] * cropped_object.shape[0], - 0, - ] + else: + extent = [ + 0, + self._scan_sampling[1] * cropped_object.shape[1], + self._scan_sampling[0] * cropped_object.shape[0], + 0, + ] ax.imshow( cropped_object, From ab76946469b565a1bb1c868a21f869a84999971f Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 16 Oct 2023 04:38:00 -0700 Subject: [PATCH 39/62] added read-write functionality to parralax --- py4DSTEM/process/phase/iterative_parallax.py | 134 +++++++++++++++++-- 1 file changed, 125 insertions(+), 9 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index a8cbd0998..ff6fb52af 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -8,7 +8,7 @@ import matplotlib.pyplot as plt import numpy as np -from emdfile import Custom, tqdmnd +from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from matplotlib.gridspec import GridSpec from py4DSTEM import DataCube from py4DSTEM.preprocess.utils import get_shifted_ar @@ -75,6 +75,8 @@ def __init__( else: raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_save_defaults() + # Data self._datacube = datacube @@ -88,9 +90,68 @@ def __init__( def to_h5(self, group): """ Wraps datasets and metadata to write in emdfile classes, - notably ... + notably the (subpixel-)aligned BF. """ - raise NotImplementedError() + # instantiation metadata + self.metadata = Metadata( + name="instantiation_metadata", + data={ + "energy": self._energy, + "verbose": self._verbose, + "device": self._device, + "object_padding_px": self._object_padding_px, + "name": self.name, + }, + ) + + # preprocessing metadata + self.metadata = Metadata( + name="preprocess_metadata", + data={ + "scan_sampling": self._scan_sampling, + "wavelength": self._wavelength, + }, + ) + + # reconstruction metadata + recon_metadata = {"reconstruction_error": float(self._recon_error)} + + if hasattr(self, "aberration_C1"): + recon_metadata |= { + "aberration_rotation_QR": self.rotation_Q_to_R_rads, + "aberration_C1": self.aberration_C1, + "aberration_A1x": self.aberration_A1x, + "aberration_A1y": self.aberration_A1y, + } + + if hasattr(self, "_kde_upsample_factor"): + recon_metadata |= { + "kde_upsample_factor": self._kde_upsample_factor, + } + self._subpixel_aligned_BF_emd = Array( + name="subpixel_aligned_BF", + data=self._asnumpy(self._recon_BF_subpixel_aligned), + ) + + self.metadata = Metadata( + name="reconstruction_metadata", + data=recon_metadata, + ) + + self._aligned_BF_emd = Array( + name="aligned_BF", + data=self._asnumpy(self._recon_BF), + ) + + # datacube + if self._save_datacube: + self.metadata = self._datacube.calibration + Custom.to_h5(self, group) + else: + dc = self._datacube + self._datacube = None + Custom.to_h5(self, group) + self._datacube = dc @classmethod def _get_constructor_args(cls, group): @@ -98,14 +159,67 @@ def _get_constructor_args(cls, group): Returns a dictionary of arguments/values to pass to the class' __init__ function """ - raise NotImplementedError() + # Get data + dict_data = cls._get_emd_attr_data(cls, group) + + # Get metadata dictionaries + instance_md = _read_metadata(group, "instantiation_metadata") + + # Fix calibrations bug + if "_datacube" in dict_data: + calibrations_dict = _read_metadata(group, "calibration")._params + cal = Calibration() + cal._params.update(calibrations_dict) + dc = dict_data["_datacube"] + dc.calibration = cal + else: + dc = None + + # Populate args and return + kwargs = { + "datacube": dc, + "energy": instance_md["energy"], + "verbose": instance_md["verbose"], + "device": instance_md["device"], + "object_padding_px": instance_md["object_padding_px"], + "name": instance_md["name"], + } + + return kwargs def _populate_instance(self, group): """ Sets post-initialization properties, notably some preprocessing meta optional; during read, this method is run after object instantiation. """ - raise NotImplementedError() + + xp = self._xp + + # Preprocess metadata + preprocess_md = _read_metadata(group, "preprocess_metadata") + self._scan_sampling = preprocess_md["scan_sampling"] + self._wavelength = preprocess_md["wavelength"] + + # Reconstruction metadata + reconstruction_md = _read_metadata(group, "reconstruction_metadata") + self._recon_error = reconstruction_md["reconstruction_error"] + + # Data + dict_data = Custom._get_emd_attr_data(Custom, group) + + if "aberration_C1" in reconstruction_md.keys: + self.rotation_Q_to_R_rads = reconstruction_md["aberration_rotation_QR"] + self.aberration_C1 = reconstruction_md["aberration_C1"] + self.aberration_A1x = reconstruction_md["aberration_A1x"] + self.aberration_A1y = reconstruction_md["aberration_A1y"] + + if "kde_upsample_factor" in reconstruction_md.keys: + self._kde_upsample_factor = reconstruction_md["kde_upsample_factor"] + self._recon_BF_subpixel_aligned = xp.asarray( + dict_data["_subpixel_aligned_BF_emd"].data, dtype=xp.float32 + ) + + self._recon_BF = xp.asarray(dict_data["_aligned_BF_emd"].data, dtype=xp.float32) def preprocess( self, @@ -1630,11 +1744,11 @@ def _visualize_figax( cmap = kwargs.pop("cmap", "magma") - cropped_object = self._crop_padded_object( - self._recon_BF, remaining_padding, upsampled - ) - if upsampled: + cropped_object = self._crop_padded_object( + self._recon_BF_subpixel_aligned, remaining_padding, upsampled + ) + extent = [ 0, self._scan_sampling[1] @@ -1647,6 +1761,8 @@ def _visualize_figax( ] else: + cropped_object = self._crop_padded_object(self._recon_BF, remaining_padding) + extent = [ 0, self._scan_sampling[1] * cropped_object.shape[1], From 06e18376b820776e9f8a46f4f80095a94f474743 Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 16 Oct 2023 12:54:58 -0700 Subject: [PATCH 40/62] Starting on CTF fitting --- py4DSTEM/process/phase/iterative_parallax.py | 208 ++++++++++++------- 1 file changed, 128 insertions(+), 80 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index ff6fb52af..c855c1451 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1267,9 +1267,11 @@ def subpixel_alignment( def aberration_fit( self, + fit_thon_rings = True, + fit_upsampled_fft = True, plot_CTF_compare: bool = False, - plot_dk: float = 0.005, - plot_k_sigma: float = 0.02, + # plot_dk: float = 0.005, + # plot_k_sigma: float = 0.02, ): """ Fit aberrations to the measured image shifts. @@ -1277,17 +1279,27 @@ def aberration_fit( Parameters ---------- plot_CTF_compare: bool, optional - If True, the fitted CTF is plotted against the reconstructed frequencies + If True, the fitted CTF is plotted against the reconstructed frequencies. + fit_thon_rings: bool + Set to True to directly fit aberrations in the FFT of the upsampled BF + image (if available). Note that this method relies on visible zero + crossings in the FFT, and will not work if they are not present. + fit_upsampled_fft: bool + If True, we aberration fit is performed on the upsampled BF image. + This option does nothing if fit_thon_rings is not True. plot_dk: float, optional Reciprocal bin-size for polar-averaged FFT plot_k_sigma: float, optional sigma to gaussian blur polar-averaged FFT by + """ xp = self._xp asnumpy = self._asnumpy gaussian_filter = self._gaussian_filter + # initial aberration fit + # Convert real space shifts to Angstroms self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) @@ -1316,6 +1328,42 @@ def aberration_fit( xp._default_memory_pool.free_all_blocks() xp.clear_memo() + # Refinement using Thon rings + if fit_thon_rings: + if fit_upsampled_fft: + # Get mean FFT of BF reconstruction + im_fft = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + + # coordinates + print(self._kde_upsample_factor) + q_pixel_size = np.array(self._reciprocal_sampling) \ + / self._kde_upsample_factor + else: + # Get mean FFT of upsampled BF reconstruction + im_fft = np.abs(xp.fft.fft2(self._recon_BF)) + + # coordinates + q_pixel_size = np.array(parallax_recon._reciprocal_sampling) + + + # FFT coordinates + qx = fft + + # weights for fits + + # #zero origin pixel + # im_fft[0,0] = 0 + + + print(im_fft.shape) + + fig,ax = plt.subplots(figsize=(6,6)) + ax.imshow( + np.fft.fftshift(im_fft)**0.5, + ) + + + # Print results if self._verbose: print( @@ -1334,83 +1382,83 @@ def aberration_fit( print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") - # Plot the CTF comparison between experiment and fit - if plot_CTF_compare: - # Get polar mean from FFT of BF reconstruction - im_fft = xp.abs(xp.fft.fft2(self._recon_BF)) - - # coordinates - kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0]) - ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1]) - kra = xp.sqrt(kx[:, None] ** 2 + ky[None, :] ** 2) - k_max = xp.max(kra) / np.sqrt(2.0) - k_num_bins = int(xp.ceil(k_max / plot_dk)) - k_bins = xp.arange(k_num_bins + 1) * plot_dk - - # histogram - k_ind = kra / plot_dk - kf = np.floor(k_ind).astype("int") - dk = k_ind - kf - sub = kf <= k_num_bins - hist_exp = xp.bincount( - kf[sub], weights=im_fft[sub] * (1 - dk[sub]), minlength=k_num_bins - ) - hist_norm = xp.bincount( - kf[sub], weights=(1 - dk[sub]), minlength=k_num_bins - ) - sub = kf <= k_num_bins - 1 - - hist_exp += xp.bincount( - kf[sub] + 1, weights=im_fft[sub] * (dk[sub]), minlength=k_num_bins - ) - hist_norm += xp.bincount( - kf[sub] + 1, weights=(dk[sub]), minlength=k_num_bins - ) - - # KDE and normalizing - k_sigma = plot_dk / plot_k_sigma - hist_exp[0] = 0.0 - hist_exp = gaussian_filter(hist_exp, sigma=k_sigma, mode="nearest") - hist_norm = gaussian_filter(hist_norm, sigma=k_sigma, mode="nearest") - hist_exp /= hist_norm - - # CTF comparison - CTF_fit = xp.sin( - (-np.pi * self._wavelength * self.aberration_C1) * k_bins**2 - ) - - # plotting input - log scale - min_hist_val = xp.max(hist_exp) * 1e-3 - hist_plot = xp.log(np.maximum(hist_exp, min_hist_val)) - hist_plot -= xp.min(hist_plot) - hist_plot /= xp.max(hist_plot) - - hist_plot = asnumpy(hist_plot) - k_bins = asnumpy(k_bins) - CTF_fit = asnumpy(CTF_fit) - - fig, ax = plt.subplots(figsize=(8, 4)) - - ax.fill_between( - k_bins, - hist_plot, - color=(0.7, 0.7, 0.7, 1), - ) - - ax.plot( - k_bins, - np.clip(CTF_fit, 0.0, np.inf), - color=(1, 0, 0, 1), - linewidth=2, - ) - ax.plot( - k_bins, - np.clip(-CTF_fit, 0.0, np.inf), - color=(0, 0.5, 1, 1), - linewidth=2, - ) - ax.set_xlim([0, k_bins[-1]]) - ax.set_ylim([0, 1.05]) + # # Plot the CTF comparison between experiment and fit + # if plot_CTF_compare: + # # Get polar mean from FFT of BF reconstruction + # im_fft = xp.abs(xp.fft.fft2(self._recon_BF)) + + # # coordinates + # kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0]) + # ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1]) + # kra = xp.sqrt(kx[:, None] ** 2 + ky[None, :] ** 2) + # k_max = xp.max(kra) / np.sqrt(2.0) + # k_num_bins = int(xp.ceil(k_max / plot_dk)) + # k_bins = xp.arange(k_num_bins + 1) * plot_dk + + # # histogram + # k_ind = kra / plot_dk + # kf = np.floor(k_ind).astype("int") + # dk = k_ind - kf + # sub = kf <= k_num_bins + # hist_exp = xp.bincount( + # kf[sub], weights=im_fft[sub] * (1 - dk[sub]), minlength=k_num_bins + # ) + # hist_norm = xp.bincount( + # kf[sub], weights=(1 - dk[sub]), minlength=k_num_bins + # ) + # sub = kf <= k_num_bins - 1 + + # hist_exp += xp.bincount( + # kf[sub] + 1, weights=im_fft[sub] * (dk[sub]), minlength=k_num_bins + # ) + # hist_norm += xp.bincount( + # kf[sub] + 1, weights=(dk[sub]), minlength=k_num_bins + # ) + + # # KDE and normalizing + # k_sigma = plot_dk / plot_k_sigma + # hist_exp[0] = 0.0 + # hist_exp = gaussian_filter(hist_exp, sigma=k_sigma, mode="nearest") + # hist_norm = gaussian_filter(hist_norm, sigma=k_sigma, mode="nearest") + # hist_exp /= hist_norm + + # # CTF comparison + # CTF_fit = xp.sin( + # (-np.pi * self._wavelength * self.aberration_C1) * k_bins**2 + # ) + + # # plotting input - log scale + # min_hist_val = xp.max(hist_exp) * 1e-3 + # hist_plot = xp.log(np.maximum(hist_exp, min_hist_val)) + # hist_plot -= xp.min(hist_plot) + # hist_plot /= xp.max(hist_plot) + + # hist_plot = asnumpy(hist_plot) + # k_bins = asnumpy(k_bins) + # CTF_fit = asnumpy(CTF_fit) + + # fig, ax = plt.subplots(figsize=(8, 4)) + + # ax.fill_between( + # k_bins, + # hist_plot, + # color=(0.7, 0.7, 0.7, 1), + # ) + + # ax.plot( + # k_bins, + # np.clip(CTF_fit, 0.0, np.inf), + # color=(1, 0, 0, 1), + # linewidth=2, + # ) + # ax.plot( + # k_bins, + # np.clip(-CTF_fit, 0.0, np.inf), + # color=(0, 0.5, 1, 1), + # linewidth=2, + # ) + # ax.set_xlim([0, k_bins[-1]]) + # ax.set_ylim([0, 1.05]) def aberration_correct( self, From 593f07d4c7417167c681f84e628b937c29c36e67 Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 16 Oct 2023 18:00:11 -0700 Subject: [PATCH 41/62] Adding more parts of parallax CTF fitting --- py4DSTEM/process/phase/iterative_parallax.py | 178 +++++++++++++++++-- 1 file changed, 160 insertions(+), 18 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index c855c1451..f7618c749 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -18,6 +18,8 @@ from py4DSTEM.visualize import show from scipy.linalg import polar from scipy.special import comb +from scipy.optimize import curve_fit +from scipy.signal import medfilt2d try: import cupy as cp @@ -1269,6 +1271,10 @@ def aberration_fit( self, fit_thon_rings = True, fit_upsampled_fft = True, + aber_order_max = 2, + q_power_fit = 0.0, + medfilt_size = None, + maxfev = None, plot_CTF_compare: bool = False, # plot_dk: float = 0.005, # plot_k_sigma: float = 0.02, @@ -1287,10 +1293,14 @@ def aberration_fit( fit_upsampled_fft: bool If True, we aberration fit is performed on the upsampled BF image. This option does nothing if fit_thon_rings is not True. + aber_order_max: int + Max radial order for fitting of aberrations. + q_power_fit: float + q power fitting weight. plot_dk: float, optional - Reciprocal bin-size for polar-averaged FFT + Reciprocal bin-size for polar-averaged FFT. plot_k_sigma: float, optional - sigma to gaussian blur polar-averaged FFT by + sigma to gaussian blur polar-averaged FFT by. """ @@ -1335,7 +1345,6 @@ def aberration_fit( im_fft = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) # coordinates - print(self._kde_upsample_factor) q_pixel_size = np.array(self._reciprocal_sampling) \ / self._kde_upsample_factor else: @@ -1347,22 +1356,130 @@ def aberration_fit( # FFT coordinates - qx = fft + qx = np.fft.fftfreq(im_fft.shape[0],q_pixel_size[0]) + qy = np.fft.fftfreq(im_fft.shape[1],q_pixel_size[1]) + qr2 = qx[:,None]**2 + qy[None,:]**2 + self.alpha2 = qr2 * self._wavelength**2 + self.theta = np.arctan2(qy[None,:],qx[:,None]) # weights for fits - - # #zero origin pixel - # im_fft[0,0] = 0 - - - print(im_fft.shape) - - fig,ax = plt.subplots(figsize=(6,6)) - ax.imshow( - np.fft.fftshift(im_fft)**0.5, + self.q_weight = qr2 ** (q_power_fit/2) + + # Aberration coefs + mn = [] + for m in range(0,aber_order_max//2+1): + n_max = np.floor(aber_order_max-2*m).astype('int') + + for n in range(0,n_max+1): + if m + n > 1 or (m > 0 and n == 0): + if n == 0: + mn.append([m,n,0]) + else: + mn.append([m,n,0]) + mn.append([m,n,1]) + self.aber_mn = np.array(mn) + + # Aberration basis + self.aber_num = self.aber_mn.shape[0] + self.aber_basis = np.zeros((self.alpha2.size,self.aber_num)) + # self.aber_basis[:,0] = self.alpha2.ravel() + for a0 in range(self.aber_num): + if self.aber_mn[a0,1] == 0: + # Radially symmetric basis + self.aber_basis[:,a0] = self.alpha2.ravel()**self.aber_mn[a0,0] + elif self.aber_mn[a0,2] == 0: + # cos coef + self.aber_basis[:,a0] = \ + self.alpha2.ravel()**(self.aber_mn[a0,0] + self.aber_mn[a0,1]/2.0) * \ + np.cos(self.aber_mn[a0,1] * self.theta.ravel()) + else: + # sin coef + self.aber_basis[:,a0] = \ + self.alpha2.ravel()**(self.aber_mn[a0,0] + self.aber_mn[a0,1]/2.0) * \ + np.sin(self.aber_mn[a0,1] * self.theta.ravel()) + + # fitting image + im_fit = im_fft * self.q_weight + if medfilt_size is not None: + im_fit = np.fft.ifftshift(medfilt2d( + np.fft.fftshift(im_fit), + medfilt_size)) + + # initial coefs + int_max = np.max(im_fit) + sigma_init = np.sqrt(np.max(self.alpha2) / 8.0) + coefs = np.zeros(5 + self.aber_num) + lb = np.zeros(5 + self.aber_num) + ub = np.ones(5 + self.aber_num) * np.inf + coefs[0] = 1e-3 + coefs[1] = int_max * 0.1 + coefs[2] = sigma_init + coefs[3] = int_max * 0.9 + coefs[4] = sigma_init + lb[5:] = -np.inf + # initial C1 value (defocus) + ind = np.argmin( + np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) + ) + C1_dimensionless = self.aberration_C1 / np.pi * self._wavelength + coefs[ind + 5] = C1_dimensionless + + # Fitting mask + fit_mask = self.alpha2 > np.sqrt(np.pi/2/np.abs(C1_dimensionless)) + basis_masked = self.aber_basis[fit_mask.ravel(),:] + + # Define fitting functions + + def calc_CTF_mag(alpha2, *coefs): + int0 = coefs[0] + int1 = coefs[1] + sigma1 = coefs[2] + int_env = coefs[3] + sigma_env = coefs[4] + + im_CTF_mag = int0 + int1 * np.exp(alpha2/(-2.0*sigma1**2)) + env = int_env * np.exp(alpha2/(-2.0*sigma_env**2)) + chi = np.zeros_like(im_CTF_mag) + for a0 in range(5,len(coefs)): + chi += coefs[a0] * self.aber_basis[:,a0-5] + return im_CTF_mag + np.abs(np.sin(chi)) * env + + def calc_CTF_mag_masked(alpha2, *coefs): + int0 = coefs[0] + int1 = coefs[1] + sigma1 = coefs[2] + int_env = coefs[3] + sigma_env = coefs[4] + + im_CTF_mag = int0 + int1 * np.exp(alpha2/(-2.0*sigma1**2)) + env = int_env * np.exp(alpha2/(-2.0*sigma_env**2)) + chi = np.zeros_like(im_CTF_mag) + for a0 in range(5,len(coefs)): + chi += coefs[a0] * basis_masked[:,a0-5] + return im_CTF_mag + np.abs(np.sin(chi)) * env + + # Refine aberration coefficients + if maxfev is None: + coefs = np.array( + curve_fit( + calc_CTF_mag_masked, + self.alpha2[fit_mask], + im_fit[fit_mask], + p0 = tuple(coefs), + bounds = (lb,ub), + )[0] + ) + else: + coefs = np.array( + curve_fit( + calc_CTF_mag_masked, + self.alpha2[fit_mask], + im_fit[fit_mask], + p0 = tuple(coefs), + bounds = (lb,ub), + maxfev = maxfev, + )[0] ) - - # Print results if self._verbose: @@ -1382,8 +1499,33 @@ def aberration_fit( print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") - # # Plot the CTF comparison between experiment and fit - # if plot_CTF_compare: + # Plot the CTF comparison between experiment and fit + if plot_CTF_compare: + im_test = np.reshape(calc_CTF_mag(self.alpha2.ravel(), *coefs), im_fit.shape) + + fig,ax = plt.subplots(figsize=(12,6)) + ax.imshow( + np.hstack(( + np.fft.fftshift(im_fit), + np.fft.fftshift(im_test), + )), + vmin = np.min(im_test[fit_mask]), + vmax = np.max(im_test[fit_mask]), + cmap = 'gray', + ) + + # ax.imshow( + # im_plot / np.max(im_plot), + # vmin = 0, + # vmax = 1, + # cmap = 'gray', + # ) + # ax.imshow( \ + # np.fft.fftshift( + # np.mod(np.reshape(self.aber_basis[:,2],im_fft.shape)+np.pi,2*np.pi)-np.pi + # )) + + # # Get polar mean from FFT of BF reconstruction # im_fft = xp.abs(xp.fft.fft2(self._recon_BF)) From e5d7425b29fcb613057ca7d8a90e7671d2c04c38 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 17 Oct 2023 03:13:26 -0700 Subject: [PATCH 42/62] added chroma_boost for show_complex --- py4DSTEM/visualize/vis_special.py | 33 ++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 6dd980bce..f7beec241 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -938,12 +938,13 @@ def show_selected_dps( ) -def Complex2RGB(complex_data, vmin=None, vmax=None, power=None): +def Complex2RGB(complex_data, vmin=None, vmax=None, power=None, chroma_boost=1): """ complex_data (array): complex array to plot vmin (float) : minimum absolute value vmax (float) : maximum absolute value power (float) : power to raise amplitude to + chroma_boost (float): boosts chroma for higher-contrast (~1-2.5) """ amp = np.abs(complex_data) phase = np.angle(complex_data) @@ -974,7 +975,7 @@ def Complex2RGB(complex_data, vmin=None, vmax=None, power=None): amp = ((amp - vmin) / vmax).clip(1e-16, 1) J = amp * 61.5 # Note we restrict luminance to the monotonic chroma cutoff - C = np.where(J < 61.5, 98 * J / 123, 1400 / 11 - 14 * J / 11) # Min uniform chroma + C = np.minimum(chroma_boost * 98 * J / 123, 110) h = np.rad2deg(phase) + 180 JCh = np.stack((J, C, h), axis=-1) @@ -983,16 +984,17 @@ def Complex2RGB(complex_data, vmin=None, vmax=None, power=None): return rgb -def add_colorbar_arg(cax, c=49, j=61.5): +def add_colorbar_arg(cax, chroma_boost=1, c=49, j=61.5): """ cax : axis to add cbar to - c : constant chroma value - j : constant luminance value + chroma_boost (float): boosts chroma for higher-contrast (~1-2.25) + c (float) : constant chroma value + j (float) : constant luminance value """ h = np.linspace(0, 360, 256, endpoint=False) J = np.full_like(h, j) - C = np.full_like(h, c) + C = np.full_like(h, np.minimum(c * chroma_boost, 110)) JCh = np.stack((J, C, h), axis=-1) rgb_vals = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) newcmp = mcolors.ListedColormap(rgb_vals) @@ -1012,12 +1014,13 @@ def show_complex( ar_complex, vmin=None, vmax=None, + power=None, + chroma_boost=1, cbar=True, scalebar=False, pixelunits="pixels", pixelsize=1, returnfig=False, - power=None, **kwargs ): """ @@ -1030,12 +1033,13 @@ def show_complex( vmax (float, optional) : maximum absolute value if None, vmin/vmax are set to fractions of the distribution of pixel values in the array, e.g. vmin=0.02 will set the minumum display value to saturate the lower 2% of pixels + power (float,optional) : power to raise amplitude to + chroma_boost (float) : boosts chroma for higher-contrast (~1-2.25) cbar (bool, optional) : if True, include color bar scalebar (bool, optional) : if True, adds scale bar pixelunits (str, optional) : units for scalebar pixelsize (float, optional) : size of one pixel in pixelunits for scalebar returnfig (bool, optional) : if True, the function returns the tuple (figure,axis) - power (float,optional) : power to raise amplitude to Returns: if returnfig==False (default), the figure is plotted and nothing is returned. @@ -1050,7 +1054,7 @@ def show_complex( if isinstance(ar_complex, list): if isinstance(ar_complex[0], list): rgb = [ - Complex2RGB(ar, vmin, vmax, power=power) + Complex2RGB(ar, vmin, vmax, power=power, chroma_boost=chroma_boost) for sublist in ar_complex for ar in sublist ] @@ -1058,7 +1062,10 @@ def show_complex( W = len(ar_complex[0]) else: - rgb = [Complex2RGB(ar, vmin, vmax, power=power) for ar in ar_complex] + rgb = [ + Complex2RGB(ar, vmin, vmax, power=power, chroma_boost=chroma_boost) + for ar in ar_complex + ] if len(rgb[0].shape) == 4: H = len(ar_complex) W = rgb[0].shape[0] @@ -1067,7 +1074,9 @@ def show_complex( W = len(ar_complex) is_grid = True else: - rgb = Complex2RGB(ar_complex, vmin, vmax, power=power) + rgb = Complex2RGB( + ar_complex, vmin, vmax, power=power, chroma_boost=chroma_boost + ) if len(rgb.shape) == 4: is_grid = True H = 1 @@ -1127,7 +1136,7 @@ def show_complex( else: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) fig.tight_layout() From 62c4cae86ad5ad63242b4f22f1e742a0b80746bc Mon Sep 17 00:00:00 2001 From: Colin Date: Tue, 17 Oct 2023 10:46:11 -0700 Subject: [PATCH 43/62] Working on CTF --- py4DSTEM/process/phase/iterative_parallax.py | 383 ++++++++++--------- 1 file changed, 209 insertions(+), 174 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index f7618c749..4bfd265f9 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1269,15 +1269,12 @@ def subpixel_alignment( def aberration_fit( self, - fit_thon_rings = True, - fit_upsampled_fft = True, - aber_order_max = 2, - q_power_fit = 0.0, - medfilt_size = None, - maxfev = None, + fit_CTF_FFT = True, + fit_CTF_threshold = 0.1, + fit_upsampled_FFT = True, + fit_aber_order_max = 2, + fit_maxfev = None, plot_CTF_compare: bool = False, - # plot_dk: float = 0.005, - # plot_k_sigma: float = 0.02, ): """ Fit aberrations to the measured image shifts. @@ -1286,21 +1283,17 @@ def aberration_fit( ---------- plot_CTF_compare: bool, optional If True, the fitted CTF is plotted against the reconstructed frequencies. - fit_thon_rings: bool + fit_CTF_FFT: bool Set to True to directly fit aberrations in the FFT of the upsampled BF image (if available). Note that this method relies on visible zero crossings in the FFT, and will not work if they are not present. - fit_upsampled_fft: bool + fit_upsampled_FFT: bool If True, we aberration fit is performed on the upsampled BF image. This option does nothing if fit_thon_rings is not True. - aber_order_max: int + fit_aber_order_max: int Max radial order for fitting of aberrations. - q_power_fit: float - q power fitting weight. - plot_dk: float, optional - Reciprocal bin-size for polar-averaged FFT. - plot_k_sigma: float, optional - sigma to gaussian blur polar-averaged FFT by. + ctf_threshold: float + CTF fitting minimizes value at CTF zero crossings (Thon ring minima). """ @@ -1338,182 +1331,204 @@ def aberration_fit( xp._default_memory_pool.free_all_blocks() xp.clear_memo() - # Refinement using Thon rings - if fit_thon_rings: - if fit_upsampled_fft: - # Get mean FFT of BF reconstruction - im_fft = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) - - # coordinates - q_pixel_size = np.array(self._reciprocal_sampling) \ - / self._kde_upsample_factor + # Aberration coefs + mn = [] + for m in range(0,fit_aber_order_max//2+1): + n_max = np.floor(fit_aber_order_max-2*m).astype('int') + + for n in range(0,n_max+1): + if m + n > 1 or (m > 0 and n == 0): + if n == 0: + mn.append([m,n,0]) + else: + mn.append([m,n,0]) + mn.append([m,n,1]) + self.aber_mn = np.array(mn) + + # Aberration basis + self.aber_num = self.aber_mn.shape[0] + self.aber_basis = np.zeros((self.alpha2.size,self.aber_num)) + for a0 in range(self.aber_num): + if self.aber_mn[a0,1] == 0: + # Radially symmetric basis + self.aber_basis[:,a0] = self.alpha.ravel()**(2*self.aber_mn[a0,0]) + elif self.aber_mn[a0,2] == 0: + # cos coef + self.aber_basis[:,a0] = \ + self.alpha.ravel()**(2.0*self.aber_mn[a0,0] + self.aber_mn[a0,1]) * \ + np.cos(self.aber_mn[a0,1] * self.theta.ravel()) else: - # Get mean FFT of upsampled BF reconstruction - im_fft = np.abs(xp.fft.fft2(self._recon_BF)) + # sin coef + self.aber_basis[:,a0] = \ + self.alpha.ravel()**(2.0*self.aber_mn[a0,0] + self.aber_mn[a0,1]) * \ + np.sin(self.aber_mn[a0,1] * self.theta.ravel()) + + # CTF function + def calc_CTF(alpha, *coefs): + chi = np.zeros_like(alpha.ravel()) + for a0 in range(len(coefs)): + chi += coefs[a0] * self.aber_basis[:,a0] + return np.reshape(chi, alpha.shape) + + if fit_upsampled_FFT: + # Get mean FFT of BF reconstruction + im_fft = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + + # coordinates + q_pixel_size = np.array(self._reciprocal_sampling) \ + / self._kde_upsample_factor + else: + # Get mean FFT of upsampled BF reconstruction + im_fft = np.abs(xp.fft.fft2(self._recon_BF)) + + # coordinates + q_pixel_size = np.array(parallax_recon._reciprocal_sampling) + + # FFT coordinates + qx = np.fft.fftfreq(im_fft.shape[0],q_pixel_size[0]) + qy = np.fft.fftfreq(im_fft.shape[1],q_pixel_size[1]) + qr2 = qx[:,None]**2 + qy[None,:]**2 + self.alpha = np.sqrt(qr2) * self._wavelength + self.theta = np.arctan2(qy[None,:],qx[:,None]) + + # initial coefficients and plotting intensity range mask + C1_dimensionless = self.aberration_C1 * 0.5 * self._wavelength + coefs = np.zeros(self.aber_num) + ind = np.argmin( + np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) + ) + coefs[ind] = C1_dimensionless + plot_mask = self.alpha > np.sqrt(np.pi/np.abs(C1_dimensionless)) + angular_mask = np.cos(4.0 * np.arctan2(qy[:,None],qx[None,:]))**2 < 0.5 - # coordinates - q_pixel_size = np.array(parallax_recon._reciprocal_sampling) + # Refinement using CTF fitting / Thon rings + if fit_CTF_FFT: + pass - # FFT coordinates - qx = np.fft.fftfreq(im_fft.shape[0],q_pixel_size[0]) - qy = np.fft.fftfreq(im_fft.shape[1],q_pixel_size[1]) - qr2 = qx[:,None]**2 + qy[None,:]**2 - self.alpha2 = qr2 * self._wavelength**2 - self.theta = np.arctan2(qy[None,:],qx[:,None]) - - # weights for fits - self.q_weight = qr2 ** (q_power_fit/2) - - # Aberration coefs - mn = [] - for m in range(0,aber_order_max//2+1): - n_max = np.floor(aber_order_max-2*m).astype('int') - - for n in range(0,n_max+1): - if m + n > 1 or (m > 0 and n == 0): - if n == 0: - mn.append([m,n,0]) - else: - mn.append([m,n,0]) - mn.append([m,n,1]) - self.aber_mn = np.array(mn) - - # Aberration basis - self.aber_num = self.aber_mn.shape[0] - self.aber_basis = np.zeros((self.alpha2.size,self.aber_num)) - # self.aber_basis[:,0] = self.alpha2.ravel() - for a0 in range(self.aber_num): - if self.aber_mn[a0,1] == 0: - # Radially symmetric basis - self.aber_basis[:,a0] = self.alpha2.ravel()**self.aber_mn[a0,0] - elif self.aber_mn[a0,2] == 0: - # cos coef - self.aber_basis[:,a0] = \ - self.alpha2.ravel()**(self.aber_mn[a0,0] + self.aber_mn[a0,1]/2.0) * \ - np.cos(self.aber_mn[a0,1] * self.theta.ravel()) - else: - # sin coef - self.aber_basis[:,a0] = \ - self.alpha2.ravel()**(self.aber_mn[a0,0] + self.aber_mn[a0,1]/2.0) * \ - np.sin(self.aber_mn[a0,1] * self.theta.ravel()) + + # fitting image - im_fit = im_fft * self.q_weight - if medfilt_size is not None: - im_fit = np.fft.ifftshift(medfilt2d( - np.fft.fftshift(im_fit), - medfilt_size)) - - # initial coefs - int_max = np.max(im_fit) - sigma_init = np.sqrt(np.max(self.alpha2) / 8.0) - coefs = np.zeros(5 + self.aber_num) - lb = np.zeros(5 + self.aber_num) - ub = np.ones(5 + self.aber_num) * np.inf - coefs[0] = 1e-3 - coefs[1] = int_max * 0.1 - coefs[2] = sigma_init - coefs[3] = int_max * 0.9 - coefs[4] = sigma_init - lb[5:] = -np.inf - # initial C1 value (defocus) - ind = np.argmin( - np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) - ) - C1_dimensionless = self.aberration_C1 / np.pi * self._wavelength - coefs[ind + 5] = C1_dimensionless + # im_fit = im_fft * self.q_weight + # if medfilt_size is not None: + # im_fit = np.fft.ifftshift(medfilt2d( + # np.fft.fftshift(im_fit), + # medfilt_size)) + + # # initial coefs + # int_max = np.max(im_fit) + # sigma_init = np.sqrt(np.max(self.alpha2) / 8.0) + # coefs = np.zeros(5 + self.aber_num) + # lb = np.zeros(5 + self.aber_num) + # ub = np.ones(5 + self.aber_num) * np.inf + # coefs[0] = 1e-3 + # coefs[1] = int_max * 0.1 + # coefs[2] = sigma_init + # coefs[3] = int_max * 0.9 + # coefs[4] = sigma_init + # lb[5:] = -np.inf + # # initial C1 value (defocus) + # ind = np.argmin( + # np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) + # ) + # C1_dimensionless = self.aberration_C1 / np.pi * self._wavelength + # coefs[ind + 5] = C1_dimensionless # Fitting mask - fit_mask = self.alpha2 > np.sqrt(np.pi/2/np.abs(C1_dimensionless)) - basis_masked = self.aber_basis[fit_mask.ravel(),:] + # fit_mask = self.alpha2 > np.sqrt(np.pi/2/np.abs(C1_dimensionless)) + # basis_masked = self.aber_basis[fit_mask.ravel(),:] # Define fitting functions - def calc_CTF_mag(alpha2, *coefs): - int0 = coefs[0] - int1 = coefs[1] - sigma1 = coefs[2] - int_env = coefs[3] - sigma_env = coefs[4] - - im_CTF_mag = int0 + int1 * np.exp(alpha2/(-2.0*sigma1**2)) - env = int_env * np.exp(alpha2/(-2.0*sigma_env**2)) - chi = np.zeros_like(im_CTF_mag) - for a0 in range(5,len(coefs)): - chi += coefs[a0] * self.aber_basis[:,a0-5] - return im_CTF_mag + np.abs(np.sin(chi)) * env - - def calc_CTF_mag_masked(alpha2, *coefs): - int0 = coefs[0] - int1 = coefs[1] - sigma1 = coefs[2] - int_env = coefs[3] - sigma_env = coefs[4] - - im_CTF_mag = int0 + int1 * np.exp(alpha2/(-2.0*sigma1**2)) - env = int_env * np.exp(alpha2/(-2.0*sigma_env**2)) - chi = np.zeros_like(im_CTF_mag) - for a0 in range(5,len(coefs)): - chi += coefs[a0] * basis_masked[:,a0-5] - return im_CTF_mag + np.abs(np.sin(chi)) * env - - # Refine aberration coefficients - if maxfev is None: - coefs = np.array( - curve_fit( - calc_CTF_mag_masked, - self.alpha2[fit_mask], - im_fit[fit_mask], - p0 = tuple(coefs), - bounds = (lb,ub), - )[0] - ) - else: - coefs = np.array( - curve_fit( - calc_CTF_mag_masked, - self.alpha2[fit_mask], - im_fit[fit_mask], - p0 = tuple(coefs), - bounds = (lb,ub), - maxfev = maxfev, - )[0] - ) - - # Print results - if self._verbose: - print( - ( - "Rotation of Q w.r.t. R = " - f"{np.rad2deg(self.rotation_Q_to_R_rads):.3f} deg" - ) - ) - print( - ( - "Astigmatism (A1x,A1y) = (" - f"{self.aberration_A1x:.0f}," - f"{self.aberration_A1y:.0f}) Ang" - ) - ) - print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") - print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + # def calc_CTF_mag(alpha2, *coefs): + # int0 = coefs[0] + # int1 = coefs[1] + # sigma1 = coefs[2] + # int_env = coefs[3] + # sigma_env = coefs[4] + + # im_CTF_mag = int0 + int1 * np.exp(alpha2/(-2.0*sigma1**2)) + # env = int_env * np.exp(alpha2/(-2.0*sigma_env**2)) + # chi = np.zeros_like(im_CTF_mag) + # for a0 in range(5,len(coefs)): + # chi += coefs[a0] * self.aber_basis[:,a0-5] + # return im_CTF_mag + np.abs(np.sin(chi)) * env + + # def calc_CTF_mag_masked(alpha2, *coefs): + # int0 = coefs[0] + # int1 = coefs[1] + # sigma1 = coefs[2] + # int_env = coefs[3] + # sigma_env = coefs[4] + + # im_CTF_mag = int0 + int1 * np.exp(alpha2/(-2.0*sigma1**2)) + # env = int_env * np.exp(alpha2/(-2.0*sigma_env**2)) + # chi = np.zeros_like(im_CTF_mag) + # for a0 in range(5,len(coefs)): + # chi += coefs[a0] * basis_masked[:,a0-5] + # return im_CTF_mag + np.abs(np.sin(chi)) * env + + # # Refine aberration coefficients + # if maxfev is None: + # coefs = np.array( + # curve_fit( + # calc_CTF_mag_masked, + # self.alpha2[fit_mask], + # im_fit[fit_mask], + # p0 = tuple(coefs), + # bounds = (lb,ub), + # )[0] + # ) + # else: + # coefs = np.array( + # curve_fit( + # calc_CTF_mag_masked, + # self.alpha2[fit_mask], + # im_fit[fit_mask], + # p0 = tuple(coefs), + # bounds = (lb,ub), + # maxfev = maxfev, + # )[0] + # ) # Plot the CTF comparison between experiment and fit if plot_CTF_compare: - im_test = np.reshape(calc_CTF_mag(self.alpha2.ravel(), *coefs), im_fit.shape) + # Generate FFT plotting image + int_range = (np.min(im_fft[plot_mask]),np.max(im_fft[plot_mask])) + int_range = (int_range[0],(int_range[1]-int_range[0])*0.5 + int_range[0]) + im_scale = np.clip( + (np.fft.fftshift(im_fft) - int_range[0]) / (int_range[1] - int_range[0]), + 0,1) + # im_scale = im_scale**0.5 + im_plot = np.tile(im_scale[:,:,None],(1,1,3)) + + # Add CTF zero crossings + im_CTF = calc_CTF(self.alpha,*coefs) + # im_CTF = np.sin(im_CTF)**2 + # im_CTF = np.fft.fftshift(im_CTF) + # print(np.max(im_CTF)) + im_CTF = np.sin(im_CTF)**2 < fit_CTF_threshold + im_CTF[np.logical_not(plot_mask)] = 0 + im_CTF = np.fft.fftshift(im_CTF * angular_mask) + im_plot[:,:,0] += im_CTF + im_plot[:,:,1] -= im_CTF + im_plot[:,:,2] -= im_CTF + im_plot = np.clip(im_plot,0,1) fig,ax = plt.subplots(figsize=(12,6)) ax.imshow( - np.hstack(( - np.fft.fftshift(im_fit), - np.fft.fftshift(im_test), - )), - vmin = np.min(im_test[fit_mask]), - vmax = np.max(im_test[fit_mask]), - cmap = 'gray', + im_plot, + # np.fft.fftshift(np.reshape(self.aber_basis[:,1],im_CTF.shape)) + # angular_mask, + # np.hstack(( + # im_scale, + # im_CTF + # )) + # im_ctf ) + # ax.imshow( # im_plot / np.max(im_plot), # vmin = 0, @@ -1602,6 +1617,26 @@ def calc_CTF_mag_masked(alpha2, *coefs): # ax.set_xlim([0, k_bins[-1]]) # ax.set_ylim([0, 1.05]) + + # Print results + if self._verbose: + print( + ( + "Rotation of Q w.r.t. R = " + f"{np.rad2deg(self.rotation_Q_to_R_rads):.3f} deg" + ) + ) + print( + ( + "Astigmatism (A1x,A1y) = (" + f"{self.aberration_A1x:.0f}," + f"{self.aberration_A1y:.0f}) Ang" + ) + ) + print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") + print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + + def aberration_correct( self, plot_corrected_phase: bool = True, From ea070f85b8196828ff597f8479dad4fb910fb87c Mon Sep 17 00:00:00 2001 From: Colin Date: Tue, 17 Oct 2023 15:52:13 -0700 Subject: [PATCH 44/62] It works! --- py4DSTEM/process/phase/iterative_parallax.py | 120 +++++++++++++------ 1 file changed, 85 insertions(+), 35 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 4bfd265f9..243a8ec5b 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -18,7 +18,7 @@ from py4DSTEM.visualize import show from scipy.linalg import polar from scipy.special import comb -from scipy.optimize import curve_fit +from scipy.optimize import curve_fit, minimize from scipy.signal import medfilt2d try: @@ -1270,10 +1270,12 @@ def subpixel_alignment( def aberration_fit( self, fit_CTF_FFT = True, - fit_CTF_threshold = 0.1, + fit_CTF_threshold = 0.25, fit_upsampled_FFT = True, fit_aber_order_max = 2, - fit_maxfev = None, + fit_max_num_rings = 6, + fit_power_alpha = 2.0, + # fit_maxfev = None, plot_CTF_compare: bool = False, ): """ @@ -1301,7 +1303,7 @@ def aberration_fit( asnumpy = self._asnumpy gaussian_filter = self._gaussian_filter - # initial aberration fit + ### initial aberration fit ### # Convert real space shifts to Angstroms self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) @@ -1331,6 +1333,30 @@ def aberration_fit( xp._default_memory_pool.free_all_blocks() xp.clear_memo() + + ### FFT fitting / plotting code ### + + if fit_upsampled_FFT: + # Get mean FFT of BF reconstruction + im_FFT = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + + # coordinates + q_pixel_size = np.array(self._reciprocal_sampling) \ + / self._kde_upsample_factor + else: + # Get mean FFT of upsampled BF reconstruction + im_FFT = np.abs(xp.fft.fft2(self._recon_BF)) + + # coordinates + q_pixel_size = np.array(parallax_recon._reciprocal_sampling) + + # FFT coordinates + qx = np.fft.fftfreq(im_FFT.shape[0],1/im_FFT.shape[0]/q_pixel_size[0]) + qy = np.fft.fftfreq(im_FFT.shape[1],1/im_FFT.shape[1]/q_pixel_size[1]) + qr2 = qx[:,None]**2 + qy[None,:]**2 + self.alpha = np.sqrt(qr2) * self._wavelength + self.theta = np.arctan2(qy[None,:],qx[:,None]) + # Aberration coefs mn = [] for m in range(0,fit_aber_order_max//2+1): @@ -1347,7 +1373,7 @@ def aberration_fit( # Aberration basis self.aber_num = self.aber_mn.shape[0] - self.aber_basis = np.zeros((self.alpha2.size,self.aber_num)) + self.aber_basis = np.zeros((self.alpha.size,self.aber_num)) for a0 in range(self.aber_num): if self.aber_mn[a0,1] == 0: # Radially symmetric basis @@ -1370,40 +1396,56 @@ def calc_CTF(alpha, *coefs): chi += coefs[a0] * self.aber_basis[:,a0] return np.reshape(chi, alpha.shape) - if fit_upsampled_FFT: - # Get mean FFT of BF reconstruction - im_fft = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) - - # coordinates - q_pixel_size = np.array(self._reciprocal_sampling) \ - / self._kde_upsample_factor - else: - # Get mean FFT of upsampled BF reconstruction - im_fft = np.abs(xp.fft.fft2(self._recon_BF)) - - # coordinates - q_pixel_size = np.array(parallax_recon._reciprocal_sampling) - - # FFT coordinates - qx = np.fft.fftfreq(im_fft.shape[0],q_pixel_size[0]) - qy = np.fft.fftfreq(im_fft.shape[1],q_pixel_size[1]) - qr2 = qx[:,None]**2 + qy[None,:]**2 - self.alpha = np.sqrt(qr2) * self._wavelength - self.theta = np.arctan2(qy[None,:],qx[:,None]) - # initial coefficients and plotting intensity range mask - C1_dimensionless = self.aberration_C1 * 0.5 * self._wavelength + C10_dimensionless = self.aberration_C1 * np.pi / 4 / self._wavelength coefs = np.zeros(self.aber_num) ind = np.argmin( np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) ) - coefs[ind] = C1_dimensionless - plot_mask = self.alpha > np.sqrt(np.pi/np.abs(C1_dimensionless)) - angular_mask = np.cos(4.0 * np.arctan2(qy[:,None],qx[None,:]))**2 < 0.5 + coefs[ind] = C10_dimensionless + plot_mask = self.alpha > np.sqrt(np.pi/2/np.abs(C10_dimensionless)) + # plot_mask[:] = True + angular_mask = np.cos(8.0 * np.arctan2(qy[:,None],qx[None,:]))**2 < 0.25 # Refinement using CTF fitting / Thon rings if fit_CTF_FFT: - pass + # scoring function to minimize - mean value of zero crossing regions of FFT + def score_CTF(coefs): + im_CTF = np.abs(calc_CTF(self.alpha,*coefs)) + mask = np.logical_and( + im_CTF > 0.5*np.pi, + im_CTF < (max_num_rings+0.5)*np.pi, + ) + if np.any(mask): + weights = np.cos(im_CTF[mask])**4 + return np.sum(weights*im_FFT[mask]*self.alpha[mask]**fit_power_alpha) / np.sum(weights) + else: + return np.inf + + for max_num_rings in range(1,fit_max_num_rings+1): + # minimization + res = minimize( + score_CTF, + coefs, + # method = 'Nelder-Mead', + # method = 'CG', + method = 'BFGS', + tol = 1e-8, + ) + coefs = res.x + + # basis = np.vstack(( + # self.alpha.ravel(), + # im_FFT.ravel() + # )) + # print(basis.shape) + # score = score_CTF(self.alpha,coefs*1) + # print(score) + + + # im_CTF = calc_CTF(self.alpha,*coefs) + # im_CTF[np.abs(im_CTF) > (fit_max_num_rings+0.5)*np.pi] = np.pi/2; + # im_CTF = np.sin(im_CTF)**2 < fit_CTF_threshold @@ -1495,20 +1537,28 @@ def calc_CTF(alpha, *coefs): # Plot the CTF comparison between experiment and fit if plot_CTF_compare: # Generate FFT plotting image - int_range = (np.min(im_fft[plot_mask]),np.max(im_fft[plot_mask])) - int_range = (int_range[0],(int_range[1]-int_range[0])*0.5 + int_range[0]) + im_scale = im_FFT * self.alpha**fit_power_alpha + # int_range = (np.min(im_scale[plot_mask]),np.max(im_scale[plot_mask])) + int_vals = np.sort(im_scale.ravel()) + int_range = ( + int_vals[np.round(0.02*im_scale.size).astype('int')], + int_vals[np.round(0.98*im_scale.size).astype('int')], + ) + + int_range = (int_range[0],(int_range[1]-int_range[0])*1.0 + int_range[0]) im_scale = np.clip( - (np.fft.fftshift(im_fft) - int_range[0]) / (int_range[1] - int_range[0]), + (np.fft.fftshift(im_scale) - int_range[0]) / (int_range[1] - int_range[0]), 0,1) # im_scale = im_scale**0.5 im_plot = np.tile(im_scale[:,:,None],(1,1,3)) # Add CTF zero crossings im_CTF = calc_CTF(self.alpha,*coefs) + im_CTF[np.abs(im_CTF) > (fit_max_num_rings+0.5)*np.pi] = np.pi/2; # im_CTF = np.sin(im_CTF)**2 # im_CTF = np.fft.fftshift(im_CTF) # print(np.max(im_CTF)) - im_CTF = np.sin(im_CTF)**2 < fit_CTF_threshold + im_CTF = np.abs(np.sin(im_CTF)) < fit_CTF_threshold im_CTF[np.logical_not(plot_mask)] = 0 im_CTF = np.fft.fftshift(im_CTF * angular_mask) im_plot[:,:,0] += im_CTF From b2cbede265a97cf20b3c36d2e8d3edb9936b87b9 Mon Sep 17 00:00:00 2001 From: Colin Date: Tue, 17 Oct 2023 16:27:00 -0700 Subject: [PATCH 45/62] Updating outputs --- py4DSTEM/process/phase/iterative_parallax.py | 404 ++++++------------- 1 file changed, 121 insertions(+), 283 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 243a8ec5b..d4ed2a80a 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1335,77 +1335,85 @@ def aberration_fit( ### FFT fitting / plotting code ### - - if fit_upsampled_FFT: - # Get mean FFT of BF reconstruction - im_FFT = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) - - # coordinates - q_pixel_size = np.array(self._reciprocal_sampling) \ - / self._kde_upsample_factor - else: - # Get mean FFT of upsampled BF reconstruction - im_FFT = np.abs(xp.fft.fft2(self._recon_BF)) - - # coordinates - q_pixel_size = np.array(parallax_recon._reciprocal_sampling) - - # FFT coordinates - qx = np.fft.fftfreq(im_FFT.shape[0],1/im_FFT.shape[0]/q_pixel_size[0]) - qy = np.fft.fftfreq(im_FFT.shape[1],1/im_FFT.shape[1]/q_pixel_size[1]) - qr2 = qx[:,None]**2 + qy[None,:]**2 - self.alpha = np.sqrt(qr2) * self._wavelength - self.theta = np.arctan2(qy[None,:],qx[:,None]) - - # Aberration coefs - mn = [] - for m in range(0,fit_aber_order_max//2+1): - n_max = np.floor(fit_aber_order_max-2*m).astype('int') - - for n in range(0,n_max+1): - if m + n > 1 or (m > 0 and n == 0): - if n == 0: - mn.append([m,n,0]) - else: - mn.append([m,n,0]) - mn.append([m,n,1]) - self.aber_mn = np.array(mn) - - # Aberration basis - self.aber_num = self.aber_mn.shape[0] - self.aber_basis = np.zeros((self.alpha.size,self.aber_num)) - for a0 in range(self.aber_num): - if self.aber_mn[a0,1] == 0: - # Radially symmetric basis - self.aber_basis[:,a0] = self.alpha.ravel()**(2*self.aber_mn[a0,0]) - elif self.aber_mn[a0,2] == 0: - # cos coef - self.aber_basis[:,a0] = \ - self.alpha.ravel()**(2.0*self.aber_mn[a0,0] + self.aber_mn[a0,1]) * \ - np.cos(self.aber_mn[a0,1] * self.theta.ravel()) + if fit_CTF_FFT or plot_CTF_compare: + if fit_upsampled_FFT: + # Get mean FFT of BF reconstruction + im_FFT = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + + # coordinates + q_pixel_size = np.array(self._reciprocal_sampling) \ + / self._kde_upsample_factor else: - # sin coef - self.aber_basis[:,a0] = \ - self.alpha.ravel()**(2.0*self.aber_mn[a0,0] + self.aber_mn[a0,1]) * \ - np.sin(self.aber_mn[a0,1] * self.theta.ravel()) - - # CTF function - def calc_CTF(alpha, *coefs): - chi = np.zeros_like(alpha.ravel()) - for a0 in range(len(coefs)): - chi += coefs[a0] * self.aber_basis[:,a0] - return np.reshape(chi, alpha.shape) - - # initial coefficients and plotting intensity range mask - C10_dimensionless = self.aberration_C1 * np.pi / 4 / self._wavelength - coefs = np.zeros(self.aber_num) - ind = np.argmin( - np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) - ) - coefs[ind] = C10_dimensionless - plot_mask = self.alpha > np.sqrt(np.pi/2/np.abs(C10_dimensionless)) - # plot_mask[:] = True - angular_mask = np.cos(8.0 * np.arctan2(qy[:,None],qx[None,:]))**2 < 0.25 + # Get mean FFT of upsampled BF reconstruction + im_FFT = np.abs(xp.fft.fft2(self._recon_BF)) + + # coordinates + q_pixel_size = np.array(parallax_recon._reciprocal_sampling) + + # FFT coordinates + qx = np.fft.fftfreq(im_FFT.shape[0],1/im_FFT.shape[0]/q_pixel_size[0]) + qy = np.fft.fftfreq(im_FFT.shape[1],1/im_FFT.shape[1]/q_pixel_size[1]) + qr2 = qx[:,None]**2 + qy[None,:]**2 + self.alpha = np.sqrt(qr2) * self._wavelength + self.theta = np.arctan2(qy[None,:],qx[:,None]) + + # Aberration coefs + mn = [] + for m in range(0,fit_aber_order_max//2+1): + n_max = np.floor(fit_aber_order_max-2*m).astype('int') + + for n in range(0,n_max+1): + if m + n > 1 or (m > 0 and n == 0): + if n == 0: + mn.append([m,n,0]) + else: + mn.append([m,n,0]) + mn.append([m,n,1]) + self.aber_mn = np.array(mn) + self.aber_mn = self.aber_mn[np.argsort(self.aber_mn[:,1]),:] + # self.aber_mn = self.aber_mn[np.lexsort(( + # self.aber_mn[:,0], + # self.aber_mn[:,2], + # self.aber_mn[:,1], + # ))] + sub = self.aber_mn[:,1] > 0 + self.aber_mn[sub,:] = self.aber_mn[sub,:][np.argsort(self.aber_mn[sub,0]),:] + + # Aberration basis + self.aber_num = self.aber_mn.shape[0] + self.aber_basis = np.zeros((self.alpha.size,self.aber_num)) + for a0 in range(self.aber_num): + if self.aber_mn[a0,1] == 0: + # Radially symmetric basis + self.aber_basis[:,a0] = self.alpha.ravel()**(2*self.aber_mn[a0,0]) + elif self.aber_mn[a0,2] == 0: + # cos coef + self.aber_basis[:,a0] = \ + self.alpha.ravel()**(2.0*self.aber_mn[a0,0] + self.aber_mn[a0,1]) * \ + np.cos(self.aber_mn[a0,1] * self.theta.ravel()) + else: + # sin coef + self.aber_basis[:,a0] = \ + self.alpha.ravel()**(2.0*self.aber_mn[a0,0] + self.aber_mn[a0,1]) * \ + np.sin(self.aber_mn[a0,1] * self.theta.ravel()) + + # CTF function + def calc_CTF(alpha, *coefs): + chi = np.zeros_like(alpha.ravel()) + for a0 in range(len(coefs)): + chi += coefs[a0] * self.aber_basis[:,a0] + return np.reshape(chi, alpha.shape) + + # initial coefficients and plotting intensity range mask + C10_dimensionless = self.aberration_C1 * np.pi / 4 / self._wavelength + coefs = np.zeros(self.aber_num) + ind = np.argmin( + np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) + ) + coefs[ind] = C10_dimensionless + plot_mask = self.alpha > np.sqrt(np.pi/2/np.abs(C10_dimensionless)) + # plot_mask[:] = True + angular_mask = np.cos(8.0 * np.arctan2(qy[:,None],qx[None,:]))**2 < 0.25 # Refinement using CTF fitting / Thon rings if fit_CTF_FFT: @@ -1422,142 +1430,36 @@ def score_CTF(coefs): else: return np.inf - for max_num_rings in range(1,fit_max_num_rings+1): - # minimization - res = minimize( - score_CTF, - coefs, - # method = 'Nelder-Mead', - # method = 'CG', - method = 'BFGS', - tol = 1e-8, - ) - coefs = res.x - - # basis = np.vstack(( - # self.alpha.ravel(), - # im_FFT.ravel() - # )) - # print(basis.shape) - # score = score_CTF(self.alpha,coefs*1) - # print(score) - - - # im_CTF = calc_CTF(self.alpha,*coefs) - # im_CTF[np.abs(im_CTF) > (fit_max_num_rings+0.5)*np.pi] = np.pi/2; - # im_CTF = np.sin(im_CTF)**2 < fit_CTF_threshold - - - - - - # fitting image - # im_fit = im_fft * self.q_weight - # if medfilt_size is not None: - # im_fit = np.fft.ifftshift(medfilt2d( - # np.fft.fftshift(im_fit), - # medfilt_size)) - - # # initial coefs - # int_max = np.max(im_fit) - # sigma_init = np.sqrt(np.max(self.alpha2) / 8.0) - # coefs = np.zeros(5 + self.aber_num) - # lb = np.zeros(5 + self.aber_num) - # ub = np.ones(5 + self.aber_num) * np.inf - # coefs[0] = 1e-3 - # coefs[1] = int_max * 0.1 - # coefs[2] = sigma_init - # coefs[3] = int_max * 0.9 - # coefs[4] = sigma_init - # lb[5:] = -np.inf - # # initial C1 value (defocus) - # ind = np.argmin( - # np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) - # ) - # C1_dimensionless = self.aberration_C1 / np.pi * self._wavelength - # coefs[ind + 5] = C1_dimensionless - - # Fitting mask - # fit_mask = self.alpha2 > np.sqrt(np.pi/2/np.abs(C1_dimensionless)) - # basis_masked = self.aber_basis[fit_mask.ravel(),:] - - # Define fitting functions - - # def calc_CTF_mag(alpha2, *coefs): - # int0 = coefs[0] - # int1 = coefs[1] - # sigma1 = coefs[2] - # int_env = coefs[3] - # sigma_env = coefs[4] - - # im_CTF_mag = int0 + int1 * np.exp(alpha2/(-2.0*sigma1**2)) - # env = int_env * np.exp(alpha2/(-2.0*sigma_env**2)) - # chi = np.zeros_like(im_CTF_mag) - # for a0 in range(5,len(coefs)): - # chi += coefs[a0] * self.aber_basis[:,a0-5] - # return im_CTF_mag + np.abs(np.sin(chi)) * env - - # def calc_CTF_mag_masked(alpha2, *coefs): - # int0 = coefs[0] - # int1 = coefs[1] - # sigma1 = coefs[2] - # int_env = coefs[3] - # sigma_env = coefs[4] - - # im_CTF_mag = int0 + int1 * np.exp(alpha2/(-2.0*sigma1**2)) - # env = int_env * np.exp(alpha2/(-2.0*sigma_env**2)) - # chi = np.zeros_like(im_CTF_mag) - # for a0 in range(5,len(coefs)): - # chi += coefs[a0] * basis_masked[:,a0-5] - # return im_CTF_mag + np.abs(np.sin(chi)) * env - - # # Refine aberration coefficients - # if maxfev is None: - # coefs = np.array( - # curve_fit( - # calc_CTF_mag_masked, - # self.alpha2[fit_mask], - # im_fit[fit_mask], - # p0 = tuple(coefs), - # bounds = (lb,ub), - # )[0] - # ) - # else: - # coefs = np.array( - # curve_fit( - # calc_CTF_mag_masked, - # self.alpha2[fit_mask], - # im_fit[fit_mask], - # p0 = tuple(coefs), - # bounds = (lb,ub), - # maxfev = maxfev, - # )[0] + # for max_num_rings in range(1,fit_max_num_rings+1): + # # minimization + # res = minimize( + # score_CTF, + # coefs, + # # method = 'Nelder-Mead', + # # method = 'CG', + # method = 'BFGS', + # tol = 1e-8, # ) + # coefs = res.x # Plot the CTF comparison between experiment and fit if plot_CTF_compare: # Generate FFT plotting image im_scale = im_FFT * self.alpha**fit_power_alpha - # int_range = (np.min(im_scale[plot_mask]),np.max(im_scale[plot_mask])) int_vals = np.sort(im_scale.ravel()) int_range = ( int_vals[np.round(0.02*im_scale.size).astype('int')], int_vals[np.round(0.98*im_scale.size).astype('int')], ) - int_range = (int_range[0],(int_range[1]-int_range[0])*1.0 + int_range[0]) im_scale = np.clip( (np.fft.fftshift(im_scale) - int_range[0]) / (int_range[1] - int_range[0]), 0,1) - # im_scale = im_scale**0.5 im_plot = np.tile(im_scale[:,:,None],(1,1,3)) # Add CTF zero crossings im_CTF = calc_CTF(self.alpha,*coefs) im_CTF[np.abs(im_CTF) > (fit_max_num_rings+0.5)*np.pi] = np.pi/2; - # im_CTF = np.sin(im_CTF)**2 - # im_CTF = np.fft.fftshift(im_CTF) - # print(np.max(im_CTF)) im_CTF = np.abs(np.sin(im_CTF)) < fit_CTF_threshold im_CTF[np.logical_not(plot_mask)] = 0 im_CTF = np.fft.fftshift(im_CTF * angular_mask) @@ -1569,107 +1471,13 @@ def score_CTF(coefs): fig,ax = plt.subplots(figsize=(12,6)) ax.imshow( im_plot, - # np.fft.fftshift(np.reshape(self.aber_basis[:,1],im_CTF.shape)) - # angular_mask, - # np.hstack(( - # im_scale, - # im_CTF - # )) - # im_ctf ) - - # ax.imshow( - # im_plot / np.max(im_plot), - # vmin = 0, - # vmax = 1, - # cmap = 'gray', - # ) - # ax.imshow( \ - # np.fft.fftshift( - # np.mod(np.reshape(self.aber_basis[:,2],im_fft.shape)+np.pi,2*np.pi)-np.pi - # )) - - - # # Get polar mean from FFT of BF reconstruction - # im_fft = xp.abs(xp.fft.fft2(self._recon_BF)) - - # # coordinates - # kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0]) - # ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1]) - # kra = xp.sqrt(kx[:, None] ** 2 + ky[None, :] ** 2) - # k_max = xp.max(kra) / np.sqrt(2.0) - # k_num_bins = int(xp.ceil(k_max / plot_dk)) - # k_bins = xp.arange(k_num_bins + 1) * plot_dk - - # # histogram - # k_ind = kra / plot_dk - # kf = np.floor(k_ind).astype("int") - # dk = k_ind - kf - # sub = kf <= k_num_bins - # hist_exp = xp.bincount( - # kf[sub], weights=im_fft[sub] * (1 - dk[sub]), minlength=k_num_bins - # ) - # hist_norm = xp.bincount( - # kf[sub], weights=(1 - dk[sub]), minlength=k_num_bins - # ) - # sub = kf <= k_num_bins - 1 - - # hist_exp += xp.bincount( - # kf[sub] + 1, weights=im_fft[sub] * (dk[sub]), minlength=k_num_bins - # ) - # hist_norm += xp.bincount( - # kf[sub] + 1, weights=(dk[sub]), minlength=k_num_bins - # ) - - # # KDE and normalizing - # k_sigma = plot_dk / plot_k_sigma - # hist_exp[0] = 0.0 - # hist_exp = gaussian_filter(hist_exp, sigma=k_sigma, mode="nearest") - # hist_norm = gaussian_filter(hist_norm, sigma=k_sigma, mode="nearest") - # hist_exp /= hist_norm - - # # CTF comparison - # CTF_fit = xp.sin( - # (-np.pi * self._wavelength * self.aberration_C1) * k_bins**2 - # ) - - # # plotting input - log scale - # min_hist_val = xp.max(hist_exp) * 1e-3 - # hist_plot = xp.log(np.maximum(hist_exp, min_hist_val)) - # hist_plot -= xp.min(hist_plot) - # hist_plot /= xp.max(hist_plot) - - # hist_plot = asnumpy(hist_plot) - # k_bins = asnumpy(k_bins) - # CTF_fit = asnumpy(CTF_fit) - - # fig, ax = plt.subplots(figsize=(8, 4)) - - # ax.fill_between( - # k_bins, - # hist_plot, - # color=(0.7, 0.7, 0.7, 1), - # ) - - # ax.plot( - # k_bins, - # np.clip(CTF_fit, 0.0, np.inf), - # color=(1, 0, 0, 1), - # linewidth=2, - # ) - # ax.plot( - # k_bins, - # np.clip(-CTF_fit, 0.0, np.inf), - # color=(0, 0.5, 1, 1), - # linewidth=2, - # ) - # ax.set_xlim([0, k_bins[-1]]) - # ax.set_ylim([0, 1.05]) - - # Print results if self._verbose: + if fit_CTF_FFT: + print('Initial Aberration coefficients') + print('-------------------------------') print( ( "Rotation of Q w.r.t. R = " @@ -1686,6 +1494,36 @@ def score_CTF(coefs): print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + if fit_CTF_FFT: + # radial_order = 2 * self.aber_mn[a0,0] + + + print() + print('Refined Aberration coefficients') + print('-------------------------------') + print('radial annular dir. coefs') + print('order order ') + print('------ ------- ---- -----') + + for a0 in range(self.aber_mn.shape[0]): + if self.aber_mn[a0,1] == 0: + print( + str(self.aber_mn[a0,0]) + \ + ' 0 - ' + \ + str(np.round(coefs[a0]).astype('int')) ) + elif self.aber_mn[a0,2] == 0: + print( + str(self.aber_mn[a0,0]) + \ + ' ' + \ + str(self.aber_mn[a0,1]) + \ + ' x ' + \ + str(np.round(coefs[a0]).astype('int')) ) + else: + print( + str(self.aber_mn[a0,0]) + \ + ' ' + \ + str(self.aber_mn[a0,1]) + \ + ' y ' + \ + str(np.round(coefs[a0]).astype('int')) ) def aberration_correct( self, From 270551459904b4ccb23c07bf003dd0c1b99fec3c Mon Sep 17 00:00:00 2001 From: Colin Date: Tue, 17 Oct 2023 16:44:47 -0700 Subject: [PATCH 46/62] Adding outputs, plotting --- py4DSTEM/process/phase/iterative_parallax.py | 107 +++++++++++-------- 1 file changed, 63 insertions(+), 44 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index d4ed2a80a..25bca20e8 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1403,14 +1403,15 @@ def calc_CTF(alpha, *coefs): for a0 in range(len(coefs)): chi += coefs[a0] * self.aber_basis[:,a0] return np.reshape(chi, alpha.shape) + self.calc_CTF = calc_CTF # initial coefficients and plotting intensity range mask C10_dimensionless = self.aberration_C1 * np.pi / 4 / self._wavelength - coefs = np.zeros(self.aber_num) + self.aber_coefs = np.zeros(self.aber_num) ind = np.argmin( np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) ) - coefs[ind] = C10_dimensionless + self.aber_coefs[ind] = C10_dimensionless plot_mask = self.alpha > np.sqrt(np.pi/2/np.abs(C10_dimensionless)) # plot_mask[:] = True angular_mask = np.cos(8.0 * np.arctan2(qy[:,None],qx[None,:]))**2 < 0.25 @@ -1430,17 +1431,17 @@ def score_CTF(coefs): else: return np.inf - # for max_num_rings in range(1,fit_max_num_rings+1): - # # minimization - # res = minimize( - # score_CTF, - # coefs, - # # method = 'Nelder-Mead', - # # method = 'CG', - # method = 'BFGS', - # tol = 1e-8, - # ) - # coefs = res.x + for max_num_rings in range(1,fit_max_num_rings+1): + # minimization + res = minimize( + score_CTF, + self.aber_coefs, + # method = 'Nelder-Mead', + # method = 'CG', + method = 'BFGS', + tol = 1e-8, + ) + self.aber_coefs = res.x # Plot the CTF comparison between experiment and fit if plot_CTF_compare: @@ -1458,7 +1459,7 @@ def score_CTF(coefs): im_plot = np.tile(im_scale[:,:,None],(1,1,3)) # Add CTF zero crossings - im_CTF = calc_CTF(self.alpha,*coefs) + im_CTF = calc_CTF(self.alpha,*self.aber_coefs) im_CTF[np.abs(im_CTF) > (fit_max_num_rings+0.5)*np.pi] = np.pi/2; im_CTF = np.abs(np.sin(im_CTF)) < fit_CTF_threshold im_CTF[np.logical_not(plot_mask)] = 0 @@ -1495,7 +1496,7 @@ def score_CTF(coefs): print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") if fit_CTF_FFT: - # radial_order = 2 * self.aber_mn[a0,0] + + radial_order = 2 * self.aber_mn[:,0] + self.aber_mn[:,1] print() print('Refined Aberration coefficients') @@ -1507,26 +1508,27 @@ def score_CTF(coefs): for a0 in range(self.aber_mn.shape[0]): if self.aber_mn[a0,1] == 0: print( - str(self.aber_mn[a0,0]) + \ + str(radial_order[a0]) + \ ' 0 - ' + \ - str(np.round(coefs[a0]).astype('int')) ) + str(np.round(self.aber_coefs[a0]).astype('int')) ) elif self.aber_mn[a0,2] == 0: print( - str(self.aber_mn[a0,0]) + \ + str(radial_order[a0]) + \ ' ' + \ str(self.aber_mn[a0,1]) + \ ' x ' + \ - str(np.round(coefs[a0]).astype('int')) ) + str(np.round(self.aber_coefs[a0]).astype('int')) ) else: print( - str(self.aber_mn[a0,0]) + \ + str(radial_order[a0]) + \ ' ' + \ str(self.aber_mn[a0,1]) + \ ' y ' + \ - str(np.round(coefs[a0]).astype('int')) ) + str(np.round(self.aber_coefs[a0]).astype('int')) ) def aberration_correct( self, + use_FFT_fit = True, plot_corrected_phase: bool = True, k_info_limit: float = None, k_info_power: float = 1.0, @@ -1541,6 +1543,8 @@ def aberration_correct( Parameters ---------- + use_FFT_fit: bool + Use the CTF fitted to the zero crossings of the FFT. plot_corrected_phase: bool, optional If True, the CTF-corrected phase is plotted k_info_limit: float, optional @@ -1581,30 +1585,9 @@ def aberration_correct( ky = xp.fft.fftfreq(im.shape[1], sy) kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2 - # CTF - sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) - - if Wiener_filter: - SNR_inv = ( - xp.sqrt( - 1 + (kra2**k_info_power) / ((k_info_limit) ** (2 * k_info_power)) - ) - / Wiener_signal_noise_ratio - ) - CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv) - if Wiener_filter_low_only: - # limit Wiener filter to only the part of the CTF before 1st maxima - k_thresh = 1 / xp.sqrt( - 2.0 * self._wavelength * xp.abs(self.aberration_C1) - ) - k_mask = kra2 >= k_thresh**2 - CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) - - # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(im) * CTF_corr + if use_FFT_fit: + sin_chi = np.sin(self.calc_CTF(self.alpha,*self.aber_coefs)) - else: - # CTF without tilt correction (beyond the parallax operator) CTF_corr = xp.sign(sin_chi) CTF_corr[0, 0] = 0 @@ -1616,6 +1599,42 @@ def aberration_correct( im_fft_corr /= 1 + (kra2**k_info_power) / ( (k_info_limit) ** (2 * k_info_power) ) + else: + # CTF + sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) + + if Wiener_filter: + SNR_inv = ( + xp.sqrt( + 1 + (kra2**k_info_power) / ((k_info_limit) ** (2 * k_info_power)) + ) + / Wiener_signal_noise_ratio + ) + CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv) + if Wiener_filter_low_only: + # limit Wiener filter to only the part of the CTF before 1st maxima + k_thresh = 1 / xp.sqrt( + 2.0 * self._wavelength * xp.abs(self.aberration_C1) + ) + k_mask = kra2 >= k_thresh**2 + CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) + + # apply correction to mean reconstructed BF image + im_fft_corr = xp.fft.fft2(im) * CTF_corr + + else: + # CTF without tilt correction (beyond the parallax operator) + CTF_corr = xp.sign(sin_chi) + CTF_corr[0, 0] = 0 + + # apply correction to mean reconstructed BF image + im_fft_corr = xp.fft.fft2(im) * CTF_corr + + # if needed, add low pass filter output image + if k_info_limit is not None: + im_fft_corr /= 1 + (kra2**k_info_power) / ( + (k_info_limit) ** (2 * k_info_power) + ) # Output phase image self._recon_phase_corrected = xp.real(xp.fft.ifft2(im_fft_corr)) From 7eae9484b4e46bf6ce0bc2af5519aef4324fff9e Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 19 Oct 2023 10:19:02 -0700 Subject: [PATCH 47/62] finally works --- py4DSTEM/process/phase/iterative_parallax.py | 386 +++++++++++++------ 1 file changed, 264 insertions(+), 122 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 25bca20e8..ef6ae9f5b 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -13,6 +13,7 @@ from py4DSTEM import DataCube from py4DSTEM.preprocess.utils import get_shifted_ar from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.utils import AffineTransform from py4DSTEM.process.utils.cross_correlate import align_images_fourier from py4DSTEM.process.utils.utils import electron_wavelength_angstrom from py4DSTEM.visualize import show @@ -1236,10 +1237,10 @@ def subpixel_alignment( ) reciprocal_extent = [ - -self._reciprocal_sampling[1] * cropped_object_aligned.shape[1] / 2, - self._reciprocal_sampling[1] * cropped_object_aligned.shape[1] / 2, - self._reciprocal_sampling[0] * cropped_object_aligned.shape[0] / 2, - -self._reciprocal_sampling[0] * cropped_object_aligned.shape[0] / 2, + -0.5/(self._scan_sampling[1]/self._kde_upsample_factor), + 0.5/(self._scan_sampling[1]/self._kde_upsample_factor), + 0.5/(self._scan_sampling[0]/self._kde_upsample_factor), + -0.5/(self._scan_sampling[0]/self._kde_upsample_factor), ] show( @@ -1269,22 +1270,23 @@ def subpixel_alignment( def aberration_fit( self, - fit_CTF_FFT = True, - fit_CTF_threshold = 0.25, - fit_upsampled_FFT = True, - fit_aber_order_max = 2, - fit_max_num_rings = 6, - fit_power_alpha = 2.0, - # fit_maxfev = None, - plot_CTF_compare: bool = False, + fit_BF_shifts:bool = True, + fit_CTF_FFT:bool = False, + fit_aberrations_max_radial_order:int=3, + fit_aberrations_max_angular_order:int=4, + fit_aberrations_min_radial_order:int=1, + fit_aberrations_min_angular_order:int=0, + fit_max_thon_rings:int = 6, + fit_power_alpha:float = 2.0, + plot_CTF_comparison: bool = None, + plot_BF_shifts_comparison: bool = None, + upsampled:bool=True, ): """ Fit aberrations to the measured image shifts. Parameters ---------- - plot_CTF_compare: bool, optional - If True, the fitted CTF is plotted against the reconstructed frequencies. fit_CTF_FFT: bool Set to True to directly fit aberrations in the FFT of the upsampled BF image (if available). Note that this method relies on visible zero @@ -1296,14 +1298,14 @@ def aberration_fit( Max radial order for fitting of aberrations. ctf_threshold: float CTF fitting minimizes value at CTF zero crossings (Thon ring minima). - + plot_CTF_compare: bool, optional + If True, the fitted CTF is plotted against the reconstructed frequencies. """ xp = self._xp asnumpy = self._asnumpy - gaussian_filter = self._gaussian_filter - ### initial aberration fit ### + ### First pass # Convert real space shifts to Angstroms self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) @@ -1326,127 +1328,210 @@ def aberration_fit( self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0 self.aberration_A1x = ( m_aberration[0, 0] - m_aberration[1, 1] - ) / 2.0 # factor /2 for A1 astigmatism? /4? + ) / 2.0 self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + ### Second pass + # Aberration coefs + mn = [] - ### FFT fitting / plotting code ### - if fit_CTF_FFT or plot_CTF_compare: - if fit_upsampled_FFT: - # Get mean FFT of BF reconstruction - im_FFT = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + for m in range(fit_aberrations_min_radial_order,fit_aberrations_max_radial_order+1): + n_max = np.minimum(fit_aberrations_max_angular_order,m+1) + for n in range(fit_aberrations_min_angular_order,n_max+1): + if (m+n) % 2: + mn.append([m,n,0]) + if n > 0: + mn.append([m,n,1]) + + self._aberrations_mn = np.array(mn) + self._aberrations_mn = self._aberrations_mn[np.argsort(self._aberrations_mn[:,1]),:] + + sub = self._aberrations_mn[:,1] > 0 + self._aberrations_mn[sub,:] = self._aberrations_mn[sub,:][np.argsort(self._aberrations_mn[sub,0]),:] + self._aberrations_num = self._aberrations_mn.shape[0] + + if plot_CTF_comparison is None: + if fit_CTF_FFT: + plot_CTF_comparison = True + + if plot_BF_shifts_comparison is None: + if fit_BF_shifts: + plot_BF_shifts_comparison = True + + # Thon Rings Fitting + if fit_CTF_FFT or plot_CTF_comparison: + if upsampled and hasattr(self,"_kde_upsample_factor"): + im_FFT = xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + sx = self._scan_sampling[0] / self._kde_upsample_factor + sy = self._scan_sampling[1] / self._kde_upsample_factor - # coordinates - q_pixel_size = np.array(self._reciprocal_sampling) \ - / self._kde_upsample_factor else: - # Get mean FFT of upsampled BF reconstruction - im_FFT = np.abs(xp.fft.fft2(self._recon_BF)) + im_FFT = xp.abs(xp.fft.fft2(self._recon_BF)) + sx = self._scan_sampling[0] + sy = self._scan_sampling[1] + upsampled=False - # coordinates - q_pixel_size = np.array(parallax_recon._reciprocal_sampling) + # FFT coordinates + qx = xp.fft.fftfreq(im_FFT.shape[0],sx) + qy = xp.fft.fftfreq(im_FFT.shape[1],sy) + qr2 = qx[:,None]**2 + qy[None,:]**2 + + alpha = xp.sqrt(qr2) * self._wavelength + theta = xp.arctan2(qy[None,:],qx[:,None]) + + # Aberration basis + self._aberrations_basis = xp.zeros((alpha.size,self._aberrations_num)) + for a0 in range(self._aberrations_num): + m,n,a = self._aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + self._aberrations_basis[:,a0] = (alpha**(m+1) / (m+1)).ravel() + elif a == 0: + # cos coef + self._aberrations_basis[:,a0] = (alpha**(m+1) * xp.cos(n * theta) / (m+1)).ravel() + else: + # sin coef + self._aberrations_basis[:,a0] = (alpha**(m+1) * xp.sin(n * theta) / (m+1)).ravel() + + # global scaling + self._aberrations_basis *= 2*np.pi/self._wavelength + self._aberrations_surface_shape = alpha.shape + plot_mask = qr2 > np.pi**2/4/np.abs(self.aberration_C1) + angular_mask = np.cos(8.0 * theta)**2 < 0.25 + + # Direct Shifts Fitting + elif fit_BF_shifts: + # FFT coordinates - qx = np.fft.fftfreq(im_FFT.shape[0],1/im_FFT.shape[0]/q_pixel_size[0]) - qy = np.fft.fftfreq(im_FFT.shape[1],1/im_FFT.shape[1]/q_pixel_size[1]) + sx = 1/(self._reciprocal_sampling[0]*self._region_of_interest_shape[0]) + sy = 1/(self._reciprocal_sampling[1]*self._region_of_interest_shape[1]) + qx = xp.fft.fftfreq(self._region_of_interest_shape[0],sx) + qy = xp.fft.fftfreq(self._region_of_interest_shape[1],sy) qr2 = qx[:,None]**2 + qy[None,:]**2 - self.alpha = np.sqrt(qr2) * self._wavelength - self.theta = np.arctan2(qy[None,:],qx[:,None]) - - # Aberration coefs - mn = [] - for m in range(0,fit_aber_order_max//2+1): - n_max = np.floor(fit_aber_order_max-2*m).astype('int') - - for n in range(0,n_max+1): - if m + n > 1 or (m > 0 and n == 0): - if n == 0: - mn.append([m,n,0]) - else: - mn.append([m,n,0]) - mn.append([m,n,1]) - self.aber_mn = np.array(mn) - self.aber_mn = self.aber_mn[np.argsort(self.aber_mn[:,1]),:] - # self.aber_mn = self.aber_mn[np.lexsort(( - # self.aber_mn[:,0], - # self.aber_mn[:,2], - # self.aber_mn[:,1], - # ))] - sub = self.aber_mn[:,1] > 0 - self.aber_mn[sub,:] = self.aber_mn[sub,:][np.argsort(self.aber_mn[sub,0]),:] + + u = qx[:,None]*self._wavelength + v = qy[None,:]*self._wavelength + alpha = xp.sqrt(qr2) * self._wavelength + theta = xp.arctan2(qy[None,:],qx[:,None]) # Aberration basis - self.aber_num = self.aber_mn.shape[0] - self.aber_basis = np.zeros((self.alpha.size,self.aber_num)) - for a0 in range(self.aber_num): - if self.aber_mn[a0,1] == 0: + self._aberrations_basis = xp.zeros((alpha.size,self._aberrations_num)) + self._aberrations_basis_du = xp.zeros((alpha.size,self._aberrations_num)) + self._aberrations_basis_dv = xp.zeros((alpha.size,self._aberrations_num)) + for a0 in range(self._aberrations_num): + m,n,a = self._aberrations_mn[a0] + + if n == 0: # Radially symmetric basis - self.aber_basis[:,a0] = self.alpha.ravel()**(2*self.aber_mn[a0,0]) - elif self.aber_mn[a0,2] == 0: + self._aberrations_basis[:,a0] = (alpha**(m+1) / (m+1)).ravel() + self._aberrations_basis_du[:,a0] = (u*alpha**(m-1)).ravel() + self._aberrations_basis_dv[:,a0] = (v*alpha**(m-1)).ravel() + + elif a == 0: # cos coef - self.aber_basis[:,a0] = \ - self.alpha.ravel()**(2.0*self.aber_mn[a0,0] + self.aber_mn[a0,1]) * \ - np.cos(self.aber_mn[a0,1] * self.theta.ravel()) + self._aberrations_basis[:,a0] = (alpha**(m+1) * xp.cos(n * theta) / (m+1)).ravel() + self._aberrations_basis_du[:,a0] = (alpha**(m-1)*((m+1)*u*xp.cos(n*theta) + n*v*xp.sin(n*theta))/(m+1)).ravel() + self._aberrations_basis_dv[:,a0] = (alpha**(m-1)*((m+1)*v*xp.cos(n*theta) - n*u*xp.sin(n*theta))/(m+1)).ravel() + else: # sin coef - self.aber_basis[:,a0] = \ - self.alpha.ravel()**(2.0*self.aber_mn[a0,0] + self.aber_mn[a0,1]) * \ - np.sin(self.aber_mn[a0,1] * self.theta.ravel()) - - # CTF function - def calc_CTF(alpha, *coefs): - chi = np.zeros_like(alpha.ravel()) - for a0 in range(len(coefs)): - chi += coefs[a0] * self.aber_basis[:,a0] - return np.reshape(chi, alpha.shape) - self.calc_CTF = calc_CTF - - # initial coefficients and plotting intensity range mask - C10_dimensionless = self.aberration_C1 * np.pi / 4 / self._wavelength - self.aber_coefs = np.zeros(self.aber_num) - ind = np.argmin( - np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) - ) - self.aber_coefs[ind] = C10_dimensionless - plot_mask = self.alpha > np.sqrt(np.pi/2/np.abs(C10_dimensionless)) - # plot_mask[:] = True - angular_mask = np.cos(8.0 * np.arctan2(qy[:,None],qx[None,:]))**2 < 0.25 + self._aberrations_basis[:,a0] = (alpha**(m+1) * xp.sin(n * theta) / (m+1)).ravel() + self._aberrations_basis_du[:,a0] = (alpha**(m-1)*((m+1)*u*xp.sin(n*theta) - n*v*xp.cos(n*theta))/(m+1)).ravel() + self._aberrations_basis_dv[:,a0] = (alpha**(m-1)*((m+1)*v*xp.sin(n*theta) + n*u*xp.cos(n*theta))/(m+1)).ravel() + + # global scaling + self._aberrations_basis *= 2*np.pi/self._wavelength + self._aberrations_surface_shape = alpha.shape + + # CTF function + def calculate_CTF(alpha_shape, *coefs): + chi = xp.zeros_like(self._aberrations_basis[:,0]) + for a0 in range(len(coefs)): + chi += coefs[a0] * self._aberrations_basis[:,a0] + return xp.reshape(chi, alpha_shape) + self._calculate_CTF = calculate_CTF + + # initial coefficients and plotting intensity range mask + self._aberrations_coefs = np.zeros(self._aberrations_num) + ind = np.argmin( + np.abs(self._aberrations_mn[:,0] - 1.0) + self._aberrations_mn[:,1] + ) + self._aberrations_coefs[ind] = self.aberration_C1 # Refinement using CTF fitting / Thon rings if fit_CTF_FFT: + # scoring function to minimize - mean value of zero crossing regions of FFT def score_CTF(coefs): - im_CTF = np.abs(calc_CTF(self.alpha,*coefs)) - mask = np.logical_and( + im_CTF = xp.abs(self._calculate_CTF(self._aberrations_surface_shape,*coefs)) + mask = xp.logical_and( im_CTF > 0.5*np.pi, im_CTF < (max_num_rings+0.5)*np.pi, ) if np.any(mask): - weights = np.cos(im_CTF[mask])**4 - return np.sum(weights*im_FFT[mask]*self.alpha[mask]**fit_power_alpha) / np.sum(weights) + weights = xp.cos(im_CTF[mask])**4 + return asnumpy(xp.sum(weights*im_FFT[mask]*alpha[mask]**fit_power_alpha) / xp.sum(weights)) else: return np.inf - for max_num_rings in range(1,fit_max_num_rings+1): + for max_num_rings in range(1,fit_max_thon_rings+1): # minimization res = minimize( score_CTF, - self.aber_coefs, + self._aberrations_coefs, # method = 'Nelder-Mead', # method = 'CG', method = 'BFGS', tol = 1e-8, ) - self.aber_coefs = res.x + self._aberrations_coefs = res.x + + # Refinement using CTF fitting / Thon rings + elif fit_BF_shifts: + + # Gradient basis + corner_indices = self._xy_inds-xp.asarray(self._region_of_interest_shape//2) + raveled_indices = np.ravel_multi_index(corner_indices.T,self._region_of_interest_shape,mode='wrap') + gradients = xp.vstack(( + self._aberrations_basis_du[raveled_indices,:], + self._aberrations_basis_dv[raveled_indices,:] + )) + + # Untransposed fit + tf = AffineTransform(angle=self.rotation_Q_to_R_rads) + rotated_shifts = tf(self._xy_shifts_Ang,xp=xp).T.ravel() + aberrations_coefs, res = xp.linalg.lstsq(gradients, rotated_shifts,rcond=None)[:2] + + # Transposed fit + transposed_shifts = xp.flip(self._xy_shifts_Ang,axis=1) + m_T = asnumpy(xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[0]) + m_rotation_T, _ = polar(m_T, side="right") + rotation_Q_to_R_rads_T = -1 * np.arctan2(m_rotation_T[1, 0], m_rotation_T[0, 0]) + if np.abs(np.mod(rotation_Q_to_R_rads_T + np.pi, 2.0 * np.pi) - np.pi) > ( + np.pi * 0.5 + ): + rotation_Q_to_R_rads_T = ( + np.mod(rotation_Q_to_R_rads_T, 2.0 * np.pi) - np.pi + ) + + tf = AffineTransform(angle=rotation_Q_to_R_rads_T) + rotated_shifts_T = tf(transposed_shifts,xp=xp).T.ravel() + aberrations_coefs_T, res_T = xp.linalg.lstsq(gradients, rotated_shifts_T,rcond=None)[:2] + if res_T.sum() < res.sum(): + self._aberrations_coefs = asnumpy(aberrations_coefs_T) + self._rotated_shifts = rotated_shifts_T + else: + self._aberrations_coefs = asnumpy(aberrations_coefs) + self._rotated_shifts = rotated_shifts + + # Plot the CTF comparison between experiment and fit - if plot_CTF_compare: + if plot_CTF_comparison: # Generate FFT plotting image - im_scale = im_FFT * self.alpha**fit_power_alpha + im_scale = asnumpy(im_FFT * alpha**fit_power_alpha) int_vals = np.sort(im_scale.ravel()) int_range = ( int_vals[np.round(0.02*im_scale.size).astype('int')], @@ -1459,24 +1544,77 @@ def score_CTF(coefs): im_plot = np.tile(im_scale[:,:,None],(1,1,3)) # Add CTF zero crossings - im_CTF = calc_CTF(self.alpha,*self.aber_coefs) - im_CTF[np.abs(im_CTF) > (fit_max_num_rings+0.5)*np.pi] = np.pi/2; - im_CTF = np.abs(np.sin(im_CTF)) < fit_CTF_threshold - im_CTF[np.logical_not(plot_mask)] = 0 - im_CTF = np.fft.fftshift(im_CTF * angular_mask) + im_CTF = self._calculate_CTF(self._aberrations_surface_shape,*self._aberrations_coefs) + im_CTF_cos = xp.cos(xp.abs(im_CTF))**4 + im_CTF[xp.abs(im_CTF) > (fit_max_thon_rings+0.5)*np.pi] = np.pi/2 + im_CTF = xp.abs(xp.sin(im_CTF)) < 0.15 + im_CTF[xp.logical_not(plot_mask)] = 0 + + im_CTF = np.fft.fftshift(asnumpy(im_CTF * angular_mask)) im_plot[:,:,0] += im_CTF im_plot[:,:,1] -= im_CTF im_plot[:,:,2] -= im_CTF im_plot = np.clip(im_plot,0,1) - fig,ax = plt.subplots(figsize=(12,6)) - ax.imshow( + fig,(ax1,ax2) = plt.subplots(1,2,figsize=(12,6)) + ax1.imshow( im_plot, + vmin=int_range[0], + vmax=int_range[1] + ) + + ax2.imshow( + np.fft.fftshift(asnumpy(im_CTF_cos)), + cmap='gray' + ) + + fig.tight_layout() + + # Plot the measured/fitted shifts comparison + if plot_BF_shifts_comparison: + + if not fit_BF_shifts: + raise ValueError() + + measured_shifts_sx = xp.zeros(self._region_of_interest_shape,dtype=xp.float32) + measured_shifts_sx[self._xy_inds[:,0],self._xy_inds[:,1]] = self._rotated_shifts[:self._xy_inds.shape[0]] + + measured_shifts_sy = xp.zeros(self._region_of_interest_shape,dtype=xp.float32) + measured_shifts_sy[self._xy_inds[:,0],self._xy_inds[:,1]] = self._rotated_shifts[self._xy_inds.shape[0]:] + + fitted_shifts = xp.tensordot(gradients,xp.array(self._aberrations_coefs),axes=1) + + fitted_shifts_sx = xp.zeros(self._region_of_interest_shape,dtype=xp.float32) + fitted_shifts_sx[self._xy_inds[:,0],self._xy_inds[:,1]] = fitted_shifts[:self._xy_inds.shape[0]] + + fitted_shifts_sy = xp.zeros(self._region_of_interest_shape,dtype=xp.float32) + fitted_shifts_sy[self._xy_inds[:,0],self._xy_inds[:,1]] = fitted_shifts[self._xy_inds.shape[0]:] + + max_shift = xp.max(xp.array([xp.abs(measured_shifts_sx).max(),xp.abs(measured_shifts_sy).max(),xp.abs(fitted_shifts_sx).max(),xp.abs(fitted_shifts_sy).max()])) + + show( + [ + [ + asnumpy(measured_shifts_sx), + asnumpy(measured_shifts_sy) + ], + [ + asnumpy(fitted_shifts_sx), + asnumpy(fitted_shifts_sy) + ], + ], + cmap='PiYG', + vmin=-max_shift, + vmax=max_shift, + intensity_range='absolute', + axsize=(4,4), + ticks=False, + title=["Measured Vertical Shifts","Measured Horizontal Shifts","Fitted Vertical Shifts","Fitted Horizontal Shifts"] ) # Print results if self._verbose: - if fit_CTF_FFT: + if fit_CTF_FFT or fit_BF_shifts: print('Initial Aberration coefficients') print('-------------------------------') print( @@ -1495,36 +1633,40 @@ def score_CTF(coefs): print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") - if fit_CTF_FFT: - radial_order = 2 * self.aber_mn[:,0] + self.aber_mn[:,1] + if fit_CTF_FFT or fit_BF_shifts: print() print('Refined Aberration coefficients') print('-------------------------------') - print('radial annular dir. coefs') + print('radial angular dir. coefs') print('order order ') print('------ ------- ---- -----') - for a0 in range(self.aber_mn.shape[0]): - if self.aber_mn[a0,1] == 0: + for a0 in range(self._aberrations_mn.shape[0]): + m, n, a = self._aberrations_mn[a0] + if n == 0: print( - str(radial_order[a0]) + \ + str(m) + \ ' 0 - ' + \ - str(np.round(self.aber_coefs[a0]).astype('int')) ) - elif self.aber_mn[a0,2] == 0: + str(np.round(self._aberrations_coefs[a0]).astype('int')) ) + elif a == 0: print( - str(radial_order[a0]) + \ + str(m) + \ ' ' + \ - str(self.aber_mn[a0,1]) + \ + str(n) + \ ' x ' + \ - str(np.round(self.aber_coefs[a0]).astype('int')) ) + str(np.round(self._aberrations_coefs[a0]).astype('int')) ) else: print( - str(radial_order[a0]) + \ + str(m) + \ ' ' + \ - str(self.aber_mn[a0,1]) + \ + str(n) + \ ' y ' + \ - str(np.round(self.aber_coefs[a0]).astype('int')) ) + str(np.round(self._aberrations_coefs[a0]).astype('int')) ) + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() def aberration_correct( self, @@ -1909,7 +2051,7 @@ def _visualize_figax( **kwargs, ) - def _visualize_shifts( + def show_shifts( self, scale_arrows=1, plot_arrow_freq=1, From d1f6efb6d4bce315a7596d4f4f5e6a21ad26845f Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 19 Oct 2023 10:31:37 -0700 Subject: [PATCH 48/62] some support for aberration correct --- py4DSTEM/process/phase/iterative_parallax.py | 442 +++++++++++-------- 1 file changed, 264 insertions(+), 178 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index ef6ae9f5b..5f7e1c25e 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1237,10 +1237,10 @@ def subpixel_alignment( ) reciprocal_extent = [ - -0.5/(self._scan_sampling[1]/self._kde_upsample_factor), - 0.5/(self._scan_sampling[1]/self._kde_upsample_factor), - 0.5/(self._scan_sampling[0]/self._kde_upsample_factor), - -0.5/(self._scan_sampling[0]/self._kde_upsample_factor), + -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), ] show( @@ -1270,17 +1270,17 @@ def subpixel_alignment( def aberration_fit( self, - fit_BF_shifts:bool = True, - fit_CTF_FFT:bool = False, - fit_aberrations_max_radial_order:int=3, - fit_aberrations_max_angular_order:int=4, - fit_aberrations_min_radial_order:int=1, - fit_aberrations_min_angular_order:int=0, - fit_max_thon_rings:int = 6, - fit_power_alpha:float = 2.0, + fit_BF_shifts: bool = True, + fit_CTF_FFT: bool = False, + fit_aberrations_max_radial_order: int = 3, + fit_aberrations_max_angular_order: int = 4, + fit_aberrations_min_radial_order: int = 1, + fit_aberrations_min_angular_order: int = 0, + fit_max_thon_rings: int = 6, + fit_power_alpha: float = 2.0, plot_CTF_comparison: bool = None, plot_BF_shifts_comparison: bool = None, - upsampled:bool=True, + upsampled: bool = True, ): """ Fit aberrations to the measured image shifts. @@ -1288,8 +1288,8 @@ def aberration_fit( Parameters ---------- fit_CTF_FFT: bool - Set to True to directly fit aberrations in the FFT of the upsampled BF - image (if available). Note that this method relies on visible zero + Set to True to directly fit aberrations in the FFT of the upsampled BF + image (if available). Note that this method relies on visible zero crossings in the FFT, and will not work if they are not present. fit_upsampled_FFT: bool If True, we aberration fit is performed on the upsampled BF image. @@ -1326,9 +1326,7 @@ def aberration_fit( ) m_aberration = -1.0 * m_aberration self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0 - self.aberration_A1x = ( - m_aberration[0, 0] - m_aberration[1, 1] - ) / 2.0 + self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 ### Second pass @@ -1336,32 +1334,38 @@ def aberration_fit( # Aberration coefs mn = [] - for m in range(fit_aberrations_min_radial_order,fit_aberrations_max_radial_order+1): - n_max = np.minimum(fit_aberrations_max_angular_order,m+1) - for n in range(fit_aberrations_min_angular_order,n_max+1): - if (m+n) % 2: - mn.append([m,n,0]) + for m in range( + fit_aberrations_min_radial_order, fit_aberrations_max_radial_order + 1 + ): + n_max = np.minimum(fit_aberrations_max_angular_order, m + 1) + for n in range(fit_aberrations_min_angular_order, n_max + 1): + if (m + n) % 2: + mn.append([m, n, 0]) if n > 0: - mn.append([m,n,1]) + mn.append([m, n, 1]) self._aberrations_mn = np.array(mn) - self._aberrations_mn = self._aberrations_mn[np.argsort(self._aberrations_mn[:,1]),:] - - sub = self._aberrations_mn[:,1] > 0 - self._aberrations_mn[sub,:] = self._aberrations_mn[sub,:][np.argsort(self._aberrations_mn[sub,0]),:] + self._aberrations_mn = self._aberrations_mn[ + np.argsort(self._aberrations_mn[:, 1]), : + ] + + sub = self._aberrations_mn[:, 1] > 0 + self._aberrations_mn[sub, :] = self._aberrations_mn[sub, :][ + np.argsort(self._aberrations_mn[sub, 0]), : + ] self._aberrations_num = self._aberrations_mn.shape[0] if plot_CTF_comparison is None: if fit_CTF_FFT: plot_CTF_comparison = True - + if plot_BF_shifts_comparison is None: if fit_BF_shifts: plot_BF_shifts_comparison = True # Thon Rings Fitting if fit_CTF_FFT or plot_CTF_comparison: - if upsampled and hasattr(self,"_kde_upsample_factor"): + if upsampled and hasattr(self, "_kde_upsample_factor"): im_FFT = xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) sx = self._scan_sampling[0] / self._kde_upsample_factor sy = self._scan_sampling[1] / self._kde_upsample_factor @@ -1370,145 +1374,188 @@ def aberration_fit( im_FFT = xp.abs(xp.fft.fft2(self._recon_BF)) sx = self._scan_sampling[0] sy = self._scan_sampling[1] - upsampled=False + upsampled = False # FFT coordinates - qx = xp.fft.fftfreq(im_FFT.shape[0],sx) - qy = xp.fft.fftfreq(im_FFT.shape[1],sy) - qr2 = qx[:,None]**2 + qy[None,:]**2 + qx = xp.fft.fftfreq(im_FFT.shape[0], sx) + qy = xp.fft.fftfreq(im_FFT.shape[1], sy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 alpha = xp.sqrt(qr2) * self._wavelength - theta = xp.arctan2(qy[None,:],qx[:,None]) + theta = xp.arctan2(qy[None, :], qx[:, None]) # Aberration basis - self._aberrations_basis = xp.zeros((alpha.size,self._aberrations_num)) + self._aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) for a0 in range(self._aberrations_num): - m,n,a = self._aberrations_mn[a0] + m, n, a = self._aberrations_mn[a0] if n == 0: # Radially symmetric basis - self._aberrations_basis[:,a0] = (alpha**(m+1) / (m+1)).ravel() + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) / (m + 1) + ).ravel() elif a == 0: # cos coef - self._aberrations_basis[:,a0] = (alpha**(m+1) * xp.cos(n * theta) / (m+1)).ravel() + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() else: # sin coef - self._aberrations_basis[:,a0] = (alpha**(m+1) * xp.sin(n * theta) / (m+1)).ravel() - + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() + # global scaling - self._aberrations_basis *= 2*np.pi/self._wavelength + self._aberrations_basis *= 2 * np.pi / self._wavelength self._aberrations_surface_shape = alpha.shape - plot_mask = qr2 > np.pi**2/4/np.abs(self.aberration_C1) - angular_mask = np.cos(8.0 * theta)**2 < 0.25 + plot_mask = qr2 > np.pi**2 / 4 / np.abs(self.aberration_C1) + angular_mask = np.cos(8.0 * theta) ** 2 < 0.25 # Direct Shifts Fitting elif fit_BF_shifts: - # FFT coordinates - sx = 1/(self._reciprocal_sampling[0]*self._region_of_interest_shape[0]) - sy = 1/(self._reciprocal_sampling[1]*self._region_of_interest_shape[1]) - qx = xp.fft.fftfreq(self._region_of_interest_shape[0],sx) - qy = xp.fft.fftfreq(self._region_of_interest_shape[1],sy) - qr2 = qx[:,None]**2 + qy[None,:]**2 - - u = qx[:,None]*self._wavelength - v = qy[None,:]*self._wavelength + sx = 1 / (self._reciprocal_sampling[0] * self._region_of_interest_shape[0]) + sy = 1 / (self._reciprocal_sampling[1] * self._region_of_interest_shape[1]) + qx = xp.fft.fftfreq(self._region_of_interest_shape[0], sx) + qy = xp.fft.fftfreq(self._region_of_interest_shape[1], sy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + + u = qx[:, None] * self._wavelength + v = qy[None, :] * self._wavelength alpha = xp.sqrt(qr2) * self._wavelength - theta = xp.arctan2(qy[None,:],qx[:,None]) + theta = xp.arctan2(qy[None, :], qx[:, None]) # Aberration basis - self._aberrations_basis = xp.zeros((alpha.size,self._aberrations_num)) - self._aberrations_basis_du = xp.zeros((alpha.size,self._aberrations_num)) - self._aberrations_basis_dv = xp.zeros((alpha.size,self._aberrations_num)) + self._aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) + self._aberrations_basis_du = xp.zeros((alpha.size, self._aberrations_num)) + self._aberrations_basis_dv = xp.zeros((alpha.size, self._aberrations_num)) for a0 in range(self._aberrations_num): - m,n,a = self._aberrations_mn[a0] + m, n, a = self._aberrations_mn[a0] if n == 0: # Radially symmetric basis - self._aberrations_basis[:,a0] = (alpha**(m+1) / (m+1)).ravel() - self._aberrations_basis_du[:,a0] = (u*alpha**(m-1)).ravel() - self._aberrations_basis_dv[:,a0] = (v*alpha**(m-1)).ravel() + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = (u * alpha ** (m - 1)).ravel() + self._aberrations_basis_dv[:, a0] = (v * alpha ** (m - 1)).ravel() elif a == 0: # cos coef - self._aberrations_basis[:,a0] = (alpha**(m+1) * xp.cos(n * theta) / (m+1)).ravel() - self._aberrations_basis_du[:,a0] = (alpha**(m-1)*((m+1)*u*xp.cos(n*theta) + n*v*xp.sin(n*theta))/(m+1)).ravel() - self._aberrations_basis_dv[:,a0] = (alpha**(m-1)*((m+1)*v*xp.cos(n*theta) - n*u*xp.sin(n*theta))/(m+1)).ravel() + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.cos(n * theta) + n * v * xp.sin(n * theta)) + / (m + 1) + ).ravel() + self._aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.cos(n * theta) - n * u * xp.sin(n * theta)) + / (m + 1) + ).ravel() else: # sin coef - self._aberrations_basis[:,a0] = (alpha**(m+1) * xp.sin(n * theta) / (m+1)).ravel() - self._aberrations_basis_du[:,a0] = (alpha**(m-1)*((m+1)*u*xp.sin(n*theta) - n*v*xp.cos(n*theta))/(m+1)).ravel() - self._aberrations_basis_dv[:,a0] = (alpha**(m-1)*((m+1)*v*xp.sin(n*theta) + n*u*xp.cos(n*theta))/(m+1)).ravel() + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.sin(n * theta) - n * v * xp.cos(n * theta)) + / (m + 1) + ).ravel() + self._aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.sin(n * theta) + n * u * xp.cos(n * theta)) + / (m + 1) + ).ravel() # global scaling - self._aberrations_basis *= 2*np.pi/self._wavelength + self._aberrations_basis *= 2 * np.pi / self._wavelength self._aberrations_surface_shape = alpha.shape # CTF function def calculate_CTF(alpha_shape, *coefs): - chi = xp.zeros_like(self._aberrations_basis[:,0]) + chi = xp.zeros_like(self._aberrations_basis[:, 0]) for a0 in range(len(coefs)): - chi += coefs[a0] * self._aberrations_basis[:,a0] + chi += coefs[a0] * self._aberrations_basis[:, a0] return xp.reshape(chi, alpha_shape) + self._calculate_CTF = calculate_CTF # initial coefficients and plotting intensity range mask self._aberrations_coefs = np.zeros(self._aberrations_num) ind = np.argmin( - np.abs(self._aberrations_mn[:,0] - 1.0) + self._aberrations_mn[:,1] + np.abs(self._aberrations_mn[:, 0] - 1.0) + self._aberrations_mn[:, 1] ) self._aberrations_coefs[ind] = self.aberration_C1 # Refinement using CTF fitting / Thon rings if fit_CTF_FFT: - # scoring function to minimize - mean value of zero crossing regions of FFT def score_CTF(coefs): - im_CTF = xp.abs(self._calculate_CTF(self._aberrations_surface_shape,*coefs)) + im_CTF = xp.abs( + self._calculate_CTF(self._aberrations_surface_shape, *coefs) + ) mask = xp.logical_and( - im_CTF > 0.5*np.pi, - im_CTF < (max_num_rings+0.5)*np.pi, + im_CTF > 0.5 * np.pi, + im_CTF < (max_num_rings + 0.5) * np.pi, ) if np.any(mask): - weights = xp.cos(im_CTF[mask])**4 - return asnumpy(xp.sum(weights*im_FFT[mask]*alpha[mask]**fit_power_alpha) / xp.sum(weights)) + weights = xp.cos(im_CTF[mask]) ** 4 + return asnumpy( + xp.sum(weights * im_FFT[mask] * alpha[mask] ** fit_power_alpha) + / xp.sum(weights) + ) else: return np.inf - for max_num_rings in range(1,fit_max_thon_rings+1): + for max_num_rings in range(1, fit_max_thon_rings + 1): # minimization res = minimize( - score_CTF, - self._aberrations_coefs, - # method = 'Nelder-Mead', + score_CTF, + self._aberrations_coefs, + # method = 'Nelder-Mead', # method = 'CG', - method = 'BFGS', - tol = 1e-8, + method="BFGS", + tol=1e-8, ) self._aberrations_coefs = res.x - + # Refinement using CTF fitting / Thon rings elif fit_BF_shifts: - # Gradient basis - corner_indices = self._xy_inds-xp.asarray(self._region_of_interest_shape//2) - raveled_indices = np.ravel_multi_index(corner_indices.T,self._region_of_interest_shape,mode='wrap') - gradients = xp.vstack(( - self._aberrations_basis_du[raveled_indices,:], - self._aberrations_basis_dv[raveled_indices,:] - )) + corner_indices = self._xy_inds - xp.asarray( + self._region_of_interest_shape // 2 + ) + raveled_indices = np.ravel_multi_index( + corner_indices.T, self._region_of_interest_shape, mode="wrap" + ) + gradients = xp.vstack( + ( + self._aberrations_basis_du[raveled_indices, :], + self._aberrations_basis_dv[raveled_indices, :], + ) + ) # Untransposed fit tf = AffineTransform(angle=self.rotation_Q_to_R_rads) - rotated_shifts = tf(self._xy_shifts_Ang,xp=xp).T.ravel() - aberrations_coefs, res = xp.linalg.lstsq(gradients, rotated_shifts,rcond=None)[:2] + rotated_shifts = tf(self._xy_shifts_Ang, xp=xp).T.ravel() + aberrations_coefs, res = xp.linalg.lstsq( + gradients, rotated_shifts, rcond=None + )[:2] # Transposed fit - transposed_shifts = xp.flip(self._xy_shifts_Ang,axis=1) - m_T = asnumpy(xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[0]) + transposed_shifts = xp.flip(self._xy_shifts_Ang, axis=1) + m_T = asnumpy( + xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[0] + ) m_rotation_T, _ = polar(m_T, side="right") - rotation_Q_to_R_rads_T = -1 * np.arctan2(m_rotation_T[1, 0], m_rotation_T[0, 0]) + rotation_Q_to_R_rads_T = -1 * np.arctan2( + m_rotation_T[1, 0], m_rotation_T[0, 0] + ) if np.abs(np.mod(rotation_Q_to_R_rads_T + np.pi, 2.0 * np.pi) - np.pi) > ( np.pi * 0.5 ): @@ -1517,8 +1564,10 @@ def score_CTF(coefs): ) tf = AffineTransform(angle=rotation_Q_to_R_rads_T) - rotated_shifts_T = tf(transposed_shifts,xp=xp).T.ravel() - aberrations_coefs_T, res_T = xp.linalg.lstsq(gradients, rotated_shifts_T,rcond=None)[:2] + rotated_shifts_T = tf(transposed_shifts, xp=xp).T.ravel() + aberrations_coefs_T, res_T = xp.linalg.lstsq( + gradients, rotated_shifts_T, rcond=None + )[:2] if res_T.sum() < res.sum(): self._aberrations_coefs = asnumpy(aberrations_coefs_T) @@ -1526,97 +1575,122 @@ def score_CTF(coefs): else: self._aberrations_coefs = asnumpy(aberrations_coefs) self._rotated_shifts = rotated_shifts - - + # Plot the CTF comparison between experiment and fit if plot_CTF_comparison: # Generate FFT plotting image im_scale = asnumpy(im_FFT * alpha**fit_power_alpha) int_vals = np.sort(im_scale.ravel()) int_range = ( - int_vals[np.round(0.02*im_scale.size).astype('int')], - int_vals[np.round(0.98*im_scale.size).astype('int')], - ) - int_range = (int_range[0],(int_range[1]-int_range[0])*1.0 + int_range[0]) + int_vals[np.round(0.02 * im_scale.size).astype("int")], + int_vals[np.round(0.98 * im_scale.size).astype("int")], + ) + int_range = ( + int_range[0], + (int_range[1] - int_range[0]) * 1.0 + int_range[0], + ) im_scale = np.clip( - (np.fft.fftshift(im_scale) - int_range[0]) / (int_range[1] - int_range[0]), - 0,1) - im_plot = np.tile(im_scale[:,:,None],(1,1,3)) + (np.fft.fftshift(im_scale) - int_range[0]) + / (int_range[1] - int_range[0]), + 0, + 1, + ) + im_plot = np.tile(im_scale[:, :, None], (1, 1, 3)) # Add CTF zero crossings - im_CTF = self._calculate_CTF(self._aberrations_surface_shape,*self._aberrations_coefs) - im_CTF_cos = xp.cos(xp.abs(im_CTF))**4 - im_CTF[xp.abs(im_CTF) > (fit_max_thon_rings+0.5)*np.pi] = np.pi/2 + im_CTF = self._calculate_CTF( + self._aberrations_surface_shape, *self._aberrations_coefs + ) + im_CTF_cos = xp.cos(xp.abs(im_CTF)) ** 4 + im_CTF[xp.abs(im_CTF) > (fit_max_thon_rings + 0.5) * np.pi] = np.pi / 2 im_CTF = xp.abs(xp.sin(im_CTF)) < 0.15 im_CTF[xp.logical_not(plot_mask)] = 0 im_CTF = np.fft.fftshift(asnumpy(im_CTF * angular_mask)) - im_plot[:,:,0] += im_CTF - im_plot[:,:,1] -= im_CTF - im_plot[:,:,2] -= im_CTF - im_plot = np.clip(im_plot,0,1) - - fig,(ax1,ax2) = plt.subplots(1,2,figsize=(12,6)) - ax1.imshow( - im_plot, - vmin=int_range[0], - vmax=int_range[1] - ) - - ax2.imshow( - np.fft.fftshift(asnumpy(im_CTF_cos)), - cmap='gray' - ) + im_plot[:, :, 0] += im_CTF + im_plot[:, :, 1] -= im_CTF + im_plot[:, :, 2] -= im_CTF + im_plot = np.clip(im_plot, 0, 1) + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) + ax1.imshow(im_plot, vmin=int_range[0], vmax=int_range[1]) + + ax2.imshow(np.fft.fftshift(asnumpy(im_CTF_cos)), cmap="gray") fig.tight_layout() - + # Plot the measured/fitted shifts comparison if plot_BF_shifts_comparison: - if not fit_BF_shifts: raise ValueError() - - measured_shifts_sx = xp.zeros(self._region_of_interest_shape,dtype=xp.float32) - measured_shifts_sx[self._xy_inds[:,0],self._xy_inds[:,1]] = self._rotated_shifts[:self._xy_inds.shape[0]] - - measured_shifts_sy = xp.zeros(self._region_of_interest_shape,dtype=xp.float32) - measured_shifts_sy[self._xy_inds[:,0],self._xy_inds[:,1]] = self._rotated_shifts[self._xy_inds.shape[0]:] - fitted_shifts = xp.tensordot(gradients,xp.array(self._aberrations_coefs),axes=1) + measured_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + measured_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._rotated_shifts[: self._xy_inds.shape[0]] + + measured_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + measured_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._rotated_shifts[self._xy_inds.shape[0] :] + + fitted_shifts = xp.tensordot( + gradients, xp.array(self._aberrations_coefs), axes=1 + ) - fitted_shifts_sx = xp.zeros(self._region_of_interest_shape,dtype=xp.float32) - fitted_shifts_sx[self._xy_inds[:,0],self._xy_inds[:,1]] = fitted_shifts[:self._xy_inds.shape[0]] + fitted_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = fitted_shifts[ + : self._xy_inds.shape[0] + ] - fitted_shifts_sy = xp.zeros(self._region_of_interest_shape,dtype=xp.float32) - fitted_shifts_sy[self._xy_inds[:,0],self._xy_inds[:,1]] = fitted_shifts[self._xy_inds.shape[0]:] + fitted_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = fitted_shifts[ + self._xy_inds.shape[0] : + ] - max_shift = xp.max(xp.array([xp.abs(measured_shifts_sx).max(),xp.abs(measured_shifts_sy).max(),xp.abs(fitted_shifts_sx).max(),xp.abs(fitted_shifts_sy).max()])) + max_shift = xp.max( + xp.array( + [ + xp.abs(measured_shifts_sx).max(), + xp.abs(measured_shifts_sy).max(), + xp.abs(fitted_shifts_sx).max(), + xp.abs(fitted_shifts_sy).max(), + ] + ) + ) show( [ - [ - asnumpy(measured_shifts_sx), - asnumpy(measured_shifts_sy) - ], - [ - asnumpy(fitted_shifts_sx), - asnumpy(fitted_shifts_sy) - ], + [asnumpy(measured_shifts_sx), asnumpy(measured_shifts_sy)], + [asnumpy(fitted_shifts_sx), asnumpy(fitted_shifts_sy)], ], - cmap='PiYG', + cmap="PiYG", vmin=-max_shift, vmax=max_shift, - intensity_range='absolute', - axsize=(4,4), + intensity_range="absolute", + axsize=(4, 4), ticks=False, - title=["Measured Vertical Shifts","Measured Horizontal Shifts","Fitted Vertical Shifts","Fitted Horizontal Shifts"] + title=[ + "Measured Vertical Shifts", + "Measured Horizontal Shifts", + "Fitted Vertical Shifts", + "Fitted Horizontal Shifts", + ], ) # Print results if self._verbose: if fit_CTF_FFT or fit_BF_shifts: - print('Initial Aberration coefficients') - print('-------------------------------') + print("Initial Aberration coefficients") + print("-------------------------------") print( ( "Rotation of Q w.r.t. R = " @@ -1634,35 +1708,37 @@ def score_CTF(coefs): print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") if fit_CTF_FFT or fit_BF_shifts: - print() - print('Refined Aberration coefficients') - print('-------------------------------') - print('radial angular dir. coefs') - print('order order ') - print('------ ------- ---- -----') + print("Refined Aberration coefficients") + print("-------------------------------") + print("radial angular dir. coefs") + print("order order ") + print("------ ------- ---- -----") for a0 in range(self._aberrations_mn.shape[0]): m, n, a = self._aberrations_mn[a0] if n == 0: print( - str(m) + \ - ' 0 - ' + \ - str(np.round(self._aberrations_coefs[a0]).astype('int')) ) - elif a == 0: + str(m) + + " 0 - " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) + elif a == 0: print( - str(m) + \ - ' ' + \ - str(n) + \ - ' x ' + \ - str(np.round(self._aberrations_coefs[a0]).astype('int')) ) + str(m) + + " " + + str(n) + + " x " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) else: print( - str(m) + \ - ' ' + \ - str(n) + \ - ' y ' + \ - str(np.round(self._aberrations_coefs[a0]).astype('int')) ) + str(m) + + " " + + str(n) + + " y " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) if self._device == "gpu": xp._default_memory_pool.free_all_blocks() @@ -1670,7 +1746,7 @@ def score_CTF(coefs): def aberration_correct( self, - use_FFT_fit = True, + use_CTF_fit=None, plot_corrected_phase: bool = True, k_info_limit: float = None, k_info_power: float = 1.0, @@ -1727,8 +1803,16 @@ def aberration_correct( ky = xp.fft.fftfreq(im.shape[1], sy) kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2 - if use_FFT_fit: - sin_chi = np.sin(self.calc_CTF(self.alpha,*self.aber_coefs)) + if use_CTF_fit is None: + if hasattr(self, "_aberrations_surface_shape"): + use_CTF_fit = True + + if use_CTF_fit: + sin_chi = np.sin( + self._calculate_CTF( + self._aberrations_surface_shape, *self._aberrations_coefs + ) + ) CTF_corr = xp.sign(sin_chi) CTF_corr[0, 0] = 0 @@ -1748,7 +1832,9 @@ def aberration_correct( if Wiener_filter: SNR_inv = ( xp.sqrt( - 1 + (kra2**k_info_power) / ((k_info_limit) ** (2 * k_info_power)) + 1 + + (kra2**k_info_power) + / ((k_info_limit) ** (2 * k_info_power)) ) / Wiener_signal_noise_ratio ) From de75b517f89be4f1976cc7bdd29dd2908d58bedf Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 19 Oct 2023 15:03:35 -0700 Subject: [PATCH 49/62] small bug fixes --- py4DSTEM/process/phase/iterative_parallax.py | 106 ++++++++++++++----- 1 file changed, 77 insertions(+), 29 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 5f7e1c25e..1a03f4fa3 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1381,38 +1381,47 @@ def aberration_fit( qy = xp.fft.fftfreq(im_FFT.shape[1], sy) qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 - alpha = xp.sqrt(qr2) * self._wavelength - theta = xp.arctan2(qy[None, :], qx[:, None]) + alpha_FFT = xp.sqrt(qr2) * self._wavelength + theta_FFT = xp.arctan2(qy[None, :], qx[:, None]) # Aberration basis - self._aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) + self._aberrations_basis_FFT = xp.zeros( + (alpha_FFT.size, self._aberrations_num) + ) for a0 in range(self._aberrations_num): m, n, a = self._aberrations_mn[a0] if n == 0: # Radially symmetric basis - self._aberrations_basis[:, a0] = ( - alpha ** (m + 1) / (m + 1) + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) / (m + 1) ).ravel() elif a == 0: # cos coef - self._aberrations_basis[:, a0] = ( - alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) * xp.cos(n * theta_FFT) / (m + 1) ).ravel() else: # sin coef - self._aberrations_basis[:, a0] = ( - alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) * xp.sin(n * theta_FFT) / (m + 1) ).ravel() # global scaling - self._aberrations_basis *= 2 * np.pi / self._wavelength - self._aberrations_surface_shape = alpha.shape + self._aberrations_basis_FFT *= 2 * np.pi / self._wavelength + self._aberrations_surface_shape_FFT = alpha_FFT.shape plot_mask = qr2 > np.pi**2 / 4 / np.abs(self.aberration_C1) - angular_mask = np.cos(8.0 * theta) ** 2 < 0.25 + angular_mask = np.cos(8.0 * theta_FFT) ** 2 < 0.25 + + # CTF function + def calculate_CTF_FFT(alpha_shape, *coefs): + chi = xp.zeros_like(self._aberrations_basis_FFT[:, 0]) + for a0 in range(len(coefs)): + chi += coefs[a0] * self._aberrations_basis_FFT[:, a0] + return xp.reshape(chi, alpha_shape) # Direct Shifts Fitting - elif fit_BF_shifts: + if fit_BF_shifts: # FFT coordinates sx = 1 / (self._reciprocal_sampling[0] * self._region_of_interest_shape[0]) sy = 1 / (self._reciprocal_sampling[1] * self._region_of_interest_shape[1]) @@ -1476,14 +1485,12 @@ def aberration_fit( self._aberrations_basis *= 2 * np.pi / self._wavelength self._aberrations_surface_shape = alpha.shape - # CTF function - def calculate_CTF(alpha_shape, *coefs): - chi = xp.zeros_like(self._aberrations_basis[:, 0]) - for a0 in range(len(coefs)): - chi += coefs[a0] * self._aberrations_basis[:, a0] - return xp.reshape(chi, alpha_shape) - - self._calculate_CTF = calculate_CTF + # CTF function + def calculate_CTF(alpha_shape, *coefs): + chi = xp.zeros_like(self._aberrations_basis[:, 0]) + for a0 in range(len(coefs)): + chi += coefs[a0] * self._aberrations_basis[:, a0] + return xp.reshape(chi, alpha_shape) # initial coefficients and plotting intensity range mask self._aberrations_coefs = np.zeros(self._aberrations_num) @@ -1497,7 +1504,7 @@ def calculate_CTF(alpha_shape, *coefs): # scoring function to minimize - mean value of zero crossing regions of FFT def score_CTF(coefs): im_CTF = xp.abs( - self._calculate_CTF(self._aberrations_surface_shape, *coefs) + calculate_CTF_FFT(self._aberrations_surface_shape_FFT, *coefs) ) mask = xp.logical_and( im_CTF > 0.5 * np.pi, @@ -1506,7 +1513,9 @@ def score_CTF(coefs): if np.any(mask): weights = xp.cos(im_CTF[mask]) ** 4 return asnumpy( - xp.sum(weights * im_FFT[mask] * alpha[mask] ** fit_power_alpha) + xp.sum( + weights * im_FFT[mask] * alpha_FFT[mask] ** fit_power_alpha + ) / xp.sum(weights) ) else: @@ -1579,7 +1588,7 @@ def score_CTF(coefs): # Plot the CTF comparison between experiment and fit if plot_CTF_comparison: # Generate FFT plotting image - im_scale = asnumpy(im_FFT * alpha**fit_power_alpha) + im_scale = asnumpy(im_FFT * alpha_FFT**fit_power_alpha) int_vals = np.sort(im_scale.ravel()) int_range = ( int_vals[np.round(0.02 * im_scale.size).astype("int")], @@ -1598,8 +1607,8 @@ def score_CTF(coefs): im_plot = np.tile(im_scale[:, :, None], (1, 1, 3)) # Add CTF zero crossings - im_CTF = self._calculate_CTF( - self._aberrations_surface_shape, *self._aberrations_coefs + im_CTF = calculate_CTF_FFT( + self._aberrations_surface_shape_FFT, *self._aberrations_coefs ) im_CTF_cos = xp.cos(xp.abs(im_CTF)) ** 4 im_CTF[xp.abs(im_CTF) > (fit_max_thon_rings + 0.5) * np.pi] = np.pi / 2 @@ -1744,6 +1753,47 @@ def score_CTF(coefs): xp._default_memory_pool.free_all_blocks() xp.clear_memo() + def _calculate_CTF(self, alpha_shape, sampling, *coefs): + xp = self._xp + + # FFT coordinates + sx, sy = sampling + qx = xp.fft.fftfreq(alpha_shape[0], sx) + qy = xp.fft.fftfreq(alpha_shape[1], sy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + + alpha = xp.sqrt(qr2) * self._wavelength + theta = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) + for a0 in range(self._aberrations_num): + m, n, a = self._aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel() + + elif a == 0: + # cos coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + else: + # sin coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() + + # global scaling + aberrations_basis *= 2 * np.pi / self._wavelength + + chi = xp.zeros_like(aberrations_basis[:, 0]) + + for a0 in range(len(coefs)): + chi += coefs[a0] * aberrations_basis[:, a0] + + return xp.reshape(chi, alpha_shape) + def aberration_correct( self, use_CTF_fit=None, @@ -1809,9 +1859,7 @@ def aberration_correct( if use_CTF_fit: sin_chi = np.sin( - self._calculate_CTF( - self._aberrations_surface_shape, *self._aberrations_coefs - ) + self._calculate_CTF(im.shape, (sx, sy), *self._aberrations_coefs) ) CTF_corr = xp.sign(sin_chi) From 52ee427f9403839c02eb4a7de83719fa8c19585c Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 20 Oct 2023 04:09:47 -0700 Subject: [PATCH 50/62] cleaned up parallax --- py4DSTEM/process/phase/iterative_parallax.py | 110 +++++++++++++++---- 1 file changed, 86 insertions(+), 24 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 1a03f4fa3..34631256c 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -10,6 +10,7 @@ import numpy as np from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable from py4DSTEM import DataCube from py4DSTEM.preprocess.utils import get_shifted_ar from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction @@ -29,6 +30,23 @@ warnings.simplefilter(action="always", category=UserWarning) +_aberration_names = { + (1, 0): "-defocus ", + (1, 2): "stig ", + (2, 1): "coma ", + (2, 3): "trefoil ", + (3, 0): "Cs ", + (3, 2): "stig2 ", + (3, 4): "quadfoil ", + (4, 1): "coma2 ", + (4, 3): "trefoil2 ", + (4, 5): "pentafoil ", + (5, 0): "C5 ", + (5, 2): "stig3 ", + (5, 4): "quadfoil2 ", + (5, 6): "hexafoil ", +} + class ParallaxReconstruction(PhaseReconstruction): """ @@ -40,9 +58,6 @@ class ParallaxReconstruction(PhaseReconstruction): Input 4D diffraction pattern intensities energy: float The electron energy of the wave functions in eV - dp_mean: ndarray, optional - Mean diffraction pattern - If None, get_dp_mean() is used verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -122,6 +137,7 @@ def to_h5(self, group): if hasattr(self, "aberration_C1"): recon_metadata |= { "aberration_rotation_QR": self.rotation_Q_to_R_rads, + "aberration_transpose": self.transpose_detected, "aberration_C1": self.aberration_C1, "aberration_A1x": self.aberration_A1x, "aberration_A1y": self.aberration_A1y, @@ -136,6 +152,15 @@ def to_h5(self, group): data=self._asnumpy(self._recon_BF_subpixel_aligned), ) + if hasattr(self, "aberration_dict"): + self.metadata = Metadata( + name="aberrations_metadata", + data={ + v["common name"]: v["value [Ang]"] + for k, v in self.aberration_dict.items() + }, + ) + self.metadata = Metadata( name="reconstruction_metadata", data=recon_metadata, @@ -212,6 +237,7 @@ def _populate_instance(self, group): if "aberration_C1" in reconstruction_md.keys: self.rotation_Q_to_R_rads = reconstruction_md["aberration_rotation_QR"] + self.transpose_detected = reconstruction_md["aberration_transpose"] self.aberration_C1 = reconstruction_md["aberration_C1"] self.aberration_A1x = reconstruction_md["aberration_A1x"] self.aberration_A1y = reconstruction_md["aberration_A1y"] @@ -327,9 +353,6 @@ def preprocess( intensities = asnumpy(self._intensities) intensities_shifted = np.zeros_like(intensities) - # center_x = np.mean(com_fitted_x) - # center_y = np.mean(com_fitted_y) - center_x, center_y = self._region_of_interest_shape / 2 for rx in range(intensities_shifted.shape[0]): @@ -706,8 +729,6 @@ def tune_angle_and_defocus( convergence.append(asnumpy(self._recon_error[0])) if plot_convergence: - from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable - fig, ax = plt.subplots() ax.set_title("convergence") im = ax.imshow( @@ -1287,19 +1308,29 @@ def aberration_fit( Parameters ---------- + fit_BF_shifts: bool + Set to True to fit aberrations to the measured BF shifts directly. fit_CTF_FFT: bool - Set to True to directly fit aberrations in the FFT of the upsampled BF - image (if available). Note that this method relies on visible zero - crossings in the FFT, and will not work if they are not present. - fit_upsampled_FFT: bool - If True, we aberration fit is performed on the upsampled BF image. - This option does nothing if fit_thon_rings is not True. - fit_aber_order_max: int + Set to True to fit aberrations in the FFT of the (upsampled) BF + image. Note that this method relies on visible zero crossings in the FFT. + fit_aberrations_max_radial_order: int Max radial order for fitting of aberrations. - ctf_threshold: float - CTF fitting minimizes value at CTF zero crossings (Thon ring minima). - plot_CTF_compare: bool, optional + fit_aberrations_max_angular_order: int + Max angular order for fitting of aberrations. + fit_aberrations_min_radial_order: int + Min radial order for fitting of aberrations. + fit_aberrations_min_angular_order: int + Min angular order for fitting of aberrations. + fit_max_thon_rings: int + Max number of Thon rings to search for during CTF FFT fitting. + fit_power_alpha: int + Power to raise FFT alpha weighting during CTF FFT fitting. + plot_CTF_comparison: bool, optional If True, the fitted CTF is plotted against the reconstructed frequencies. + plot_BF_shifts_comparison: bool, optional + If True, the measured vs fitted BF shifts are plotted. + upsampled: bool + If True, and upsampled BF is available, uses that for CTF FFT fitting. """ xp = self._xp @@ -1328,6 +1359,7 @@ def aberration_fit( self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0 self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 + self.transpose_detected = False ### Second pass @@ -1353,6 +1385,9 @@ def aberration_fit( self._aberrations_mn[sub, :] = self._aberrations_mn[sub, :][ np.argsort(self._aberrations_mn[sub, 0]), : ] + self._aberrations_mn[~sub, :] = self._aberrations_mn[~sub, :][ + np.argsort(self._aberrations_mn[~sub, 0]), : + ] self._aberrations_num = self._aberrations_mn.shape[0] if plot_CTF_comparison is None: @@ -1579,8 +1614,18 @@ def score_CTF(coefs): )[:2] if res_T.sum() < res.sum(): + self.rotation_Q_to_R_rads = rotation_Q_to_R_rads_T + self.transpose_detected = True self._aberrations_coefs = asnumpy(aberrations_coefs_T) self._rotated_shifts = rotated_shifts_T + + warnings.warn( + ( + "Data transpose detected. " + f"Overwriting rotation value to {np.rad2deg(rotation_Q_to_R_rads_T):.3f}" + ), + UserWarning, + ) else: self._aberrations_coefs = asnumpy(aberrations_coefs) self._rotated_shifts = rotated_shifts @@ -1695,6 +1740,16 @@ def score_CTF(coefs): ], ) + self.aberration_dict = { + tuple(self._aberrations_mn[a0]): { + "common name": _aberration_names.get( + tuple(self._aberrations_mn[a0, :2]), "-" + ).strip(), + "value [Ang]": self._aberrations_coefs[a0], + } + for a0 in range(self._aberrations_num) + } + # Print results if self._verbose: if fit_CTF_FFT or fit_BF_shifts: @@ -1720,21 +1775,26 @@ def score_CTF(coefs): print() print("Refined Aberration coefficients") print("-------------------------------") - print("radial angular dir. coefs") - print("order order ") - print("------ ------- ---- -----") + print("common radial angular dir. coefs") + print("name order order Ang ") + print("---------- ------- ------- ---- -----") for a0 in range(self._aberrations_mn.shape[0]): m, n, a = self._aberrations_mn[a0] + name = _aberration_names.get((m, n), " -- ") if n == 0: print( - str(m) + name + + " " + + str(m) + " 0 - " + str(np.round(self._aberrations_coefs[a0]).astype("int")) ) elif a == 0: print( - str(m) + name + + " " + + str(m) + " " + str(n) + " x " @@ -1742,7 +1802,9 @@ def score_CTF(coefs): ) else: print( - str(m) + name + + " " + + str(m) + " " + str(n) + " y " From 54d5859f9c58c78d2dc97aa75728081ac04d02f5 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 20 Oct 2023 07:15:51 -0700 Subject: [PATCH 51/62] ptycho new aberration formalism --- .../iterative_ptychographic_constraints.py | 2 + py4DSTEM/process/phase/utils.py | 80 ++++++++++++++----- 2 files changed, 60 insertions(+), 22 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 4721ed12b..3eebdb068 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -555,10 +555,12 @@ def _probe_aberration_fitting_constraint( fourier_probe = xp.fft.fft2(current_probe) fourier_probe_abs = xp.abs(fourier_probe) sampling = self.sampling + energy = self._energy fitted_angle, _ = fit_aberration_surface( fourier_probe, sampling, + energy, max_angular_order, max_radial_order, xp=xp, diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index d06db111c..cc5fa8cb4 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1543,39 +1543,75 @@ def step_model(radius, sig_0, rad_0, width): def aberrations_basis_function( probe_size, probe_sampling, + energy, max_angular_order, max_radial_order, xp=np, ): """ """ + # mn = [[0,0,0]] + mn = [] + + for m in range(1, max_radial_order + 1): + n_max = np.minimum(max_angular_order, m + 1) + for n in range(0, n_max + 1): + if (m + n) % 2: + mn.append([m, n, 0]) + if n > 0: + mn.append([m, n, 1]) + + aberrations_mn = np.array(mn) + aberrations_mn = aberrations_mn[np.argsort(aberrations_mn[:, 1]), :] + + sub = aberrations_mn[:, 1] > 0 + aberrations_mn[sub, :] = aberrations_mn[sub, :][ + np.argsort(aberrations_mn[sub, 0]), : + ] + aberrations_mn[~sub, :] = aberrations_mn[~sub, :][ + np.argsort(aberrations_mn[~sub, 0]), : + ] + aberrations_num = aberrations_mn.shape[0] + sx, sy = probe_size dx, dy = probe_sampling + wavelength = electron_wavelength_angstrom(energy) + qx = xp.fft.fftfreq(sx, dx) qy = xp.fft.fftfreq(sy, dy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + alpha = xp.sqrt(qr2) * wavelength + theta = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + aberrations_basis = xp.zeros((alpha.size, aberrations_num)) + + for a0 in range(aberrations_num): + m, n, a = aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel() + + elif a == 0: + # cos coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + else: + # sin coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() - qxa, qya = xp.meshgrid(qx, qy, indexing="ij") - q2 = qxa**2 + qya**2 - theta = xp.arctan2(qya, qxa) - - basis = [] - index = [] - - for n in range(max_angular_order + 1): - for m in range((max_radial_order - n) // 2 + 1): - basis.append((q2 ** (m + n / 2) * np.cos(n * theta))) - index.append((m, n, 0)) - if n > 0: - basis.append((q2 ** (m + n / 2) * np.sin(n * theta))) - index.append((m, n, 1)) - - basis = xp.array(basis) + # global scaling + aberrations_basis *= 2 * np.pi / wavelength - return basis, index + return aberrations_basis, aberrations_mn def fit_aberration_surface( complex_probe, probe_sampling, + energy, max_angular_order, max_radial_order, xp=np, @@ -1592,22 +1628,22 @@ def fit_aberration_surface( unwrapped_angle = unwrap_phase(probe_angle, wrap_around=True) unwrapped_angle = xp.asarray(unwrapped_angle).astype(xp.float32) - basis, _ = aberrations_basis_function( + raveled_basis, _ = aberrations_basis_function( complex_probe.shape, probe_sampling, + energy, max_angular_order, max_radial_order, xp=xp, ) - raveled_basis = basis.reshape((basis.shape[0], -1)) raveled_weights = probe_amp.ravel() - Aw = raveled_basis.T * raveled_weights[:, None] + Aw = raveled_basis * raveled_weights[:, None] bw = unwrapped_angle.ravel() * raveled_weights - coeff = xp.linalg.lstsq(Aw, bw, rcond=None)[0] + coeff = -xp.linalg.lstsq(Aw, bw, rcond=None)[0] - fitted_angle = xp.tensordot(coeff, basis, axes=1) + fitted_angle = xp.tensordot(raveled_basis, coeff, axes=1).reshape(probe_angle.shape) return fitted_angle, coeff From 9865f39a8475dc155a30fc6c9f1faafff6111ccf Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 20 Oct 2023 07:58:47 -0700 Subject: [PATCH 52/62] adding chroma_boost defaults --- .../process/phase/iterative_base_class.py | 2 ++ ...tive_mixedstate_multislice_ptychography.py | 28 ++++++++++++++++--- .../iterative_mixedstate_ptychography.py | 28 +++++++++++++++---- .../iterative_multislice_ptychography.py | 25 ++++++++++++++--- .../iterative_overlap_magnetic_tomography.py | 5 ++++ .../phase/iterative_overlap_tomography.py | 21 ++++++++++++++ .../iterative_simultaneous_ptychography.py | 13 +++++++-- .../iterative_singleslice_ptychography.py | 21 ++++++++++++-- 8 files changed, 126 insertions(+), 17 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 62cf3a3a1..4dced4291 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -2208,6 +2208,7 @@ def show_fourier_probe( pixelunits = r"$\AA^{-1}$" figsize = kwargs.pop("figsize", (6, 6)) + chroma_boost = kwargs.pop("chroma_boost", 2) fig, ax = plt.subplots(figsize=figsize) show_complex( @@ -2218,6 +2219,7 @@ def show_fourier_probe( pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, + chroma_boost = chroma_boost, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 306f47f77..2747fe601 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -613,11 +613,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered[0], power=2, + chroma_boost = chroma_boost, ) # propagated @@ -630,6 +632,7 @@ def preprocess( complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), power=2, + chroma_boost = chroma_boost, ) extent = [ @@ -657,6 +660,7 @@ def preprocess( cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax1, + chroma_boost = chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") @@ -669,7 +673,7 @@ def preprocess( divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax2) + add_colorbar_arg(cax2, chroma_boost=chroma_boost) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") ax2.set_title("Propagated probe[0] intensity") @@ -2502,6 +2506,11 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + if self._object_type == "complex": obj = np.angle(self.object) else: @@ -2595,12 +2604,12 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: - probe_array = Complex2RGB(self.probe_fourier[0]) + probe_array = Complex2RGB(self.probe_fourier[0],chroma_boost=chroma_boost) ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB(self.probe[0], power=2) + probe_array = Complex2RGB(self.probe[0], power=2, chroma_boost=chroma_boost) ax.set_title("Reconstructed probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2613,7 +2622,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb) + add_colorbar_arg(ax_cb,chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -2722,6 +2731,11 @@ def _visualize_all_iterations( figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + errors = np.array(self.error_iterations) objects = [] @@ -2825,6 +2839,7 @@ def _visualize_all_iterations( probes[grid_range[n]][0] ) ), + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") @@ -2833,6 +2848,7 @@ def _visualize_all_iterations( probe_array = Complex2RGB( probes[grid_range[n]][0], power=2, + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} probe[0]") ax.set_ylabel("x [A]") @@ -2846,6 +2862,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], + chroma_boost = chroma_boost, ) if plot_convergence: @@ -2953,12 +2970,15 @@ def show_fourier_probe( if pixelunits is None: pixelunits = r"$\AA^{-1}$" + chroma_boost = kwargs.pop("chroma_boost", 2) + show_complex( probe if len(probe) > 1 else probe[0], scalebar=scalebar, pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, + chroma_boost = chroma_boost, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 658079c3e..07b1fe9aa 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -510,11 +510,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (4.5 * self._num_probes + 4, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, + chroma_boost = chroma_boost, ) extent = [ @@ -544,7 +546,7 @@ def preprocess( divider = make_axes_locatable(axs[i]) cax = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax) + add_colorbar_arg(cax, chroma_boost = chroma_boost) axs[-1].imshow( asnumpy(probe_overlap), @@ -1847,6 +1849,11 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + if self._object_type == "complex": obj = np.angle(self.object) else: @@ -1939,6 +1946,7 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier[0], + chroma_boost = chroma_boost, ) ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") @@ -1947,6 +1955,7 @@ def _visualize_last_iteration( probe_array = Complex2RGB( self.probe[0], power=2, + chroma_boost = chroma_boost, ) ax.set_title("Reconstructed probe[0] intensity") ax.set_ylabel("x [A]") @@ -1960,7 +1969,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb) + add_colorbar_arg(ax_cb,chroma_boost = chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -2068,8 +2077,11 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2172,6 +2184,7 @@ def _visualize_all_iterations( probes[grid_range[n]][0] ) ), + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") @@ -2180,6 +2193,7 @@ def _visualize_all_iterations( probe_array = Complex2RGB( probes[grid_range[n]][0], power=2, + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} probe[0] intensity") ax.set_ylabel("x [A]") @@ -2192,7 +2206,8 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost = chroma_boost, ) if plot_convergence: @@ -2301,11 +2316,14 @@ def show_fourier_probe( if pixelunits is None: pixelunits = r"$\AA^{-1}$" + chroma_boost = kwargs.pop("chroma_boost", 2) + show_complex( probe if len(probe) > 1 else probe[0], scalebar=scalebar, pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, + chroma_boost = chroma_boost, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 382efedcd..823c71ca0 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -608,11 +608,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, + chroma_boost = chroma_boost, ) # propagated @@ -625,6 +627,7 @@ def preprocess( complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), power=2, + chroma_boost = chroma_boost, ) extent = [ @@ -650,7 +653,7 @@ def preprocess( divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax1) + add_colorbar_arg(cax1, chroma_boost=chroma_boost) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") ax1.set_title("Initial probe intensity") @@ -664,6 +667,7 @@ def preprocess( cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax2, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") @@ -2404,6 +2408,11 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -2500,12 +2509,13 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier, + chroma_boost = chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB(self.probe, power=2) + probe_array = Complex2RGB(self.probe, power=2, chroma_boost = chroma_boost) ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2518,7 +2528,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb) + add_colorbar_arg(ax_cb, chroma_boost = chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -2627,6 +2637,11 @@ def _visualize_all_iterations( figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + errors = np.array(self.error_iterations) objects = [] @@ -2730,12 +2745,13 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB(probes[grid_range[n]], power=2) + probe_array = Complex2RGB(probes[grid_range[n]], power=2, chroma_boost = chroma_boost) ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2748,6 +2764,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], + chroma_boost = chroma_boost, ) if plot_convergence: diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 459b0ae8c..5035ced81 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -813,11 +813,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, + chroma_boost = chroma_boost, ) # propagated @@ -830,6 +832,7 @@ def preprocess( complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), power=2, + chroma_boost = chroma_boost, ) extent = [ @@ -857,6 +860,7 @@ def preprocess( cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax1, + chroma_boost = chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") @@ -871,6 +875,7 @@ def preprocess( cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax2, + chroma_boost = chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index bb3ee09c2..193c4f5eb 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -753,11 +753,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, + chroma_boost = chroma_boost, ) # propagated @@ -770,6 +772,7 @@ def preprocess( complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), power=2, + chroma_boost = chroma_boost, ) extent = [ @@ -797,6 +800,7 @@ def preprocess( cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax1, + chroma_boost = chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") @@ -811,6 +815,7 @@ def preprocess( cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax2, + chroma_boost = chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") @@ -2585,6 +2590,11 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + asnumpy = self._asnumpy if projection_angle_deg is not None: @@ -2686,6 +2696,7 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier, + chroma_boost = chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") @@ -2694,6 +2705,7 @@ def _visualize_last_iteration( probe_array = Complex2RGB( self.probe, power=2, + chroma_boost = chroma_boost, ) ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") @@ -2710,6 +2722,7 @@ def _visualize_last_iteration( ax_cb = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( ax_cb, + chroma_boost = chroma_boost, ) else: ax = fig.add_subplot(spec[0]) @@ -2827,6 +2840,11 @@ def _visualize_all_iterations( figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + errors = np.array(self.error_iterations) if projection_angle_deg is not None: @@ -2940,6 +2958,7 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") @@ -2948,6 +2967,7 @@ def _visualize_all_iterations( probe_array = Complex2RGB( probes[grid_range[n]], power=2, + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") @@ -2961,6 +2981,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], + chroma_boost = chroma_boost, ) if plot_convergence: diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 084a6fcb8..8af804325 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -753,11 +753,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, + chroma_boost = chroma_boost, ) extent = [ @@ -785,6 +787,7 @@ def preprocess( cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax1, + chroma_boost = chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") @@ -3078,6 +3081,11 @@ def _visualize_last_iteration( vmax_e = kwargs.pop("vmax_e", max_e) vmin_m = kwargs.pop("vmin_m", min_m) vmax_m = kwargs.pop("vmax_m", max_m) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) extent = [ 0, @@ -3184,12 +3192,13 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier, + chroma_boost = chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB(self.probe, power=2) + probe_array = Complex2RGB(self.probe, power=2,chroma_boost=chroma_boost) ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -3202,7 +3211,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb) + add_colorbar_arg(ax_cb,chroma_boost=chroma_boost) else: # Electrostatic Object diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 0dc2cd053..26020e8de 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -480,11 +480,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, + chroma_boost = chroma_boost, ) extent = [ @@ -510,7 +512,7 @@ def preprocess( divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax1) + add_colorbar_arg(cax1,chroma_boost=chroma_boost) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") ax1.set_title("Initial probe intensity") @@ -1757,6 +1759,11 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + if self._object_type == "complex": obj = np.angle(self.object) else: @@ -1849,6 +1856,7 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier, + chroma_boost = chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") @@ -1857,6 +1865,7 @@ def _visualize_last_iteration( probe_array = Complex2RGB( self.probe, power=2, + chroma_boost = chroma_boost, ) ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") @@ -1870,7 +1879,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb) + add_colorbar_arg(ax_cb,chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -1978,6 +1987,11 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2081,6 +2095,7 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") @@ -2090,6 +2105,7 @@ def _visualize_all_iterations( probe_array = Complex2RGB( probes[grid_range[n]], power=2, + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") @@ -2103,6 +2119,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], + chroma_boost = chroma_boost, ) if plot_convergence: From 020e170b57428fd48f6704dfd68a274d83f6f952 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 20 Oct 2023 08:01:18 -0700 Subject: [PATCH 53/62] formatted, linted, isorted --- .../process/phase/iterative_base_class.py | 2 +- ...tive_mixedstate_multislice_ptychography.py | 26 +++++++++++-------- .../iterative_mixedstate_ptychography.py | 22 ++++++++-------- .../iterative_multislice_ptychography.py | 22 +++++++++------- .../iterative_overlap_magnetic_tomography.py | 8 +++--- .../phase/iterative_overlap_tomography.py | 20 +++++++------- py4DSTEM/process/phase/iterative_parallax.py | 5 ++-- .../iterative_simultaneous_ptychography.py | 24 +++++++---------- .../iterative_singleslice_ptychography.py | 18 ++++++------- 9 files changed, 75 insertions(+), 72 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 4dced4291..6e02dd598 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -2219,7 +2219,7 @@ def show_fourier_probe( pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 2747fe601..fca48b38c 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -619,7 +619,7 @@ def preprocess( complex_probe_rgb = Complex2RGB( self.probe_centered[0], power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) # propagated @@ -632,7 +632,7 @@ def preprocess( complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) extent = [ @@ -660,7 +660,7 @@ def preprocess( cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax1, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") @@ -2604,12 +2604,16 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: - probe_array = Complex2RGB(self.probe_fourier[0],chroma_boost=chroma_boost) + probe_array = Complex2RGB( + self.probe_fourier[0], chroma_boost=chroma_boost + ) ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB(self.probe[0], power=2, chroma_boost=chroma_boost) + probe_array = Complex2RGB( + self.probe[0], power=2, chroma_boost=chroma_boost + ) ax.set_title("Reconstructed probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2622,7 +2626,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb,chroma_boost=chroma_boost) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -2839,7 +2843,7 @@ def _visualize_all_iterations( probes[grid_range[n]][0] ) ), - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") @@ -2848,7 +2852,7 @@ def _visualize_all_iterations( probe_array = Complex2RGB( probes[grid_range[n]][0], power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} probe[0]") ax.set_ylabel("x [A]") @@ -2862,7 +2866,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2971,14 +2975,14 @@ def show_fourier_probe( pixelunits = r"$\AA^{-1}$" chroma_boost = kwargs.pop("chroma_boost", 2) - + show_complex( probe if len(probe) > 1 else probe[0], scalebar=scalebar, pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 07b1fe9aa..21a29b0b1 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -516,7 +516,7 @@ def preprocess( complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) extent = [ @@ -546,7 +546,7 @@ def preprocess( divider = make_axes_locatable(axs[i]) cax = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax, chroma_boost = chroma_boost) + add_colorbar_arg(cax, chroma_boost=chroma_boost) axs[-1].imshow( asnumpy(probe_overlap), @@ -1946,7 +1946,7 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier[0], - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") @@ -1955,7 +1955,7 @@ def _visualize_last_iteration( probe_array = Complex2RGB( self.probe[0], power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed probe[0] intensity") ax.set_ylabel("x [A]") @@ -1969,7 +1969,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb,chroma_boost = chroma_boost) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -2077,7 +2077,7 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - + if plot_fourier_probe: chroma_boost = kwargs.pop("chroma_boost", 2) else: @@ -2184,7 +2184,7 @@ def _visualize_all_iterations( probes[grid_range[n]][0] ) ), - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") @@ -2193,7 +2193,7 @@ def _visualize_all_iterations( probe_array = Complex2RGB( probes[grid_range[n]][0], power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} probe[0] intensity") ax.set_ylabel("x [A]") @@ -2207,7 +2207,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2317,13 +2317,13 @@ def show_fourier_probe( pixelunits = r"$\AA^{-1}$" chroma_boost = kwargs.pop("chroma_boost", 2) - + show_complex( probe if len(probe) > 1 else probe[0], scalebar=scalebar, pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 823c71ca0..a22fad715 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -614,7 +614,7 @@ def preprocess( complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) # propagated @@ -627,7 +627,7 @@ def preprocess( complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) extent = [ @@ -2408,7 +2408,7 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - + if plot_fourier_probe: chroma_boost = kwargs.pop("chroma_boost", 2) else: @@ -2509,13 +2509,15 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB(self.probe, power=2, chroma_boost = chroma_boost) + probe_array = Complex2RGB( + self.probe, power=2, chroma_boost=chroma_boost + ) ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2528,7 +2530,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, chroma_boost = chroma_boost) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -2745,13 +2747,15 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB(probes[grid_range[n]], power=2, chroma_boost = chroma_boost) + probe_array = Complex2RGB( + probes[grid_range[n]], power=2, chroma_boost=chroma_boost + ) ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2764,7 +2768,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) if plot_convergence: diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 5035ced81..af665baac 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -819,7 +819,7 @@ def preprocess( complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) # propagated @@ -832,7 +832,7 @@ def preprocess( complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) extent = [ @@ -860,7 +860,7 @@ def preprocess( cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax1, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") @@ -875,7 +875,7 @@ def preprocess( cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 193c4f5eb..d1c323a5d 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -759,7 +759,7 @@ def preprocess( complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) # propagated @@ -772,7 +772,7 @@ def preprocess( complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) extent = [ @@ -800,7 +800,7 @@ def preprocess( cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax1, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") @@ -815,7 +815,7 @@ def preprocess( cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") @@ -2696,7 +2696,7 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") @@ -2705,7 +2705,7 @@ def _visualize_last_iteration( probe_array = Complex2RGB( self.probe, power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") @@ -2722,7 +2722,7 @@ def _visualize_last_iteration( ax_cb = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( ax_cb, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) else: ax = fig.add_subplot(spec[0]) @@ -2958,7 +2958,7 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") @@ -2967,7 +2967,7 @@ def _visualize_all_iterations( probe_array = Complex2RGB( probes[grid_range[n]], power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") @@ -2981,7 +2981,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) if plot_convergence: diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 34631256c..828fd12a2 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -11,7 +11,7 @@ from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable -from py4DSTEM import DataCube +from py4DSTEM import Calibration, DataCube from py4DSTEM.preprocess.utils import get_shifted_ar from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction from py4DSTEM.process.phase.utils import AffineTransform @@ -19,9 +19,8 @@ from py4DSTEM.process.utils.utils import electron_wavelength_angstrom from py4DSTEM.visualize import show from scipy.linalg import polar +from scipy.optimize import minimize from scipy.special import comb -from scipy.optimize import curve_fit, minimize -from scipy.signal import medfilt2d try: import cupy as cp diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 8af804325..eb900d5d0 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -485,10 +485,7 @@ def preprocess( amplitudes_1, mean_diffraction_intensity_1, ) = self._normalize_diffraction_intensities( - intensities_1, - com_fitted_x_1, - com_fitted_y_1, - crop_patterns + intensities_1, com_fitted_x_1, com_fitted_y_1, crop_patterns ) # explicitly delete namescapes @@ -570,10 +567,7 @@ def preprocess( amplitudes_2, mean_diffraction_intensity_2, ) = self._normalize_diffraction_intensities( - intensities_2, - com_fitted_x_2, - com_fitted_y_2, - crop_patterns + intensities_2, com_fitted_x_2, com_fitted_y_2, crop_patterns ) # explicitly delete namescapes @@ -759,7 +753,7 @@ def preprocess( complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) extent = [ @@ -787,7 +781,7 @@ def preprocess( cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax1, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") @@ -3081,7 +3075,7 @@ def _visualize_last_iteration( vmax_e = kwargs.pop("vmax_e", max_e) vmin_m = kwargs.pop("vmin_m", min_m) vmax_m = kwargs.pop("vmax_m", max_m) - + if plot_fourier_probe: chroma_boost = kwargs.pop("chroma_boost", 2) else: @@ -3192,13 +3186,15 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB(self.probe, power=2,chroma_boost=chroma_boost) + probe_array = Complex2RGB( + self.probe, power=2, chroma_boost=chroma_boost + ) ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -3211,7 +3207,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb,chroma_boost=chroma_boost) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: # Electrostatic Object diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 26020e8de..547117f8d 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -486,7 +486,7 @@ def preprocess( complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) extent = [ @@ -512,7 +512,7 @@ def preprocess( divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax1,chroma_boost=chroma_boost) + add_colorbar_arg(cax1, chroma_boost=chroma_boost) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") ax1.set_title("Initial probe intensity") @@ -1856,7 +1856,7 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") @@ -1865,7 +1865,7 @@ def _visualize_last_iteration( probe_array = Complex2RGB( self.probe, power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") @@ -1879,7 +1879,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb,chroma_boost=chroma_boost) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -1987,7 +1987,7 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - + if plot_fourier_probe: chroma_boost = kwargs.pop("chroma_boost", 2) else: @@ -2095,7 +2095,7 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") @@ -2105,7 +2105,7 @@ def _visualize_all_iterations( probe_array = Complex2RGB( probes[grid_range[n]], power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") @@ -2119,7 +2119,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) if plot_convergence: From 9806e277b7892f1a3eb8f2b54f390b0c1e326cf5 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 20 Oct 2023 16:15:08 -0700 Subject: [PATCH 54/62] fixing radial order accounting --- py4DSTEM/process/phase/iterative_parallax.py | 18 +++++++++--------- py4DSTEM/process/phase/utils.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 8b2b007f5..716b84342 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -34,7 +34,7 @@ (1, 2): "stig ", (2, 1): "coma ", (2, 3): "trefoil ", - (3, 0): "Cs ", + (3, 0): "C3 ", (3, 2): "stig2 ", (3, 4): "quadfoil ", (4, 1): "coma2 ", @@ -155,7 +155,7 @@ def to_h5(self, group): self.metadata = Metadata( name="aberrations_metadata", data={ - v["common name"]: v["value [Ang]"] + v["aberration name"]: v["value [Ang]"] for k, v in self.aberration_dict.items() }, ) @@ -1294,7 +1294,7 @@ def aberration_fit( fit_CTF_FFT: bool = False, fit_aberrations_max_radial_order: int = 3, fit_aberrations_max_angular_order: int = 4, - fit_aberrations_min_radial_order: int = 1, + fit_aberrations_min_radial_order: int = 2, fit_aberrations_min_angular_order: int = 0, fit_max_thon_rings: int = 6, fit_power_alpha: float = 2.0, @@ -1366,7 +1366,7 @@ def aberration_fit( mn = [] for m in range( - fit_aberrations_min_radial_order, fit_aberrations_max_radial_order + 1 + fit_aberrations_min_radial_order - 1, fit_aberrations_max_radial_order ): n_max = np.minimum(fit_aberrations_max_angular_order, m + 1) for n in range(fit_aberrations_min_angular_order, n_max + 1): @@ -1741,7 +1741,7 @@ def score_CTF(coefs): self.aberration_dict = { tuple(self._aberrations_mn[a0]): { - "common name": _aberration_names.get( + "aberration name": _aberration_names.get( tuple(self._aberrations_mn[a0, :2]), "-" ).strip(), "value [Ang]": self._aberrations_coefs[a0], @@ -1774,7 +1774,7 @@ def score_CTF(coefs): print() print("Refined Aberration coefficients") print("-------------------------------") - print("common radial angular dir. coefs") + print("aberration radial angular dir. coefs") print("name order order Ang ") print("---------- ------- ------- ---- -----") @@ -1785,7 +1785,7 @@ def score_CTF(coefs): print( name + " " - + str(m) + + str(m + 1) + " 0 - " + str(np.round(self._aberrations_coefs[a0]).astype("int")) ) @@ -1793,7 +1793,7 @@ def score_CTF(coefs): print( name + " " - + str(m) + + str(m + 1) + " " + str(n) + " x " @@ -1803,7 +1803,7 @@ def score_CTF(coefs): print( name + " " - + str(m) + + str(m + 1) + " " + str(n) + " y " diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index cc5fa8cb4..7e348826e 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1552,7 +1552,7 @@ def aberrations_basis_function( # mn = [[0,0,0]] mn = [] - for m in range(1, max_radial_order + 1): + for m in range(1, max_radial_order): n_max = np.minimum(max_angular_order, m + 1) for n in range(0, n_max + 1): if (m + n) % 2: From 2eca4b7eb1d5f74947148cd87f745af0d05e9c38 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Fri, 20 Oct 2023 16:48:36 -0700 Subject: [PATCH 55/62] make lint happy I hope! --- py4DSTEM/visualize/vis_special.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 19e0c5c7a..cfa017299 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -18,6 +18,7 @@ from py4DSTEM.visualize.vis_RQ import ax_addaxes, ax_addaxes_QtoR from colorspacious import cspace_convert + def show_elliptical_fit( ar, fitradii, From 71b8f6d9039bce7d5177341b74516c62cd475b5a Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Fri, 20 Oct 2023 17:28:32 -0700 Subject: [PATCH 56/62] fix extent for ms depth sectioning --- .../process/phase/iterative_multislice_ptychography.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 576b431bf..4515590fe 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -3095,9 +3095,9 @@ def show_depth( self._slice_thicknesses[0] * plot_im.shape[0], 0, ] - + figsize = kwargs.pop("figsize", (6, 6)) if not plot_line_profile: - fig, ax = plt.subplots() + fig, ax = plt.subplots(figsize=figsize) im = ax.imshow(plot_im, cmap="magma", extent=extent) if aspect is not None: ax.set_aspect(aspect) @@ -3112,11 +3112,12 @@ def show_depth( else: extent2 = [ 0, - self.sampling[0] * ms_obj.shape[1], self.sampling[1] * ms_obj.shape[2], + self.sampling[0] * ms_obj.shape[1], 0, ] - fig, ax = plt.subplots(2, 1) + + fig, ax = plt.subplots(2, 1, figsize=figsize) ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) ax[0].plot( [y1 * self.sampling[0], y2 * self.sampling[1]], From dfc312b190b8144655a1f49005aacbfb3510684e Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 21 Oct 2023 10:28:32 -0700 Subject: [PATCH 57/62] small fixes --- py4DSTEM/process/phase/iterative_base_class.py | 2 +- .../phase/iterative_mixedstate_multislice_ptychography.py | 7 ++++--- .../process/phase/iterative_overlap_magnetic_tomography.py | 2 +- py4DSTEM/process/phase/iterative_overlap_tomography.py | 2 +- py4DSTEM/process/phase/iterative_parallax.py | 1 + 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 66dc3a8d6..906e9add1 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -2243,7 +2243,7 @@ def show_object_fft(self, obj=None, **kwargs): vmax = kwargs.pop("vmax", 1) power = kwargs.pop("power", 0.2) - pixelsize = 1 / (object_fft.shape[0] * self.sampling[0]) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index fca48b38c..3eeb07814 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -3237,8 +3237,9 @@ def show_depth( 0, ] + figsize = kwargs.pop("figsize", (6, 6)) if not plot_line_profile: - fig, ax = plt.subplots() + fig, ax = plt.subplots(figsize=figsize) im = ax.imshow(plot_im, cmap="magma", extent=extent) if aspect is not None: ax.set_aspect(aspect) @@ -3253,11 +3254,11 @@ def show_depth( else: extent2 = [ 0, - self.sampling[0] * ms_obj.shape[1], self.sampling[1] * ms_obj.shape[2], + self.sampling[0] * ms_obj.shape[1], 0, ] - fig, ax = plt.subplots(2, 1) + fig, ax = plt.subplots(2, 1, figsize=figsize) ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) ax[0].plot( [y1 * self.sampling[0], y2 * self.sampling[1]], diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 57e42a366..32b0f6fd4 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -3292,7 +3292,7 @@ def show_object_fft( vmax = kwargs.pop("vmax", 1) power = kwargs.pop("power", 0.2) - pixelsize = 1 / (object_fft.shape[0] * self.sampling[0]) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index ab37dfad5..66cf46487 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -3172,7 +3172,7 @@ def show_object_fft( vmax = kwargs.pop("vmax", 1) power = kwargs.pop("power", 0.2) - pixelsize = 1 / (object_fft.shape[0] * self.sampling[0]) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 716b84342..a69dece3b 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1874,6 +1874,7 @@ def aberration_correct( ---------- use_FFT_fit: bool Use the CTF fitted to the zero crossings of the FFT. + Default is True plot_corrected_phase: bool, optional If True, the CTF-corrected phase is plotted k_info_limit: float, optional From b0e2c4244b1241505315b5a7bd2e555fff4b07d2 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 21 Oct 2023 11:46:35 -0700 Subject: [PATCH 58/62] fix for ptycho aberration fit --- py4DSTEM/process/phase/utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 7e348826e..93428f5bb 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1549,10 +1549,10 @@ def aberrations_basis_function( xp=np, ): """ """ - # mn = [[0,0,0]] - mn = [] + mn = [[0,0,0]] + # mn = [] - for m in range(1, max_radial_order): + for m in range(max_radial_order+1): n_max = np.minimum(max_angular_order, m + 1) for n in range(0, n_max + 1): if (m + n) % 2: @@ -1583,9 +1583,9 @@ def aberrations_basis_function( theta = xp.arctan2(qy[None, :], qx[:, None]) # Aberration basis - aberrations_basis = xp.zeros((alpha.size, aberrations_num)) + aberrations_basis = xp.ones((alpha.size, aberrations_num)) - for a0 in range(aberrations_num): + for a0 in range(1,aberrations_num): m, n, a = aberrations_mn[a0] if n == 0: # Radially symmetric basis @@ -1641,7 +1641,7 @@ def fit_aberration_surface( Aw = raveled_basis * raveled_weights[:, None] bw = unwrapped_angle.ravel() * raveled_weights - coeff = -xp.linalg.lstsq(Aw, bw, rcond=None)[0] + coeff = xp.linalg.lstsq(Aw, bw, rcond=None)[0] fitted_angle = xp.tensordot(raveled_basis, coeff, axes=1).reshape(probe_angle.shape) From 17dd9a2ef212fadf8f282b97ba5934fea21cc042 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 21 Oct 2023 11:48:32 -0700 Subject: [PATCH 59/62] black format --- py4DSTEM/process/phase/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 93428f5bb..374e3fc15 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1549,10 +1549,10 @@ def aberrations_basis_function( xp=np, ): """ """ - mn = [[0,0,0]] + mn = [[0, 0, 0]] # mn = [] - for m in range(max_radial_order+1): + for m in range(max_radial_order + 1): n_max = np.minimum(max_angular_order, m + 1) for n in range(0, n_max + 1): if (m + n) % 2: @@ -1585,7 +1585,7 @@ def aberrations_basis_function( # Aberration basis aberrations_basis = xp.ones((alpha.size, aberrations_num)) - for a0 in range(1,aberrations_num): + for a0 in range(1, aberrations_num): m, n, a = aberrations_mn[a0] if n == 0: # Radially symmetric basis From ada4d4d8a2149b35c4ef03011477a024ee4a7ced Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 21 Oct 2023 17:17:10 -0700 Subject: [PATCH 60/62] fixed ptycho fitting, added transpose flag in parallax --- py4DSTEM/process/phase/iterative_parallax.py | 98 ++++++++++++-------- py4DSTEM/process/phase/utils.py | 8 +- 2 files changed, 65 insertions(+), 41 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index a69dece3b..dcfd8f504 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -30,7 +30,7 @@ warnings.simplefilter(action="always", category=UserWarning) _aberration_names = { - (1, 0): "-defocus ", + (1, 0): "C1 ", (1, 2): "stig ", (2, 1): "coma ", (2, 3): "trefoil ", @@ -1290,7 +1290,7 @@ def subpixel_alignment( def aberration_fit( self, - fit_BF_shifts: bool = True, + fit_BF_shifts: bool = False, fit_CTF_FFT: bool = False, fit_aberrations_max_radial_order: int = 3, fit_aberrations_max_angular_order: int = 4, @@ -1301,6 +1301,7 @@ def aberration_fit( plot_CTF_comparison: bool = None, plot_BF_shifts_comparison: bool = None, upsampled: bool = True, + force_transpose: bool = None, ): """ Fit aberrations to the measured image shifts. @@ -1330,6 +1331,8 @@ def aberration_fit( If True, the measured vs fitted BF shifts are plotted. upsampled: bool If True, and upsampled BF is available, uses that for CTF FFT fitting. + force_transpose: bool + If True, and fit_BF_shifts is True, flips the measured x and y shifts """ xp = self._xp @@ -1358,7 +1361,11 @@ def aberration_fit( self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0 self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 - self.transpose_detected = False + + if force_transpose is None: + self.transpose_detected = False + else: + self.transpose_detected = force_transpose ### Second pass @@ -1583,48 +1590,63 @@ def score_CTF(coefs): ) ) - # Untransposed fit - tf = AffineTransform(angle=self.rotation_Q_to_R_rads) - rotated_shifts = tf(self._xy_shifts_Ang, xp=xp).T.ravel() - aberrations_coefs, res = xp.linalg.lstsq( - gradients, rotated_shifts, rcond=None - )[:2] - - # Transposed fit - transposed_shifts = xp.flip(self._xy_shifts_Ang, axis=1) - m_T = asnumpy( - xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[0] - ) - m_rotation_T, _ = polar(m_T, side="right") - rotation_Q_to_R_rads_T = -1 * np.arctan2( - m_rotation_T[1, 0], m_rotation_T[0, 0] - ) - if np.abs(np.mod(rotation_Q_to_R_rads_T + np.pi, 2.0 * np.pi) - np.pi) > ( - np.pi * 0.5 - ): - rotation_Q_to_R_rads_T = ( - np.mod(rotation_Q_to_R_rads_T, 2.0 * np.pi) - np.pi + if force_transpose is None or force_transpose is True: + # Transposed fit + transposed_shifts = xp.flip(self._xy_shifts_Ang, axis=1) + m_T = asnumpy( + xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[ + 0 + ] ) + m_rotation_T, _ = polar(m_T, side="right") + rotation_Q_to_R_rads_T = -1 * np.arctan2( + m_rotation_T[1, 0], m_rotation_T[0, 0] + ) + if np.abs( + np.mod(rotation_Q_to_R_rads_T + np.pi, 2.0 * np.pi) - np.pi + ) > (np.pi * 0.5): + rotation_Q_to_R_rads_T = ( + np.mod(rotation_Q_to_R_rads_T, 2.0 * np.pi) - np.pi + ) - tf = AffineTransform(angle=rotation_Q_to_R_rads_T) - rotated_shifts_T = tf(transposed_shifts, xp=xp).T.ravel() - aberrations_coefs_T, res_T = xp.linalg.lstsq( - gradients, rotated_shifts_T, rcond=None - )[:2] + tf_T = AffineTransform(angle=rotation_Q_to_R_rads_T) + rotated_shifts_T = tf_T(transposed_shifts, xp=xp).T.ravel() + aberrations_coefs_T, res_T = xp.linalg.lstsq( + gradients, rotated_shifts_T, rcond=None + )[:2] + + if force_transpose is None or force_transpose is False: + # Untransposed fit + tf = AffineTransform(angle=self.rotation_Q_to_R_rads) + rotated_shifts = tf(self._xy_shifts_Ang, xp=xp).T.ravel() + aberrations_coefs, res = xp.linalg.lstsq( + gradients, rotated_shifts, rcond=None + )[:2] + + if force_transpose is None: + # Compare fits + if res_T.sum() < res.sum(): + self.rotation_Q_to_R_rads = rotation_Q_to_R_rads_T + self.transpose_detected = True + self._aberrations_coefs = asnumpy(aberrations_coefs_T) + self._rotated_shifts = rotated_shifts_T + + warnings.warn( + ( + "Data transpose detected. " + f"Overwriting rotation value to {np.rad2deg(rotation_Q_to_R_rads_T):.3f}" + ), + UserWarning, + ) + else: + self._aberrations_coefs = asnumpy(aberrations_coefs) + self._rotated_shifts = rotated_shifts - if res_T.sum() < res.sum(): + elif force_transpose is True: self.rotation_Q_to_R_rads = rotation_Q_to_R_rads_T - self.transpose_detected = True self._aberrations_coefs = asnumpy(aberrations_coefs_T) self._rotated_shifts = rotated_shifts_T - warnings.warn( - ( - "Data transpose detected. " - f"Overwriting rotation value to {np.rad2deg(rotation_Q_to_R_rads_T):.3f}" - ), - UserWarning, - ) else: self._aberrations_coefs = asnumpy(aberrations_coefs) self._rotated_shifts = rotated_shifts diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 374e3fc15..d29765d04 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1549,10 +1549,11 @@ def aberrations_basis_function( xp=np, ): """ """ - mn = [[0, 0, 0]] - # mn = [] - for m in range(max_radial_order + 1): + # Add constant phase shift in basis + mn = [[-1, 0, 0]] + + for m in range(1, max_radial_order): n_max = np.minimum(max_angular_order, m + 1) for n in range(0, n_max + 1): if (m + n) % 2: @@ -1585,6 +1586,7 @@ def aberrations_basis_function( # Aberration basis aberrations_basis = xp.ones((alpha.size, aberrations_num)) + # Skip constant to avoid dividing by zero in normalization for a0 in range(1, aberrations_num): m, n, a = aberrations_mn[a0] if n == 0: From 9529945ee4f899277b6338a5e9de484235e88b70 Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Sun, 22 Oct 2023 11:39:59 -0700 Subject: [PATCH 61/62] added force_transpose option for other two aberration fit methods --- py4DSTEM/process/phase/iterative_parallax.py | 49 +++++++++----------- 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index dcfd8f504..6ebb9962e 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1341,7 +1341,18 @@ def aberration_fit( ### First pass # Convert real space shifts to Angstroms - self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) + + if force_transpose is None: + self.transpose_detected = False + else: + self.transpose_detected = force_transpose + + if force_transpose is True: + self._xy_shifts_Ang = xp.flip(self._xy_shifts, axis=1) * xp.array( + self._scan_sampling + ) + else: + self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) # Solve affine transformation m = asnumpy( @@ -1362,11 +1373,6 @@ def aberration_fit( self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 - if force_transpose is None: - self.transpose_detected = False - else: - self.transpose_detected = force_transpose - ### Second pass # Aberration coefs @@ -1590,8 +1596,15 @@ def score_CTF(coefs): ) ) - if force_transpose is None or force_transpose is True: - # Transposed fit + # (Relative) untransposed fit + tf = AffineTransform(angle=self.rotation_Q_to_R_rads) + rotated_shifts = tf(self._xy_shifts_Ang, xp=xp).T.ravel() + aberrations_coefs, res = xp.linalg.lstsq( + gradients, rotated_shifts, rcond=None + )[:2] + + if force_transpose is None: + # (Relative) transposed fit transposed_shifts = xp.flip(self._xy_shifts_Ang, axis=1) m_T = asnumpy( xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[ @@ -1615,19 +1628,10 @@ def score_CTF(coefs): gradients, rotated_shifts_T, rcond=None )[:2] - if force_transpose is None or force_transpose is False: - # Untransposed fit - tf = AffineTransform(angle=self.rotation_Q_to_R_rads) - rotated_shifts = tf(self._xy_shifts_Ang, xp=xp).T.ravel() - aberrations_coefs, res = xp.linalg.lstsq( - gradients, rotated_shifts, rcond=None - )[:2] - - if force_transpose is None: # Compare fits if res_T.sum() < res.sum(): self.rotation_Q_to_R_rads = rotation_Q_to_R_rads_T - self.transpose_detected = True + self.transpose_detected = not self.transpose_detected self._aberrations_coefs = asnumpy(aberrations_coefs_T) self._rotated_shifts = rotated_shifts_T @@ -1638,15 +1642,6 @@ def score_CTF(coefs): ), UserWarning, ) - else: - self._aberrations_coefs = asnumpy(aberrations_coefs) - self._rotated_shifts = rotated_shifts - - elif force_transpose is True: - self.rotation_Q_to_R_rads = rotation_Q_to_R_rads_T - self._aberrations_coefs = asnumpy(aberrations_coefs_T) - self._rotated_shifts = rotated_shifts_T - else: self._aberrations_coefs = asnumpy(aberrations_coefs) self._rotated_shifts = rotated_shifts From 43220d0b68705ec8d59ad94b7ff68446d7738255 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 22 Oct 2023 12:47:16 -0700 Subject: [PATCH 62/62] read-write device bugfix --- .../process/phase/iterative_base_class.py | 51 ++++++++++++++++++- py4DSTEM/process/phase/iterative_dpc.py | 4 +- py4DSTEM/process/phase/iterative_parallax.py | 4 +- 3 files changed, 53 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 906e9add1..04cfd6a60 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -56,6 +56,53 @@ def attach_datacube(self, datacube: DataCube): self._datacube = datacube return self + def reinitialize_parameters(self, device: str = None, verbose: bool = None): + """ + Reinitializes common parameters. This is useful when loading a previously-saved + reconstruction (which set device='cpu' and verbose=True for compatibility) , + using different initialization parameters. + + Parameters + ---------- + device: str, optional + If not None, imports and assigns appropriate device modules + verbose: bool, optional + If not None, sets the verbosity to verbose + + Returns + -------- + self: PhaseReconstruction + Self to enable chaining + """ + + if device is not None: + if device == "cpu": + self._xp = np + self._asnumpy = np.asarray + from scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from scipy.special import erf + + self._erf = erf + elif device == "gpu": + self._xp = cp + self._asnumpy = cp.asnumpy + from cupyx.scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from cupyx.scipy.special import erf + + self._erf = erf + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self._device = device + + if verbose is not None: + self._verbose = verbose + + return self + def set_save_defaults( self, save_datacube: bool = False, @@ -1408,10 +1455,10 @@ def _get_constructor_args(cls, group): "object_type": instance_md["object_type"], "semiangle_cutoff": instance_md["semiangle_cutoff"], "rolloff": instance_md["rolloff"], - "verbose": instance_md["verbose"], "name": instance_md["name"], - "device": instance_md["device"], "polar_parameters": polar_params, + "verbose": True, # for compatibility + "device": "cpu", # for compatibility } class_specific_kwargs = {} diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/iterative_dpc.py index 4ca2c170f..af3cbbb45 100644 --- a/py4DSTEM/process/phase/iterative_dpc.py +++ b/py4DSTEM/process/phase/iterative_dpc.py @@ -195,9 +195,9 @@ def _get_constructor_args(cls, group): "datacube": dc, "initial_object_guess": np.asarray(obj), "energy": instance_md["energy"], - "verbose": instance_md["verbose"], "name": instance_md["name"], - "device": instance_md["device"], + "verbose": True, # for compatibility + "device": "cpu", # for compatibility } return kwargs diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 6ebb9962e..74688fa0b 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -206,10 +206,10 @@ def _get_constructor_args(cls, group): kwargs = { "datacube": dc, "energy": instance_md["energy"], - "verbose": instance_md["verbose"], - "device": instance_md["device"], "object_padding_px": instance_md["object_padding_px"], "name": instance_md["name"], + "verbose": True, # for compatibility + "device": "cpu", # for compatibility } return kwargs