From 610612d51acb07bf9895b784d6ebee64d8cd1c6f Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 29 Dec 2023 20:21:58 -0800 Subject: [PATCH 001/128] DCT-based poisson solver phase unwrapping Former-commit-id: f7a2deb40b2539729d646d605632bdca95fd3516 --- py4DSTEM/process/phase/utils.py | 73 ++++++++++++++++++++++++++++----- 1 file changed, 63 insertions(+), 10 deletions(-) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 0aaa67a9f..31e1aac65 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -7,16 +7,15 @@ try: import cupy as cp - from cupyx.scipy.fft import rfft + from cupyx.scipy.fft import dctn, idctn, rfft except (ImportError, ModuleNotFoundError): cp = None - from scipy.fft import dstn, idstn + from scipy.fft import dstn, idstn, dctn, idctn from py4DSTEM.process.utils import get_CoM from py4DSTEM.process.utils.cross_correlate import align_and_shift_images from py4DSTEM.process.utils.utils import electron_wavelength_angstrom from scipy.ndimage import gaussian_filter, uniform_filter1d -from skimage.restoration import unwrap_phase # fmt: off @@ -1611,6 +1610,61 @@ def aberrations_basis_function( return aberrations_basis, aberrations_mn +def preconditioned_laplacian_dct(shape, xp=np): + """DCT eigenvalues""" + n, m = shape + i, j = xp.ogrid[0:n, 0:m] + + op = 4 - 2 * xp.cos(np.pi * i / n) - 2 * xp.cos(np.pi * j / m) + op[0, 0] = 1 # gauge invariance + return -op + + +def preconditioned_poisson_solver_dct(rhs, gauge=None, xp=np): + """DCT based poisson solver""" + op = preconditioned_laplacian_dct(rhs.shape, xp=xp) + + if gauge is None: + gauge = xp.mean(rhs) + + fft_rhs = dctn(rhs, type=2) + fft_rhs[0, 0] = gauge # gauge invariance + sol = idctn(fft_rhs / op, type=2) + return sol + + +def unwrap_phase_2d(array, weights=None, gauge=None, corner_centered=True, xp=np): + """Weigted phase unwrapping using DCT-based poisson solver""" + + if np.iscomplexobj(array): + raise ValueError() + + if corner_centered: + array = xp.fft.fftshift(array) + if weights is not None: + weights = xp.fft.fftshift(weights) + + dx = xp.mod(xp.diff(array, axis=0) + np.pi, 2 * np.pi) - np.pi + dy = xp.mod(xp.diff(array, axis=1) + np.pi, 2 * np.pi) - np.pi + + if weights is not None: + ww = weights**2 + dx *= xp.minimum(ww[:-1, :], ww[1:, :]) + dy *= xp.minimum(ww[:, :-1], ww[:, 1:]) + + rho = xp.diff(dx, axis=0, prepend=0, append=0) + xp.diff( + dy, axis=1, prepend=0, append=0 + ) + + unwrapped_array = preconditioned_poisson_solver_dct(rho, gauge=gauge, xp=xp).real + unwrapped_array -= unwrapped_array.min() + + if corner_centered: + unwrapped_array = xp.fft.ifftshift(unwrapped_array) + + return unwrapped_array + + def fit_aberration_surface( complex_probe, probe_sampling, @@ -1623,13 +1677,12 @@ def fit_aberration_surface( probe_amp = xp.abs(complex_probe) probe_angle = -xp.angle(complex_probe) - if xp is np: - probe_angle = probe_angle.astype(np.float64) - unwrapped_angle = unwrap_phase(probe_angle, wrap_around=True).astype(xp.float32) - else: - probe_angle = xp.asnumpy(probe_angle).astype(np.float64) - unwrapped_angle = unwrap_phase(probe_angle, wrap_around=True) - unwrapped_angle = xp.asarray(unwrapped_angle).astype(xp.float32) + unwrapped_angle = unwrap_phase_2d( + probe_angle, + weights=probe_amp, + corner_centered=True, + xp=xp, + ) raveled_basis, _ = aberrations_basis_function( complex_probe.shape, From 7b44e664d77b0c293d528e9a664775af5cecb2a0 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 29 Dec 2023 21:29:55 -0800 
Subject: [PATCH 002/128] constraints refactoring Former-commit-id: c85faa8b79ca81e35b72e86e01da16229208970e --- .../process/phase/iterative_base_class.py | 5 +- ...tive_mixedstate_multislice_ptychography.py | 312 +------------ .../iterative_mixedstate_ptychography.py | 76 +-- .../iterative_multislice_ptychography.py | 248 +--------- .../iterative_overlap_magnetic_tomography.py | 198 +------- .../phase/iterative_overlap_tomography.py | 190 +------- .../iterative_ptychographic_constraints.py | 439 ++++++++++++++++-- .../iterative_simultaneous_ptychography.py | 12 +- .../iterative_singleslice_ptychography.py | 12 +- 9 files changed, 528 insertions(+), 964 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index a0ed485ba..e82816042 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -20,9 +20,6 @@ from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube from py4DSTEM.process.calibration import fit_origin -from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( - PtychographicConstraints, -) from py4DSTEM.process.phase.utils import ( AffineTransform, generate_batches, @@ -1293,7 +1290,7 @@ def show_complex_CoM( ) -class PtychographicReconstruction(PhaseReconstruction, PtychographicConstraints): +class PtychographicReconstruction(PhaseReconstruction): """ Base ptychographic reconstruction class. Inherits from PhaseReconstruction and PtychographicConstraints. diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 10dc40e00..0f5b79e29 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -16,14 +16,17 @@ import cupy as cp except (ModuleNotFoundError, ImportError): cp = None - import os - # make sure pylops doesn't try to use cupy - os.environ["CUPY_PYLOPS"] = "0" -import pylops # this must follow the exception from emdfile import Custom, tqdmnd from py4DSTEM import DataCube from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( + Object2p5DConstraintsMixin, + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + ProbeMixedConstraintsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -38,7 +41,14 @@ warnings.simplefilter(action="always", category=UserWarning) -class MixedstateMultislicePtychographicReconstruction(PtychographicReconstruction): +class MixedstateMultislicePtychographicReconstruction( + PositionsConstraintsMixin, + ProbeMixedConstraintsMixin, + ProbeConstraintsMixin, + Object2p5DConstraintsMixin, + ObjectNDConstraintsMixin, + PtychographicReconstruction, +): """ Mixed-State Multislice Ptychographic Reconstruction Class. @@ -1473,281 +1483,6 @@ def _position_correction( return current_positions - def _probe_center_of_mass_constraint(self, current_probe): - """ - Ptychographic center of mass constraint. - Used for centering corner-centered probe intensity. 
- - Parameters - -------- - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - constrained_probe: np.ndarray - Constrained probe estimate - """ - xp = self._xp - probe_intensity = xp.abs(current_probe[0]) ** 2 - - probe_x0, probe_y0 = get_CoM( - probe_intensity, device=self._device, corner_centered=True - ) - shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) - - return shifted_probe - - def _probe_orthogonalization_constraint(self, current_probe): - """ - Ptychographic probe-orthogonalization constraint. - Used to ensure mixed states are orthogonal to each other. - Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690 - - Parameters - -------- - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - constrained_probe: np.ndarray - Orthogonalized probe estimate - """ - xp = self._xp - n_probes = self._num_probes - - # compute upper half of P* @ P - pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype) - - for i in range(n_probes): - for j in range(i, n_probes): - pairwise_dot_product[i, j] = xp.sum( - current_probe[i].conj() * current_probe[j] - ) - - # compute eigenvectors (effectively cheaper way of computing V* from SVD) - _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U") - current_probe = xp.tensordot(evecs.T, current_probe, axes=1) - - # sort by real-space intensity - intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1)) - intensities_order = xp.argsort(intensities, axis=None)[::-1] - return current_probe[intensities_order] - - def _object_butterworth_constraint( - self, current_object, q_lowpass, q_highpass, butterworth_order - ): - """ - 2D Butterworth filter - Used for low/high-pass filtering object. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qya, qxa = xp.meshgrid(qy, qx) - qra = xp.sqrt(qxa**2 + qya**2) - - env = xp.ones_like(qra) - if q_highpass: - env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) - if q_lowpass: - env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - - current_object_mean = xp.mean(current_object) - current_object -= current_object_mean - current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env[None]) - current_object += current_object_mean - - if self._object_type == "potential": - current_object = xp.real(current_object) - - return current_object - - def _object_kz_regularization_constraint( - self, current_object, kz_regularization_gamma - ): - """ - Arctan regularization filter - - Parameters - -------- - current_object: np.ndarray - Current object estimate - kz_regularization_gamma: float - Slice regularization strength - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - - current_object = xp.pad( - current_object, pad_width=((1, 0), (0, 0), (0, 0)), mode="constant" - ) - - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]) - - kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0] - - qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") - qz2 = qza**2 * kz_regularization_gamma**2 - qr2 = qxa**2 + qya**2 - - w = 1 - 2 / np.pi * xp.arctan2(qz2, qr2) - - current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * w) - current_object = current_object[1:] - - if self._object_type == "potential": - current_object = xp.real(current_object) - - return current_object - - def _object_identical_slices_constraint(self, current_object): - """ - Strong regularization forcing all slices to be identical - - Parameters - -------- - current_object: np.ndarray - Current object estimate - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - object_mean = current_object.mean(0, keepdims=True) - current_object[:] = object_mean - - return current_object - - def _object_denoise_tv_pylops(self, current_object, weights, iterations): - """ - Performs second order TV denoising along x and y - - Parameters - ---------- - current_object: np.ndarray - Current object estimate - weights : [float, float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - iterations: float - Number of iterations to run in denoising algorithm. 
- `niter_out` in pylops - - Returns - ------- - constrained_object: np.ndarray - Constrained object estimate - - """ - xp = self._xp - - if xp.iscomplexobj(current_object): - current_object_tv = current_object - warnings.warn( - ("TV denoising is currently only supported for potential objects."), - UserWarning, - ) - - else: - # zero pad at top and bottom slice - pad_width = ((1, 1), (0, 0), (0, 0)) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" - ) - - # run tv denoising - nz, nx, ny = current_object.shape - niter_out = iterations - niter_in = 1 - Iop = pylops.Identity(nx * ny * nz) - - if weights[0] == 0: - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[1]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - elif weights[1] == 0: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - l1_regs = [z_gradient] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[0]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - else: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [z_gradient, xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=weights, - tol=1e-4, - tau=1.0, - show=False, - )[0] - - # remove padding - current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] - - return current_object_tv - def _constraints( self, current_object, @@ -1855,8 +1590,8 @@ def _constraints( If True, performs TV denoising along z tv_denoise_weight_chambolle: float weight of tv denoising constraint - tv_denoise_pad_chambolle: bool - if True, pads object at top and bottom with zeros before applying denoising + tv_denoise_pad_chambolle: int + If not None, pads object at top and bottom with this many zeros before applying denoising tv_denoise: bool If True, applies TV denoising on object tv_denoise_weights: [float,float] @@ -1894,20 +1629,23 @@ def _constraints( current_object = self._object_identical_slices_constraint(current_object) elif kz_regularization_filter: current_object = self._object_kz_regularization_constraint( - current_object, kz_regularization_gamma + current_object, + kz_regularization_gamma, + z_padding=1, ) elif tv_denoise: current_object = self._object_denoise_tv_pylops( current_object, tv_denoise_weights, tv_denoise_inner_iter, + z_padding=1, ) elif tv_denoise_chambolle: current_object = self._object_denoise_tv_chambolle( current_object, tv_denoise_weight_chambolle, axis=0, - pad_object=tv_denoise_pad_chambolle, + padding=tv_denoise_pad_chambolle, ) if shrinkage_rad > 0.0 or object_mask is not None: @@ -1996,7 +1734,7 @@ def reconstruct( pure_phase_object_iter: int = 0, tv_denoise_iter_chambolle=np.inf, tv_denoise_weight_chambolle=None, - tv_denoise_pad_chambolle=True, + tv_denoise_pad_chambolle=1, tv_denoise_iter=np.inf, tv_denoise_weights=None, tv_denoise_inner_iter=40, @@ -2096,8 +1834,8 @@ def reconstruct( Number 
of iterations with TV denoisining tv_denoise_weight_chambolle: float weight of tv denoising constraint - tv_denoise_pad_chambolle: bool - if True, pads object at top and bottom with zeros before applying denoising + tv_denoise_pad_chambolle: int + If not None, pads object at top and bottom with this many zeros before applying denoising tv_denoise: bool If True, applies TV denoising on object tv_denoise_weights: [float,float] diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 880858f30..2809f0144 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -20,6 +20,12 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + ProbeMixedConstraintsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -32,7 +38,13 @@ warnings.simplefilter(action="always", category=UserWarning) -class MixedstatePtychographicReconstruction(PtychographicReconstruction): +class MixedstatePtychographicReconstruction( + PositionsConstraintsMixin, + ProbeMixedConstraintsMixin, + ProbeConstraintsMixin, + ObjectNDConstraintsMixin, + PtychographicReconstruction, +): """ Mixed-State Ptychographic Reconstruction Class. @@ -1047,68 +1059,6 @@ def _adjoint( return current_object, current_probe - def _probe_center_of_mass_constraint(self, current_probe): - """ - Ptychographic center of mass constraint. - Used for centering corner-centered probe intensity. - - Parameters - -------- - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - constrained_probe: np.ndarray - Constrained probe estimate - """ - xp = self._xp - probe_intensity = xp.abs(current_probe[0]) ** 2 - - probe_x0, probe_y0 = get_CoM( - probe_intensity, device=self._device, corner_centered=True - ) - shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) - - return shifted_probe - - def _probe_orthogonalization_constraint(self, current_probe): - """ - Ptychographic probe-orthogonalization constraint. - Used to ensure mixed states are orthogonal to each other. 
- Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690 - - Parameters - -------- - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - constrained_probe: np.ndarray - Orthogonalized probe estimate - """ - xp = self._xp - n_probes = self._num_probes - - # compute upper half of P* @ P - pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype) - - for i in range(n_probes): - for j in range(i, n_probes): - pairwise_dot_product[i, j] = xp.sum( - current_probe[i].conj() * current_probe[j] - ) - - # compute eigenvectors (effectively cheaper way of computing V* from SVD) - _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U") - current_probe = xp.tensordot(evecs.T, current_probe, axes=1) - - # sort by real-space intensity - intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1)) - intensities_order = xp.argsort(intensities, axis=None)[::-1] - return current_probe[intensities_order] - def _constraints( self, current_object, diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 39cb62fdd..eaa0ae396 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -16,14 +16,16 @@ import cupy as cp except (ModuleNotFoundError, ImportError): cp = np - import os - # make sure pylops doesn't try to use cupy - os.environ["CUPY_PYLOPS"] = "0" -import pylops # this must follow the exception from emdfile import Custom, tqdmnd from py4DSTEM import DataCube from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( + Object2p5DConstraintsMixin, + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -38,7 +40,13 @@ warnings.simplefilter(action="always", category=UserWarning) -class MultislicePtychographicReconstruction(PtychographicReconstruction): +class MultislicePtychographicReconstruction( + PositionsConstraintsMixin, + ProbeConstraintsMixin, + Object2p5DConstraintsMixin, + ObjectNDConstraintsMixin, + PtychographicReconstruction, +): """ Multislice Ptychographic Reconstruction Class. @@ -1402,219 +1410,6 @@ def _position_correction( return current_positions - def _object_butterworth_constraint( - self, current_object, q_lowpass, q_highpass, butterworth_order - ): - """ - 2D Butterworth filter - Used for low/high-pass filtering object. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qya, qxa = xp.meshgrid(qy, qx) - qra = xp.sqrt(qxa**2 + qya**2) - - env = xp.ones_like(qra) - if q_highpass: - env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) - if q_lowpass: - env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - - current_object_mean = xp.mean(current_object) - current_object -= current_object_mean - current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env[None]) - current_object += current_object_mean - - if self._object_type == "potential": - current_object = xp.real(current_object) - - return current_object - - def _object_kz_regularization_constraint( - self, current_object, kz_regularization_gamma - ): - """ - Arctan regularization filter - - Parameters - -------- - current_object: np.ndarray - Current object estimate - kz_regularization_gamma: float - Slice regularization strength - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - - current_object = xp.pad( - current_object, pad_width=((1, 0), (0, 0), (0, 0)), mode="constant" - ) - - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]) - - kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0] - - qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") - qz2 = qza**2 * kz_regularization_gamma**2 - qr2 = qxa**2 + qya**2 - - w = 1 - 2 / np.pi * xp.arctan2(qz2, qr2) - - current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * w) - current_object = current_object[1:] - - if self._object_type == "potential": - current_object = xp.real(current_object) - - return current_object - - def _object_identical_slices_constraint(self, current_object): - """ - Strong regularization forcing all slices to be identical - - Parameters - -------- - current_object: np.ndarray - Current object estimate - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - object_mean = current_object.mean(0, keepdims=True) - current_object[:] = object_mean - - return current_object - - def _object_denoise_tv_pylops(self, current_object, weights, iterations): - """ - Performs second order TV denoising along x and y - - Parameters - ---------- - current_object: np.ndarray - Current object estimate - weights : [float, float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - iterations: float - Number of iterations to run in denoising algorithm. 
- `niter_out` in pylops - - Returns - ------- - constrained_object: np.ndarray - Constrained object estimate - - """ - xp = self._xp - - if xp.iscomplexobj(current_object): - current_object_tv = current_object - warnings.warn( - ("TV denoising is currently only supported for potential objects."), - UserWarning, - ) - - else: - # zero pad at top and bottom slice - pad_width = ((1, 1), (0, 0), (0, 0)) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" - ) - - # run tv denoising - nz, nx, ny = current_object.shape - niter_out = iterations - niter_in = 1 - Iop = pylops.Identity(nx * ny * nz) - - if weights[0] == 0: - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[1]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - elif weights[1] == 0: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - l1_regs = [z_gradient] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[0]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - else: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [z_gradient, xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=weights, - tol=1e-4, - tau=1.0, - show=False, - )[0] - - # remove padding - current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] - - return current_object_tv - def _constraints( self, current_object, @@ -1721,8 +1516,8 @@ def _constraints( If True, performs TV denoising along z tv_denoise_weight_chambolle: float weight of tv denoising constraint - tv_denoise_pad_chambolle: bool - if True, pads object at top and bottom with zeros before applying denoising + tv_denoise_pad_chambolle: int + if not None, pads object at top and bottom with this many zeros before applying denoising tv_denoise: bool If True, applies TV denoising on object tv_denoise_weights: [float,float] @@ -1758,20 +1553,23 @@ def _constraints( current_object = self._object_identical_slices_constraint(current_object) elif kz_regularization_filter: current_object = self._object_kz_regularization_constraint( - current_object, kz_regularization_gamma + current_object, + kz_regularization_gamma, + z_padding=1, ) elif tv_denoise: current_object = self._object_denoise_tv_pylops( current_object, tv_denoise_weights, tv_denoise_inner_iter, + z_padding=1, ) elif tv_denoise_chambolle: current_object = self._object_denoise_tv_chambolle( current_object, tv_denoise_weight_chambolle, axis=0, - pad_object=tv_denoise_pad_chambolle, + padding=tv_denoise_pad_chambolle, ) if shrinkage_rad > 0.0 or object_mask is not None: @@ -1872,7 +1670,7 @@ def reconstruct( pure_phase_object_iter: int = 0, tv_denoise_iter_chambolle=np.inf, tv_denoise_weight_chambolle=None, - tv_denoise_pad_chambolle=True, + tv_denoise_pad_chambolle=1, tv_denoise_iter=np.inf, tv_denoise_weights=None, tv_denoise_inner_iter=40, @@ -1974,8 +1772,8 @@ def reconstruct( Number 
of iterations with TV denoisining tv_denoise_weight_chambolle: float weight of tv denoising constraint - tv_denoise_pad_chambolle: bool - if True, pads object at top and bottom with zeros before applying denoising + tv_denoise_pad_chambolle: int + if not None, pads object at top and bottom with this many zeros before applying denoising tv_denoise: bool If True, applies TV denoising on object tv_denoise_weights: [float,float] diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 670ea5e40..e9c15b097 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -18,15 +18,16 @@ import cupy as cp except (ModuleNotFoundError, ImportError): cp = np - import os - - # make sure pylops doesn't try to use cupy - os.environ["CUPY_PYLOPS"] = "0" -import pylops # this must follow the exception from emdfile import Custom, tqdmnd from py4DSTEM import DataCube from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( + Object3DConstraintsMixin, + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -41,7 +42,13 @@ warnings.simplefilter(action="always", category=UserWarning) -class OverlapMagneticTomographicReconstruction(PtychographicReconstruction): +class OverlapMagneticTomographicReconstruction( + PositionsConstraintsMixin, + ProbeConstraintsMixin, + Object3DConstraintsMixin, + ObjectNDConstraintsMixin, + PtychographicReconstruction, +): """ Overlap Magnetic Tomographic Reconstruction Class. @@ -1627,72 +1634,6 @@ def _position_correction( return current_positions - def _object_gaussian_constraint(self, current_object, gaussian_filter_sigma): - """ - Ptychographic smoothness constraint. - Used for blurring object. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - gaussian_filter = self._gaussian_filter - - gaussian_filter_sigma /= self.sampling[0] - current_object = gaussian_filter(current_object, gaussian_filter_sigma) - - return current_object - - def _object_butterworth_constraint( - self, current_object, q_lowpass, q_highpass, butterworth_order - ): - """ - Butterworth filter - - Parameters - -------- - current_object: np.ndarray - Current object estimate - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - qz = xp.fft.fftfreq(current_object.shape[0], self.sampling[1]) - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") - qra = xp.sqrt(qza**2 + qxa**2 + qya**2) - - env = xp.ones_like(qra) - if q_highpass: - env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) - if q_lowpass: - env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - - current_object_mean = xp.mean(current_object) - current_object -= current_object_mean - current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * env) - current_object += current_object_mean - - return xp.real(current_object) - def _divergence_free_constraint(self, vector_field): """ Leray projection operator @@ -1716,111 +1657,6 @@ def _divergence_free_constraint(self, vector_field): return vector_field - def _object_denoise_tv_pylops(self, current_object, weights, iterations): - """ - Performs second order TV denoising along x and y - - Parameters - ---------- - current_object: np.ndarray - Current object estimate - weights : [float, float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - iterations: float - Number of iterations to run in denoising algorithm. - `niter_out` in pylops - - Returns - ------- - constrained_object: np.ndarray - Constrained object estimate - - """ - xp = self._xp - - if xp.iscomplexobj(current_object): - current_object_tv = current_object - warnings.warn( - ("TV denoising is currently only supported for potential objects."), - UserWarning, - ) - - else: - # zero pad at top and bottom slice - pad_width = ((1, 1), (0, 0), (0, 0)) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" - ) - - # run tv denoising - nz, nx, ny = current_object.shape - niter_out = iterations - niter_in = 1 - Iop = pylops.Identity(nx * ny * nz) - - if weights[0] == 0: - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[1]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - elif weights[1] == 0: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - l1_regs = [z_gradient] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[0]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - else: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [z_gradient, xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=weights, - tol=1e-4, - tau=1.0, - show=False, - )[0] - - # remove padding - current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] - - return current_object_tv - def _constraints( self, current_object, @@ -1938,16 +1774,16 @@ def _constraints( if gaussian_filter: 
current_object[0] = self._object_gaussian_constraint( - current_object[0], gaussian_filter_sigma_e + current_object[0], gaussian_filter_sigma_e, pure_phase_object=False ) current_object[1] = self._object_gaussian_constraint( - current_object[1], gaussian_filter_sigma_m + current_object[1], gaussian_filter_sigma_m, pure_phase_object=False ) current_object[2] = self._object_gaussian_constraint( - current_object[2], gaussian_filter_sigma_m + current_object[2], gaussian_filter_sigma_m, pure_phase_object=False ) current_object[3] = self._object_gaussian_constraint( - current_object[3], gaussian_filter_sigma_m + current_object[3], gaussian_filter_sigma_m, pure_phase_object=False ) if butterworth_filter: diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 749028b83..df5064206 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -18,15 +18,16 @@ import cupy as cp except (ModuleNotFoundError, ImportError): cp = np - import os - - # make sure pylops doesn't try to use cupy - os.environ["CUPY_PYLOPS"] = "0" -import pylops # this must follow the exception from emdfile import Custom, tqdmnd from py4DSTEM import DataCube from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( + Object3DConstraintsMixin, + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -40,7 +41,13 @@ warnings.simplefilter(action="always", category=UserWarning) -class OverlapTomographicReconstruction(PtychographicReconstruction): +class OverlapTomographicReconstruction( + PositionsConstraintsMixin, + ProbeConstraintsMixin, + Object3DConstraintsMixin, + ObjectNDConstraintsMixin, + PtychographicReconstruction, +): """ Overlap Tomographic Reconstruction Class. @@ -1529,175 +1536,6 @@ def _position_correction( return current_positions - def _object_gaussian_constraint(self, current_object, gaussian_filter_sigma): - """ - Ptychographic smoothness constraint. - Used for blurring object. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - gaussian_filter = self._gaussian_filter - - gaussian_filter_sigma /= self.sampling[0] - current_object = gaussian_filter(current_object, gaussian_filter_sigma) - - return current_object - - def _object_butterworth_constraint( - self, current_object, q_lowpass, q_highpass, butterworth_order - ): - """ - Butterworth filter - - Parameters - -------- - current_object: np.ndarray - Current object estimate - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - qz = xp.fft.fftfreq(current_object.shape[0], self.sampling[1]) - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") - qra = xp.sqrt(qza**2 + qxa**2 + qya**2) - - env = xp.ones_like(qra) - if q_highpass: - env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) - if q_lowpass: - env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - - current_object_mean = xp.mean(current_object) - current_object -= current_object_mean - current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * env) - current_object += current_object_mean - return xp.real(current_object) - - def _object_denoise_tv_pylops(self, current_object, weights, iterations): - """ - Performs second order TV denoising along x and y - - Parameters - ---------- - current_object: np.ndarray - Current object estimate - weights : [float, float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - iterations: float - Number of iterations to run in denoising algorithm. - `niter_out` in pylops - - Returns - ------- - constrained_object: np.ndarray - Constrained object estimate - - """ - xp = self._xp - - if xp.iscomplexobj(current_object): - current_object_tv = current_object - warnings.warn( - ("TV denoising is currently only supported for potential objects."), - UserWarning, - ) - - else: - # zero pad at top and bottom slice - pad_width = ((1, 1), (0, 0), (0, 0)) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" - ) - - # run tv denoising - nz, nx, ny = current_object.shape - niter_out = iterations - niter_in = 1 - Iop = pylops.Identity(nx * ny * nz) - - if weights[0] == 0: - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[1]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - elif weights[1] == 0: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - l1_regs = [z_gradient] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[0]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - else: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [z_gradient, xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=weights, - tol=1e-4, - tau=1.0, - show=False, - )[0] - - # remove padding - current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] - - return current_object_tv - def _constraints( self, current_object, @@ -1805,7 +1643,7 @@ def _constraints( if gaussian_filter: current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma + current_object, gaussian_filter_sigma, pure_phase_object=False ) if butterworth_filter: diff --git 
a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 59bf61da2..538b9d7aa 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -21,9 +21,9 @@ import pylops # this must follow the exception -class PtychographicConstraints: +class ObjectNDConstraintsMixin: """ - Container class for PtychographicReconstruction methods. + Mixin class for object constraints applicable to 2D,2.5D, and 3D objects. """ def _object_threshold_constraint(self, current_object, pure_phase_object): @@ -44,14 +44,18 @@ def _object_threshold_constraint(self, current_object, pure_phase_object): Constrained object estimate """ xp = self._xp - phase = xp.angle(current_object) - if pure_phase_object: - amplitude = 1.0 - else: - amplitude = xp.minimum(xp.abs(current_object), 1.0) + if self._object_type == "complex": + phase = xp.angle(current_object) + + if pure_phase_object: + amplitude = 1.0 + else: + amplitude = xp.minimum(xp.abs(current_object), 1.0) - return amplitude * xp.exp(1.0j * phase) + return amplitude * xp.exp(1.0j * phase) + else: + return current_object def _object_shrinkage_constraint(self, current_object, shrinkage_rad, object_mask): """ @@ -107,9 +111,10 @@ def _object_positivity_constraint(self, current_object): constrained_object: np.ndarray Constrained object estimate """ - xp = self._xp - - return xp.maximum(current_object, 0.0) + if self._object_type == "complex": + return current_object + else: + return current_object.clip(0.0) def _object_gaussian_constraint( self, current_object, gaussian_filter_sigma, pure_phase_object @@ -136,12 +141,12 @@ def _object_gaussian_constraint( gaussian_filter = self._gaussian_filter gaussian_filter_sigma /= self.sampling[0] - if pure_phase_object: + if not pure_phase_object or self._object_type == "potential": + current_object = gaussian_filter(current_object, gaussian_filter_sigma) + else: phase = xp.angle(current_object) phase = gaussian_filter(phase, gaussian_filter_sigma) current_object = xp.exp(1.0j * phase) - else: - current_object = gaussian_filter(current_object, gaussian_filter_sigma) return current_object @@ -185,7 +190,7 @@ def _object_butterworth_constraint( if q_lowpass: env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - current_object_mean = xp.mean(current_object) + current_object_mean = xp.mean(current_object, axis=(-2, -1), keepdims=True) current_object -= current_object_mean current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env) current_object += current_object_mean @@ -216,12 +221,12 @@ def _object_denoise_tv_pylops(self, current_object, weight, iterations): Constrained object estimate """ - xp = self._xp - - if xp.iscomplexobj(current_object): + if self._object_type == "complex": current_object_tv = current_object warnings.warn( - ("TV denoising is currently only supported for potential objects."), + ( + "TV denoising is currently only supported for object_type=='potential'." + ), UserWarning, ) @@ -257,7 +262,7 @@ def _object_denoise_tv_chambolle( current_object, weight, axis, - pad_object, + padding, eps=2.0e-4, max_num_iter=200, scaling=None, @@ -298,14 +303,19 @@ def _object_denoise_tv_chambolle( Adapted skimage.restoration.denoise_tv_chambolle. 
""" xp = self._xp - if xp.iscomplexobj(current_object): + + if self._object_type == "complex": updated_object = current_object warnings.warn( - ("TV denoising is currently only supported for potential objects."), + ( + "TV denoising is currently only supported for object_type=='potential'." + ), UserWarning, ) + else: current_object_sum = xp.sum(current_object) + if axis is None: ndim = xp.arange(current_object.ndim).tolist() elif isinstance(axis, tuple): @@ -313,11 +323,13 @@ def _object_denoise_tv_chambolle( else: ndim = [axis] - if pad_object: + if padding is not None: pad_width = ((0, 0),) * current_object.ndim pad_width = list(pad_width) + for ax in range(len(ndim)): - pad_width[ndim[ax]] = (1, 1) + pad_width[ndim[ax]] = (padding, padding) + current_object = xp.pad( current_object, pad_width=pad_width, mode="constant" ) @@ -383,16 +395,316 @@ def _object_denoise_tv_chambolle( E_previous = E i += 1 - if pad_object: + if padding is not None: for ax in range(len(ndim)): - slices = array_slice(ndim[ax], current_object.ndim, 1, -1) + slices = array_slice( + ndim[ax], current_object.ndim, padding, -padding + ) updated_object = updated_object[slices] + updated_object = ( updated_object / xp.sum(updated_object) * current_object_sum ) return updated_object + +class Object2p5DConstraintsMixin: + """ + Mixin class for object constraints unique to 2.5D objects. + Overwrites ObjectNDConstraintsMixin. + """ + + def _object_denoise_tv_pylops(self, current_object, weights, iterations, z_padding): + """ + Performs second order TV denoising along x and y, and first order along z + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z_weight, r_weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + z_padding: int + Symmetric padding around the first axis + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if self._object_type == "complex": + current_object_tv = current_object + warnings.warn( + ( + "TV denoising is currently only supported for object_type=='potential'." 
+ ), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((z_padding, z_padding), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[ + z_padding:-z_padding + ] + + return current_object_tv + + def _object_kz_regularization_constraint( + self, current_object, kz_regularization_gamma, z_padding + ): + """ + Arctan regularization filter + + Parameters + -------- + current_object: np.ndarray + Current object estimate + kz_regularization_gamma: float + Slice regularization strength + z_padding: int + Symmetric padding around the first axis + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + + # zero pad at top and bottom slice + pad_width = ((z_padding, z_padding), (0, 0), (0, 0)) + current_object = xp.pad(current_object, pad_width=pad_width, mode="constant") + + qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]) + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + + kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0] + + qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") + qz2 = qza**2 * kz_regularization_gamma**2 + qr2 = qxa**2 + qya**2 + + w = 1 - 2 / np.pi * xp.arctan2(qz2, qr2) + + current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * w) + current_object = current_object[z_padding:-z_padding] + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_identical_slices_constraint(self, current_object): + """ + Strong regularization forcing all slices to be identical + + Parameters + -------- + current_object: np.ndarray + Current object estimate + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + object_mean = current_object.mean(0, keepdims=True) + current_object[:] = object_mean + + return current_object + + +class Object3DConstraintsMixin: + """ + Mixin class for object 
constraints unique to 3D objects. + Overwrites ObjectNDConstraintsMixin and Object2p5DConstraintsMixin. + """ + + def _object_denoise_tv_pylops(self, current_object, weight, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weight : float + Denoising weight. The greater `weight`, the more denoising (at + the expense of fidelity to `input`). + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + if self._object_type == "complex": + current_object_tv = current_object + warnings.warn( + ( + "TV denoising is currently only supported for object_type=='potential'." + ), + UserWarning, + ) + + else: + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + xyz_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(0, 1, 2), edge=False, kind="backward" + ) + + l1_regs = [xyz_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weight], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + current_object_tv = current_object_tv.reshape(current_object.shape) + + return current_object_tv + + def _object_butterworth_constraint( + self, current_object, q_lowpass, q_highpass, butterworth_order + ): + """ + Butterworth filter + + Parameters + -------- + current_object: np.ndarray + Current object estimate + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + qz = xp.fft.fftfreq(current_object.shape[0], self.sampling[1]) + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") + qra = xp.sqrt(qza**2 + qxa**2 + qya**2) + + env = xp.ones_like(qra) + if q_highpass: + env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) + if q_lowpass: + env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) + + current_object_mean = xp.mean(current_object) + current_object -= current_object_mean + current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * env) + current_object += current_object_mean + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + +class ProbeConstraintsMixin: + """ + Mixin class for regularizations applicable to a single probe. + """ + def _probe_center_of_mass_constraint(self, current_probe): """ Ptychographic center of mass constraint. @@ -580,6 +892,81 @@ def _probe_aberration_fitting_constraint( return current_probe + +class ProbeMixedConstraintsMixin: + """ + Mixin class for regularizations unique to mixed probes. + Overwrites ProbeConstraintsMixin. + """ + + def _probe_center_of_mass_constraint(self, current_probe): + """ + Ptychographic center of mass constraint. + Used for centering corner-centered probe intensity. 
+ + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Constrained probe estimate + """ + xp = self._xp + probe_intensity = xp.abs(current_probe[0]) ** 2 + + probe_x0, probe_y0 = get_CoM( + probe_intensity, device=self._device, corner_centered=True + ) + shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) + + return shifted_probe + + def _probe_orthogonalization_constraint(self, current_probe): + """ + Ptychographic probe-orthogonalization constraint. + Used to ensure mixed states are orthogonal to each other. + Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690 + + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Orthogonalized probe estimate + """ + xp = self._xp + n_probes = self._num_probes + + # compute upper half of P* @ P + pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype) + + for i in range(n_probes): + for j in range(i, n_probes): + pairwise_dot_product[i, j] = xp.sum( + current_probe[i].conj() * current_probe[j] + ) + + # compute eigenvectors (effectively cheaper way of computing V* from SVD) + _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U") + current_probe = xp.tensordot(evecs.T, current_probe, axes=1) + + # sort by real-space intensity + intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1)) + intensities_order = xp.argsort(intensities, axis=None)[::-1] + return current_probe[intensities_order] + + +class PositionsConstraintsMixin: + """ + Mixin class for probe positions constraints. + """ + def _positions_center_of_mass_constraint(self, current_positions): """ Ptychographic position center of mass constraint. diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index c8cc5ee3e..aecbf0970 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -20,6 +20,11 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -32,7 +37,12 @@ warnings.simplefilter(action="always", category=UserWarning) -class SimultaneousPtychographicReconstruction(PtychographicReconstruction): +class SimultaneousPtychographicReconstruction( + PositionsConstraintsMixin, + ProbeConstraintsMixin, + ObjectNDConstraintsMixin, + PtychographicReconstruction, +): """ Iterative Simultaneous Ptychographic Reconstruction Class. 
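# --- Illustrative sketch (not part of the patch): how the new constraints mixins compose. ---
# The refactored reconstruction classes list the more specific mixins (e.g.
# ProbeMixedConstraintsMixin, Object2p5DConstraintsMixin) before the generic ones, so
# Python's method resolution order lets them override same-named constraint methods such as
# _probe_center_of_mass_constraint or _object_denoise_tv_pylops. A minimal stand-in using a
# hypothetical MultisliceReconstruction class and simplified toy method bodies (not the real
# implementations) to show the override behavior this ordering relies on:

class ObjectNDConstraintsMixin:
    def _object_denoise_tv_pylops(self, current_object, *args):
        # toy placeholder for the generic 2D TV constraint
        return "2D TV denoising"


class Object2p5DConstraintsMixin:
    def _object_denoise_tv_pylops(self, current_object, *args):
        # toy placeholder for the multislice (2.5D) override
        return "z + xy TV denoising"


class MultisliceReconstruction(Object2p5DConstraintsMixin, ObjectNDConstraintsMixin):
    pass


# the 2.5D mixin, listed first, wins in the MRO
assert MultisliceReconstruction()._object_denoise_tv_pylops(None) == "z + xy TV denoising"
assert [c.__name__ for c in MultisliceReconstruction.__mro__[:3]] == [
    "MultisliceReconstruction",
    "Object2p5DConstraintsMixin",
    "ObjectNDConstraintsMixin",
]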
diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 36baac21e..47de9ad29 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -20,6 +20,11 @@ from emdfile import Custom, tqdmnd from py4DSTEM.datacube import DataCube from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -32,7 +37,12 @@ warnings.simplefilter(action="always", category=UserWarning) -class SingleslicePtychographicReconstruction(PtychographicReconstruction): +class SingleslicePtychographicReconstruction( + PositionsConstraintsMixin, + ProbeConstraintsMixin, + ObjectNDConstraintsMixin, + PtychographicReconstruction, +): """ Iterative Ptychographic Reconstruction Class. From 92e44e0a46ba416b0af3fa862510b35c0ff67b7c Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 30 Dec 2023 10:50:09 -0800 Subject: [PATCH 003/128] fixing reshaping bugs Former-commit-id: a4cb6f9e184871373c1e6baa48c83c187bf88416 --- .../process/phase/iterative_base_class.py | 55 +++++++++++++++++-- 1 file changed, 49 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index e82816042..c321d98e4 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -9,7 +9,7 @@ from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex -from scipy.ndimage import rotate +from scipy.ndimage import rotate, zoom try: import cupy as cp @@ -190,6 +190,7 @@ def _preprocess_datacube_and_vacuum_probe( datacube: Datacube Resampled and Padded datacube """ + if com_shifts is not None: if np.isscalar(com_shifts[0]): com_shifts = ( @@ -224,12 +225,31 @@ def _preprocess_datacube_and_vacuum_probe( datacube = datacube.bin_Q(N=bin_factor) if vacuum_probe_intensity is not None: - vacuum_probe_intensity = vacuum_probe_intensity[ - ::bin_factor, ::bin_factor - ] + # crop edges if necessary + if Qx % bin_factor == 0: + vacuum_probe_intensity = vacuum_probe_intensity[ + : -(Qx % bin_factor), : + ] + if Qy % bin_factor == 0: + vacuum_probe_intensity = vacuum_probe_intensity[ + :, : -(Qy % bin_factor) + ] + + vacuum_probe_intensity = vacuum_probe_intensity.reshape( + Qx // bin_factor, bin_factor, Qy // bin_factor, bin_factor + ).sum(axis=(1, 3)) if dp_mask is not None: - dp_mask = dp_mask[::bin_factor, ::bin_factor] - else: + # crop edges if necessary + if Qx % bin_factor == 0: + dp_mask = dp_mask[: -(Qx % bin_factor), :] + if Qy % bin_factor == 0: + dp_mask = dp_mask[:, : -(Qy % bin_factor)] + + dp_mask = dp_mask.reshape( + Qx // bin_factor, bin_factor, Qy // bin_factor, bin_factor + ).sum(axis=(1, 3)) + + elif reshaping_method == "fourier": datacube = datacube.resample_Q( N=resampling_factor_x, method=reshaping_method ) @@ -246,6 +266,29 @@ def _preprocess_datacube_and_vacuum_probe( force_nonnegative=True, ) + elif reshaping_method == "bilinear": + datacube = datacube.resample_Q( + N=resampling_factor_x, method=reshaping_method + ) + if vacuum_probe_intensity is not None: + vacuum_probe_intensity = zoom( 
+ vacuum_probe_intensity, + (resampling_factor_x, resampling_factor_x), + order=1, + ) + if dp_mask is not None: + dp_mask = zoom( + dp_mask, (resampling_factor_x, resampling_factor_x), order=1 + ) + + else: + raise ValueError( + ( + "reshaping_method needs to be one of 'bilinear', 'fourier', or 'bin', " + f"not {reshaping_method}." + ) + ) + if probe_roi_shape is not None: Qx, Qy = datacube.shape[-2:] Sx, Sy = probe_roi_shape From c14fbf59e16c09002674cfe8b7ce8499bf3056f6 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 30 Dec 2023 11:02:56 -0800 Subject: [PATCH 004/128] cleaning up calibrations Former-commit-id: 7415621d13a3bc9a97528163dd25c0bb7db5d9e2 --- py4DSTEM/process/phase/iterative_base_class.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index c321d98e4..8ec6b3dda 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -413,15 +413,16 @@ def _extract_intensities_and_calibrations_from_datacube( # Reciprocal-space if force_angular_sampling is not None or force_reciprocal_sampling is not None: - # there is no xor keyword in Python! - angular = force_angular_sampling is not None - reciprocal = force_reciprocal_sampling is not None - assert (angular and not reciprocal) or ( - not angular and reciprocal - ), "Only one of angular or reciprocal calibration can be forced!" + if ( + force_angular_sampling is not None + and force_reciprocal_sampling is not None + ): + raise ValueError( + "Only one of angular or reciprocal calibration can be forced." + ) # angular calibration specified - if angular: + if force_angular_sampling is not None: self._angular_sampling = (force_angular_sampling,) * 2 self._angular_units = ("mrad",) * 2 @@ -434,7 +435,7 @@ def _extract_intensities_and_calibrations_from_datacube( self._reciprocal_units = ("A^-1",) * 2 # reciprocal calibration specified - if reciprocal: + if force_reciprocal_sampling is not None: self._reciprocal_sampling = (force_reciprocal_sampling,) * 2 self._reciprocal_units = ("A^-1",) * 2 From 1f88bc78543f1cb8137b2611c1d596a973ace7c7 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 30 Dec 2023 11:10:05 -0800 Subject: [PATCH 005/128] cleaning up calculate rotation Former-commit-id: b995222b5b396d6ab1c9f7b5b4270b6d75a8a23f --- py4DSTEM/process/phase/iterative_base_class.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 8ec6b3dda..24984de04 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -544,7 +544,6 @@ def _calculate_intensities_center_of_mass( xp = self._xp asnumpy = self._asnumpy - # for ptycho if com_measured: com_measured_x, com_measured_y = com_measured @@ -615,7 +614,7 @@ def _solve_for_center_of_mass_relative_rotation( _com_measured_y: np.ndarray, _com_normalized_x: np.ndarray, _com_normalized_y: np.ndarray, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + rotation_angles_deg: np.ndarray = None, plot_rotation: bool = True, plot_center_of_mass: str = "default", maximize_divergence: bool = False, @@ -680,6 +679,9 @@ def _solve_for_center_of_mass_relative_rotation( xp = self._xp asnumpy = self._asnumpy + if rotation_angles_deg is None: + rotation_angles_deg = np.arange(-89.0, 90.0, 1.0) + if force_com_rotation is 
not None: # Rotation known @@ -1132,7 +1134,7 @@ def _solve_for_center_of_mass_relative_rotation( ax.set_xlabel(f"y [{self._scan_units[1]}]") ax.set_title(title) - elif plot_center_of_mass == "default": + elif plot_center_of_mass == "default" or plot_center_of_mass is True: figsize = kwargs.pop("figsize", (8, 4)) cmap = kwargs.pop("cmap", "RdBu_r") From 6b6ec6ab94e99868fb3332042cf1061995b43b24 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 30 Dec 2023 16:22:06 -0800 Subject: [PATCH 006/128] adding grid search functionality to optimize Former-commit-id: 8021a6037bec773faecbbd7cf98ea91901be9fd7 --- py4DSTEM/process/phase/parameter_optimize.py | 144 +++++++++++++++++++ 1 file changed, 144 insertions(+) diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index 91a71cb30..49f131585 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -1,4 +1,5 @@ from functools import partial +from itertools import product from typing import Callable, Union import matplotlib.pyplot as plt @@ -102,6 +103,149 @@ def __init__( self._set_optimizer_defaults() + def _generate_inclusive_boundary_grid( + self, + parameter, + n_points, + ): + """ """ + + # Categorical + if hasattr(parameter, "categories"): + return np.array(parameter.categories) + + # Real or Integer + else: + return np.unique( + np.linspace(parameter.low, parameter.high, n_points).astype( + parameter.dtype + ) + ) + + def grid_search( + self, + n_points: Union[tuple, int] = 3, + error_metric: Union[Callable, str] = "log", + plot_reconstructed_objects: bool = True, + return_reconstructed_objects: bool = False, + **kwargs: dict, + ): + """ + Run optimizer + + Parameters + ---------- + n_initial_points: int + Number of uniformly spaced trial points to run on a grid + error_metric: Callable or str + Function used to compute the reconstruction error. 
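# Hedged sketch, not part of the committed hunk: error_metric may also be a custom
# Callable taking the reconstruction object as its only argument and returning a
# single float, mirroring the built-in 'std-phase' option enumerated just below.
# Only NumPy and the object_cropped property used by evaluation_callback below are assumed.
import numpy as np

def negative_phase_std(ptycho):
    obj = ptycho.object_cropped
    phase = np.angle(obj) if np.iscomplexobj(obj) else obj
    return -float(np.std(phase))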
+ When passed as a string, may be one of: + 'log': log(NMSE) of final object + 'linear': NMSE of final object + 'log-converged': log(NMSE) of final object if + NMSE is decreasing, 0 if NMSE increasing + 'linear-converged': NMSE of final object if + NMSE is decreasing, 1 if NMSE increasing + 'TV': sum( abs( grad( object ) ) ) / sum( abs( object ) ) + 'std': negative standard deviation of cropped object + 'std-phase': negative standard deviation of + phase of the cropped object + 'entropy-phase': entropy of the phase of the + cropped object + When passed as a Callable, a function that takes the + PhaseReconstruction object as its only argument + and returns the error metric as a single float + + """ + + num_params = len(self._parameter_list) + + if isinstance(n_points, int): + n_points = [n_points] * num_params + elif len(n_points) != num_params: + raise ValueError() + + params_grid = [ + self._generate_inclusive_boundary_grid(param, n_pts) + for param, n_pts in zip(self._parameter_list, n_points) + ] + params_grid = list(product(*params_grid)) + num_evals = len(params_grid) + + error_metric = self._get_error_metric(error_metric) + pbar = tqdm(total=num_evals, desc="Searching parameters") + + def evaluation_callback(ptycho): + if plot_reconstructed_objects or return_reconstructed_objects: + pbar.update(1) + return (ptycho.object_cropped, error_metric(ptycho)) + else: + pbar.update(1) + error_metric(ptycho) + + self._grid_search_function = self._get_optimization_function( + self._reconstruction_type, + self._parameter_list, + self._init_static_args, + self._affine_static_args, + self._preprocess_static_args, + self._reconstruction_static_args, + self._init_optimize_args, + self._affine_optimize_args, + self._preprocess_optimize_args, + self._reconstruction_optimize_args, + evaluation_callback, + ) + + grid_search_res = list(map(self._grid_search_function, params_grid)) + pbar.close() + + if plot_reconstructed_objects: + + if len(n_points) == 2: + nrows, ncols = n_points + else: + nrows = kwargs.pop("nrows", int(np.sqrt(num_evals))) + ncols = kwargs.pop("ncols", int(np.ceil(num_evals / nrows))) + if nrows * ncols != num_evals: + raise ValueError() + + spec = GridSpec( + ncols=ncols, + nrows=nrows, + hspace=0.15, + wspace=0.15, + ) + + sx, sy = grid_search_res[0][0].shape + + separator = kwargs.pop("separator", "\n") + cmap = kwargs.pop("cmap", "magma") + figsize = kwargs.pop("figsize", (2.5 * ncols, 3 / sy * sx * nrows)) + fig = plt.figure(figsize=figsize) + + for index, (params, res) in enumerate(zip(params_grid, grid_search_res)): + row_index, col_index = np.unravel_index(index, (nrows, ncols)) + + ax = fig.add_subplot(spec[row_index, col_index]) + ax.imshow(res[0], cmap=cmap) + + title_substrings = [ + f"{param.name}: {val}" + for param, val in zip(self._parameter_list, params) + ] + title_substrings.append(f"error: {res[1]:.3e}") + title = separator.join(title_substrings) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title(title) + spec.tight_layout(fig) + + if return_reconstructed_objects: + return grid_search_res + else: + return grid_search_res + def optimize( self, n_calls: int = 50, From 165e5f96021d214e17b30b34d5c71378c83ef126 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 30 Dec 2023 16:23:53 -0800 Subject: [PATCH 007/128] removing tuning functions from subclasses - moved to optimizer Former-commit-id: b03fddb5308aea93459cdbc44a178e8c36ba4c35 --- .../process/phase/iterative_base_class.py | 199 +---------------- ...tive_mixedstate_multislice_ptychography.py | 204 
------------------ .../iterative_multislice_ptychography.py | 204 ------------------ 3 files changed, 4 insertions(+), 603 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 24984de04..a824874cd 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1215,6 +1215,7 @@ def _normalize_diffraction_intensities( else: number_of_patterns = np.prod(diffraction_intensities.shape[:2]) + # Aggressive cropping for when off-centered high scattering angle data was recorded if crop_patterns: crop_x = int( np.minimum( @@ -1339,7 +1340,7 @@ def show_complex_CoM( class PtychographicReconstruction(PhaseReconstruction): """ Base ptychographic reconstruction class. - Inherits from PhaseReconstruction and PtychographicConstraints. + Inherits from PhaseReconstruction. Defines various common functions and properties for subclasses to inherit. """ @@ -1783,7 +1784,7 @@ def _crop_rotate_object_fov( Parameters ---------- array: np.ndarray - Object array to crop and rotate. Only operates on numpy arrays for comptatibility. + Object array to crop and rotate. Only operates on numpy arrays for compatibility. padding: int, optional Optional padding outside pixel positions @@ -1810,7 +1811,7 @@ def _crop_rotate_object_fov( max_x, max_y = np.ceil(np.amax(rotated_points, axis=0) + padding).astype("int") rotated_array = rotate( - asnumpy(array), np.rad2deg(-angle), reshape=False, axes=(-2, -1) + asnumpy(array), np.rad2deg(-angle), order=1, reshape=False, axes=(-2, -1) )[..., min_x:max_x, min_y:max_y] if self._rotation_best_transpose: @@ -1818,198 +1819,6 @@ def _crop_rotate_object_fov( return rotated_array - def tune_angle_and_defocus( - self, - angle_guess=None, - defocus_guess=None, - transpose=None, - angle_step_size=1, - defocus_step_size=20, - num_angle_values=5, - num_defocus_values=5, - max_iter=5, - plot_reconstructions=True, - plot_convergence=True, - return_values=False, - **kwargs, - ): - """ - Run reconstructions over a parameters space of angles and - defocus values. Should be run after preprocess step. - - Parameters - ---------- - angle_guess: float (degrees), optional - initial starting guess for rotation angle between real and reciprocal space - if None, uses current initialized values - defocus_guess: float (A), optional - initial starting guess for defocus - if None, uses current initialized values - angle_step_size: float (degrees), optional - size of change of rotation angle between real and reciprocal space for - each step in parameter space - defocus_step_size: float (A), optional - size of change of defocus for each step in parameter space - num_angle_values: int, optional - number of values of angle to test, must be >= 1. - num_defocus_values: int,optional - number of values of defocus to test, must be >= 1 - max_iter: int, optional - number of iterations to run in ptychographic reconstruction - plot_reconstructions: bool, optional - if True, plot phase of reconstructed objects - plot_convergence: bool, optional - if True, plots error for each iteration for each reconstruction. 
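# Hedged sketch, not part of the committed hunk: the angle/defocus grid this removed
# tuning loop assembled by hand from np.linspace calls is what the optimizer's
# grid_search (previous commit) now builds from endpoint-inclusive per-parameter
# grids combined with itertools.product. Example values only:
from itertools import product
import numpy as np

angles = np.linspace(-2.0, 2.0, 5)             # degrees
defocus_values = np.linspace(-40.0, 40.0, 5)   # Angstroms
params_grid = list(product(angles, defocus_values))   # 25 (angle, defocus) trial points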
- return_values: bool, optional - if True, returns objects, convergence - - Returns - ------- - objects: list - reconstructed objects - convergence: np.ndarray - array of convergence values from reconstructions - """ - # calculate angles and defocus values to test - if angle_guess is None: - angle_guess = self._rotation_best_rad * 180 / np.pi - if defocus_guess is None: - defocus_guess = -self._polar_parameters["C10"] - if transpose is None: - transpose = self._rotation_best_transpose - - if num_angle_values == 1: - angle_step_size = 0 - - if num_defocus_values == 1: - defocus_step_size = 0 - - angles = np.linspace( - angle_guess - angle_step_size * (num_angle_values - 1) / 2, - angle_guess + angle_step_size * (num_angle_values - 1) / 2, - num_angle_values, - ) - - defocus_values = np.linspace( - defocus_guess - defocus_step_size * (num_defocus_values - 1) / 2, - defocus_guess + defocus_step_size * (num_defocus_values - 1) / 2, - num_defocus_values, - ) - - if return_values: - convergence = [] - objects = [] - - # current initialized values - current_verbose = self._verbose - current_defocus = -self._polar_parameters["C10"] - current_rotation_deg = self._rotation_best_rad * 180 / np.pi - current_transpose = self._rotation_best_transpose - - # Gridspec to plot on - if plot_reconstructions: - if plot_convergence: - spec = GridSpec( - ncols=num_defocus_values, - nrows=num_angle_values * 2, - height_ratios=[1, 1 / 4] * num_angle_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_defocus_values, 5 * num_angle_values) - ) - else: - spec = GridSpec( - ncols=num_defocus_values, - nrows=num_angle_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_defocus_values, 4 * num_angle_values) - ) - - fig = plt.figure(figsize=figsize) - - progress_bar = kwargs.pop("progress_bar", False) - # run loop and plot along the way - self._verbose = False - for flat_index, (angle, defocus) in enumerate( - tqdmnd(angles, defocus_values, desc="Tuning angle and defocus") - ): - self._polar_parameters["C10"] = -defocus - self._probe = None - self._object = None - self.preprocess( - force_com_rotation=angle, - force_com_transpose=transpose, - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - ) - - self.reconstruct( - reset=True, - store_iterations=True, - max_iter=max_iter, - progress_bar=progress_bar, - **kwargs, - ) - - if plot_reconstructions: - row_index, col_index = np.unravel_index( - flat_index, (num_angle_values, num_defocus_values) - ) - - if plot_convergence: - object_ax = fig.add_subplot(spec[row_index * 2, col_index]) - convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=convergence_ax, - cbar=True, - ) - convergence_ax.yaxis.tick_right() - else: - object_ax = fig.add_subplot(spec[row_index, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=None, - cbar=True, - ) - - object_ax.set_title( - f" angle = {angle:.1f} °, defocus = {defocus:.1f} A \n error = {self.error:.3e}" - ) - object_ax.set_xticks([]) - object_ax.set_yticks([]) - - if return_values: - objects.append(self.object) - convergence.append(self.error_iterations.copy()) - - # initialize back to pre-tuning values - self._polar_parameters["C10"] = -current_defocus - self._probe = None - self._object = None - self.preprocess( - force_com_rotation=current_rotation_deg, - 
force_com_transpose=current_transpose, - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - ) - self._verbose = current_verbose - - if plot_reconstructions: - spec.tight_layout(fig) - - if return_values: - return objects, convergence - def _position_correction( self, relevant_object, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 0f5b79e29..7ccbbd4bd 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -3156,210 +3156,6 @@ def show_depth( fig.colorbar(im, cax=ax_cb) plt.tight_layout() - def tune_num_slices_and_thicknesses( - self, - num_slices_guess=None, - thicknesses_guess=None, - num_slices_step_size=1, - thicknesses_step_size=20, - num_slices_values=3, - num_thicknesses_values=3, - update_defocus=False, - max_iter=5, - plot_reconstructions=True, - plot_convergence=True, - return_values=False, - **kwargs, - ): - """ - Run reconstructions over a parameters space of number of slices - and slice thicknesses. Should be run after the preprocess step. - - Parameters - ---------- - num_slices_guess: float, optional - initial starting guess for number of slices, rounds to nearest integer - if None, uses current initialized values - thicknesses_guess: float (A), optional - initial starting guess for thicknesses of slices assuming same - thickness for each slice - if None, uses current initialized values - num_slices_step_size: float, optional - size of change of number of slices for each step in parameter space - thicknesses_step_size: float (A), optional - size of change of slice thicknesses for each step in parameter space - num_slices_values: int, optional - number of number of slice values to test, must be >= 1 - num_thicknesses_values: int,optional - number of thicknesses values to test, must be >= 1 - update_defocus: bool, optional - if True, updates defocus based on estimated total thickness - max_iter: int, optional - number of iterations to run in ptychographic reconstruction - plot_reconstructions: bool, optional - if True, plot phase of reconstructed objects - plot_convergence: bool, optional - if True, plots error for each iteration for each reconstruction - return_values: bool, optional - if True, returns objects, convergence - - Returns - ------- - objects: list - reconstructed objects - convergence: np.ndarray - array of convergence values from reconstructions - """ - - # calculate number of slices and thicknesses values to test - if num_slices_guess is None: - num_slices_guess = self._num_slices - if thicknesses_guess is None: - thicknesses_guess = np.mean(self._slice_thicknesses) - - if num_slices_values == 1: - num_slices_step_size = 0 - - if num_thicknesses_values == 1: - thicknesses_step_size = 0 - - num_slices = np.linspace( - num_slices_guess - num_slices_step_size * (num_slices_values - 1) / 2, - num_slices_guess + num_slices_step_size * (num_slices_values - 1) / 2, - num_slices_values, - ) - - thicknesses = np.linspace( - thicknesses_guess - - thicknesses_step_size * (num_thicknesses_values - 1) / 2, - thicknesses_guess - + thicknesses_step_size * (num_thicknesses_values - 1) / 2, - num_thicknesses_values, - ) - - if return_values: - convergence = [] - objects = [] - - # current initialized values - current_verbose = self._verbose - current_num_slices = self._num_slices - current_thicknesses = self._slice_thicknesses - 
current_rotation_deg = self._rotation_best_rad * 180 / np.pi - current_transpose = self._rotation_best_transpose - current_defocus = -self._polar_parameters["C10"] - - # Gridspec to plot on - if plot_reconstructions: - if plot_convergence: - spec = GridSpec( - ncols=num_thicknesses_values, - nrows=num_slices_values * 2, - height_ratios=[1, 1 / 4] * num_slices_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_thicknesses_values, 5 * num_slices_values) - ) - else: - spec = GridSpec( - ncols=num_thicknesses_values, - nrows=num_slices_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_thicknesses_values, 4 * num_slices_values) - ) - - fig = plt.figure(figsize=figsize) - - progress_bar = kwargs.pop("progress_bar", False) - # run loop and plot along the way - self._verbose = False - for flat_index, (slices, thickness) in enumerate( - tqdmnd(num_slices, thicknesses, desc="Tuning angle and defocus") - ): - slices = int(slices) - self._num_slices = slices - self._slice_thicknesses = np.tile(thickness, slices - 1) - self._probe = None - self._object = None - if update_defocus: - defocus = current_defocus + slices / 2 * thickness - self._polar_parameters["C10"] = -defocus - - self.preprocess( - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - force_com_rotation=current_rotation_deg, - force_com_transpose=current_transpose, - ) - self.reconstruct( - reset=True, - store_iterations=True if plot_convergence else False, - max_iter=max_iter, - progress_bar=progress_bar, - **kwargs, - ) - - if plot_reconstructions: - row_index, col_index = np.unravel_index( - flat_index, (num_slices_values, num_thicknesses_values) - ) - - if plot_convergence: - object_ax = fig.add_subplot(spec[row_index * 2, col_index]) - convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=convergence_ax, - cbar=True, - ) - convergence_ax.yaxis.tick_right() - else: - object_ax = fig.add_subplot(spec[row_index, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=None, - cbar=True, - ) - - object_ax.set_title( - f" num slices = {slices:.0f}, slices thickness = {thickness:.1f} A \n error = {self.error:.3e}" - ) - object_ax.set_xticks([]) - object_ax.set_yticks([]) - - if return_values: - objects.append(self.object) - convergence.append(self.error_iterations.copy()) - - # initialize back to pre-tuning values - self._probe = None - self._object = None - self._num_slices = current_num_slices - self._slice_thicknesses = np.tile(current_thicknesses, current_num_slices - 1) - self._polar_parameters["C10"] = -current_defocus - self.preprocess( - force_com_rotation=current_rotation_deg, - force_com_transpose=current_transpose, - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - ) - self._verbose = current_verbose - - if plot_reconstructions: - spec.tight_layout(fig) - - if return_values: - return objects, convergence - def _return_object_fft( self, obj=None, diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index eaa0ae396..1baab6f21 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -3023,210 +3023,6 @@ def show_depth( fig.colorbar(im, cax=ax_cb) plt.tight_layout() - def 
tune_num_slices_and_thicknesses( - self, - num_slices_guess=None, - thicknesses_guess=None, - num_slices_step_size=1, - thicknesses_step_size=20, - num_slices_values=3, - num_thicknesses_values=3, - update_defocus=False, - max_iter=5, - plot_reconstructions=True, - plot_convergence=True, - return_values=False, - **kwargs, - ): - """ - Run reconstructions over a parameters space of number of slices - and slice thicknesses. Should be run after the preprocess step. - - Parameters - ---------- - num_slices_guess: float, optional - initial starting guess for number of slices, rounds to nearest integer - if None, uses current initialized values - thicknesses_guess: float (A), optional - initial starting guess for thicknesses of slices assuming same - thickness for each slice - if None, uses current initialized values - num_slices_step_size: float, optional - size of change of number of slices for each step in parameter space - thicknesses_step_size: float (A), optional - size of change of slice thicknesses for each step in parameter space - num_slices_values: int, optional - number of number of slice values to test, must be >= 1 - num_thicknesses_values: int,optional - number of thicknesses values to test, must be >= 1 - update_defocus: bool, optional - if True, updates defocus based on estimated total thickness - max_iter: int, optional - number of iterations to run in ptychographic reconstruction - plot_reconstructions: bool, optional - if True, plot phase of reconstructed objects - plot_convergence: bool, optional - if True, plots error for each iteration for each reconstruction - return_values: bool, optional - if True, returns objects, convergence - - Returns - ------- - objects: list - reconstructed objects - convergence: np.ndarray - array of convergence values from reconstructions - """ - - # calculate number of slices and thicknesses values to test - if num_slices_guess is None: - num_slices_guess = self._num_slices - if thicknesses_guess is None: - thicknesses_guess = np.mean(self._slice_thicknesses) - - if num_slices_values == 1: - num_slices_step_size = 0 - - if num_thicknesses_values == 1: - thicknesses_step_size = 0 - - num_slices = np.linspace( - num_slices_guess - num_slices_step_size * (num_slices_values - 1) / 2, - num_slices_guess + num_slices_step_size * (num_slices_values - 1) / 2, - num_slices_values, - ) - - thicknesses = np.linspace( - thicknesses_guess - - thicknesses_step_size * (num_thicknesses_values - 1) / 2, - thicknesses_guess - + thicknesses_step_size * (num_thicknesses_values - 1) / 2, - num_thicknesses_values, - ) - - if return_values: - convergence = [] - objects = [] - - # current initialized values - current_verbose = self._verbose - current_num_slices = self._num_slices - current_thicknesses = self._slice_thicknesses - current_rotation_deg = self._rotation_best_rad * 180 / np.pi - current_transpose = self._rotation_best_transpose - current_defocus = -self._polar_parameters["C10"] - - # Gridspec to plot on - if plot_reconstructions: - if plot_convergence: - spec = GridSpec( - ncols=num_thicknesses_values, - nrows=num_slices_values * 2, - height_ratios=[1, 1 / 4] * num_slices_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_thicknesses_values, 5 * num_slices_values) - ) - else: - spec = GridSpec( - ncols=num_thicknesses_values, - nrows=num_slices_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_thicknesses_values, 4 * num_slices_values) - ) - - fig = plt.figure(figsize=figsize) - - 
progress_bar = kwargs.pop("progress_bar", False) - # run loop and plot along the way - self._verbose = False - for flat_index, (slices, thickness) in enumerate( - tqdmnd(num_slices, thicknesses, desc="Tuning angle and defocus") - ): - slices = int(slices) - self._num_slices = slices - self._slice_thicknesses = np.tile(thickness, slices - 1) - self._probe = None - self._object = None - if update_defocus: - defocus = current_defocus + slices / 2 * thickness - self._polar_parameters["C10"] = -defocus - - self.preprocess( - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - force_com_rotation=current_rotation_deg, - force_com_transpose=current_transpose, - ) - self.reconstruct( - reset=True, - store_iterations=True if plot_convergence else False, - max_iter=max_iter, - progress_bar=progress_bar, - **kwargs, - ) - - if plot_reconstructions: - row_index, col_index = np.unravel_index( - flat_index, (num_slices_values, num_thicknesses_values) - ) - - if plot_convergence: - object_ax = fig.add_subplot(spec[row_index * 2, col_index]) - convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=convergence_ax, - cbar=True, - ) - convergence_ax.yaxis.tick_right() - else: - object_ax = fig.add_subplot(spec[row_index, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=None, - cbar=True, - ) - - object_ax.set_title( - f" num slices = {slices:.0f}, slices thickness = {thickness:.1f} A \n error = {self.error:.3e}" - ) - object_ax.set_xticks([]) - object_ax.set_yticks([]) - - if return_values: - objects.append(self.object) - convergence.append(self.error_iterations.copy()) - - # initialize back to pre-tuning values - self._probe = None - self._object = None - self._num_slices = current_num_slices - self._slice_thicknesses = np.tile(current_thicknesses, current_num_slices - 1) - self._polar_parameters["C10"] = -current_defocus - self.preprocess( - force_com_rotation=current_rotation_deg, - force_com_transpose=current_transpose, - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - ) - self._verbose = current_verbose - - if plot_reconstructions: - spec.tight_layout(fig) - - if return_values: - return objects, convergence - def _return_object_fft( self, obj=None, From 2c2cd40856180ffb7f0dc2a7e0d682ddcfe88be0 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 30 Dec 2023 16:35:40 -0800 Subject: [PATCH 008/128] small typo Former-commit-id: a6d35bb03b3292a118cd8a5fa710f5dbd56ce327 --- py4DSTEM/process/phase/parameter_optimize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index 49f131585..ba3614f40 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -207,7 +207,7 @@ def evaluation_callback(ptycho): else: nrows = kwargs.pop("nrows", int(np.sqrt(num_evals))) ncols = kwargs.pop("ncols", int(np.ceil(num_evals / nrows))) - if nrows * ncols != num_evals: + if nrows * ncols < num_evals: raise ValueError() spec = GridSpec( From 3407e8676b0bf52f60076ec1f30fae9cdc6aec60 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 30 Dec 2023 17:17:34 -0800 Subject: [PATCH 009/128] cupy numpy bug Former-commit-id: 923388028b0f39b5df7826de38ad5dfc34801711 --- py4DSTEM/process/phase/utils.py | 17 +++++++++++++---- 1 file changed, 13 
insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 31e1aac65..95dbc4511 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -3,14 +3,16 @@ import matplotlib.pyplot as plt import numpy as np +from scipy.fft import dctn, dstn, idctn, idstn from scipy.optimize import curve_fit try: import cupy as cp - from cupyx.scipy.fft import dctn, idctn, rfft + from cupyx.scipy.fft import dctn as dctn_cp + from cupyx.scipy.fft import idctn as idctn_cp + from cupyx.scipy.fft import rfft except (ImportError, ModuleNotFoundError): cp = None - from scipy.fft import dstn, idstn, dctn, idctn from py4DSTEM.process.utils import get_CoM from py4DSTEM.process.utils.cross_correlate import align_and_shift_images @@ -1627,9 +1629,16 @@ def preconditioned_poisson_solver_dct(rhs, gauge=None, xp=np): if gauge is None: gauge = xp.mean(rhs) - fft_rhs = dctn(rhs, type=2) + if xp is np: + dctn_xp = dctn + idctn_xp = idctn + else: + dctn_xp = dctn_cp + idctn_xp = idctn_cp + + fft_rhs = dctn_xp(rhs, type=2) fft_rhs[0, 0] = gauge # gauge invariance - sol = idctn(fft_rhs / op, type=2) + sol = idctn_xp(fft_rhs / op, type=2) return sol From 13d4fb65842099c9300665ab7c73934b5906e572 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 30 Dec 2023 19:00:10 -0800 Subject: [PATCH 010/128] refactoring object and probe methods Former-commit-id: 47adef700f4dbfcbb2ad26ed773d454c61b43079 --- .../process/phase/iterative_base_class.py | 308 +------ ...tive_mixedstate_multislice_ptychography.py | 367 +------- .../iterative_mixedstate_ptychography.py | 78 +- .../iterative_multislice_ptychography.py | 297 +----- .../iterative_overlap_magnetic_tomography.py | 163 +--- .../phase/iterative_overlap_tomography.py | 162 +--- .../phase/iterative_ptychographic_methods.py | 859 ++++++++++++++++++ .../iterative_simultaneous_ptychography.py | 6 + .../iterative_singleslice_ptychography.py | 6 + py4DSTEM/process/phase/parameter_optimize.py | 1 - 10 files changed, 917 insertions(+), 1330 deletions(-) create mode 100644 py4DSTEM/process/phase/iterative_ptychographic_methods.py diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index a824874cd..9bfbca3ba 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -8,15 +8,15 @@ import numpy as np from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid -from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex -from scipy.ndimage import rotate, zoom +from py4DSTEM.visualize import return_scaled_histogram_ordering, show_complex +from scipy.ndimage import zoom try: import cupy as cp except (ModuleNotFoundError, ImportError): cp = np -from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd +from emdfile import Array, Custom, Metadata, _read_metadata from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube from py4DSTEM.process.calibration import fit_origin @@ -1773,52 +1773,6 @@ def _extract_vectorized_patch_indices(self): return vectorized_patch_indices_row, vectorized_patch_indices_col - def _crop_rotate_object_fov( - self, - array, - padding=0, - ): - """ - Crops and rotated object to FOV bounded by current pixel positions. - - Parameters - ---------- - array: np.ndarray - Object array to crop and rotate. Only operates on numpy arrays for compatibility. 
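# Hedged sketch, not part of the committed hunk: the method removed here (and, per this
# commit's intent, moved into the new iterative_ptychographic_methods mixins) crops a
# reconstruction to the scanned field of view by rotating the probe positions about
# their center of mass, taking the padded bounding box, and slicing the object after
# rotating it by the opposite angle. A rough NumPy/SciPy-only standalone version,
# ignoring the transpose case:
import numpy as np
from scipy.ndimage import rotate

def crop_rotate_fov(array, positions_px, angle_rad, padding=0):
    com = positions_px.mean(axis=0)
    c, s = np.cos(angle_rad), np.sin(angle_rad)
    rotated_pts = (positions_px - com) @ np.array([[c, -s], [s, c]]).T + com
    min_x, min_y = np.maximum(np.floor(rotated_pts.min(axis=0) - padding).astype(int), 0)
    max_x, max_y = np.ceil(rotated_pts.max(axis=0) + padding).astype(int)
    rotated = rotate(array, np.rad2deg(-angle_rad), order=1, reshape=False, axes=(-2, -1))
    return rotated[..., min_x:max_x, min_y:max_y]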
- padding: int, optional - Optional padding outside pixel positions - - Returns - cropped_rotated_array: np.ndarray - Cropped and rotated object array - """ - - asnumpy = self._asnumpy - angle = ( - self._rotation_best_rad - if self._rotation_best_transpose - else -self._rotation_best_rad - ) - - tf = AffineTransform(angle=angle) - rotated_points = tf( - asnumpy(self._positions_px), origin=asnumpy(self._positions_px_com), xp=np - ) - - min_x, min_y = np.floor(np.amin(rotated_points, axis=0) - padding).astype("int") - min_x = min_x if min_x > 0 else 0 - min_y = min_y if min_y > 0 else 0 - max_x, max_y = np.ceil(np.amax(rotated_points, axis=0) + padding).astype("int") - - rotated_array = rotate( - asnumpy(array), np.rad2deg(-angle), order=1, reshape=False, axes=(-2, -1) - )[..., min_x:max_x, min_y:max_y] - - if self._rotation_best_transpose: - rotated_array = rotated_array.swapaxes(-2, -1) - - return rotated_array - def _position_correction( self, relevant_object, @@ -2003,119 +1957,6 @@ def plot_position_correction( ax.set_aspect("equal") ax.set_title("Probe positions correction") - def _return_fourier_probe( - self, - probe=None, - remove_initial_probe_aberrations=False, - ): - """ - Returns complex fourier probe shifted to center of array from - corner-centered complex real space probe - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses self._probe - remove_initial_probe_aberrations: bool, optional - If True, removes initial probe aberrations from Fourier probe - - Returns - ------- - fourier_probe: np.ndarray - Fourier-transformed and center-shifted probe. - """ - xp = self._xp - - if probe is None: - probe = self._probe - else: - probe = xp.asarray(probe, dtype=xp.complex64) - - fourier_probe = xp.fft.fft2(probe) - - if remove_initial_probe_aberrations: - fourier_probe *= xp.conjugate(self._known_aberrations_array) - - return xp.fft.fftshift(fourier_probe, axes=(-2, -1)) - - def _return_fourier_probe_from_centered_probe( - self, - probe=None, - remove_initial_probe_aberrations=False, - ): - """ - Returns complex fourier probe shifted to center of array from - centered complex real space probe - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses self._probe - remove_initial_probe_aberrations: bool, optional - If True, removes initial probe aberrations from Fourier probe - - Returns - ------- - fourier_probe: np.ndarray - Fourier-transformed and center-shifted probe. - """ - xp = self._xp - return self._return_fourier_probe( - xp.fft.ifftshift(probe, axes=(-2, -1)), - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - - def _return_centered_probe( - self, - probe=None, - ): - """ - Returns complex probe centered in middle of the array. - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses self._probe - - Returns - ------- - centered_probe: np.ndarray - Center-shifted probe. - """ - xp = self._xp - - if probe is None: - probe = self._probe - else: - probe = xp.asarray(probe, dtype=xp.complex64) - - return xp.fft.fftshift(probe, axes=(-2, -1)) - - def _return_object_fft( - self, - obj=None, - ): - """ - Returns absolute value of obj fft shifted to center of array - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - - Returns - ------- - object_fft_amplitude: np.ndarray - Amplitude of Fourier-transformed and center-shifted obj. 
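# Hedged sketch, not part of the committed hunk: the probe/object Fourier helpers
# removed in this hunk amount to center-shifting FFTs, i.e. fftshift(fft2(probe))
# for the corner-centered probe (optionally multiplied by the conjugate of the known
# initial aberration surface) and the amplitude of a shifted FFT for the object.
# NumPy-only analogues:
import numpy as np

def fourier_probe(probe_corner_centered, known_aberrations_array=None):
    fp = np.fft.fft2(probe_corner_centered)
    if known_aberrations_array is not None:
        fp = fp * np.conjugate(known_aberrations_array)
    return np.fft.fftshift(fp, axes=(-2, -1))

def object_fft_amplitude(obj_cropped):
    return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj_cropped))))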
- """ - asnumpy = self._asnumpy - - if obj is None: - obj = self._object - - obj = self._crop_rotate_object_fov(asnumpy(obj)) - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) - def _return_self_consistency_errors( self, max_batch_size=None, @@ -2162,17 +2003,6 @@ def _return_self_consistency_errors( return asnumpy(errors) - def _return_projected_cropped_potential( - self, - ): - """Utility function to accommodate multiple classes""" - if self._object_type == "complex": - projected_cropped_potential = np.angle(self.object_cropped) - else: - projected_cropped_potential = self.object_cropped - - return projected_cropped_potential - def show_uncertainty_visualization( self, errors=None, @@ -2353,132 +2183,6 @@ def show_uncertainty_visualization( spec.tight_layout(fig) - def show_fourier_probe( - self, - probe=None, - remove_initial_probe_aberrations=False, - cbar=True, - scalebar=True, - pixelsize=None, - pixelunits=None, - **kwargs, - ): - """ - Plot probe in fourier space - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses the `probe_fourier` property - remove_initial_probe_aberrations: bool, optional - If True, removes initial probe aberrations from Fourier probe - cbar: bool, optional - if True, adds colorbar - scalebar: bool, optional - if True, adds scalebar to probe - pixelunits: str, optional - units for scalebar, default is A^-1 - pixelsize: float, optional - default is probe reciprocal sampling - """ - asnumpy = self._asnumpy - - probe = asnumpy( - self._return_fourier_probe( - probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations - ) - ) - - if pixelsize is None: - pixelsize = self._reciprocal_sampling[1] - if pixelunits is None: - pixelunits = r"$\AA^{-1}$" - - figsize = kwargs.pop("figsize", (6, 6)) - chroma_boost = kwargs.pop("chroma_boost", 1) - - fig, ax = plt.subplots(figsize=figsize) - show_complex( - probe, - cbar=cbar, - figax=(fig, ax), - scalebar=scalebar, - pixelsize=pixelsize, - pixelunits=pixelunits, - ticks=False, - chroma_boost=chroma_boost, - **kwargs, - ) - - def show_object_fft(self, obj=None, **kwargs): - """ - Plot FFT of reconstructed object - - Parameters - ---------- - obj: complex array, optional - if None is specified, uses the `object_fft` property - """ - if obj is None: - object_fft = self.object_fft - else: - object_fft = self._return_object_fft(obj) - - figsize = kwargs.pop("figsize", (6, 6)) - cmap = kwargs.pop("cmap", "magma") - - pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) - show( - object_fft, - figsize=figsize, - cmap=cmap, - scalebar=True, - pixelsize=pixelsize, - ticks=False, - pixelunits=r"$\AA^{-1}$", - **kwargs, - ) - - @property - def probe_fourier(self): - """Current probe estimate in Fourier space""" - if not hasattr(self, "_probe"): - return None - - asnumpy = self._asnumpy - return asnumpy(self._return_fourier_probe(self._probe)) - - @property - def probe_fourier_residual(self): - """Current probe estimate in Fourier space""" - if not hasattr(self, "_probe"): - return None - - asnumpy = self._asnumpy - return asnumpy( - self._return_fourier_probe( - self._probe, remove_initial_probe_aberrations=True - ) - ) - - @property - def probe_centered(self): - """Current probe estimate shifted to the center""" - if not hasattr(self, "_probe"): - return None - - asnumpy = self._asnumpy - return asnumpy(self._return_centered_probe(self._probe)) - - @property - def object_fft(self): - """Fourier transform of current object estimate""" - - if not hasattr(self, 
"_object"): - return None - - return self._return_object_fft(self._object) - @property def angular_sampling(self): """Angular sampling [mrad]""" @@ -2510,9 +2214,3 @@ def positions(self): positions[:, 1] *= self.sampling[1] return asnumpy(positions) - - @property - def object_cropped(self): - """Cropped and rotated object""" - - return self._crop_rotate_object_fov(self._object) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 7ccbbd4bd..2369eea38 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -27,6 +27,12 @@ ProbeConstraintsMixin, ProbeMixedConstraintsMixin, ) +from py4DSTEM.process.phase.iterative_ptychographic_methods import ( + Object2p5DMethodsMixin, + ObjectNDMethodsMixin, + ProbeMethodsMixin, + ProbeMixedMethodsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -36,7 +42,6 @@ spatial_frequencies, ) from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar -from scipy.ndimage import rotate warnings.simplefilter(action="always", category=UserWarning) @@ -47,6 +52,10 @@ class MixedstateMultislicePtychographicReconstruction( ProbeConstraintsMixin, Object2p5DConstraintsMixin, ObjectNDConstraintsMixin, + ProbeMixedMethodsMixin, + ProbeMethodsMixin, + Object2p5DMethodsMixin, + ObjectNDMethodsMixin, PtychographicReconstruction, ): """ @@ -2762,74 +2771,6 @@ def visualize( ) return self - def show_fourier_probe( - self, - probe=None, - remove_initial_probe_aberrations=False, - cbar=True, - scalebar=True, - pixelsize=None, - pixelunits=None, - **kwargs, - ): - """ - Plot probe in fourier space - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses the `probe_fourier` property - remove_initial_probe_aberrations: bool, optional - If True, removes initial probe aberrations from Fourier probe - scalebar: bool, optional - if True, adds scalebar to probe - pixelunits: str, optional - units for scalebar, default is A^-1 - pixelsize: float, optional - default is probe reciprocal sampling - """ - asnumpy = self._asnumpy - - if probe is None: - probe = list( - asnumpy( - self._return_fourier_probe( - probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - ) - else: - if isinstance(probe, np.ndarray) and probe.ndim == 2: - probe = [probe] - probe = [ - asnumpy( - self._return_fourier_probe( - pr, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - for pr in probe - ] - - if pixelsize is None: - pixelsize = self._reciprocal_sampling[1] - if pixelunits is None: - pixelunits = r"$\AA^{-1}$" - - chroma_boost = kwargs.pop("chroma_boost", 1) - - show_complex( - probe if len(probe) > 1 else probe[0], - cbar=cbar, - scalebar=scalebar, - pixelsize=pixelsize, - pixelunits=pixelunits, - ticks=False, - chroma_boost=chroma_boost, - **kwargs, - ) - def show_transmitted_probe( self, plot_fourier_probe: bool = False, @@ -2903,283 +2844,6 @@ def show_transmitted_probe( **kwargs, ) - def show_slices( - self, - ms_object=None, - cbar: bool = True, - common_color_scale: bool = True, - padding: int = 0, - num_cols: int = 3, - show_fft: bool = False, - **kwargs, - ): - """ - Displays reconstructed slices of object - - Parameters - -------- - ms_object: nd.array, optional - Object to plot slices of. 
If None, uses current object - cbar: bool, optional - If True, displays a colorbar - padding: int, optional - Padding to leave uncropped - num_cols: int, optional - Number of GridSpec columns - show_fft: bool, optional - if True, plots fft of object slices - """ - - if ms_object is None: - ms_object = self._object - - rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) - if show_fft: - rotated_object = np.abs( - np.fft.fftshift( - np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) - ) - ) - rotated_shape = rotated_object.shape - - if np.iscomplexobj(rotated_object): - rotated_object = np.angle(rotated_object) - - extent = [ - 0, - self.sampling[1] * rotated_shape[2], - self.sampling[0] * rotated_shape[1], - 0, - ] - - num_rows = np.ceil(self._num_slices / num_cols).astype("int") - wspace = 0.35 if cbar else 0.15 - - axsize = kwargs.pop("axsize", (3, 3)) - cmap = kwargs.pop("cmap", "magma") - - if common_color_scale: - vals = np.sort(rotated_object.ravel()) - ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") - ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") - ind_vmin = np.max([0, ind_vmin]) - ind_vmax = np.min([len(vals) - 1, ind_vmax]) - vmin = vals[ind_vmin] - vmax = vals[ind_vmax] - if vmax == vmin: - vmin = vals[0] - vmax = vals[-1] - else: - vmax = None - vmin = None - vmin = kwargs.pop("vmin", vmin) - vmax = kwargs.pop("vmax", vmax) - - spec = GridSpec( - ncols=num_cols, - nrows=num_rows, - hspace=0.15, - wspace=wspace, - ) - - figsize = (axsize[0] * num_cols, axsize[1] * num_rows) - fig = plt.figure(figsize=figsize) - - for flat_index, obj_slice in enumerate(rotated_object): - row_index, col_index = np.unravel_index(flat_index, (num_rows, num_cols)) - ax = fig.add_subplot(spec[row_index, col_index]) - im = ax.imshow( - obj_slice, - cmap=cmap, - vmin=vmin, - vmax=vmax, - extent=extent, - **kwargs, - ) - - ax.set_title(f"Slice index: {flat_index}") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if row_index < num_rows - 1: - ax.set_xticks([]) - else: - ax.set_xlabel("y [A]") - - if col_index > 0: - ax.set_yticks([]) - else: - ax.set_ylabel("x [A]") - - spec.tight_layout(fig) - - def show_depth( - self, - x1: float, - x2: float, - y1: float, - y2: float, - specify_calibrated: bool = False, - gaussian_filter_sigma: float = None, - ms_object=None, - cbar: bool = False, - aspect: float = None, - plot_line_profile: bool = False, - **kwargs, - ): - """ - Displays line profile depth section - - Parameters - -------- - x1, x2, y1, y2: floats (pixels) - Line profile for depth section runs from (x1,y1) to (x2,y2) - Specified in pixels unless specify_calibrated is True - specify_calibrated: bool (optional) - If True, specify x1, x2, y1, y2 in A values instead of pixels - gaussian_filter_sigma: float (optional) - Standard deviation of gaussian kernel in A - ms_object: np.array - Object to plot slices of. 
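# Hedged sketch, not part of the committed hunk: the common color scale in the removed
# show_slices clips at roughly the 2nd and 98th percentiles by indexing into the sorted
# values; np.percentile gives essentially the same limits (up to interpolation) in one call.
import numpy as np

slices = np.random.rand(6, 64, 64)        # stand-in for the rotated object slices
vmin, vmax = np.percentile(slices, (2, 98))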
If None, uses current object - cbar: bool, optional - If True, displays a colorbar - aspect: float, optional - aspect ratio for depth profile plot - plot_line_profile: bool - If True, also plots line profile showing where depth profile is taken - """ - if ms_object is not None: - ms_obj = ms_object - else: - ms_obj = self.object_cropped - - if specify_calibrated: - x1 /= self.sampling[0] - x2 /= self.sampling[0] - y1 /= self.sampling[1] - y2 /= self.sampling[1] - - if x2 == x1: - angle = 0 - elif y2 == y1: - angle = np.pi / 2 - else: - angle = np.arctan((x2 - x1) / (y2 - y1)) - - x0 = ms_obj.shape[1] / 2 - y0 = ms_obj.shape[2] / 2 - - if ( - x1 > ms_obj.shape[1] - or x2 > ms_obj.shape[1] - or y1 > ms_obj.shape[2] - or y2 > ms_obj.shape[2] - ): - raise ValueError("depth section must be in field of view of object") - - from py4DSTEM.process.phase.utils import rotate_point - - x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) - x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) - - rotated_object = np.roll( - rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), - int(x1_0), - axis=1, - ) - - if np.iscomplexobj(rotated_object): - rotated_object = np.angle(rotated_object) - if gaussian_filter_sigma is not None: - from scipy.ndimage import gaussian_filter - - gaussian_filter_sigma /= self.sampling[0] - rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) - - plot_im = rotated_object[ - :, 0, np.max((0, int(y1_0))) : np.min((int(y2_0), rotated_object.shape[2])) - ] - - extent = [ - 0, - self.sampling[1] * plot_im.shape[1], - self._slice_thicknesses[0] * plot_im.shape[0], - 0, - ] - - figsize = kwargs.pop("figsize", (6, 6)) - if not plot_line_profile: - fig, ax = plt.subplots(figsize=figsize) - im = ax.imshow(plot_im, cmap="magma", extent=extent) - if aspect is not None: - ax.set_aspect(aspect) - ax.set_xlabel("r [A]") - ax.set_ylabel("z [A]") - ax.set_title("Multislice depth profile") - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - else: - extent2 = [ - 0, - self.sampling[1] * ms_obj.shape[2], - self.sampling[0] * ms_obj.shape[1], - 0, - ] - fig, ax = plt.subplots(2, 1, figsize=figsize) - ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) - ax[0].plot( - [y1 * self.sampling[0], y2 * self.sampling[1]], - [x1 * self.sampling[0], x2 * self.sampling[1]], - color="red", - ) - ax[0].set_xlabel("y [A]") - ax[0].set_ylabel("x [A]") - ax[0].set_title("Multislice depth profile location") - - im = ax[1].imshow(plot_im, cmap="magma", extent=extent) - if aspect is not None: - ax[1].set_aspect(aspect) - ax[1].set_xlabel("r [A]") - ax[1].set_ylabel("z [A]") - ax[1].set_title("Multislice depth profile") - if cbar: - divider = make_axes_locatable(ax[1]) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - plt.tight_layout() - - def _return_object_fft( - self, - obj=None, - ): - """ - Returns obj fft shifted to center of array - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - """ - asnumpy = self._asnumpy - - if obj is None: - obj = self._object - - obj = asnumpy(obj) - if np.iscomplexobj(obj): - obj = np.angle(obj) - - obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) - def _return_self_consistency_errors( self, max_batch_size=None, @@ -3226,14 +2890,3 @@ def 
_return_self_consistency_errors( errors /= self._mean_diffraction_intensity return asnumpy(errors) - - def _return_projected_cropped_potential( - self, - ): - """Utility function to accommodate multiple classes""" - if self._object_type == "complex": - projected_cropped_potential = np.angle(self.object_cropped).sum(0) - else: - projected_cropped_potential = self.object_cropped.sum(0) - - return projected_cropped_potential diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 2809f0144..b03bb456d 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -10,7 +10,7 @@ import numpy as np from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg try: import cupy as cp @@ -26,6 +26,11 @@ ProbeConstraintsMixin, ProbeMixedConstraintsMixin, ) +from py4DSTEM.process.phase.iterative_ptychographic_methods import ( + ObjectNDMethodsMixin, + ProbeMethodsMixin, + ProbeMixedMethodsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -43,6 +48,9 @@ class MixedstatePtychographicReconstruction( ProbeMixedConstraintsMixin, ProbeConstraintsMixin, ObjectNDConstraintsMixin, + ProbeMixedMethodsMixin, + ProbeMethodsMixin, + ObjectNDMethodsMixin, PtychographicReconstruction, ): """ @@ -2264,74 +2272,6 @@ def visualize( return self - def show_fourier_probe( - self, - probe=None, - remove_initial_probe_aberrations=False, - cbar=True, - scalebar=True, - pixelsize=None, - pixelunits=None, - **kwargs, - ): - """ - Plot probe in fourier space - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses the `probe_fourier` property - remove_initial_probe_aberrations: bool, optional - If True, removes initial probe aberrations from Fourier probe - scalebar: bool, optional - if True, adds scalebar to probe - pixelunits: str, optional - units for scalebar, default is A^-1 - pixelsize: float, optional - default is probe reciprocal sampling - """ - asnumpy = self._asnumpy - - if probe is None: - probe = list( - asnumpy( - self._return_fourier_probe( - probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - ) - else: - if isinstance(probe, np.ndarray) and probe.ndim == 2: - probe = [probe] - probe = [ - asnumpy( - self._return_fourier_probe( - pr, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - for pr in probe - ] - - if pixelsize is None: - pixelsize = self._reciprocal_sampling[1] - if pixelunits is None: - pixelunits = r"$\AA^{-1}$" - - chroma_boost = kwargs.pop("chroma_boost", 1) - - show_complex( - probe if len(probe) > 1 else probe[0], - cbar=cbar, - scalebar=scalebar, - pixelsize=pixelsize, - pixelunits=pixelunits, - ticks=False, - chroma_boost=chroma_boost, - **kwargs, - ) - def _return_self_consistency_errors( self, max_batch_size=None, diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 1baab6f21..bb46f25c8 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -26,6 +26,11 @@ PositionsConstraintsMixin, ProbeConstraintsMixin, ) +from 
py4DSTEM.process.phase.iterative_ptychographic_methods import ( + Object2p5DMethodsMixin, + ObjectNDMethodsMixin, + ProbeMethodsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -35,7 +40,6 @@ spatial_frequencies, ) from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar -from scipy.ndimage import rotate warnings.simplefilter(action="always", category=UserWarning) @@ -45,6 +49,9 @@ class MultislicePtychographicReconstruction( ProbeConstraintsMixin, Object2p5DConstraintsMixin, ObjectNDConstraintsMixin, + ProbeMethodsMixin, + Object2p5DMethodsMixin, + ObjectNDMethodsMixin, PtychographicReconstruction, ): """ @@ -2769,291 +2776,3 @@ def show_transmitted_probe( title=title, **kwargs, ) - - def show_slices( - self, - ms_object=None, - cbar: bool = True, - common_color_scale: bool = True, - padding: int = 0, - num_cols: int = 3, - show_fft: bool = False, - **kwargs, - ): - """ - Displays reconstructed slices of object - - Parameters - -------- - ms_object: nd.array, optional - Object to plot slices of. If None, uses current object - cbar: bool, optional - If True, displays a colorbar - padding: int, optional - Padding to leave uncropped - num_cols: int, optional - Number of GridSpec columns - show_fft: bool, optional - if True, plots fft of object slices - """ - - if ms_object is None: - ms_object = self._object - - rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) - if show_fft: - rotated_object = np.abs( - np.fft.fftshift( - np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) - ) - ) - rotated_shape = rotated_object.shape - - if np.iscomplexobj(rotated_object): - rotated_object = np.angle(rotated_object) - - extent = [ - 0, - self.sampling[1] * rotated_shape[2], - self.sampling[0] * rotated_shape[1], - 0, - ] - - num_rows = np.ceil(self._num_slices / num_cols).astype("int") - wspace = 0.35 if cbar else 0.15 - - axsize = kwargs.pop("axsize", (3, 3)) - cmap = kwargs.pop("cmap", "magma") - - if common_color_scale: - vals = np.sort(rotated_object.ravel()) - ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") - ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") - ind_vmin = np.max([0, ind_vmin]) - ind_vmax = np.min([len(vals) - 1, ind_vmax]) - vmin = vals[ind_vmin] - vmax = vals[ind_vmax] - if vmax == vmin: - vmin = vals[0] - vmax = vals[-1] - else: - vmax = None - vmin = None - vmin = kwargs.pop("vmin", vmin) - vmax = kwargs.pop("vmax", vmax) - - spec = GridSpec( - ncols=num_cols, - nrows=num_rows, - hspace=0.15, - wspace=wspace, - ) - - figsize = (axsize[0] * num_cols, axsize[1] * num_rows) - fig = plt.figure(figsize=figsize) - - for flat_index, obj_slice in enumerate(rotated_object): - row_index, col_index = np.unravel_index(flat_index, (num_rows, num_cols)) - ax = fig.add_subplot(spec[row_index, col_index]) - im = ax.imshow( - obj_slice, - cmap=cmap, - vmin=vmin, - vmax=vmax, - extent=extent, - **kwargs, - ) - - ax.set_title(f"Slice index: {flat_index}") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if row_index < num_rows - 1: - ax.set_xticks([]) - else: - ax.set_xlabel("y [A]") - - if col_index > 0: - ax.set_yticks([]) - else: - ax.set_ylabel("x [A]") - - spec.tight_layout(fig) - - def show_depth( - self, - x1: float, - x2: float, - y1: float, - y2: float, - specify_calibrated: bool = False, - gaussian_filter_sigma: float = None, - ms_object=None, - cbar: bool 
= False, - aspect: float = None, - plot_line_profile: bool = False, - **kwargs, - ): - """ - Displays line profile depth section - - Parameters - -------- - x1, x2, y1, y2: floats (pixels) - Line profile for depth section runs from (x1,y1) to (x2,y2) - Specified in pixels unless specify_calibrated is True - specify_calibrated: bool (optional) - If True, specify x1, x2, y1, y2 in A values instead of pixels - gaussian_filter_sigma: float (optional) - Standard deviation of gaussian kernel in A - ms_object: np.array - Object to plot slices of. If None, uses current object - cbar: bool, optional - If True, displays a colorbar - aspect: float, optional - aspect ratio for depth profile plot - plot_line_profile: bool - If True, also plots line profile showing where depth profile is taken - """ - if ms_object is not None: - ms_obj = ms_object - else: - ms_obj = self.object_cropped - - if specify_calibrated: - x1 /= self.sampling[0] - x2 /= self.sampling[0] - y1 /= self.sampling[1] - y2 /= self.sampling[1] - - if x2 == x1: - angle = 0 - elif y2 == y1: - angle = np.pi / 2 - else: - angle = np.arctan((x2 - x1) / (y2 - y1)) - - x0 = ms_obj.shape[1] / 2 - y0 = ms_obj.shape[2] / 2 - - if ( - x1 > ms_obj.shape[1] - or x2 > ms_obj.shape[1] - or y1 > ms_obj.shape[2] - or y2 > ms_obj.shape[2] - ): - raise ValueError("depth section must be in field of view of object") - - from py4DSTEM.process.phase.utils import rotate_point - - x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) - x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) - - rotated_object = np.roll( - rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), - -int(x1_0), - axis=1, - ) - - if np.iscomplexobj(rotated_object): - rotated_object = np.angle(rotated_object) - if gaussian_filter_sigma is not None: - from scipy.ndimage import gaussian_filter - - gaussian_filter_sigma /= self.sampling[0] - rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) - - plot_im = rotated_object[ - :, 0, np.max((0, int(y1_0))) : np.min((int(y2_0), rotated_object.shape[2])) - ] - - extent = [ - 0, - self.sampling[1] * plot_im.shape[1], - self._slice_thicknesses[0] * plot_im.shape[0], - 0, - ] - figsize = kwargs.pop("figsize", (6, 6)) - if not plot_line_profile: - fig, ax = plt.subplots(figsize=figsize) - im = ax.imshow(plot_im, cmap="magma", extent=extent) - if aspect is not None: - ax.set_aspect(aspect) - ax.set_xlabel("r [A]") - ax.set_ylabel("z [A]") - ax.set_title("Multislice depth profile") - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - else: - extent2 = [ - 0, - self.sampling[1] * ms_obj.shape[2], - self.sampling[0] * ms_obj.shape[1], - 0, - ] - - fig, ax = plt.subplots(2, 1, figsize=figsize) - ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) - ax[0].plot( - [y1 * self.sampling[0], y2 * self.sampling[1]], - [x1 * self.sampling[0], x2 * self.sampling[1]], - color="red", - ) - ax[0].set_xlabel("y [A]") - ax[0].set_ylabel("x [A]") - ax[0].set_title("Multislice depth profile location") - - im = ax[1].imshow(plot_im, cmap="magma", extent=extent) - if aspect is not None: - ax[1].set_aspect(aspect) - ax[1].set_xlabel("r [A]") - ax[1].set_ylabel("z [A]") - ax[1].set_title("Multislice depth profile") - if cbar: - divider = make_axes_locatable(ax[1]) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - plt.tight_layout() - - def _return_object_fft( - self, - 
obj=None, - ): - """ - Returns obj fft shifted to center of array - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - """ - asnumpy = self._asnumpy - - if obj is None: - obj = self._object - - obj = asnumpy(obj) - if np.iscomplexobj(obj): - obj = np.angle(obj) - - obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) - - def _return_projected_cropped_potential( - self, - ): - """Utility function to accommodate multiple classes""" - if self._object_type == "complex": - projected_cropped_potential = np.angle(self.object_cropped).sum(0) - else: - projected_cropped_potential = self.object_cropped.sum(0) - - return projected_cropped_potential diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index e9c15b097..db2feaf10 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -10,9 +10,7 @@ import numpy as np from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import make_axes_locatable -from py4DSTEM.visualize import show from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg -from scipy.ndimage import rotate as rotate_np try: import cupy as cp @@ -28,6 +26,11 @@ PositionsConstraintsMixin, ProbeConstraintsMixin, ) +from py4DSTEM.process.phase.iterative_ptychographic_methods import ( + Object3DMethodsMixin, + ObjectNDMethodsMixin, + ProbeMethodsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -47,6 +50,9 @@ class OverlapMagneticTomographicReconstruction( ProbeConstraintsMixin, Object3DConstraintsMixin, ObjectNDConstraintsMixin, + ProbeMethodsMixin, + Object3DMethodsMixin, + ObjectNDMethodsMixin, PtychographicReconstruction, ): """ @@ -2583,46 +2589,6 @@ def reconstruct( return self - def _crop_rotate_object_manually( - self, - array, - angle, - x_lims, - y_lims, - ): - """ - Crops and rotates rotates object manually. - - Parameters - ---------- - array: np.ndarray - Object array to crop and rotate. Only operates on numpy arrays for comptatibility. 
- angle: float - In-plane angle in degrees to rotate by - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - - Returns - ------- - cropped_rotated_array: np.ndarray - Cropped and rotated object array - """ - - asnumpy = self._asnumpy - min_x, max_x = x_lims - min_y, max_y = y_lims - - if angle is not None: - rotated_array = rotate_np( - asnumpy(array), angle, reshape=False, axes=(-2, -1) - ) - else: - rotated_array = asnumpy(array) - - return rotated_array[..., min_x:max_x, min_y:max_y] - def _visualize_last_iteration_figax( self, fig, @@ -3070,113 +3036,6 @@ def visualize( return self - def _return_object_fft( - self, - obj=None, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims: Tuple[int, int] = (None, None), - y_lims: Tuple[int, int] = (None, None), - ): - """ - Returns obj fft shifted to center of array - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - - xp = self._xp - asnumpy = self._asnumpy - - if obj is None: - obj = self._object[0] - else: - obj = xp.asarray(obj[0], dtype=xp.float32) - - if projection_angle_deg is not None: - rotated_3d_obj = self._rotate( - obj, - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - rotated_3d_obj = asnumpy(rotated_3d_obj) - else: - rotated_3d_obj = asnumpy(obj) - - rotated_object = self._crop_rotate_object_manually( - rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(rotated_object)))) - - def show_object_fft( - self, - obj=None, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims: Tuple[int, int] = (None, None), - y_lims: Tuple[int, int] = (None, None), - **kwargs, - ): - """ - Plot FFT of reconstructed object - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - if obj is None: - object_fft = self._return_object_fft( - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - ) - else: - object_fft = self._return_object_fft( - obj, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - ) - - figsize = kwargs.pop("figsize", (6, 6)) - cmap = kwargs.pop("cmap", "magma") - - pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) - show( - object_fft, - figsize=figsize, - cmap=cmap, - scalebar=True, - pixelsize=pixelsize, - ticks=False, - pixelunits=r"$\AA^{-1}$", - **kwargs, - ) - @property def positions(self): """Probe positions [A]""" @@ -3205,12 +3064,6 @@ def _return_self_consistency_errors( """Compute the self-consistency errors for each probe position""" raise NotImplementedError() - def _return_projected_cropped_potential( - self, - ): - """Utility function to accommodate multiple classes""" - raise NotImplementedError() - def show_uncertainty_visualization( self, errors=None, diff 
--git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index df5064206..442cb79a0 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -10,7 +10,6 @@ import numpy as np from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable -from py4DSTEM.visualize import show from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg from scipy.ndimage import rotate as rotate_np @@ -28,6 +27,11 @@ PositionsConstraintsMixin, ProbeConstraintsMixin, ) +from py4DSTEM.process.phase.iterative_ptychographic_methods import ( + Object3DMethodsMixin, + ObjectNDMethodsMixin, + ProbeMethodsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -46,6 +50,9 @@ class OverlapTomographicReconstruction( ProbeConstraintsMixin, Object3DConstraintsMixin, ObjectNDConstraintsMixin, + ProbeMethodsMixin, + Object3DMethodsMixin, + ObjectNDMethodsMixin, PtychographicReconstruction, ): """ @@ -2306,46 +2313,6 @@ def reconstruct( return self - def _crop_rotate_object_manually( - self, - array, - angle, - x_lims, - y_lims, - ): - """ - Crops and rotates rotates object manually. - - Parameters - ---------- - array: np.ndarray - Object array to crop and rotate. Only operates on numpy arrays for comptatibility. - angle: float - In-plane angle in degrees to rotate by - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - - Returns - ------- - cropped_rotated_array: np.ndarray - Cropped and rotated object array - """ - - asnumpy = self._asnumpy - min_x, max_x = x_lims - min_y, max_y = y_lims - - if angle is not None: - rotated_array = rotate_np( - asnumpy(array), angle, reshape=False, axes=(-2, -1) - ) - else: - rotated_array = asnumpy(array) - - return rotated_array[..., min_x:max_x, min_y:max_y] - def _visualize_last_iteration_figax( self, fig, @@ -2969,113 +2936,6 @@ def visualize( return self - def _return_object_fft( - self, - obj=None, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims: Tuple[int, int] = (None, None), - y_lims: Tuple[int, int] = (None, None), - ): - """ - Returns obj fft shifted to center of array - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - - xp = self._xp - asnumpy = self._asnumpy - - if obj is None: - obj = self._object - else: - obj = xp.asarray(obj, dtype=xp.float32) - - if projection_angle_deg is not None: - rotated_3d_obj = self._rotate( - obj, - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - rotated_3d_obj = asnumpy(rotated_3d_obj) - else: - rotated_3d_obj = asnumpy(obj) - - rotated_object = self._crop_rotate_object_manually( - rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(rotated_object)))) - - def show_object_fft( - self, - obj=None, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims: Tuple[int, int] = (None, None), - y_lims: Tuple[int, int] = (None, None), - **kwargs, - ): - """ - Plot FFT of reconstructed object - - 
Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - if obj is None: - object_fft = self._return_object_fft( - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - ) - else: - object_fft = self._return_object_fft( - obj, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - ) - - figsize = kwargs.pop("figsize", (6, 6)) - cmap = kwargs.pop("cmap", "magma") - - pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) - show( - object_fft, - figsize=figsize, - cmap=cmap, - scalebar=True, - pixelsize=pixelsize, - ticks=False, - pixelunits=r"$\AA^{-1}$", - **kwargs, - ) - @property def positions(self): """Probe positions [A]""" @@ -3104,12 +2964,6 @@ def _return_self_consistency_errors( """Compute the self-consistency errors for each probe position""" raise NotImplementedError() - def _return_projected_cropped_potential( - self, - ): - """Utility function to accommodate multiple classes""" - raise NotImplementedError() - def show_uncertainty_visualization( self, errors=None, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py new file mode 100644 index 000000000..c86cf8650 --- /dev/null +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -0,0 +1,859 @@ +from typing import Tuple + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.process.phase.utils import AffineTransform +from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex +from scipy.ndimage import rotate + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + + +class ObjectNDMethodsMixin: + """ + Mixin class for object methods applicable to 2D,2.5D, and 3D objects. + """ + + def _crop_rotate_object_fov( + self, + array, + padding=0, + ): + """ + Crops and rotated object to FOV bounded by current pixel positions. + + Parameters + ---------- + array: np.ndarray + Object array to crop and rotate. Only operates on numpy arrays for compatibility. 
+ padding: int, optional + Optional padding outside pixel positions + + Returns + cropped_rotated_array: np.ndarray + Cropped and rotated object array + """ + + asnumpy = self._asnumpy + angle = ( + self._rotation_best_rad + if self._rotation_best_transpose + else -self._rotation_best_rad + ) + + tf = AffineTransform(angle=angle) + rotated_points = tf( + asnumpy(self._positions_px), origin=asnumpy(self._positions_px_com), xp=np + ) + + min_x, min_y = np.floor(np.amin(rotated_points, axis=0) - padding).astype("int") + min_x = min_x if min_x > 0 else 0 + min_y = min_y if min_y > 0 else 0 + max_x, max_y = np.ceil(np.amax(rotated_points, axis=0) + padding).astype("int") + + rotated_array = rotate( + asnumpy(array), np.rad2deg(-angle), order=1, reshape=False, axes=(-2, -1) + )[..., min_x:max_x, min_y:max_y] + + if self._rotation_best_transpose: + rotated_array = rotated_array.swapaxes(-2, -1) + + return rotated_array + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped) + else: + projected_cropped_potential = self.object_cropped + + return projected_cropped_potential + + def _return_object_fft( + self, + obj=None, + ): + """ + Returns absolute value of obj fft shifted to center of array + + Parameters + ---------- + obj: array, optional + if None is specified, uses self._object + + Returns + ------- + object_fft_amplitude: np.ndarray + Amplitude of Fourier-transformed and center-shifted obj. + """ + xp = self._xp + asnumpy = self._asnumpy + + if obj is None: + obj = self._object + + if np.iscomplexobj(obj): + obj = xp.angle(obj) + + obj = self._crop_rotate_object_fov(asnumpy(obj)) + return np.abs(np.fft.fftshift(np.fft.fft2(obj))) + + def show_object_fft(self, obj=None, **kwargs): + """ + Plot FFT of reconstructed object + + Parameters + ---------- + obj: complex array, optional + if None is specified, uses the `object_fft` property + """ + if obj is None: + object_fft = self.object_fft + else: + object_fft = self._return_object_fft(obj) + + figsize = kwargs.pop("figsize", (6, 6)) + cmap = kwargs.pop("cmap", "magma") + + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) + show( + object_fft, + figsize=figsize, + cmap=cmap, + scalebar=True, + pixelsize=pixelsize, + ticks=False, + pixelunits=r"$\AA^{-1}$", + **kwargs, + ) + + @property + def object_fft(self): + """Fourier transform of current object estimate""" + + if not hasattr(self, "_object"): + return None + + return self._return_object_fft(self._object) + + @property + def object_cropped(self): + """Cropped and rotated object""" + + return self._crop_rotate_object_fov(self._object) + + +class Object2p5DMethodsMixin: + """ + Mixin class for object methods unique to 2.5D objects. + Overwrites ObjectNDMethodsMixin. 
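(Editorial sketch, not part of the patch: the bounding-box geometry behind _crop_rotate_object_fov above, reduced to plain NumPy. The array values, the 15-degree angle, and the simplified rotation convention are illustrative assumptions only; the actual method uses py4DSTEM's AffineTransform, the stored _rotation_best_rad, and an optional transpose.)

    import numpy as np
    from scipy.ndimage import rotate

    positions_px = np.random.uniform(20, 100, size=(64, 2))  # dummy scan positions [px]
    obj = np.random.rand(128, 128)                            # dummy 2D object
    angle = np.deg2rad(15.0)                                  # stand-in for _rotation_best_rad

    com = positions_px.mean(axis=0)
    R = np.array([[np.cos(angle), np.sin(angle)],
                  [-np.sin(angle), np.cos(angle)]])
    rotated_points = (positions_px - com) @ R.T + com         # rotate positions about their center

    min_x, min_y = np.floor(rotated_points.min(axis=0)).astype(int).clip(0, None)
    max_x, max_y = np.ceil(rotated_points.max(axis=0)).astype(int)

    rotated_obj = rotate(obj, np.rad2deg(-angle), order=1, reshape=False)
    cropped = rotated_obj[min_x:max_x, min_y:max_y]           # object cropped to the scanned FOV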
+ """ + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped).sum(0) + else: + projected_cropped_potential = self.object_cropped.sum(0) + + return projected_cropped_potential + + def _return_object_fft( + self, + obj=None, + ): + """ + Returns obj fft shifted to center of array + + Parameters + ---------- + obj: array, optional + if None is specified, uses self._object + """ + xp = self._xp + + if obj is None: + obj = self._object + + if np.iscomplexobj(obj): + obj = xp.angle(obj) + + obj = self._crop_rotate_object_fov(obj.sum(axis=0)) + return np.abs(np.fft.fftshift(np.fft.fft2(obj))) + + def show_depth( + self, + x1: float, + x2: float, + y1: float, + y2: float, + specify_calibrated: bool = False, + gaussian_filter_sigma: float = None, + ms_object=None, + cbar: bool = False, + aspect: float = None, + plot_line_profile: bool = False, + **kwargs, + ): + """ + Displays line profile depth section + + Parameters + -------- + x1, x2, y1, y2: floats (pixels) + Line profile for depth section runs from (x1,y1) to (x2,y2) + Specified in pixels unless specify_calibrated is True + specify_calibrated: bool (optional) + If True, specify x1, x2, y1, y2 in A values instead of pixels + gaussian_filter_sigma: float (optional) + Standard deviation of gaussian kernel in A + ms_object: np.array + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + aspect: float, optional + aspect ratio for depth profile plot + plot_line_profile: bool + If True, also plots line profile showing where depth profile is taken + """ + if ms_object is not None: + ms_obj = ms_object + else: + ms_obj = self.object_cropped + + if specify_calibrated: + x1 /= self.sampling[0] + x2 /= self.sampling[0] + y1 /= self.sampling[1] + y2 /= self.sampling[1] + + if x2 == x1: + angle = 0 + elif y2 == y1: + angle = np.pi / 2 + else: + angle = np.arctan((x2 - x1) / (y2 - y1)) + + x0 = ms_obj.shape[1] / 2 + y0 = ms_obj.shape[2] / 2 + + if ( + x1 > ms_obj.shape[1] + or x2 > ms_obj.shape[1] + or y1 > ms_obj.shape[2] + or y2 > ms_obj.shape[2] + ): + raise ValueError("depth section must be in field of view of object") + + from py4DSTEM.process.phase.utils import rotate_point + + x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) + x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) + + rotated_object = np.roll( + rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), + -int(x1_0), + axis=1, + ) + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + if gaussian_filter_sigma is not None: + from scipy.ndimage import gaussian_filter + + gaussian_filter_sigma /= self.sampling[0] + rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) + + plot_im = rotated_object[ + :, 0, np.max((0, int(y1_0))) : np.min((int(y2_0), rotated_object.shape[2])) + ] + + extent = [ + 0, + self.sampling[1] * plot_im.shape[1], + self._slice_thicknesses[0] * plot_im.shape[0], + 0, + ] + figsize = kwargs.pop("figsize", (6, 6)) + if not plot_line_profile: + fig, ax = plt.subplots(figsize=figsize) + im = ax.imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax.set_aspect(aspect) + ax.set_xlabel("r [A]") + ax.set_ylabel("z [A]") + ax.set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + 
fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + else: + extent2 = [ + 0, + self.sampling[1] * ms_obj.shape[2], + self.sampling[0] * ms_obj.shape[1], + 0, + ] + + fig, ax = plt.subplots(2, 1, figsize=figsize) + ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) + ax[0].plot( + [y1 * self.sampling[0], y2 * self.sampling[1]], + [x1 * self.sampling[0], x2 * self.sampling[1]], + color="red", + ) + ax[0].set_xlabel("y [A]") + ax[0].set_ylabel("x [A]") + ax[0].set_title("Multislice depth profile location") + + im = ax[1].imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax[1].set_aspect(aspect) + ax[1].set_xlabel("r [A]") + ax[1].set_ylabel("z [A]") + ax[1].set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax[1]) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + plt.tight_layout() + + def show_slices( + self, + ms_object=None, + cbar: bool = True, + common_color_scale: bool = True, + padding: int = 0, + num_cols: int = 3, + show_fft: bool = False, + **kwargs, + ): + """ + Displays reconstructed slices of object + + Parameters + -------- + ms_object: nd.array, optional + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + padding: int, optional + Padding to leave uncropped + num_cols: int, optional + Number of GridSpec columns + show_fft: bool, optional + if True, plots fft of object slices + """ + + if ms_object is None: + ms_object = self._object + + rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + if show_fft: + rotated_object = np.abs( + np.fft.fftshift( + np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) + ) + ) + rotated_shape = rotated_object.shape + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + + extent = [ + 0, + self.sampling[1] * rotated_shape[2], + self.sampling[0] * rotated_shape[1], + 0, + ] + + num_rows = np.ceil(self._num_slices / num_cols).astype("int") + wspace = 0.35 if cbar else 0.15 + + axsize = kwargs.pop("axsize", (3, 3)) + cmap = kwargs.pop("cmap", "magma") + + if common_color_scale: + vals = np.sort(rotated_object.ravel()) + ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") + ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") + ind_vmin = np.max([0, ind_vmin]) + ind_vmax = np.min([len(vals) - 1, ind_vmax]) + vmin = vals[ind_vmin] + vmax = vals[ind_vmax] + if vmax == vmin: + vmin = vals[0] + vmax = vals[-1] + else: + vmax = None + vmin = None + vmin = kwargs.pop("vmin", vmin) + vmax = kwargs.pop("vmax", vmax) + + spec = GridSpec( + ncols=num_cols, + nrows=num_rows, + hspace=0.15, + wspace=wspace, + ) + + figsize = (axsize[0] * num_cols, axsize[1] * num_rows) + fig = plt.figure(figsize=figsize) + + for flat_index, obj_slice in enumerate(rotated_object): + row_index, col_index = np.unravel_index(flat_index, (num_rows, num_cols)) + ax = fig.add_subplot(spec[row_index, col_index]) + im = ax.imshow( + obj_slice, + cmap=cmap, + vmin=vmin, + vmax=vmax, + extent=extent, + **kwargs, + ) + + ax.set_title(f"Slice index: {flat_index}") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if row_index < num_rows - 1: + ax.set_xticks([]) + else: + ax.set_xlabel("y [A]") + + if col_index > 0: + ax.set_yticks([]) + else: + ax.set_ylabel("x [A]") + + spec.tight_layout(fig) + + +class 
Object3DMethodsMixin: + """ + Mixin class for object methods unique to 3D objects. + Overwrites ObjectNDMethodsMixin and Object2p5DMethodsMixin. + """ + + def _crop_rotate_object_manually( + self, + array, + angle, + x_lims, + y_lims, + ): + """ + Crops and rotates rotates object manually. + + Parameters + ---------- + array: np.ndarray + Object array to crop and rotate. Only operates on numpy arrays for compatibility. + angle: float + In-plane angle in degrees to rotate by + x_lims: tuple(float,float) + min/max x indices + y_lims: tuple(float,float) + min/max y indices + + Returns + ------- + cropped_rotated_array: np.ndarray + Cropped and rotated object array + """ + + asnumpy = self._asnumpy + min_x, max_x = x_lims + min_y, max_y = y_lims + + if angle is not None: + rotated_array = rotate(asnumpy(array), angle, reshape=False, axes=(-2, -1)) + else: + rotated_array = asnumpy(array) + + return rotated_array[..., min_x:max_x, min_y:max_y] + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + raise NotImplementedError() + + def _return_object_fft( + self, + obj=None, + projection_angle_deg: float = None, + projection_axes: Tuple[int, int] = (0, 2), + x_lims: Tuple[int, int] = (None, None), + y_lims: Tuple[int, int] = (None, None), + ): + """ + Returns obj fft shifted to center of array + + Parameters + ---------- + obj: array, optional + if None is specified, uses self._object + projection_angle_deg: float + Angle in degrees to rotate 3D array around prior to projection + projection_axes: tuple(int,int) + Axes defining projection plane + x_lims: tuple(float,float) + min/max x indices + y_lims: tuple(float,float) + min/max y indices + """ + + xp = self._xp + asnumpy = self._asnumpy + + if obj is None: + obj = self._object + else: + obj = xp.asarray(obj, dtype=xp.float32) + + if projection_angle_deg is not None: + rotated_3d_obj = self._rotate( + obj, + projection_angle_deg, + axes=projection_axes, + reshape=False, + order=2, + ) + rotated_3d_obj = asnumpy(rotated_3d_obj) + else: + rotated_3d_obj = asnumpy(obj) + + rotated_object = self._crop_rotate_object_manually( + rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims + ) + + return np.abs(np.fft.fftshift(np.fft.fft2(rotated_object))) + + def show_object_fft( + self, + obj=None, + projection_angle_deg: float = None, + projection_axes: Tuple[int, int] = (0, 2), + x_lims: Tuple[int, int] = (None, None), + y_lims: Tuple[int, int] = (None, None), + **kwargs, + ): + """ + Plot FFT of reconstructed object + + Parameters + ---------- + obj: array, optional + if None is specified, uses self._object + projection_angle_deg: float + Angle in degrees to rotate 3D array around prior to projection + projection_axes: tuple(int,int) + Axes defining projection plane + x_lims: tuple(float,float) + min/max x indices + y_lims: tuple(float,float) + min/max y indices + """ + if obj is None: + object_fft = self._return_object_fft( + projection_angle_deg=projection_angle_deg, + projection_axes=projection_axes, + x_lims=x_lims, + y_lims=y_lims, + ) + else: + object_fft = self._return_object_fft( + obj, + projection_angle_deg=projection_angle_deg, + projection_axes=projection_axes, + x_lims=x_lims, + y_lims=y_lims, + ) + + figsize = kwargs.pop("figsize", (6, 6)) + cmap = kwargs.pop("cmap", "magma") + + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) + show( + object_fft, + figsize=figsize, + cmap=cmap, + scalebar=True, + pixelsize=pixelsize, + ticks=False, + pixelunits=r"$\AA^{-1}$", 
+ **kwargs, + ) + + +class ProbeMethodsMixin: + """ + Mixin class for probe methods applicable to a single probe. + """ + + def _return_fourier_probe( + self, + probe=None, + remove_initial_probe_aberrations=False, + ): + """ + Returns complex fourier probe shifted to center of array from + corner-centered complex real space probe + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses self._probe + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe + + Returns + ------- + fourier_probe: np.ndarray + Fourier-transformed and center-shifted probe. + """ + xp = self._xp + + if probe is None: + probe = self._probe + else: + probe = xp.asarray(probe, dtype=xp.complex64) + + fourier_probe = xp.fft.fft2(probe) + + if remove_initial_probe_aberrations: + fourier_probe *= xp.conjugate(self._known_aberrations_array) + + return xp.fft.fftshift(fourier_probe, axes=(-2, -1)) + + def _return_fourier_probe_from_centered_probe( + self, + probe=None, + remove_initial_probe_aberrations=False, + ): + """ + Returns complex fourier probe shifted to center of array from + centered complex real space probe + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses self._probe + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe + + Returns + ------- + fourier_probe: np.ndarray + Fourier-transformed and center-shifted probe. + """ + xp = self._xp + return self._return_fourier_probe( + xp.fft.ifftshift(probe, axes=(-2, -1)), + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + + def _return_centered_probe( + self, + probe=None, + ): + """ + Returns complex probe centered in middle of the array. + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses self._probe + + Returns + ------- + centered_probe: np.ndarray + Center-shifted probe. 
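(Editorial sketch, not part of the patch: the corner-centered versus center-shifted probe conventions that _return_fourier_probe and _return_centered_probe above rely on, demonstrated with a dummy NumPy delta-function probe rather than an actual py4DSTEM probe.)

    import numpy as np

    probe = np.zeros((64, 64), dtype=np.complex64)
    probe[0, 0] = 1.0                                                     # corner-centered "probe"

    fourier_probe = np.fft.fftshift(np.fft.fft2(probe), axes=(-2, -1))   # center-shifted Fourier probe
    centered_probe = np.fft.fftshift(probe, axes=(-2, -1))               # center-shifted real-space probe

    assert np.allclose(np.abs(fourier_probe), 1.0)    # delta probe -> flat Fourier amplitude
    assert centered_probe[32, 32] == 1.0              # delta now sits at the array center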
+ """ + xp = self._xp + + if probe is None: + probe = self._probe + else: + probe = xp.asarray(probe, dtype=xp.complex64) + + return xp.fft.fftshift(probe, axes=(-2, -1)) + + def show_fourier_probe( + self, + probe=None, + remove_initial_probe_aberrations=False, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, + ): + """ + Plot probe in fourier space + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses the `probe_fourier` property + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe + cbar: bool, optional + if True, adds colorbar + scalebar: bool, optional + if True, adds scalebar to probe + pixelunits: str, optional + units for scalebar, default is A^-1 + pixelsize: float, optional + default is probe reciprocal sampling + """ + asnumpy = self._asnumpy + + probe = asnumpy( + self._return_fourier_probe( + probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations + ) + ) + + if pixelsize is None: + pixelsize = self._reciprocal_sampling[1] + if pixelunits is None: + pixelunits = r"$\AA^{-1}$" + + figsize = kwargs.pop("figsize", (6, 6)) + chroma_boost = kwargs.pop("chroma_boost", 1) + + fig, ax = plt.subplots(figsize=figsize) + show_complex( + probe, + cbar=cbar, + figax=(fig, ax), + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=False, + chroma_boost=chroma_boost, + **kwargs, + ) + + @property + def probe_fourier(self): + """Current probe estimate in Fourier space""" + if not hasattr(self, "_probe"): + return None + + asnumpy = self._asnumpy + return asnumpy(self._return_fourier_probe(self._probe)) + + @property + def probe_fourier_residual(self): + """Current probe estimate in Fourier space""" + if not hasattr(self, "_probe"): + return None + + asnumpy = self._asnumpy + return asnumpy( + self._return_fourier_probe( + self._probe, remove_initial_probe_aberrations=True + ) + ) + + @property + def probe_centered(self): + """Current probe estimate shifted to the center""" + if not hasattr(self, "_probe"): + return None + + asnumpy = self._asnumpy + return asnumpy(self._return_centered_probe(self._probe)) + + +class ProbeMixedMethodsMixin: + """ + Mixin class for probe methods unique to mixed probes. + Overwrites ProbeMethodsMixin. 
+ """ + + def show_fourier_probe( + self, + probe=None, + remove_initial_probe_aberrations=False, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, + ): + """ + Plot probe in fourier space + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses the `probe_fourier` property + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe + scalebar: bool, optional + if True, adds scalebar to probe + pixelunits: str, optional + units for scalebar, default is A^-1 + pixelsize: float, optional + default is probe reciprocal sampling + """ + asnumpy = self._asnumpy + + if probe is None: + probe = list( + asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + ) + else: + if isinstance(probe, np.ndarray) and probe.ndim == 2: + probe = [probe] + probe = [ + asnumpy( + self._return_fourier_probe( + pr, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + for pr in probe + ] + + if pixelsize is None: + pixelsize = self._reciprocal_sampling[1] + if pixelunits is None: + pixelunits = r"$\AA^{-1}$" + + chroma_boost = kwargs.pop("chroma_boost", 1) + + show_complex( + probe if len(probe) > 1 else probe[0], + cbar=cbar, + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=False, + chroma_boost=chroma_boost, + **kwargs, + ) diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index aecbf0970..aafd58134 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -25,6 +25,10 @@ PositionsConstraintsMixin, ProbeConstraintsMixin, ) +from py4DSTEM.process.phase.iterative_ptychographic_methods import ( + ObjectNDMethodsMixin, + ProbeMethodsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -41,6 +45,8 @@ class SimultaneousPtychographicReconstruction( PositionsConstraintsMixin, ProbeConstraintsMixin, ObjectNDConstraintsMixin, + ProbeMethodsMixin, + ObjectNDMethodsMixin, PtychographicReconstruction, ): """ diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 47de9ad29..7434e2e50 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -25,6 +25,10 @@ PositionsConstraintsMixin, ProbeConstraintsMixin, ) +from py4DSTEM.process.phase.iterative_ptychographic_methods import ( + ObjectNDMethodsMixin, + ProbeMethodsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -41,6 +45,8 @@ class SingleslicePtychographicReconstruction( PositionsConstraintsMixin, ProbeConstraintsMixin, ObjectNDConstraintsMixin, + ProbeMethodsMixin, + ObjectNDMethodsMixin, PtychographicReconstruction, ): """ diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index ba3614f40..8744ec792 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -201,7 +201,6 @@ def evaluation_callback(ptycho): pbar.close() if plot_reconstructed_objects: - if len(n_points) == 2: nrows, ncols = n_points else: From 8c581271d27e1a08f60c9d9f9b29d959fc290f82 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 30 Dec 2023 20:08:47 
-0800 Subject: [PATCH 011/128] multislice plotting tweaks Former-commit-id: b44b2d2138ce4f3517742fba37657bb3530b74d7 --- .../phase/iterative_ptychographic_methods.py | 226 +++++++++--------- 1 file changed, 119 insertions(+), 107 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index c86cf8650..a0f10fd0d 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -4,9 +4,9 @@ import numpy as np from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import make_axes_locatable -from py4DSTEM.process.phase.utils import AffineTransform +from py4DSTEM.process.phase.utils import AffineTransform, rotate_point from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex -from scipy.ndimage import rotate +from scipy.ndimage import gaussian_filter, rotate try: import cupy as cp @@ -190,45 +190,61 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(obj.sum(axis=0)) return np.abs(np.fft.fftshift(np.fft.fft2(obj))) - def show_depth( + def show_depth_section( self, - x1: float, - x2: float, - y1: float, - y2: float, - specify_calibrated: bool = False, - gaussian_filter_sigma: float = None, - ms_object=None, - cbar: bool = False, - aspect: float = None, + ptA: Tuple[float, float], + ptB: Tuple[float, float], + aspect_ratio: float = "auto", plot_line_profile: bool = False, + ms_object=None, + specify_calibrated: bool = True, + gaussian_filter_sigma: float = None, + cbar: bool = True, **kwargs, ): """ Displays line profile depth section Parameters - -------- - x1, x2, y1, y2: floats (pixels) - Line profile for depth section runs from (x1,y1) to (x2,y2) - Specified in pixels unless specify_calibrated is True + ---------- + ptA: Tuple[float,float] + Starting point (x1,y1) for line profile depth section + If either is None, assumed to be array start. + Specified in Angstroms unless specify_calibrated is False + ptB: Tuple[float,float] + End point (x2,y2) for line profile depth section + If either is None, assumed to be array end. + Specified in Angstroms unless specify_calibrated is False + aspect_ratio: float, optional + aspect ratio for depth profile plot + plot_line_profile: bool + If True, also plots line profile showing where depth profile is taken + ms_object: np.array + Object to plot slices of. If None, uses current object specify_calibrated: bool (optional) - If True, specify x1, x2, y1, y2 in A values instead of pixels + If False, ptA and ptB points specified in pixels instead of Angstroms gaussian_filter_sigma: float (optional) Standard deviation of gaussian kernel in A - ms_object: np.array - Object to plot slices of. 
If None, uses current object cbar: bool, optional If True, displays a colorbar - aspect: float, optional - aspect ratio for depth profile plot - plot_line_profile: bool - If True, also plots line profile showing where depth profile is taken """ - if ms_object is not None: - ms_obj = ms_object - else: - ms_obj = self.object_cropped + if ms_object is None: + ms_object = self.object_cropped + + if np.iscomplexobj(ms_object): + ms_object = np.angle(ms_object) + + x1, y1 = ptA + x2, y2 = ptB + + if x1 is None: + x1 = 0 + if y1 is None: + y1 = 0 + if x2 is None: + x2 = self.sampling[0] * ms_object.shape[1] + if y2 is None: + y2 = self.sampling[1] * ms_object.shape[2] if specify_calibrated: x1 /= self.sampling[0] @@ -236,98 +252,101 @@ def show_depth( y1 /= self.sampling[1] y2 /= self.sampling[1] - if x2 == x1: - angle = 0 - elif y2 == y1: - angle = np.pi / 2 - else: - angle = np.arctan((x2 - x1) / (y2 - y1)) + x1, x2 = np.array([x1, x2]).clip(0, ms_object.shape[1]) + y1, y2 = np.array([y1, y2]).clip(0, ms_object.shape[2]) - x0 = ms_obj.shape[1] / 2 - y0 = ms_obj.shape[2] / 2 + angle = np.arctan2(x2 - x1, y2 - y1) - if ( - x1 > ms_obj.shape[1] - or x2 > ms_obj.shape[1] - or y1 > ms_obj.shape[2] - or y2 > ms_obj.shape[2] - ): - raise ValueError("depth section must be in field of view of object") - - from py4DSTEM.process.phase.utils import rotate_point + x0 = ms_object.shape[1] / 2 + y0 = ms_object.shape[2] / 2 x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) rotated_object = np.roll( - rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), + rotate(ms_object, np.rad2deg(angle), reshape=False, axes=(-1, -2)), -int(x1_0), axis=1, ) - if np.iscomplexobj(rotated_object): - rotated_object = np.angle(rotated_object) if gaussian_filter_sigma is not None: - from scipy.ndimage import gaussian_filter - gaussian_filter_sigma /= self.sampling[0] rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) - plot_im = rotated_object[ - :, 0, np.max((0, int(y1_0))) : np.min((int(y2_0), rotated_object.shape[2])) - ] + y1_0, y2_0 = ( + np.array([y1_0, y2_0]).astype("int").clip(0, rotated_object.shape[2]) + ) + plot_im = rotated_object[:, 0, y1_0:y2_0] - extent = [ - 0, - self.sampling[1] * plot_im.shape[1], - self._slice_thicknesses[0] * plot_im.shape[0], - 0, - ] - figsize = kwargs.pop("figsize", (6, 6)) - if not plot_line_profile: - fig, ax = plt.subplots(figsize=figsize) - im = ax.imshow(plot_im, cmap="magma", extent=extent) - if aspect is not None: - ax.set_aspect(aspect) - ax.set_xlabel("r [A]") - ax.set_ylabel("z [A]") - ax.set_title("Multislice depth profile") - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) + # Plotting + if plot_line_profile: + ncols = 2 else: - extent2 = [ + ncols = 1 + col_index = 0 + + spec = GridSpec(ncols=ncols, nrows=1, wspace=0.15) + + figsize = kwargs.pop("figsize", (4 * ncols, 4)) + fig = plt.figure(figsize=figsize) + cmap = kwargs.pop("cmap", "magma") + + # Line profile + if plot_line_profile: + ax = fig.add_subplot(spec[0, col_index]) + + extent_line = [ 0, - self.sampling[1] * ms_obj.shape[2], - self.sampling[0] * ms_obj.shape[1], + self.sampling[1] * ms_object.shape[2], + self.sampling[0] * ms_object.shape[1], 0, ] - fig, ax = plt.subplots(2, 1, figsize=figsize) - ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) - ax[0].plot( + ax.imshow(ms_object.sum(0), cmap="gray", 
extent=extent_line) + + ax.plot( [y1 * self.sampling[0], y2 * self.sampling[1]], [x1 * self.sampling[0], x2 * self.sampling[1]], color="red", ) - ax[0].set_xlabel("y [A]") - ax[0].set_ylabel("x [A]") - ax[0].set_title("Multislice depth profile location") - - im = ax[1].imshow(plot_im, cmap="magma", extent=extent) - if aspect is not None: - ax[1].set_aspect(aspect) - ax[1].set_xlabel("r [A]") - ax[1].set_ylabel("z [A]") - ax[1].set_title("Multislice depth profile") - if cbar: - divider = make_axes_locatable(ax[1]) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - plt.tight_layout() + + ax.set_xlabel("y [A]") + ax.set_ylabel("x [A]") + ax.set_title("Multislice depth profile location") + col_index += 1 + + # Main visualization + + extent = [ + 0, + self.sampling[1] * plot_im.shape[1], + self._slice_thicknesses[0] * plot_im.shape[0], + 0, + ] + + ax = fig.add_subplot(spec[0, col_index]) + im = ax.imshow(plot_im, cmap=cmap, extent=extent) + + if aspect_ratio is not None: + if aspect_ratio == "auto": + aspect_ratio = extent[1] / extent[2] + if plot_line_profile: + aspect_ratio *= extent_line[2] / extent_line[1] + + ax.set_aspect(aspect_ratio) + cbar = False + + ax.set_xlabel("r [A]") + ax.set_ylabel("z [A]") + ax.set_title("Multislice depth profile") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + spec.tight_layout(fig) def show_slices( self, @@ -385,21 +404,14 @@ def show_slices( cmap = kwargs.pop("cmap", "magma") if common_color_scale: - vals = np.sort(rotated_object.ravel()) - ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") - ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") - ind_vmin = np.max([0, ind_vmin]) - ind_vmax = np.min([len(vals) - 1, ind_vmax]) - vmin = vals[ind_vmin] - vmax = vals[ind_vmax] - if vmax == vmin: - vmin = vals[0] - vmax = vals[-1] + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + rotated_object, vmin, vmax = return_scaled_histogram_ordering( + rotated_object, vmin=vmin, vmax=vmax + ) else: - vmax = None vmin = None - vmin = kwargs.pop("vmin", vmin) - vmax = kwargs.pop("vmax", vmax) + vmax = None spec = GridSpec( ncols=num_cols, From dfb4a699b1434cd9ca847652c0efc14bef22953a Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 31 Dec 2023 11:00:05 -0800 Subject: [PATCH 012/128] moved read-only self-attribute calls up-front. 
Might remove all-together, especially for positions Former-commit-id: fcb0d65c9d35766845a35a3df613b642287243e4 --- .../process/phase/iterative_base_class.py | 174 +++++++++++------- 1 file changed, 103 insertions(+), 71 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 9bfbca3ba..3aacf2831 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -194,8 +194,8 @@ def _preprocess_datacube_and_vacuum_probe( if com_shifts is not None: if np.isscalar(com_shifts[0]): com_shifts = ( - np.ones(self._datacube.Rshape) * com_shifts[0], - np.ones(self._datacube.Rshape) * com_shifts[1], + np.ones(datacube.Rshape) * com_shifts[0], + np.ones(datacube.Rshape) * com_shifts[1], ) if diffraction_intensities_shape is not None: @@ -368,8 +368,10 @@ def _extract_intensities_and_calibrations_from_datacube( If require_calibrations is False and calibrations are not set """ - # Copies intensities to device casting to float32 + # explicit read-only self attributes up-front xp = self._xp + verbose = self._verbose + energy = self._energy intensities = xp.asarray(datacube.data, dtype=xp.float32) self._grid_scan_shape = intensities.shape[:2] @@ -388,7 +390,7 @@ def _extract_intensities_and_calibrations_from_datacube( if require_calibrations: raise ValueError("Real-space calibrations must be given in 'A'") - if self._verbose: + if verbose: warnings.warn( ( "Iterative reconstruction will not be quantitative unless you specify " @@ -426,10 +428,10 @@ def _extract_intensities_and_calibrations_from_datacube( self._angular_sampling = (force_angular_sampling,) * 2 self._angular_units = ("mrad",) * 2 - if self._energy is not None: + if energy is not None: self._reciprocal_sampling = ( force_angular_sampling - / electron_wavelength_angstrom(self._energy) + / electron_wavelength_angstrom(energy) / 1e3, ) * 2 self._reciprocal_units = ("A^-1",) * 2 @@ -439,10 +441,10 @@ def _extract_intensities_and_calibrations_from_datacube( self._reciprocal_sampling = (force_reciprocal_sampling,) * 2 self._reciprocal_units = ("A^-1",) * 2 - if self._energy is not None: + if energy is not None: self._angular_sampling = ( force_reciprocal_sampling - * electron_wavelength_angstrom(self._energy) + * electron_wavelength_angstrom(energy) * 1e3, ) * 2 self._angular_units = ("mrad",) * 2 @@ -454,7 +456,7 @@ def _extract_intensities_and_calibrations_from_datacube( "Reciprocal-space calibrations must be given in in 'A^-1' or 'mrad'" ) - if self._verbose: + if verbose: warnings.warn( ( "Iterative reconstruction will not be quantitative unless you specify " @@ -473,11 +475,9 @@ def _extract_intensities_and_calibrations_from_datacube( self._reciprocal_sampling = (reciprocal_size,) * 2 self._reciprocal_units = ("A^-1",) * 2 - if self._energy is not None: + if energy is not None: self._angular_sampling = ( - reciprocal_size - * electron_wavelength_angstrom(self._energy) - * 1e3, + reciprocal_size * electron_wavelength_angstrom(energy) * 1e3, ) * 2 self._angular_units = ("mrad",) * 2 @@ -486,9 +486,9 @@ def _extract_intensities_and_calibrations_from_datacube( self._angular_sampling = (angular_size,) * 2 self._angular_units = ("mrad",) * 2 - if self._energy is not None: + if energy is not None: self._reciprocal_sampling = ( - angular_size / electron_wavelength_angstrom(self._energy) / 1e3, + angular_size / electron_wavelength_angstrom(energy) / 1e3, ) * 2 self._reciprocal_units = ("A^-1",) * 2 else: @@ -541,8 +541,10 @@ def 
_calculate_intensities_center_of_mass( Normalized vertical center of mass gradient """ + # explicit read-only self attributes up-front xp = self._xp asnumpy = self._asnumpy + reciprocal_sampling = self._reciprocal_sampling if com_measured: com_measured_x, com_measured_y = com_measured @@ -593,10 +595,10 @@ def _calculate_intensities_center_of_mass( # fix CoM units com_normalized_x = ( - xp.nan_to_num(com_measured_x - com_fitted_x) * self._reciprocal_sampling[0] + xp.nan_to_num(com_measured_x - com_fitted_x) * reciprocal_sampling[0] ) com_normalized_y = ( - xp.nan_to_num(com_measured_y - com_fitted_y) * self._reciprocal_sampling[1] + xp.nan_to_num(com_measured_y - com_fitted_y) * reciprocal_sampling[1] ) return ( @@ -676,8 +678,12 @@ def _solve_for_center_of_mass_relative_rotation( Summary statistics """ + # explicit read-only self attributes up-front xp = self._xp asnumpy = self._asnumpy + verbose = self._verbose + scan_sampling = self._scan_sampling + scan_units = self._scan_units if rotation_angles_deg is None: rotation_angles_deg = np.arange(-89.0, 90.0, 1.0) @@ -687,7 +693,7 @@ def _solve_for_center_of_mass_relative_rotation( _rotation_best_rad = np.deg2rad(force_com_rotation) - if self._verbose: + if verbose: warnings.warn( ( "Best fit rotation forced to " @@ -701,7 +707,7 @@ def _solve_for_center_of_mass_relative_rotation( _rotation_best_transpose = force_com_transpose - if self._verbose: + if verbose: warnings.warn( f"Transpose of intensities forced to {force_com_transpose}.", UserWarning, @@ -752,7 +758,7 @@ def _solve_for_center_of_mass_relative_rotation( else: _rotation_best_transpose = rotation_curl_transpose < rotation_curl - if self._verbose: + if verbose: if _rotation_best_transpose: print("Diffraction intensities should be transposed.") else: @@ -765,7 +771,7 @@ def _solve_for_center_of_mass_relative_rotation( _rotation_best_transpose = force_com_transpose - if self._verbose: + if verbose: warnings.warn( f"Transpose of intensities forced to {force_com_transpose}.", UserWarning, @@ -860,7 +866,7 @@ def _solve_for_center_of_mass_relative_rotation( rotation_best_deg = rotation_angles_deg[ind_min] _rotation_best_rad = rotation_angles_rad[ind_min] - if self._verbose: + if verbose: print(("Best fit rotation = " f"{rotation_best_deg:.0f} degrees.")) if plot_rotation: @@ -1014,8 +1020,9 @@ def _solve_for_center_of_mass_relative_rotation( _rotation_best_transpose = True self._rotation_angles_deg = rotation_angles_deg + # Print summary - if self._verbose: + if verbose: print(("Best fit rotation = " f"{rotation_best_deg:.0f} degrees.")) if _rotation_best_transpose: print("Diffraction intensities should be transposed.") @@ -1102,8 +1109,8 @@ def _solve_for_center_of_mass_relative_rotation( cmap = kwargs.pop("cmap", "RdBu_r") extent = [ 0, - self._scan_sampling[1] * _com_measured_x.shape[1], - self._scan_sampling[0] * _com_measured_x.shape[0], + scan_sampling[1] * _com_measured_x.shape[1], + scan_sampling[0] * _com_measured_x.shape[0], 0, ] @@ -1130,8 +1137,8 @@ def _solve_for_center_of_mass_relative_rotation( ], ): ax.imshow(asnumpy(arr), extent=extent, cmap=cmap, **kwargs) - ax.set_ylabel(f"x [{self._scan_units[0]}]") - ax.set_xlabel(f"y [{self._scan_units[1]}]") + ax.set_ylabel(f"x [{scan_units[0]}]") + ax.set_xlabel(f"y [{scan_units[1]}]") ax.set_title(title) elif plot_center_of_mass == "default" or plot_center_of_mass is True: @@ -1140,8 +1147,8 @@ def _solve_for_center_of_mass_relative_rotation( extent = [ 0, - self._scan_sampling[1] * com_x.shape[1], - self._scan_sampling[0] * 
com_x.shape[0], + scan_sampling[1] * com_x.shape[1], + scan_sampling[0] * com_x.shape[0], 0, ] @@ -1160,8 +1167,8 @@ def _solve_for_center_of_mass_relative_rotation( ], ): ax.imshow(arr, extent=extent, cmap=cmap, **kwargs) - ax.set_ylabel(f"x [{self._scan_units[0]}]") - ax.set_xlabel(f"y [{self._scan_units[1]}]") + ax.set_ylabel(f"x [{scan_units[0]}]") + ax.set_xlabel(f"y [{scan_units[1]}]") ax.set_title(title) return ( @@ -1206,10 +1213,13 @@ def _normalize_diffraction_intensities( Mean intensity value """ + # explicit read-only self attributes up-front xp = self._xp + asnumpy = self._asnumpy + mean_intensity = 0 - diffraction_intensities = self._asnumpy(diffraction_intensities) + diffraction_intensities = asnumpy(diffraction_intensities) if positions_mask is not None: number_of_patterns = np.count_nonzero(positions_mask.ravel()) else: @@ -1254,8 +1264,8 @@ def _normalize_diffraction_intensities( (number_of_patterns,) + region_of_interest_shape, dtype=np.float32 ) - com_fitted_x = self._asnumpy(com_fitted_x) - com_fitted_y = self._asnumpy(com_fitted_y) + com_fitted_x = asnumpy(com_fitted_x) + com_fitted_y = asnumpy(com_fitted_y) counter = 0 for rx in range(diffraction_intensities.shape[0]): @@ -1312,13 +1322,17 @@ def show_complex_CoM( default is scan sampling """ + # explicit read-only self attributes up-front + scan_sampling = self._scan_sampling + scan_units = self._scan_units + if com is None: com = (self.com_x, self.com_y) if pixelsize is None: - pixelsize = self._scan_sampling[0] + pixelsize = scan_sampling[0] if pixelunits is None: - pixelunits = self._scan_units[0] + pixelunits = scan_units[0] figsize = kwargs.pop("figsize", (6, 6)) fig, ax = plt.subplots(figsize=figsize) @@ -1607,7 +1621,10 @@ def _set_polar_parameters(self, parameters: dict): raise ValueError("{} not a recognized parameter".format(symbol)) def _calculate_scan_positions_in_pixels( - self, positions: np.ndarray, positions_mask + self, + positions: np.ndarray, + positions_mask, + object_padding_px, ): """ Method to compute the initial guess of scan positions in pixels. @@ -1619,16 +1636,25 @@ def _calculate_scan_positions_in_pixels( If None, a raster scan using experimental parameters is constructed. 
positions_mask: np.ndarray, optional Boolean real space mask to select positions in datacube to skip for reconstruction + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions Returns ------- positions_in_px: (J,2) np.ndarray Initial guess of scan positions in pixels + object_padding_px: Tupe[int,int] + Updated object_padding_px """ + # explicit read-only self attributes up-front grid_scan_shape = self._grid_scan_shape rotation_angle = self._rotation_best_rad + transpose = self._rotation_best_transpose step_sizes = self._scan_sampling + region_of_interest_shape = self._region_of_interest_shape + sampling = self.sampling if positions is None: if grid_scan_shape is not None: @@ -1643,45 +1669,47 @@ def _calculate_scan_positions_in_pixels( else: raise ValueError() - if self._rotation_best_transpose: - x = (x - np.ptp(x) / 2) / self.sampling[1] - y = (y - np.ptp(y) / 2) / self.sampling[0] + if transpose: + x = (x - np.ptp(x) / 2) / sampling[1] + y = (y - np.ptp(y) / 2) / sampling[0] else: - x = (x - np.ptp(x) / 2) / self.sampling[0] - y = (y - np.ptp(y) / 2) / self.sampling[1] + x = (x - np.ptp(x) / 2) / sampling[0] + y = (y - np.ptp(y) / 2) / sampling[1] x, y = np.meshgrid(x, y, indexing="ij") + if positions_mask is not None: x = x[positions_mask] y = y[positions_mask] else: positions -= np.mean(positions, axis=0) - x = positions[:, 0] / self.sampling[1] - y = positions[:, 1] / self.sampling[0] + x = positions[:, 0] / sampling[1] + y = positions[:, 1] / sampling[0] if rotation_angle is not None: x, y = x * np.cos(rotation_angle) + y * np.sin(rotation_angle), -x * np.sin( rotation_angle ) + y * np.cos(rotation_angle) - if self._rotation_best_transpose: + if transpose: positions = np.array([y.ravel(), x.ravel()]).T else: positions = np.array([x.ravel(), y.ravel()]).T + positions -= np.min(positions, axis=0) - if self._object_padding_px is None: - float_padding = self._region_of_interest_shape / 2 - self._object_padding_px = (float_padding, float_padding) - elif np.isscalar(self._object_padding_px[0]): - self._object_padding_px = ( - (self._object_padding_px[0],) * 2, - (self._object_padding_px[1],) * 2, + if object_padding_px is None: + float_padding = region_of_interest_shape / 2 + object_padding_px = (float_padding, float_padding) + elif np.isscalar(object_padding_px[0]): + object_padding_px = ( + (object_padding_px[0],) * 2, + (object_padding_px[1],) * 2, ) - positions[:, 0] += self._object_padding_px[0][0] - positions[:, 1] += self._object_padding_px[1][0] + positions[:, 0] += object_padding_px[0][0] + positions[:, 1] += object_padding_px[1][0] - return positions + return positions, object_padding_px def _sum_overlapping_patches_bincounts_base(self, patches: np.ndarray): """ @@ -1698,26 +1726,26 @@ def _sum_overlapping_patches_bincounts_base(self, patches: np.ndarray): out_array: (Px,Py) np.ndarray Summed array """ + # explicit read-only self attributes up-front xp = self._xp - x0 = xp.round(self._positions_px[:, 0]).astype("int") - y0 = xp.round(self._positions_px[:, 1]).astype("int") - + positions_px = self._positions_px roi_shape = self._region_of_interest_shape + object_shape = self._object_shape + + x0 = xp.round(positions_px[:, 0]).astype("int") + y0 = xp.round(positions_px[:, 1]).astype("int") + x_ind = xp.fft.fftfreq(roi_shape[0], d=1 / roi_shape[0]).astype("int") y_ind = xp.fft.fftfreq(roi_shape[1], d=1 / roi_shape[1]).astype("int") flat_weights = patches.ravel() - indices = ( - (y0[:, None, 
None] + y_ind[None, None, :]) % self._object_shape[1] - ) + ( - (x0[:, None, None] + x_ind[None, :, None]) % self._object_shape[0] - ) * self._object_shape[ - 1 - ] + indices = ((y0[:, None, None] + y_ind[None, None, :]) % object_shape[1]) + ( + (x0[:, None, None] + x_ind[None, :, None]) % object_shape[0] + ) * object_shape[1] counts = xp.bincount( - indices.ravel(), weights=flat_weights, minlength=np.prod(self._object_shape) + indices.ravel(), weights=flat_weights, minlength=np.prod(object_shape) ) - return xp.reshape(counts, self._object_shape) + return xp.reshape(counts, object_shape) def _sum_overlapping_patches_bincounts(self, patches: np.ndarray): """ @@ -1755,15 +1783,19 @@ def _extract_vectorized_patch_indices(self): self._vectorized_patch_indices_col: np.ndarray Column indices for probe patches inside object array """ + # explicit read-only self attributes up-front xp = self._xp - x0 = xp.round(self._positions_px[:, 0]).astype("int") - y0 = xp.round(self._positions_px[:, 1]).astype("int") - + positions_px = self._positions_px + positions_px = self._positions_px roi_shape = self._region_of_interest_shape + obj_shape = self._object_shape + + x0 = xp.round(positions_px[:, 0]).astype("int") + y0 = xp.round(positions_px[:, 1]).astype("int") + x_ind = xp.fft.fftfreq(roi_shape[0], d=1 / roi_shape[0]).astype("int") y_ind = xp.fft.fftfreq(roi_shape[1], d=1 / roi_shape[1]).astype("int") - obj_shape = self._object_shape vectorized_patch_indices_row = ( x0[:, None, None] + x_ind[None, :, None] ) % obj_shape[0] From 188a6e3c594e3a25b622e1202802be57b33e093c Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 31 Dec 2023 12:03:37 -0800 Subject: [PATCH 013/128] cleaned up single-slice preprocess Former-commit-id: 6bdc8eb0b082da5bfea066a4e5693705d3fa00eb --- .../process/phase/iterative_base_class.py | 10 +- .../phase/iterative_ptychographic_methods.py | 112 ++++++++++++- .../iterative_singleslice_ptychography.py | 148 ++++++------------ 3 files changed, 165 insertions(+), 105 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 3aacf2831..e6e3a4547 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1185,8 +1185,8 @@ def _normalize_diffraction_intensities( diffraction_intensities, com_fitted_x, com_fitted_y, - crop_patterns, positions_mask, + crop_patterns, ): """ Fix diffraction intensities CoM, shift to origin, and take square root @@ -1199,11 +1199,11 @@ def _normalize_diffraction_intensities( Best fit horizontal center of mass gradient com_fitted_y: (Rx,Ry) xp.ndarray Best fit vertical center of mass gradient + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering - positions_mask: np.ndarray, optional - Boolean real space mask to select positions in datacube to skip for reconstruction Returns ------- @@ -1256,9 +1256,9 @@ def _normalize_diffraction_intensities( crop_mask[-crop_w:, :crop_w] = True crop_mask[:crop_w:, -crop_w:] = True crop_mask[-crop_w:, -crop_w:] = True - self._crop_mask = crop_mask else: + crop_mask = None region_of_interest_shape = diffraction_intensities.shape[-2:] amplitudes = np.zeros( (number_of_patterns,) + region_of_interest_shape, dtype=np.float32 @@ -1293,7 +1293,7 @@ def _normalize_diffraction_intensities( amplitudes = xp.asarray(amplitudes) mean_intensity /= 
amplitudes.shape[0] - return amplitudes, mean_intensity + return amplitudes, mean_intensity, crop_mask def show_complex_CoM( self, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index a0f10fd0d..683905c88 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -4,7 +4,8 @@ import numpy as np from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import make_axes_locatable -from py4DSTEM.process.phase.utils import AffineTransform, rotate_point +from py4DSTEM.process.phase.utils import AffineTransform, ComplexProbe, rotate_point +from py4DSTEM.process.utils import get_CoM, get_shifted_ar from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex from scipy.ndimage import gaussian_filter, rotate @@ -19,6 +20,36 @@ class ObjectNDMethodsMixin: Mixin class for object methods applicable to 2D,2.5D, and 3D objects. """ + def _initialize_object( + self, + initial_object, + positions_px, + object_type, + ): + """ """ + # explicit read-only self attributes up-front + xp = self._xp + object_padding_px = self._object_padding_px + region_of_interest_shape = self._region_of_interest_shape + + if initial_object is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + if object_type == "potential": + _object = xp.zeros((p, q), dtype=xp.float32) + elif object_type == "complex": + _object = xp.ones((p, q), dtype=xp.complex64) + else: + if object_type == "potential": + _object = xp.asarray(initial_object, dtype=xp.float32) + elif object_type == "complex": + _object = xp.asarray(initial_object, dtype=xp.complex64) + + return _object + def _crop_rotate_object_fov( self, array, @@ -619,6 +650,85 @@ class ProbeMethodsMixin: Mixin class for probe methods applicable to a single probe. 
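
The `_initialize_object` helper added above sizes the object array from the maximum (zero-origin) scan position plus the trailing padding, never smaller than the probe region of interest. A minimal NumPy sketch of that sizing rule, with made-up values and names used purely for illustration:

    import numpy as np

    # hypothetical inputs (not the class attributes themselves)
    positions_px = np.random.rand(64, 2) * 100.0          # (J, 2) scan positions, already shifted to start at 0
    object_padding_px = ((16, 16), (16, 16))               # ((top, bottom), (left, right)) padding in pixels
    region_of_interest_shape = (64, 64)                     # probe array shape

    pad_x = object_padding_px[0][1]
    pad_y = object_padding_px[1][1]
    p, q = np.round(np.max(positions_px, axis=0))

    # the object must reach the furthest probe position plus trailing padding,
    # and can never be smaller than the probe array itself
    p = int(max(np.round(p + pad_x), region_of_interest_shape[0]))
    q = int(max(np.round(q + pad_y), region_of_interest_shape[1]))

    potential_object = np.zeros((p, q), dtype=np.float32)  # object_type == "potential"
    complex_object = np.ones((p, q), dtype=np.complex64)   # object_type == "complex"
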
""" + def _initialize_probe( + self, + initial_probe, + vacuum_probe_intensity, + mean_diffraction_intensity, + semiangle_cutoff, + crop_patterns, + ): + """ """ + # explicit read-only self attributes up-front + xp = self._xp + device = self._device + crop_mask = self._crop_mask + region_of_interest_shape = self._region_of_interest_shape + sampling = self.sampling + energy = self._energy + rolloff = self._rolloff + polar_parameters = self._polar_parameters + + if initial_probe is None: + if vacuum_probe_intensity is not None: + semiangle_cutoff = np.inf + vacuum_probe_intensity = xp.asarray( + vacuum_probe_intensity, dtype=xp.float32 + ) + probe_x0, probe_y0 = get_CoM( + vacuum_probe_intensity, + device=device, + ) + vacuum_probe_intensity = get_shifted_ar( + vacuum_probe_intensity, + -probe_x0, + -probe_y0, + bilinear=True, + device=device, + ) + + if crop_patterns: + vacuum_probe_intensity = vacuum_probe_intensity[crop_mask].reshape( + region_of_interest_shape + ) + + _probe = ( + ComplexProbe( + gpts=region_of_interest_shape, + sampling=sampling, + energy=energy, + semiangle_cutoff=semiangle_cutoff, + rolloff=rolloff, + vacuum_probe_intensity=vacuum_probe_intensity, + parameters=polar_parameters, + device=device, + ) + .build() + ._array + ) + + # Normalize probe to match mean diffraction intensity + probe_intensity = xp.sum(xp.abs(xp.fft.fft2(_probe)) ** 2) + _probe *= xp.sqrt(mean_diffraction_intensity / probe_intensity) + + else: + if isinstance(initial_probe, ComplexProbe): + if initial_probe._gpts != region_of_interest_shape: + raise ValueError() + if hasattr(initial_probe, "_array"): + _probe = initial_probe._array + else: + initial_probe._xp = xp + _probe = initial_probe.build()._array + + # Normalize probe to match mean diffraction intensity + probe_intensity = xp.sum(xp.abs(xp.fft.fft2(_probe)) ** 2) + _probe *= xp.sqrt(mean_diffraction_intensity / probe_intensity) + else: + _probe = xp.asarray(initial_probe, dtype=xp.complex64) + + return _probe, semiangle_cutoff + def _return_fourier_probe( self, probe=None, diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 7434e2e50..22f2c2a46 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -36,7 +36,6 @@ polar_aliases, polar_symbols, ) -from py4DSTEM.process.utils import get_CoM, get_shifted_ar warnings.simplefilter(action="always", category=UserWarning) @@ -291,13 +290,10 @@ def preprocess( ) ) - if self._positions_mask is not None and self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) + if self._positions_mask is not None: self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + # preprocess datacube ( self._datacube, self._vacuum_probe_intensity, @@ -313,6 +309,7 @@ def preprocess( com_shifts=force_com_shifts, ) + # calibrations self._intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=True, @@ -321,6 +318,13 @@ def preprocess( force_reciprocal_sampling=force_reciprocal_sampling, ) + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # calculate CoM ( self._com_measured_x, self._com_measured_y, @@ -335,6 +339,7 @@ def preprocess( com_shifts=force_com_shifts, ) + # estimate rotation / 
transpose ( self._rotation_best_rad, self._rotation_best_transpose, @@ -356,57 +361,46 @@ def preprocess( **kwargs, ) + # corner-center amplitudes ( self._amplitudes, self._mean_diffraction_intensity, + self._crop_mask, ) = self._normalize_diffraction_intensities( self._intensities, self._com_fitted_x, self._com_fitted_y, - crop_patterns, self._positions_mask, + crop_patterns, ) - # explicitly delete namespace + # explicitly delete intensities namespace self._num_diffraction_patterns = self._amplitudes.shape[0] self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) del self._intensities - self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions, self._positions_mask + # initialize probe positions + ( + self._positions_px, + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions, + self._positions_mask, + self._object_padding_px, ) - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - if self._object_type == "potential": - self._object = xp.zeros((p, q), dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.ones((p, q), dtype=xp.complex64) - else: - if self._object_type == "potential": - self._object = xp.asarray(self._object, dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.asarray(self._object, dtype=xp.complex64) + # initialize object + self._object = self._initialize_object( + self._object, + self._positions_px, + self._object_type, + ) self._object_initial = self._object.copy() self._object_type_initial = self._object_type self._object_shape = self._object.shape + # center probe positions self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) self._positions_px_com = xp.mean(self._positions_px, axis=0) self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 @@ -420,74 +414,25 @@ def preprocess( self._positions_initial[:, 0] *= self.sampling[0] self._positions_initial[:, 1] *= self.sampling[1] - # Vectorized Patches + # set vectorized patches ( self._vectorized_patch_indices_row, self._vectorized_patch_indices_col, ) = self._extract_vectorized_patch_indices() - # Probe Initialization - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, - device=self._device, - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - self._probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - 
parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) - - else: - if isinstance(self._probe, ComplexProbe): - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - self._probe = self._probe._array - else: - self._probe._xp = xp - self._probe = self._probe.build()._array - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - self._mean_diffraction_intensity / probe_intensity - ) - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) + # initialize probe + self._probe, self._semiangle_cutoff = self._initialize_probe( + self._probe, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity, + self._semiangle_cutoff, + crop_patterns, + ) self._probe_initial = self._probe.copy() self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + + # initialize aberrations self._known_aberrations_array = ComplexProbe( energy=self._energy, gpts=self._region_of_interest_shape, @@ -500,22 +445,27 @@ def preprocess( shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) probe_intensities = xp.abs(shifted_probes) ** 2 probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) + # initialize object_fov_mask if object_fov_mask is None: - self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) + probe_overlap_blurred = self._gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) else: self._object_fov_mask = np.asarray(object_fov_mask) self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + # plot probe overlaps if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - power=2, + power=power, chroma_boost=chroma_boost, ) From 68af4553844cc58d582e1df13e8030fa2f94f6d0 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 31 Dec 2023 13:50:53 -0800 Subject: [PATCH 014/128] cleaned up multislice preprocess Former-commit-id: dfc8aeebf8900dc9034e710508fd5ff43aec0e57 --- .../iterative_multislice_ptychography.py | 242 ++++-------------- .../phase/iterative_ptychographic_methods.py | 131 +++++++++- .../iterative_singleslice_ptychography.py | 2 +- 3 files changed, 180 insertions(+), 195 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index bb46f25c8..92064871c 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -37,9 +37,7 @@ generate_batches, polar_aliases, polar_symbols, - spatial_frequencies, ) -from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar warnings.simplefilter(action="always", category=UserWarning) @@ -137,8 +135,8 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, - theta_x: float = 0, - theta_y: float = 0, + theta_x: 
float = None, + theta_y: float = None, middle_focus: bool = False, object_type: str = "complex", positions_mask: np.ndarray = None, @@ -244,92 +242,6 @@ def __init__( self._theta_x = theta_x self._theta_y = theta_y - def _precompute_propagator_arrays( - self, - gpts: Tuple[int, int], - sampling: Tuple[float, float], - energy: float, - slice_thicknesses: Sequence[float], - theta_x: float, - theta_y: float, - ): - """ - Precomputes propagator arrays complex wave-function will be convolved by, - for all slice thicknesses. - - Parameters - ---------- - gpts: Tuple[int,int] - Wavefunction pixel dimensions - sampling: Tuple[float,float] - Wavefunction sampling in A - energy: float - The electron energy of the wave functions in eV - slice_thicknesses: Sequence[float] - Array of slice thicknesses in A - theta_x: float - x tilt of propagator (in degrees) - theta_y: float - y tilt of propagator (in degrees) - - Returns - ------- - propagator_arrays: np.ndarray - (T,Sx,Sy) shape array storing propagator arrays - """ - xp = self._xp - - # Frequencies - kx, ky = spatial_frequencies(gpts, sampling) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) - - # Propagators - wavelength = electron_wavelength_angstrom(energy) - num_slices = slice_thicknesses.shape[0] - propagators = xp.empty( - (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 - ) - - theta_x = np.deg2rad(theta_x) - theta_y = np.deg2rad(theta_y) - - for i, dz in enumerate(slice_thicknesses): - propagators[i] = xp.exp( - 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) - ) - propagators[i] *= xp.exp( - 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) - ) - - return propagators - - def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): - """ - Propagates array by Fourier convolving array with propagator_array. 
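
The `_propagate_array` method removed here (and re-homed in the shared mixins in this same patch) is a single Fourier-space multiplication per slice. A self-contained sketch of the operation, independent of the class:

    import numpy as np

    def propagate(array, propagator):
        """Advance a wavefunction by one slice: multiply by the propagator in Fourier space."""
        return np.fft.ifft2(np.fft.fft2(array) * propagator)

    # toy check: a unit propagator (zero thickness) leaves the wave unchanged
    psi = np.ones((32, 32), dtype=np.complex64)
    assert np.allclose(propagate(psi, np.ones_like(psi)), psi)
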
- - Parameters - ---------- - array: np.ndarray - Wavefunction array to be convolved - propagator_array: np.ndarray - Propagator array to convolve array with - - Returns - ------- - propagated_array: np.ndarray - Fourier-convolved array - """ - xp = self._xp - - return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) - def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, @@ -340,7 +252,7 @@ def preprocess( plot_center_of_mass: str = "default", plot_rotation: bool = True, maximize_divergence: bool = False, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + rotation_angles_deg: np.ndarray = None, plot_probe_overlaps: bool = True, force_com_rotation: float = None, force_com_transpose: float = None, @@ -432,13 +344,10 @@ def preprocess( ) ) - if self._positions_mask is not None and self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) + if self._positions_mask is not None: self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + # preprocess datacube ( self._datacube, self._vacuum_probe_intensity, @@ -454,6 +363,7 @@ def preprocess( com_shifts=force_com_shifts, ) + # calibrations self._intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=True, @@ -462,6 +372,13 @@ def preprocess( force_reciprocal_sampling=force_reciprocal_sampling, ) + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # calculate CoM ( self._com_measured_x, self._com_measured_y, @@ -476,6 +393,7 @@ def preprocess( com_shifts=force_com_shifts, ) + # estimate rotation / transpose ( self._rotation_best_rad, self._rotation_best_transpose, @@ -497,15 +415,17 @@ def preprocess( **kwargs, ) + # corner-center amplitudes ( self._amplitudes, self._mean_diffraction_intensity, + self._crop_mask, ) = self._normalize_diffraction_intensities( self._intensities, self._com_fitted_x, self._com_fitted_y, - crop_patterns, self._positions_mask, + crop_patterns, ) # explicitly delete namespace @@ -513,41 +433,29 @@ def preprocess( self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) del self._intensities - self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions, self._positions_mask + # initialize probe positions + ( + self._positions_px, + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions, + self._positions_mask, + self._object_padding_px, ) - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - if self._object_type == "potential": - self._object = xp.zeros((self._num_slices, p, q), dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.ones((self._num_slices, p, q), dtype=xp.complex64) - else: - if self._object_type == "potential": - self._object = xp.asarray(self._object, dtype=xp.float32) - elif self._object_type == "complex": - self._object = 
xp.asarray(self._object, dtype=xp.complex64) + # initialize object + self._object = self._initialize_object( + self._object, + self._num_slices, + self._positions_px, + self._object_type, + ) self._object_initial = self._object.copy() self._object_type_initial = self._object_type self._object_shape = self._object.shape[-2:] + # center probe positions self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) self._positions_px_com = xp.mean(self._positions_px, axis=0) self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 @@ -561,76 +469,25 @@ def preprocess( self._positions_initial[:, 0] *= self.sampling[0] self._positions_initial[:, 1] *= self.sampling[1] - # Vectorized Patches + # set vectorized patches ( self._vectorized_patch_indices_row, self._vectorized_patch_indices_col, ) = self._extract_vectorized_patch_indices() - # Probe Initialization - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, - device=self._device, - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - self._probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) - - else: - if isinstance(self._probe, ComplexProbe): - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - self._probe = self._probe._array - else: - self._probe._xp = xp - self._probe = self._probe.build()._array - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - self._mean_diffraction_intensity / probe_intensity - ) - - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) + # initialize probe + self._probe, self._semiangle_cutoff = self._initialize_probe( + self._probe, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity, + self._semiangle_cutoff, + crop_patterns, + ) self._probe_initial = self._probe.copy() self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + # initialize aberrations self._known_aberrations_array = ComplexProbe( energy=self._energy, gpts=self._region_of_interest_shape, @@ -639,7 +496,7 @@ def preprocess( device=self._device, )._evaluate_ctf() - # Precomputed propagator arrays + # precompute propagator arrays self._propagator_arrays = self._precompute_propagator_arrays( self._region_of_interest_shape, self.sampling, @@ -653,10 +510,12 @@ def preprocess( shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) probe_intensities = xp.abs(shifted_probes) ** 2 probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = 
self._gaussian_filter(probe_overlap, 1.0) if object_fov_mask is None: - self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) + probe_overlap_blurred = self._gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) else: self._object_fov_mask = np.asarray(object_fov_mask) self._object_fov_mask_inverse = np.invert(self._object_fov_mask) @@ -664,11 +523,12 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - power=2, + power=power, chroma_boost=chroma_boost, ) @@ -681,7 +541,7 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - power=2, + power=power, chroma_boost=chroma_boost, ) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index 683905c88..28a66d971 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -1,11 +1,16 @@ -from typing import Tuple +from typing import Sequence, Tuple import matplotlib.pyplot as plt import numpy as np from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import make_axes_locatable -from py4DSTEM.process.phase.utils import AffineTransform, ComplexProbe, rotate_point -from py4DSTEM.process.utils import get_CoM, get_shifted_ar +from py4DSTEM.process.phase.utils import ( + AffineTransform, + ComplexProbe, + rotate_point, + spatial_frequencies, +) +from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex from scipy.ndimage import gaussian_filter, rotate @@ -187,6 +192,126 @@ class Object2p5DMethodsMixin: Overwrites ObjectNDMethodsMixin. """ + def _precompute_propagator_arrays( + self, + gpts: Tuple[int, int], + sampling: Tuple[float, float], + energy: float, + slice_thicknesses: Sequence[float], + theta_x: float = None, + theta_y: float = None, + ): + """ + Precomputes propagator arrays complex wave-function will be convolved by, + for all slice thicknesses. 
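
Each slice's propagator combines a Fresnel (near-field) quadratic phase with optional linear ramps for beam tilt. A rough standalone version of that construction, assuming the electron wavelength is supplied directly (in Angstroms) and the tilts in radians rather than the degrees the method accepts:

    import numpy as np

    def fresnel_propagators(gpts, sampling, wavelength, slice_thicknesses,
                            theta_x=None, theta_y=None):
        """Stack of Fourier-space propagators, one per slice thickness (lengths in Angstroms)."""
        kx = np.fft.fftfreq(gpts[0], sampling[0]).astype(np.float32)
        ky = np.fft.fftfreq(gpts[1], sampling[1]).astype(np.float32)

        propagators = np.empty((len(slice_thicknesses), gpts[0], gpts[1]), dtype=np.complex64)
        for i, dz in enumerate(slice_thicknesses):
            # Fresnel phase for a propagation distance dz
            chi = -np.pi * wavelength * dz * (kx[:, None] ** 2 + ky[None, :] ** 2)
            # optional linear ramps for beam tilt
            if theta_x is not None:
                chi = chi + 2 * np.pi * dz * np.tan(theta_x) * kx[:, None]
            if theta_y is not None:
                chi = chi + 2 * np.pi * dz * np.tan(theta_y) * ky[None, :]
            propagators[i] = np.exp(1.0j * chi)
        return propagators

    # e.g. two 10 A slices at ~300 kV (wavelength ~0.0197 A), 0.2 A real-space sampling
    props = fresnel_propagators((64, 64), (0.2, 0.2), 0.0197, [10.0, 10.0])
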
+ + Parameters + ---------- + gpts: Tuple[int,int] + Wavefunction pixel dimensions + sampling: Tuple[float,float] + Wavefunction sampling in A + energy: float + The electron energy of the wave functions in eV + slice_thicknesses: Sequence[float] + Array of slice thicknesses in A + theta_x: float, optional + x tilt of propagator (in degrees) + theta_y: float, optional + y tilt of propagator (in degrees) + + Returns + ------- + propagator_arrays: np.ndarray + (T,Sx,Sy) shape array storing propagator arrays + """ + xp = self._xp + + # Frequencies + kx, ky = spatial_frequencies(gpts, sampling) + kx = xp.asarray(kx, dtype=xp.float32) + ky = xp.asarray(ky, dtype=xp.float32) + + # Propagators + wavelength = electron_wavelength_angstrom(energy) + num_slices = slice_thicknesses.shape[0] + propagators = xp.empty( + (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 + ) + + for i, dz in enumerate(slice_thicknesses): + propagators[i] = xp.exp( + 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) + ) + propagators[i] *= xp.exp( + 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) + ) + + if theta_x is not None: + theta_x = np.deg2rad(theta_x) + propagators[i] *= xp.exp( + 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) + ) + + if theta_y is not None: + theta_y = np.deg2rad(theta_y) + propagators[i] *= xp.exp( + 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) + ) + + return propagators + + def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): + """ + Propagates array by Fourier convolving array with propagator_array. + + Parameters + ---------- + array: np.ndarray + Wavefunction array to be convolved + propagator_array: np.ndarray + Propagator array to convolve array with + + Returns + ------- + propagated_array: np.ndarray + Fourier-convolved array + """ + xp = self._xp + + return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) + + def _initialize_object( + self, + initial_object, + num_slices, + positions_px, + object_type, + ): + """ """ + # explicit read-only self attributes up-front + xp = self._xp + object_padding_px = self._object_padding_px + region_of_interest_shape = self._region_of_interest_shape + + if initial_object is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + if object_type == "potential": + _object = xp.zeros((num_slices, p, q), dtype=xp.float32) + elif object_type == "complex": + _object = xp.ones((num_slices, p, q), dtype=xp.complex64) + else: + if object_type == "potential": + _object = xp.asarray(initial_object, dtype=xp.float32) + elif object_type == "complex": + _object = xp.asarray(initial_object, dtype=xp.complex64) + + return _object + def _return_projected_cropped_potential( self, ): diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 22f2c2a46..acd570424 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -198,7 +198,7 @@ def preprocess( plot_center_of_mass: str = "default", plot_rotation: bool = True, maximize_divergence: bool = False, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + rotation_angles_deg: np.ndarray = None, plot_probe_overlaps: bool = True, force_com_rotation: float = None, 
force_com_transpose: float = None, From f2a8583d8c5bdfd13155bf7c4ee907caef8bd80f Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 31 Dec 2023 15:04:21 -0800 Subject: [PATCH 015/128] cleaned up mixedstate preprocess Former-commit-id: a2b5ba3e72d8f19b33f856975b18fdecb886cb06 --- .../iterative_mixedstate_ptychography.py | 161 +++++------------- .../phase/iterative_ptychographic_methods.py | 47 ++++- 2 files changed, 93 insertions(+), 115 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index b03bb456d..c5f0f450b 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -38,7 +38,6 @@ polar_aliases, polar_symbols, ) -from py4DSTEM.process.utils import get_CoM, get_shifted_ar warnings.simplefilter(action="always", category=UserWarning) @@ -219,7 +218,7 @@ def preprocess( plot_center_of_mass: str = "default", plot_rotation: bool = True, maximize_divergence: bool = False, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + rotation_angles_deg: np.ndarray = None, plot_probe_overlaps: bool = True, force_com_rotation: float = None, force_com_transpose: float = None, @@ -311,13 +310,10 @@ def preprocess( ) ) - if self._positions_mask is not None and self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) + if self._positions_mask is not None: self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + # preprocess datacube ( self._datacube, self._vacuum_probe_intensity, @@ -333,6 +329,7 @@ def preprocess( com_shifts=force_com_shifts, ) + # calibrations self._intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=True, @@ -341,6 +338,13 @@ def preprocess( force_reciprocal_sampling=force_reciprocal_sampling, ) + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # calculate CoM ( self._com_measured_x, self._com_measured_y, @@ -355,6 +359,7 @@ def preprocess( com_shifts=force_com_shifts, ) + # estimate rotation / transpose ( self._rotation_best_rad, self._rotation_best_transpose, @@ -376,15 +381,17 @@ def preprocess( **kwargs, ) + # corner-center amplitudes ( self._amplitudes, self._mean_diffraction_intensity, + self._crop_mask, ) = self._normalize_diffraction_intensities( self._intensities, self._com_fitted_x, self._com_fitted_y, - crop_patterns, self._positions_mask, + crop_patterns, ) # explicitly delete namespace @@ -392,41 +399,28 @@ def preprocess( self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) del self._intensities - self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions, self._positions_mask + # initialize probe positions + ( + self._positions_px, + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions, + self._positions_mask, + self._object_padding_px, ) - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px, axis=0)) - p = np.max([np.round(p + pad_x), 
self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - if self._object_type == "potential": - self._object = xp.zeros((p, q), dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.ones((p, q), dtype=xp.complex64) - else: - if self._object_type == "potential": - self._object = xp.asarray(self._object, dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.asarray(self._object, dtype=xp.complex64) + # initialize object + self._object = self._initialize_object( + self._object, + self._positions_px, + self._object_type, + ) self._object_initial = self._object.copy() self._object_type_initial = self._object_type self._object_shape = self._object.shape + # center probe positions self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) self._positions_px_com = xp.mean(self._positions_px, axis=0) self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 @@ -440,89 +434,25 @@ def preprocess( self._positions_initial[:, 0] *= self.sampling[0] self._positions_initial[:, 1] *= self.sampling[1] - # Vectorized Patches + # set vectorized patches ( self._vectorized_patch_indices_row, self._vectorized_patch_indices_col, ) = self._extract_vectorized_patch_indices() - # Probe Initialization - if self._probe is None or isinstance(self._probe, ComplexProbe): - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, - device=self._device, - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - _probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - else: - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - _probe = self._probe._array - else: - self._probe._xp = xp - _probe = self._probe.build()._array - - self._probe = xp.zeros( - (self._num_probes,) + tuple(self._region_of_interest_shape), - dtype=xp.complex64, - ) - sx, sy = self._region_of_interest_shape - self._probe[0] = _probe - - # Randomly shift phase of other probes - for i_probe in range(1, self._num_probes): - shift_x = xp.exp( - -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx) - ) - shift_y = xp.exp( - -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy) - ) - self._probe[i_probe] = ( - self._probe[i_probe - 1] * shift_x[:, None] * shift_y[None] - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe[0])) ** 2) - self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) - - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) + # initialize probe + self._probe, self._semiangle_cutoff = self._initialize_probe( + self._probe, + self._vacuum_probe_intensity, + 
self._mean_diffraction_intensity, + self._semiangle_cutoff, + crop_patterns, + ) self._probe_initial = self._probe.copy() - self._probe_initial_aperture = None # Doesn't really make sense for mixed-state + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + # initialize aberrations self._known_aberrations_array = ComplexProbe( energy=self._energy, gpts=self._region_of_interest_shape, @@ -535,10 +465,12 @@ def preprocess( shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp) probe_intensities = xp.abs(shifted_probes) ** 2 probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) if object_fov_mask is None: - self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) + probe_overlap_blurred = self._gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) else: self._object_fov_mask = np.asarray(object_fov_mask) self._object_fov_mask_inverse = np.invert(self._object_fov_mask) @@ -546,11 +478,12 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (4.5 * self._num_probes + 4, 4)) chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - power=2, + power=power, chroma_boost=chroma_boost, ) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index 28a66d971..a411baea9 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -775,7 +775,7 @@ class ProbeMethodsMixin: Mixin class for probe methods applicable to a single probe. """ - def _initialize_probe( + def initialize_probe( self, initial_probe, vacuum_probe_intensity, @@ -1037,6 +1037,51 @@ class ProbeMixedMethodsMixin: Overwrites ProbeMethodsMixin. 
""" + def initialize_probe( + self, + initial_probe, + vacuum_probe_intensity, + mean_diffraction_intensity, + semiangle_cutoff, + crop_patterns, + ): + """ """ + + # explicit read-only self attributes up-front + xp = self._xp + num_probes = self._num_probes + region_of_interest_shape = self._region_of_interest_shape + + if initial_probe is None or isinstance(initial_probe, ComplexProbe): + # calls ProbeMethodsMixin for first probe + _probe, semiangle_cutoff = super().initialize_probe( + initial_probe, + vacuum_probe_intensity, + mean_diffraction_intensity, + semiangle_cutoff, + crop_patterns, + ) + + sx, sy = region_of_interest_shape + _probes = xp.zeros((num_probes, sx, sy), dtype=xp.complex64) + _probes[0] = _probe + + # Randomly shift phase of other probes + for i_probe in range(1, num_probes): + shift_x = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx) + ) + shift_y = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy) + ) + _probes[i_probe] = ( + _probes[i_probe - 1] * shift_x[:, None] * shift_y[None] + ) + else: + _probes = xp.asarray(initial_probe, dtype=xp.complex64) + + return _probes, semiangle_cutoff + def show_fourier_probe( self, probe=None, From 78eebff2777fd9f32e51429411c88360db973cbe Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 31 Dec 2023 15:09:46 -0800 Subject: [PATCH 016/128] underscore typos Former-commit-id: 2e231ce6377f9908fd83f209bb2b282cc1e5c128 --- py4DSTEM/process/phase/iterative_mixedstate_ptychography.py | 2 +- py4DSTEM/process/phase/iterative_ptychographic_methods.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index c5f0f450b..ab0de6258 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -441,7 +441,7 @@ def preprocess( ) = self._extract_vectorized_patch_indices() # initialize probe - self._probe, self._semiangle_cutoff = self._initialize_probe( + self._probe, self._semiangle_cutoff = self.initialize_probe( self._probe, self._vacuum_probe_intensity, self._mean_diffraction_intensity, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index a411baea9..cbf3a24d2 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -775,7 +775,7 @@ class ProbeMethodsMixin: Mixin class for probe methods applicable to a single probe. """ - def initialize_probe( + def _initialize_probe( self, initial_probe, vacuum_probe_intensity, @@ -1037,7 +1037,7 @@ class ProbeMixedMethodsMixin: Overwrites ProbeMethodsMixin. 
""" - def initialize_probe( + def _initialize_probe( self, initial_probe, vacuum_probe_intensity, @@ -1054,7 +1054,7 @@ def initialize_probe( if initial_probe is None or isinstance(initial_probe, ComplexProbe): # calls ProbeMethodsMixin for first probe - _probe, semiangle_cutoff = super().initialize_probe( + _probe, semiangle_cutoff = super()._initialize_probe( initial_probe, vacuum_probe_intensity, mean_diffraction_intensity, From 5e0191d103bde9de985e097c668bfcf4850230d7 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 31 Dec 2023 15:10:33 -0800 Subject: [PATCH 017/128] underscore typos Former-commit-id: 37607e11155f91f5686a41596cf0c399d80c6ab8 --- py4DSTEM/process/phase/iterative_mixedstate_ptychography.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index ab0de6258..c5f0f450b 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -441,7 +441,7 @@ def preprocess( ) = self._extract_vectorized_patch_indices() # initialize probe - self._probe, self._semiangle_cutoff = self.initialize_probe( + self._probe, self._semiangle_cutoff = self._initialize_probe( self._probe, self._vacuum_probe_intensity, self._mean_diffraction_intensity, From a804a5131de0b0f18d57ebcbbaee74531fd06e00 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 31 Dec 2023 15:34:49 -0800 Subject: [PATCH 018/128] huh, not how super works Former-commit-id: d8e4d528b55da6fc8f3409c1d9b73aeccab09e51 --- py4DSTEM/process/phase/iterative_ptychographic_methods.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index cbf3a24d2..decf6db26 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -1054,7 +1054,10 @@ def _initialize_probe( if initial_probe is None or isinstance(initial_probe, ComplexProbe): # calls ProbeMethodsMixin for first probe - _probe, semiangle_cutoff = super()._initialize_probe( + # annoyingly can't use super() as Mixins are defined right->left + # but MRO is defined left->right.. 
+ _probe, semiangle_cutoff = ProbeMethodsMixin._initialize_probe( + self, initial_probe, vacuum_probe_intensity, mean_diffraction_intensity, From 7de07d31295db89654f680d5223021e682d7f3f0 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 31 Dec 2023 16:15:09 -0800 Subject: [PATCH 019/128] cleaned up mixed-multi slice preprocess Former-commit-id: db8e4a5a5d0b2afaefc53dcf3063eb8297cc9175 --- ...tive_mixedstate_multislice_ptychography.py | 252 ++++-------------- 1 file changed, 50 insertions(+), 202 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 2369eea38..ecad16247 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -39,9 +39,7 @@ generate_batches, polar_aliases, polar_symbols, - spatial_frequencies, ) -from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar warnings.simplefilter(action="always", category=UserWarning) @@ -267,92 +265,6 @@ def __init__( self._theta_x = theta_x self._theta_y = theta_y - def _precompute_propagator_arrays( - self, - gpts: Tuple[int, int], - sampling: Tuple[float, float], - energy: float, - slice_thicknesses: Sequence[float], - theta_x: float, - theta_y: float, - ): - """ - Precomputes propagator arrays complex wave-function will be convolved by, - for all slice thicknesses. - - Parameters - ---------- - gpts: Tuple[int,int] - Wavefunction pixel dimensions - sampling: Tuple[float,float] - Wavefunction sampling in A - energy: float - The electron energy of the wave functions in eV - slice_thicknesses: Sequence[float] - Array of slice thicknesses in A - theta_x: float - x tilt of propagator (in degrees) - theta_y: float - y tilt of propagator (in degrees) - - Returns - ------- - propagator_arrays: np.ndarray - (T,Sx,Sy) shape array storing propagator arrays - """ - xp = self._xp - - # Frequencies - kx, ky = spatial_frequencies(gpts, sampling) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) - - # Propagators - wavelength = electron_wavelength_angstrom(energy) - num_slices = slice_thicknesses.shape[0] - propagators = xp.empty( - (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 - ) - - theta_x = np.deg2rad(theta_x) - theta_y = np.deg2rad(theta_y) - - for i, dz in enumerate(slice_thicknesses): - propagators[i] = xp.exp( - 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) - ) - propagators[i] *= xp.exp( - 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) - ) - - return propagators - - def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): - """ - Propagates array by Fourier convolving array with propagator_array. 
- - Parameters - ---------- - array: np.ndarray - Wavefunction array to be convolved - propagator_array: np.ndarray - Propagator array to convolve array with - - Returns - ------- - propagated_array: np.ndarray - Fourier-convolved array - """ - xp = self._xp - - return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) - def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, @@ -363,7 +275,7 @@ def preprocess( plot_center_of_mass: str = "default", plot_rotation: bool = True, maximize_divergence: bool = False, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + rotation_angles_deg: np.ndarray = None, plot_probe_overlaps: bool = True, force_com_rotation: float = None, force_com_transpose: float = None, @@ -455,13 +367,10 @@ def preprocess( ) ) - if self._positions_mask is not None and self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) + if self._positions_mask is not None: self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + # preprocess datacube ( self._datacube, self._vacuum_probe_intensity, @@ -477,6 +386,7 @@ def preprocess( com_shifts=force_com_shifts, ) + # calibrations self._intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=True, @@ -485,6 +395,13 @@ def preprocess( force_reciprocal_sampling=force_reciprocal_sampling, ) + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # calculate CoM ( self._com_measured_x, self._com_measured_y, @@ -499,6 +416,7 @@ def preprocess( com_shifts=force_com_shifts, ) + # estimate rotation / transpose ( self._rotation_best_rad, self._rotation_best_transpose, @@ -520,15 +438,17 @@ def preprocess( **kwargs, ) + # corner-center amplitudes ( self._amplitudes, self._mean_diffraction_intensity, + self._crop_mask, ) = self._normalize_diffraction_intensities( self._intensities, self._com_fitted_x, self._com_fitted_y, - crop_patterns, self._positions_mask, + crop_patterns, ) # explicitly delete namespace @@ -536,41 +456,29 @@ def preprocess( self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) del self._intensities - self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions, self._positions_mask + # initialize probe positions + ( + self._positions_px, + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions, + self._positions_mask, + self._object_padding_px, ) - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - if self._object_type == "potential": - self._object = xp.zeros((self._num_slices, p, q), dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.ones((self._num_slices, p, q), dtype=xp.complex64) - else: - if self._object_type == "potential": - self._object = xp.asarray(self._object, dtype=xp.float32) - elif self._object_type == "complex": - self._object = 
xp.asarray(self._object, dtype=xp.complex64) + # initialize object + self._object = self._initialize_object( + self._object, + self._num_slices, + self._positions_px, + self._object_type, + ) self._object_initial = self._object.copy() self._object_type_initial = self._object_type self._object_shape = self._object.shape[-2:] + # center probe positions self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) self._positions_px_com = xp.mean(self._positions_px, axis=0) self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 @@ -584,88 +492,25 @@ def preprocess( self._positions_initial[:, 0] *= self.sampling[0] self._positions_initial[:, 1] *= self.sampling[1] - # Vectorized Patches + # set vectorized patches ( self._vectorized_patch_indices_row, self._vectorized_patch_indices_col, ) = self._extract_vectorized_patch_indices() - # Probe Initialization - if self._probe is None or isinstance(self._probe, ComplexProbe): - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, - device=self._device, - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - _probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - else: - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - _probe = self._probe._array - else: - self._probe._xp = xp - _probe = self._probe.build()._array - - self._probe = xp.zeros( - (self._num_probes,) + tuple(self._region_of_interest_shape), - dtype=xp.complex64, - ) - sx, sy = self._region_of_interest_shape - self._probe[0] = _probe - - # Randomly shift phase of other probes - for i_probe in range(1, self._num_probes): - shift_x = xp.exp( - -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx) - ) - shift_y = xp.exp( - -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy) - ) - self._probe[i_probe] = ( - self._probe[i_probe - 1] * shift_x[:, None] * shift_y[None] - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe[0])) ** 2) - self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) - - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) + # initialize probe + self._probe, self._semiangle_cutoff = self._initialize_probe( + self._probe, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity, + self._semiangle_cutoff, + crop_patterns, + ) self._probe_initial = self._probe.copy() - self._probe_initial_aperture = None # Doesn't really make sense for mixed-state + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + # initialize aberrations self._known_aberrations_array = ComplexProbe( energy=self._energy, gpts=self._region_of_interest_shape, @@ -674,7 +519,7 @@ def preprocess( device=self._device, )._evaluate_ctf() - # Precomputed 
propagator arrays + # precompute propagator arrays self._propagator_arrays = self._precompute_propagator_arrays( self._region_of_interest_shape, self.sampling, @@ -688,10 +533,12 @@ def preprocess( shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp) probe_intensities = xp.abs(shifted_probes) ** 2 probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) if object_fov_mask is None: - self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) + probe_overlap_blurred = self._gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) else: self._object_fov_mask = np.asarray(object_fov_mask) self._object_fov_mask_inverse = np.invert(self._object_fov_mask) @@ -699,11 +546,12 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered[0], - power=2, + power=power, chroma_boost=chroma_boost, ) @@ -716,7 +564,7 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - power=2, + power=power, chroma_boost=chroma_boost, ) From c9c329924ee7fb6f11b36e34af092395dd372d97 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 31 Dec 2023 20:16:19 -0800 Subject: [PATCH 020/128] cleaned up overlap tomography preprocess Former-commit-id: c531a74d25a4a2b58a02bbfa180344908a15f1f2 --- .../phase/iterative_overlap_tomography.py | 402 +++++------------- .../phase/iterative_ptychographic_methods.py | 146 +++++++ 2 files changed, 241 insertions(+), 307 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 442cb79a0..f3abbb644 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -22,14 +22,17 @@ from py4DSTEM import DataCube from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( + Object2p5DConstraintsMixin, Object3DConstraintsMixin, ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ) from py4DSTEM.process.phase.iterative_ptychographic_methods import ( + Object2p5DMethodsMixin, Object3DMethodsMixin, ObjectNDMethodsMixin, + ProbeListMethodsMixin, ProbeMethodsMixin, ) from py4DSTEM.process.phase.utils import ( @@ -38,9 +41,7 @@ generate_batches, polar_aliases, polar_symbols, - spatial_frequencies, ) -from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar warnings.simplefilter(action="always", category=UserWarning) @@ -49,9 +50,12 @@ class OverlapTomographicReconstruction( PositionsConstraintsMixin, ProbeConstraintsMixin, Object3DConstraintsMixin, + Object2p5DConstraintsMixin, ObjectNDConstraintsMixin, + ProbeListMethodsMixin, ProbeMethodsMixin, Object3DMethodsMixin, + Object2p5DMethodsMixin, ObjectNDMethodsMixin, PtychographicReconstruction, ): @@ -116,7 +120,6 @@ class OverlapTomographicReconstruction( # Class-specific Metadata _class_specific_metadata = ("_num_slices", "_tilt_orientation_matrices") - _swap_zxy_to_xyz = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) def __init__( self, @@ -198,7 +201,7 @@ def __init__( # Data self._datacube = datacube self._object = 
initial_object_guess - self._probe = initial_probe_guess + self._probe_init = initial_probe_guess # Common Metadata self._vacuum_probe_intensity = vacuum_probe_intensity @@ -219,164 +222,6 @@ def __init__( self._tilt_orientation_matrices = tuple(tilt_orientation_matrices) self._num_tilts = num_tilts - def _precompute_propagator_arrays( - self, - gpts: Tuple[int, int], - sampling: Tuple[float, float], - energy: float, - slice_thicknesses: Sequence[float], - ): - """ - Precomputes propagator arrays complex wave-function will be convolved by, - for all slice thicknesses. - - Parameters - ---------- - gpts: Tuple[int,int] - Wavefunction pixel dimensions - sampling: Tuple[float,float] - Wavefunction sampling in A - energy: float - The electron energy of the wave functions in eV - slice_thicknesses: Sequence[float] - Array of slice thicknesses in A - - Returns - ------- - propagator_arrays: np.ndarray - (T,Sx,Sy) shape array storing propagator arrays - """ - xp = self._xp - - # Frequencies - kx, ky = spatial_frequencies(gpts, sampling) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) - - # Propagators - wavelength = electron_wavelength_angstrom(energy) - num_slices = slice_thicknesses.shape[0] - propagators = xp.empty( - (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 - ) - for i, dz in enumerate(slice_thicknesses): - propagators[i] = xp.exp( - 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) - ) - - return propagators - - def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): - """ - Propagates array by Fourier convolving array with propagator_array. - - Parameters - ---------- - array: np.ndarray - Wavefunction array to be convolved - propagator_array: np.ndarray - Propagator array to convolve array with - - Returns - ------- - propagated_array: np.ndarray - Fourier-convolved array - """ - xp = self._xp - - return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) - - def _project_sliced_object(self, array: np.ndarray, output_z): - """ - Expands supersliced object or projects voxel-sliced object. - - Parameters - ---------- - array: np.ndarray - 3D array to expand/project - output_z: int - Output_dimension to expand/project array to. - If output_z > array.shape[0] array is expanded, else it's projected - - Returns - ------- - expanded_or_projected_array: np.ndarray - expanded or projected array - """ - xp = self._xp - input_z = array.shape[0] - - voxels_per_slice = np.ceil(input_z / output_z).astype("int") - pad_size = voxels_per_slice * output_z - input_z - - padded_array = xp.pad(array, ((0, pad_size), (0, 0), (0, 0))) - - return xp.sum( - padded_array.reshape( - ( - -1, - voxels_per_slice, - ) - + array.shape[1:] - ), - axis=1, - ) - - def _expand_sliced_object(self, array: np.ndarray, output_z): - """ - Expands supersliced object or projects voxel-sliced object. - - Parameters - ---------- - array: np.ndarray - 3D array to expand/project - output_z: int - Output_dimension to expand/project array to. 
- If output_z > array.shape[0] array is expanded, else it's projected - - Returns - ------- - expanded_or_projected_array: np.ndarray - expanded or projected array - """ - xp = self._xp - input_z = array.shape[0] - - voxels_per_slice = np.ceil(output_z / input_z).astype("int") - remainder_size = voxels_per_slice - (voxels_per_slice * input_z - output_z) - - voxels_in_slice = xp.repeat(voxels_per_slice, input_z) - voxels_in_slice[-1] = remainder_size if remainder_size > 0 else voxels_per_slice - - normalized_array = array / xp.asarray(voxels_in_slice)[:, None, None] - return xp.repeat(normalized_array, voxels_per_slice, axis=0)[:output_z] - - def _rotate_zxy_volume( - self, - volume_array, - rot_matrix, - ): - """ """ - - xp = self._xp - affine_transform = self._affine_transform - swap_zxy_to_xyz = self._swap_zxy_to_xyz - - volume = volume_array.copy() - volume_shape = xp.asarray(volume.shape) - tf = xp.asarray(swap_zxy_to_xyz.T @ rot_matrix.T @ swap_zxy_to_xyz) - - in_center = (volume_shape - 1) / 2 - out_center = tf @ in_center - offset = in_center - out_center - - volume = affine_transform(volume, tf, offset=offset, order=3) - - return volume - def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, @@ -395,6 +240,7 @@ def preprocess( progress_bar: bool = True, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + main_tilt_axis: str = "vertical", **kwargs, ): """ @@ -440,7 +286,11 @@ def preprocess( Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns when centering + If True, crop patterns to avoid wrap around of patterns when centering + main_tilt_axis: str + The default, 'vertical' (first scan dimension), results in object size (q,p,q), + 'horizontal' (second scan dimension) results in object size (p,p,q), + any other value (e.g. None) results in object size (max(p,q),p,q). 
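+            (Presumably the depth dimension is chosen to match the in-plane
+            extent perpendicular to the main tilt axis, so that volumes
+            rotated by the tilt series remain inside the reconstruction array.)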
Returns -------- @@ -465,7 +315,7 @@ def preprocess( ) if self._positions_mask is not None: - self._positions_mask = np.asarray(self._positions_mask) + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") if self._positions_mask.ndim == 2: warnings.warn( @@ -476,47 +326,45 @@ def preprocess( self._positions_mask, (self._num_tilts, 1, 1) ) - if self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array."), - UserWarning, - ) - self._positions_mask = self._positions_mask.astype("bool") - else: - self._positions_mask = [None] * self._num_tilts - - # Prepopulate various arrays - - if self._positions_mask[0] is None: - num_probes_per_tilt = [0] - for dc in self._datacube: - rx, ry = dc.Rshape - num_probes_per_tilt.append(rx * ry) - - num_probes_per_tilt = np.array(num_probes_per_tilt) - else: num_probes_per_tilt = np.insert( self._positions_mask.sum(axis=(-2, -1)), 0, 0 ) - self._num_diffraction_patterns = num_probes_per_tilt.sum() - self._cum_probes_per_tilt = np.cumsum(num_probes_per_tilt) + else: + self._positions_mask = [None] * self._num_tilts + num_probes_per_tilt = [0] + [dc.R_N for dc in self._datacube] + num_probes_per_tilt = np.array(num_probes_per_tilt) + # prepopulate relevant arrays self._mean_diffraction_intensity = [] + self._num_diffraction_patterns = num_probes_per_tilt.sum() + self._cum_probes_per_tilt = np.cumsum(num_probes_per_tilt) self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) - self._rotation_best_rad = np.deg2rad(diffraction_patterns_rotate_degrees) - self._rotation_best_transpose = diffraction_patterns_transpose + # calculate roi_shape + roi_shape = self._datacube[0].Qshape + if diffraction_intensities_shape is not None: + roi_shape = diffraction_intensities_shape + if probe_roi_shape is not None: + roi_shape = tuple(max(q, s) for q, s in zip(roi_shape, probe_roi_shape)) + self._amplitudes = xp.empty((self._num_diffraction_patterns,) + roi_shape) + self._region_of_interest_shape = np.array(roi_shape) + # TO-DO: generalize this if force_com_shifts is None: force_com_shifts = [None] * self._num_tilts + self._rotation_best_rad = np.deg2rad(diffraction_patterns_rotate_degrees) + self._rotation_best_transpose = diffraction_patterns_transpose + + # loop over DPs for preprocessing for tilt_index in tqdmnd( self._num_tilts, desc="Preprocessing data", unit="tilt", disable=not progress_bar, ): + # preprocess datacube, vacuum and masks only for first tilt if tilt_index == 0: ( self._datacube[tilt_index], @@ -533,13 +381,6 @@ def preprocess( com_shifts=force_com_shifts[tilt_index], ) - self._amplitudes = xp.empty( - (self._num_diffraction_patterns,) + self._datacube[0].Qshape - ) - self._region_of_interest_shape = np.array( - self._amplitudes[0].shape[-2:] - ) - else: ( self._datacube[tilt_index], @@ -556,6 +397,7 @@ def preprocess( com_shifts=force_com_shifts[tilt_index], ) + # calibrations intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube[tilt_index], require_calibrations=True, @@ -564,6 +406,7 @@ def preprocess( force_reciprocal_sampling=force_reciprocal_sampling, ) + # calculate CoM ( com_measured_x, com_measured_y, @@ -578,19 +421,19 @@ def preprocess( com_shifts=force_com_shifts[tilt_index], ) + # corner-center amplitudes + idx_start = self._cum_probes_per_tilt[tilt_index] + idx_end = self._cum_probes_per_tilt[tilt_index + 1] ( - self._amplitudes[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ], + 
self._amplitudes[idx_start:idx_end], mean_diffraction_intensity_temp, + self._crop_mask, ) = self._normalize_diffraction_intensities( intensities, com_fitted_x, com_fitted_y, - crop_patterns, self._positions_mask[tilt_index], + crop_patterns, ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) @@ -605,12 +448,14 @@ def preprocess( com_normalized_y, ) - self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] = self._calculate_scan_positions_in_pixels( - self._scan_positions[tilt_index], self._positions_mask[tilt_index] + # initialize probe positions + ( + self._positions_px_all[idx_start:idx_end], + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions[tilt_index], + self._positions_mask[tilt_index], + self._object_padding_px, ) # handle semiangle specified in pixels @@ -619,119 +464,60 @@ def preprocess( self._semiangle_cutoff_pixels * self._angular_sampling[0] ) - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px_all, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - self._object = xp.zeros((q, p, q), dtype=xp.float32) - else: - self._object = xp.asarray(self._object, dtype=xp.float32) + # initialize object + self._object = self._initialize_object( + self._object, + self._positions_px_all, + self._object_type, + main_tilt_axis, + ) self._object_initial = self._object.copy() self._object_type_initial = self._object_type self._object_shape = self._object.shape[-2:] self._num_voxels = self._object.shape[0] - # Center Probes + # center probe positions self._positions_px_all = xp.asarray(self._positions_px_all, dtype=xp.float32) for tilt_index in range(self._num_tilts): - self._positions_px = self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] + idx_start = self._cum_probes_per_tilt[tilt_index] + idx_end = self._cum_probes_per_tilt[tilt_index + 1] + self._positions_px = self._positions_px_all[idx_start:idx_end] self._positions_px_com = xp.mean(self._positions_px, axis=0) self._positions_px -= ( self._positions_px_com - xp.array(self._object_shape) / 2 ) - self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] = self._positions_px.copy() + self._positions_px_all[idx_start:idx_end] = self._positions_px.copy() self._positions_px_initial_all = self._positions_px_all.copy() self._positions_initial_all = self._positions_px_initial_all.copy() self._positions_initial_all[:, 0] *= self.sampling[0] self._positions_initial_all[:, 1] *= self.sampling[1] - # Probe Initialization - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, device=self._device - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - self._probe = ( - ComplexProbe( 
- gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) + # initialize probe + self._probes_all = [] + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + list_Q = isinstance(self._probe_init, (list, tuple)) - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - sum(self._mean_diffraction_intensity) - / self._num_tilts - / probe_intensity + for tilt_index in range(self._num_tilts): + _probe, self._semiangle_cutoff = self._initialize_probe( + self._probe_init[tilt_index] if list_Q else self._probe_init, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity[tilt_index], + self._semiangle_cutoff, + crop_patterns, ) - else: - if isinstance(self._probe, ComplexProbe): - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - self._probe = self._probe._array - else: - self._probe._xp = xp - self._probe = self._probe.build()._array - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - sum(self._mean_diffraction_intensity) - / self._num_tilts - / probe_intensity - ) - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) + self._probes_all.append(_probe) + self._probes_all_initial = _probe.copy() + self._probes_all_initial_aperture = xp.abs(xp.fft.fft2(_probe)) - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + del self._probe_init + # initialize aberrations self._known_aberrations_array = ComplexProbe( energy=self._energy, gpts=self._region_of_interest_shape, @@ -757,7 +543,9 @@ def preprocess( probe_overlap_3D = xp.zeros_like(self._object) old_rot_matrix = np.eye(3) # identity - for tilt_index in np.arange(self._num_tilts): + for tilt_index in range(self._num_tilts): + idx_start = self._cum_probes_per_tilt[tilt_index] + idx_end = self._cum_probes_per_tilt[tilt_index + 1] rot_matrix = self._tilt_orientation_matrices[tilt_index] probe_overlap_3D = self._rotate_zxy_volume( @@ -765,16 +553,12 @@ def preprocess( rot_matrix @ old_rot_matrix.T, ) - self._positions_px = self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] + self._positions_px = self._positions_px_all[idx_start:idx_end] self._positions_px_fractional = self._positions_px - xp.round( self._positions_px ) shifted_probes = fft_shift( - self._probe, self._positions_px_fractional, xp + self._probes_all[tilt_index], self._positions_px_fractional, xp ) probe_intensities = xp.abs(shifted_probes) ** 2 probe_overlap = self._sum_overlapping_patches_bincounts( @@ -789,31 +573,35 @@ def preprocess( old_rot_matrix.T, ) - probe_overlap_3D = self._gaussian_filter(probe_overlap_3D, 1.0) + probe_overlap_3D_blurred = self._gaussian_filter(probe_overlap_3D, 1.0) self._object_fov_mask = asnumpy( - probe_overlap_3D > 0.25 * probe_overlap_3D.max() + probe_overlap_3D_blurred > 0.25 * probe_overlap_3D_blurred.max() ) + else: self._object_fov_mask = np.asarray(object_fov_mask) + self._positions_px = self._positions_px_all[: self._cum_probes_per_tilt[1]] self._positions_px_fractional = self._positions_px - 
xp.round( self._positions_px ) - shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) + shifted_probes = fft_shift( + self._probes_all[0], self._positions_px_fractional, xp + ) probe_intensities = xp.abs(shifted_probes) ** 2 probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) self._object_fov_mask_inverse = np.invert(self._object_fov_mask) if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - power=2, + power=power, chroma_boost=chroma_boost, ) @@ -826,7 +614,7 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - power=2, + power=power, chroma_boost=chroma_boost, ) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index decf6db26..a2e968473 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -535,12 +535,14 @@ def show_slices( ms_object = self._object rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + if show_fft: rotated_object = np.abs( np.fft.fftshift( np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) ) ) + rotated_shape = rotated_object.shape if np.iscomplexobj(rotated_object): @@ -618,6 +620,126 @@ class Object3DMethodsMixin: Overwrites ObjectNDMethodsMixin and Object2p5DMethodsMixin. """ + _swap_zxy_to_xyz = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) + + def _project_sliced_object(self, array: np.ndarray, output_z): + """ + Projects voxel-sliced object. + + Parameters + ---------- + array: np.ndarray + 3D array to project + output_z: int + Output_dimension to project array to. + + Returns + ------- + projected_array: np.ndarray + projected array + """ + xp = self._xp + input_z = array.shape[0] + + voxels_per_slice = np.ceil(input_z / output_z).astype("int") + pad_size = voxels_per_slice * output_z - input_z + + padded_array = xp.pad(array, ((0, pad_size), (0, 0), (0, 0))) + + return xp.sum( + padded_array.reshape( + ( + -1, + voxels_per_slice, + ) + + array.shape[1:] + ), + axis=1, + ) + + def _expand_sliced_object(self, array: np.ndarray, output_z): + """ + Expands supersliced object. + + Parameters + ---------- + array: np.ndarray + 3D array to expand + output_z: int + Output_dimension to expand array to. 
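+            As a rough worked example of the expansion implemented below:
+            a 2-slice array expanded to output_z=5 repeats the first slice
+            three times (each copy scaled by 1/3) and the last slice twice
+            (each copy scaled by 1/2), so re-projecting the result with
+            _project_sliced_object recovers the original slices.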
+ + Returns + ------- + expanded_array: np.ndarray + expanded array + """ + xp = self._xp + input_z = array.shape[0] + + voxels_per_slice = np.ceil(output_z / input_z).astype("int") + remainder_size = voxels_per_slice - (voxels_per_slice * input_z - output_z) + + voxels_in_slice = xp.repeat(voxels_per_slice, input_z) + voxels_in_slice[-1] = remainder_size if remainder_size > 0 else voxels_per_slice + + normalized_array = array / xp.asarray(voxels_in_slice)[:, None, None] + return xp.repeat(normalized_array, voxels_per_slice, axis=0)[:output_z] + + def _rotate_zxy_volume( + self, + volume_array, + rot_matrix, + order=3, + ): + """ """ + + xp = self._xp + affine_transform = self._affine_transform + swap_zxy_to_xyz = self._swap_zxy_to_xyz + + volume = volume_array.copy() + volume_shape = xp.asarray(volume.shape) + tf = xp.asarray(swap_zxy_to_xyz.T @ rot_matrix.T @ swap_zxy_to_xyz) + + in_center = (volume_shape - 1) / 2 + out_center = tf @ in_center + offset = in_center - out_center + + volume = affine_transform(volume, tf, offset=offset, order=order) + + return volume + + def _initialize_object( + self, + initial_object, + positions_px, + object_type, + main_tilt_axis="vertical", + ): + """ """ + # explicit read-only self attributes up-front + xp = self._xp + object_padding_px = self._object_padding_px + region_of_interest_shape = self._region_of_interest_shape + + if initial_object is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + + if main_tilt_axis == "vertical": + _object = xp.zeros((q, p, q), dtype=xp.float32) + elif main_tilt_axis == "horizontal": + _object = xp.zeros((p, p, q), dtype=xp.float32) + else: + _object = xp.zeros((max(p, q), p, q), dtype=xp.float32) + else: + _object = xp.asarray(initial_object, dtype=xp.float32) + + return _object + def _crop_rotate_object_manually( self, array, @@ -769,6 +891,11 @@ def show_object_fft( **kwargs, ) + @property + def object_supersliced(self): + """Returns super-sliced object""" + return self._project_sliced_object(self._object, self._num_slices) + class ProbeMethodsMixin: """ @@ -1152,3 +1279,22 @@ def show_fourier_probe( chroma_boost=chroma_boost, **kwargs, ) + + +class ProbeListMethodsMixin: + """ + Mixin class for probe methods unique to a list of single probes. + Overwrites ProbeMethodsMixin. 
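+
+    The `_probe` property defined below simply averages the per-tilt probes
+    in `_probes_all`, presumably so that single-probe utilities such as
+    `probe_centered` visualization continue to work unchanged.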
+ """ + + @property + def _probe(self): + """Dummy property to return average probe""" + + xp = self._xp + probe = xp.zeros(self._region_of_interest_shape, dtype=np.complex64) + + for pr in self._probes_all: + probe += pr + + return probe / self._num_tilts From d251c97f597822ffe40024635fbeab65d7847eb7 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 09:15:06 -0800 Subject: [PATCH 021/128] removing redundant tune func from parallax Former-commit-id: a24b9250a18db0b9c081181355ec3ba803605e0f --- py4DSTEM/process/phase/iterative_parallax.py | 159 ------------------- 1 file changed, 159 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 0bcd92240..a7251c9c7 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -736,165 +736,6 @@ def preprocess( return self - def tune_angle_and_defocus( - self, - angle_guess=None, - defocus_guess=None, - angle_step_size=5, - defocus_step_size=100, - num_angle_values=5, - num_defocus_values=5, - return_values=False, - plot_reconstructions=True, - plot_convergence=True, - **kwargs, - ): - """ - Run parallax reconstruction over a parameters space of pre-determined angles - and defocus - - Parameters - ---------- - angle_guess: float (degrees), optional - initial starting guess for rotation angle between real and reciprocal space - if None, uses 0 - defocus_guess: float (A), optional - initial starting guess for defocus (defocus dF) - if None, uses 0 - angle_step_size: float (degrees), optional - size of change of rotation angle between real and reciprocal space for - each step in parameter space - defocus_step_size: float (A), optional - size of change of defocus for each step in parameter space - num_angle_values: int, optional - number of values of angle to test, must be >= 1. 
- num_defocus_values: int,optional - number of values of defocus to test, must be >= 1 - plot_reconstructions: bool, optional - if True, plot phase of reconstructed objects - plot_convergence: bool, optional - if True, makes 2D plot of error metrix - return_values: bool, optional - if True, returns objects, convergence - - Returns - ------- - objects: list - reconstructed objects - convergence: np.ndarray - array of convergence values from reconstructions - """ - asnumpy = self._asnumpy - - if angle_guess is None: - angle_guess = 0 - if defocus_guess is None: - defocus_guess = 0 - - if num_angle_values == 1: - angle_step_size = 0 - - if num_defocus_values == 1: - defocus_step_size = 0 - - angles = np.linspace( - angle_guess - angle_step_size * (num_angle_values - 1) / 2, - angle_guess + angle_step_size * (num_angle_values - 1) / 2, - num_angle_values, - ) - - defocus_values = np.linspace( - defocus_guess - defocus_step_size * (num_defocus_values - 1) / 2, - defocus_guess + defocus_step_size * (num_defocus_values - 1) / 2, - num_defocus_values, - ) - if return_values or plot_convergence: - recon_BF = [] - convergence = [] - - if plot_reconstructions: - spec = GridSpec( - ncols=num_defocus_values, - nrows=num_angle_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_defocus_values, 4 * num_angle_values) - ) - - fig = plt.figure(figsize=figsize) - - # run loop and plot along the way - self._verbose = False - for flat_index, (angle, defocus) in enumerate( - tqdmnd(angles, defocus_values, desc="Tuning angle and defocus") - ): - self.preprocess( - defocus_guess=defocus, - rotation_guess=angle, - plot_average_bf=False, - **kwargs, - ) - - if plot_reconstructions: - row_index, col_index = np.unravel_index( - flat_index, (num_angle_values, num_defocus_values) - ) - object_ax = fig.add_subplot(spec[row_index, col_index]) - self._visualize_figax( - fig, - ax=object_ax, - **kwargs, - ) - - object_ax.set_title( - f" angle = {angle:.1f} °, defocus = {defocus:.1f} A \n error = {self._recon_error[0]:.3e}" - ) - object_ax.set_xticks([]) - object_ax.set_yticks([]) - - if return_values: - recon_BF.append(self.recon_BF) - if return_values or plot_convergence: - convergence.append(asnumpy(self._recon_error[0])) - - if plot_convergence: - fig, ax = plt.subplots() - ax.set_title("convergence") - im = ax.imshow( - np.array(convergence).reshape(angles.shape[0], defocus_values.shape[0]), - cmap="magma", - ) - - if angles.shape[0] > 1: - ax.set_ylabel("angles") - ax.set_yticks(np.arange(angles.shape[0])) - ax.set_yticklabels([f"{angle:.1f} °" for angle in angles]) - else: - ax.set_yticks([]) - ax.set_ylabel(f"angle {angles[0]:.1f}") - - if defocus_values.shape[0] > 1: - ax.set_xlabel("defocus values") - ax.set_xticks(np.arange(defocus_values.shape[0])) - ax.set_xticklabels([f"{df:.1f}" for df in defocus_values]) - else: - ax.set_xticks([]) - ax.set_xlabel(f"defocus value: {defocus_values[0]:.1f}") - - divider = make_axes_locatable(ax) - cax = divider.append_axes("right", size="5%", pad=0.05) - fig.colorbar(im, cax=cax) - - fig.tight_layout() - - if return_values: - convergence = np.array(convergence).reshape( - angles.shape[0], defocus_values.shape[0] - ) - return recon_BF, convergence - def reconstruct( self, max_alignment_bin: int = None, From 606f8d7ddc60f58be183500bb618f3ea6ae8e294 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 09:57:01 -0800 Subject: [PATCH 022/128] moved single-slice forward and adjoint to methods.py Former-commit-id: 
01d6ad6d6fb2bb0628d886e47115c35b1fbdb5ff --- .../phase/iterative_ptychographic_methods.py | 457 ++++++++++++++++++ .../iterative_singleslice_ptychography.py | 451 +---------------- 2 files changed, 459 insertions(+), 449 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index a2e968473..ac16d89a5 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -7,6 +7,7 @@ from py4DSTEM.process.phase.utils import ( AffineTransform, ComplexProbe, + fft_shift, rotate_point, spatial_frequencies, ) @@ -1298,3 +1299,459 @@ def _probe(self): probe += pr return probe / self._num_tilts + + +class ObjectNDProbeMethodsMixin: + """ + Mixin class for methods applicable to 2D, 2.5D, and 3D objects using a single probe. + """ + + def _overlap_projection(self, current_object, current_probe): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + shifted_probes:np.ndarray + fractionally-shifted probes + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + shifted_probes * object_patches + """ + + xp = self._xp + + shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) + + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object + + object_patches = complex_object[ + self._vectorized_patch_indices_row, self._vectorized_patch_indices_col + ] + + overlap = shifted_probes * object_patches + + return shifted_probes, object_patches, overlap + + def _gradient_descent_fourier_projection(self, amplitudes, overlap): + """ + Ptychographic fourier projection method for GD method. + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + overlap: np.ndarray + object * probe overlap + + Returns + -------- + exit_waves:np.ndarray + Difference between modified and estimated exit waves + error: float + Reconstruction error + """ + + xp = self._xp + fourier_overlap = xp.fft.fft2(overlap) + error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2) + + fourier_modified_overlap = amplitudes * xp.exp(1j * xp.angle(fourier_overlap)) + modified_overlap = xp.fft.ifft2(fourier_modified_overlap) + + exit_waves = modified_overlap - overlap + + return exit_waves, error + + def _projection_sets_fourier_projection( + self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c + ): + """ + Ptychographic fourier projection method for DM_AP and RAAR methods. 
+ Generalized projection using three parameters: a,b,c + + DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha + DM: DM_AP(1.0), AP: DM_AP(0.0) + + RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 + DM : RAAR(1.0) + + RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 + DM: RRR(1.0) + + SUPERFLIP : a = 0, b = 1, c = 2 + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + overlap: np.ndarray + object * probe overlap + exit_waves: np.ndarray + previously estimated exit waves + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + exit_waves:np.ndarray + Updated exit_waves + error: float + Reconstruction error + """ + + xp = self._xp + projection_x = 1 - projection_a - projection_b + projection_y = 1 - projection_c + + if exit_waves is None: + exit_waves = overlap.copy() + + fourier_overlap = xp.fft.fft2(overlap) + error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2) + + factor_to_be_projected = projection_c * overlap + projection_y * exit_waves + fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) + + fourier_projected_factor = amplitudes * xp.exp( + 1j * xp.angle(fourier_projected_factor) + ) + projected_factor = xp.fft.ifft2(fourier_projected_factor) + + exit_waves = ( + projection_x * exit_waves + + projection_a * overlap + + projection_b * projected_factor + ) + + return exit_waves, error + + def _forward( + self, + current_object, + current_probe, + amplitudes, + exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic forward operator. + Calls _overlap_projection() and the appropriate _fourier_projection(). + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + amplitudes: np.ndarray + Normalized measured amplitudes + exit_waves: np.ndarray + previously estimated exit waves + use_projection_scheme: bool, + If True, use generalized projection update + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + shifted_probes:np.ndarray + fractionally-shifted probes + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + object * probe overlap + exit_waves:np.ndarray + Updated exit_waves + error: float + Reconstruction error + """ + + shifted_probes, object_patches, overlap = self._overlap_projection( + current_object, current_probe + ) + + if use_projection_scheme: + exit_waves, error = self._projection_sets_fourier_projection( + amplitudes, + overlap, + exit_waves, + projection_a, + projection_b, + projection_c, + ) + + else: + exit_waves, error = self._gradient_descent_fourier_projection( + amplitudes, overlap + ) + + return shifted_probes, object_patches, overlap, exit_waves, error + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. 
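+
+        As a rough sketch of the gradient step implemented below (complex
+        object type; R indexes probe positions):
+
+            object += step_size * sum_R( conj(probe_R) * exit_wave_R ) * N_probe
+            probe  += step_size * sum_R( conj(object_R) * exit_wave_R ) * N_object
+
+        where N_probe and N_object are inverse overlap-intensity
+        normalizations regularized by normalization_min, and the "potential"
+        object type instead accumulates
+        real(-1j * conj(object_R) * conj(probe_R) * exit_wave_R).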
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_normalization = self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2 + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + if self._object_type == "potential": + current_object += step_size * ( + self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(object_patches) + * xp.conj(shifted_probes) + * exit_waves + ) + ) + * probe_normalization + ) + else: + current_object += step_size * ( + self._sum_overlapping_patches_bincounts( + xp.conj(shifted_probes) * exit_waves + ) + * probe_normalization + ) + + if not fix_probe: + object_normalization = xp.sum( + (xp.abs(object_patches) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += step_size * ( + xp.sum( + xp.conj(object_patches) * exit_waves, + axis=0, + ) + * object_normalization + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. 
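+
+        Unlike the gradient-descent adjoint, the object and probe estimates
+        below are replaced by (rather than incremented with) the normalized
+        back-projections, so no step_size enters this update.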
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_normalization = self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2 + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + if self._object_type == "potential": + current_object = ( + self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(object_patches) + * xp.conj(shifted_probes) + * exit_waves + ) + ) + * probe_normalization + ) + else: + current_object = ( + self._sum_overlapping_patches_bincounts( + xp.conj(shifted_probes) * exit_waves + ) + * probe_normalization + ) + + if not fix_probe: + object_normalization = xp.sum( + (xp.abs(object_patches) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + xp.conj(object_patches) * exit_waves, + axis=0, + ) + * object_normalization + ) + + return current_object, current_probe + + def _adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + exit_waves, + use_projection_scheme: bool, + step_size: float, + normalization_min: float, + fix_probe: bool, + ): + """ + Ptychographic adjoint operator. + Computes object and probe update steps. 
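+
+        This is a thin dispatcher: it calls _projection_sets_adjoint when
+        use_projection_scheme is True and _gradient_descent_adjoint otherwise.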
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + use_projection_scheme: bool, + If True, use generalized projection update + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + + if use_projection_scheme: + current_object, current_probe = self._projection_sets_adjoint( + current_object, + current_probe, + object_patches, + shifted_probes, + exit_waves, + normalization_min, + fix_probe, + ) + else: + current_object, current_probe = self._gradient_descent_adjoint( + current_object, + current_probe, + object_patches, + shifted_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ) + + return current_object, current_probe diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index acd570424..7d054a6de 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -27,6 +27,7 @@ ) from py4DSTEM.process.phase.iterative_ptychographic_methods import ( ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) from py4DSTEM.process.phase.utils import ( @@ -44,6 +45,7 @@ class SingleslicePtychographicReconstruction( PositionsConstraintsMixin, ProbeConstraintsMixin, ObjectNDConstraintsMixin, + ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ObjectNDMethodsMixin, PtychographicReconstruction, @@ -524,455 +526,6 @@ def preprocess( return self - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - shifted_probes * object_patches - """ - - xp = self._xp - - shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - object_patches = complex_object[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - overlap = shifted_probes * object_patches - - return shifted_probes, object_patches, overlap - - def _gradient_descent_fourier_projection(self, amplitudes, overlap): - """ - Ptychographic fourier projection method for GD method. 
- - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - - Returns - -------- - exit_waves:np.ndarray - Difference between modified and estimated exit waves - error: float - Reconstruction error - """ - - xp = self._xp - fourier_overlap = xp.fft.fft2(overlap) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2) - - fourier_modified_overlap = amplitudes * xp.exp(1j * xp.angle(fourier_overlap)) - modified_overlap = xp.fft.ifft2(fourier_modified_overlap) - - exit_waves = modified_overlap - overlap - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. - Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = overlap.copy() - - fourier_overlap = xp.fft.fft2(overlap) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2) - - factor_to_be_projected = projection_c * overlap + projection_y * exit_waves - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amplitudes * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * overlap - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - object * probe overlap - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - shifted_probes, object_patches, overlap = self._overlap_projection( - current_object, current_probe - ) - if use_projection_scheme: - exit_waves, error = self._projection_sets_fourier_projection( - amplitudes, - overlap, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, overlap - ) - - return shifted_probes, object_patches, overlap, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._object_type == "potential": - current_object += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(object_patches) - * xp.conj(shifted_probes) - * exit_waves - ) - ) - * probe_normalization - ) - elif self._object_type == "complex": - current_object += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes) * exit_waves - ) - * probe_normalization - ) - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(object_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += step_size * ( - xp.sum( - xp.conj(object_patches) * exit_waves, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._object_type == "potential": - current_object = ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(object_patches) - * xp.conj(shifted_probes) - * exit_waves - ) - ) - * probe_normalization - ) - elif self._object_type == "complex": - current_object = ( - self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes) * exit_waves - ) - * probe_normalization - ) - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(object_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - xp.conj(object_patches) * exit_waves, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - def _constraints( self, current_object, From 95a86a75d0112c9dc6bfe8986a80c7110f767d9d Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 13:20:31 -0800 Subject: [PATCH 023/128] added necessary multi-slice forward and adjoint methods Former-commit-id: cda0cc399dfaf332c32465c3815a9ce4fb63130d --- .../iterative_multislice_ptychography.py | 500 +----------------- .../phase/iterative_ptychographic_methods.py | 260 +++++++++ 2 files changed, 264 insertions(+), 496 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 92064871c..868f5a246 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -28,7 +28,9 @@ ) from py4DSTEM.process.phase.iterative_ptychographic_methods import ( Object2p5DMethodsMixin, + Object2p5DProbeMethodsMixin, ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) from py4DSTEM.process.phase.utils import ( @@ -47,6 +49,8 @@ class MultislicePtychographicReconstruction( ProbeConstraintsMixin, Object2p5DConstraintsMixin, ObjectNDConstraintsMixin, + Object2p5DProbeMethodsMixin, + ObjectNDProbeMethodsMixin, ProbeMethodsMixin, Object2p5DMethodsMixin, ObjectNDMethodsMixin, @@ -615,502 +619,6 @@ def preprocess( return self - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - """ - - xp = self._xp - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - object_patches = complex_object[ - :, - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ] - - propagated_probes = xp.empty_like(object_patches) - propagated_probes[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes = object_patches[s] * propagated_probes[s] - - # propagate - if s + 1 < self._num_slices: - propagated_probes[s + 1] = self._propagate_array( - transmitted_probes, self._propagator_arrays[s] - ) - - return propagated_probes, object_patches, transmitted_probes - - def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): - """ - Ptychographic fourier projection method for GD method. - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - - Returns - -------- - exit_waves:np.ndarray - Exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - modified_exit_wave = xp.fft.ifft2( - amplitudes * xp.exp(1j * xp.angle(fourier_exit_waves)) - ) - - exit_waves = modified_exit_wave - transmitted_probes - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. 
- Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = transmitted_probes.copy() - - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - factor_to_be_projected = ( - projection_c * transmitted_probes + projection_y * exit_waves - ) - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amplitudes * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * transmitted_probes - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - ( - propagated_probes, - object_patches, - transmitted_probes, - ) = self._overlap_projection(current_object, current_probe) - - if use_projection_scheme: - exit_waves, error = self._projection_sets_fourier_projection( - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, transmitted_probes - ) - - return propagated_probes, object_patches, transmitted_probes, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._object_type == "potential": - current_object[s] += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves) - ) - * probe_normalization - ) - elif self._object_type == "complex": - current_object[s] += step_size * ( - self._sum_overlapping_patches_bincounts(xp.conj(probe) * exit_waves) - * probe_normalization - ) - - # back-transmit - exit_waves *= xp.conj(obj) # / xp.abs(obj) ** 2 - - if s > 0: - # back-propagate - exit_waves = self._propagate_array( - exit_waves, xp.conj(self._propagator_arrays[s - 1]) - ) - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += ( - step_size - * xp.sum( - exit_waves, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - # careful not to modify exit_waves in-place for projection set methods - exit_waves_copy = exit_waves.copy() - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._object_type == "potential": - current_object[s] = ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves_copy) - ) - * probe_normalization - ) - elif self._object_type == "complex": - current_object[s] = ( - self._sum_overlapping_patches_bincounts( - xp.conj(probe) * exit_waves_copy - ) - * probe_normalization - ) - - # back-transmit - exit_waves_copy *= xp.conj(obj) # / xp.abs(obj) ** 2 - - if s > 0: - # back-propagate - exit_waves_copy = self._propagate_array( - exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) - ) - - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - exit_waves_copy, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
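The two object branches above differ only in the chain rule for a "potential" object, where the complex transmission is O = exp(1j * V): since dO/dV = 1j * O, the descent direction on V picks up the factor real(-1j * conj(O) * conj(probe) * chi) in place of the complex-object term conj(probe) * chi. The single-pixel finite-difference check below is purely illustrative; the scalar values and the simplified quadratic loss stand in for the full amplitude-projection objective:

import numpy as np

V = 0.3                # potential-type object value at one pixel
P = 0.8 - 0.2j         # probe value at the same pixel
target = 1.1 + 0.4j    # "modified" exit wave the update is pulling towards

def loss(v):
    # squared distance between the current exit wave O*P and the target
    return 0.5 * np.abs(np.exp(1j * v) * P - target) ** 2

# analytic descent direction, matching the 'potential' branch (chi = target - O*P)
O = np.exp(1j * V)
chi = target - O * P
descent = np.real(-1j * np.conj(O) * np.conj(P) * chi)

# finite-difference gradient of the loss; the descent direction is its negative
h = 1e-6
grad_fd = (loss(V + h) - loss(V - h)) / (2 * h)
assert np.isclose(descent, -grad_fd, rtol=1e-5)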
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - def _position_correction( self, current_object, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index ac16d89a5..40677b67d 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -1755,3 +1755,263 @@ def _adjoint( ) return current_object, current_probe + + +class Object2p5DProbeMethodsMixin: + """ + Mixin class for methods unique to 2.5D objects using a single probe. + """ + + def _overlap_projection(self, current_object, current_probe): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + shifted_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + """ + + xp = self._xp + + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object + + object_patches = complex_object[ + :, + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ] + + shifted_probes = xp.empty_like(object_patches) + shifted_probes[0] = fft_shift(current_probe, self._positions_px_fractional, xp) + + for s in range(self._num_slices): + # transmit + overlap = object_patches[s] * shifted_probes[s] + + # propagate + if s + 1 < self._num_slices: + shifted_probes[s + 1] = self._propagate_array( + overlap, self._propagator_arrays[s] + ) + + return shifted_probes, object_patches, overlap + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. 
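The transmit-then-propagate loop added above is the standard multislice recursion. The sketch below reproduces it for a single probe position; the function name, the inline Fresnel-style propagator, and the array shapes are assumptions for illustration and not the class attributes used in the diff:

import numpy as np

def multislice_forward(probe, slices, propagators):
    # probe:       (n, n) complex incident wave
    # slices:      (S, n, n) complex transmission functions, e.g. exp(1j * V_s)
    # propagators: (S-1, n, n) Fourier-space propagators between slices
    # Returns the wave incident on each slice and the final exit wave.
    incident = [probe]
    wave = probe
    for s, t in enumerate(slices):
        wave = t * wave                                            # transmit
        if s + 1 < len(slices):
            wave = np.fft.ifft2(np.fft.fft2(wave) * propagators[s])  # propagate
            incident.append(wave)
    return np.stack(incident), wave

# Toy example: 3 slices of a weak phase object with free-space-like propagators.
n, S = 32, 3
rng = np.random.default_rng(2)
probe = np.exp(1j * 0.01 * rng.standard_normal((n, n)))
slices = np.exp(1j * 0.05 * rng.standard_normal((S, n, n)))
kx = np.fft.fftfreq(n)[:, None]
ky = np.fft.fftfreq(n)[None, :]
propagators = np.exp(-1j * np.pi * (kx**2 + ky**2))[None].repeat(S - 1, axis=0)
incident_waves, exit_wave = multislice_forward(probe, slices, propagators)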
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + for s in reversed(range(self._num_slices)): + probe = shifted_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = self._sum_overlapping_patches_bincounts( + xp.abs(probe) ** 2 + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + if self._object_type == "potential": + current_object[s] += step_size * ( + self._sum_overlapping_patches_bincounts( + xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves) + ) + * probe_normalization + ) + else: + current_object[s] += step_size * ( + self._sum_overlapping_patches_bincounts(xp.conj(probe) * exit_waves) + * probe_normalization + ) + + # back-transmit + exit_waves *= xp.conj(obj) + + if s > 0: + # back-propagate + exit_waves = self._propagate_array( + exit_waves, xp.conj(self._propagator_arrays[s - 1]) + ) + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += ( + step_size + * xp.sum( + exit_waves, + axis=0, + ) + * object_normalization + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. 
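The backward loop above is the adjoint of that multislice recursion: the exit-wave residual is multiplied by conj(slice) ("back-transmit") and filtered with the conjugate propagator ("back-propagate") before updating the next-shallower slice. A compact sketch, assuming a single probe position, a complex-type object, and no normalization clamp; all names are hypothetical:

import numpy as np

def multislice_adjoint(incident_waves, slices, propagators, residual, step_size=0.5):
    # incident_waves: (S, n, n) waves incident on each slice from the forward pass
    # slices:         (S, n, n) complex transmission functions (updated in place)
    # propagators:    (S-1, n, n) Fourier-space propagators used in the forward pass
    # residual:       (n, n) exit-wave residual, e.g. modified - exit_wave
    S = len(slices)
    for s in reversed(range(S)):
        old_slice = slices[s].copy()
        # object update from the complex-object branch: conj(incident probe) * residual
        slices[s] = old_slice + step_size * np.conj(incident_waves[s]) * residual
        # back-transmit through the slice value used in the forward pass
        residual = residual * np.conj(old_slice)
        if s > 0:
            # back-propagate with the conjugate propagator
            residual = np.fft.ifft2(np.fft.fft2(residual) * np.conj(propagators[s - 1]))
    return slices

# Toy usage with random arrays (shapes only; the values are meaningless).
n, S = 16, 3
rng = np.random.default_rng(3)
incident = rng.standard_normal((S, n, n)) + 1j * rng.standard_normal((S, n, n))
slices = np.exp(1j * 0.1 * rng.standard_normal((S, n, n)))
props = np.exp(1j * rng.standard_normal((S - 1, n, n)))
residual = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
updated = multislice_adjoint(incident, slices, props, residual)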
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + # careful not to modify exit_waves in-place for projection set methods + exit_waves_copy = exit_waves.copy() + + for s in reversed(range(self._num_slices)): + probe = shifted_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = self._sum_overlapping_patches_bincounts( + xp.abs(probe) ** 2 + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + if self._object_type == "potential": + current_object[s] = ( + self._sum_overlapping_patches_bincounts( + xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves_copy) + ) + * probe_normalization + ) + else: + current_object[s] = ( + self._sum_overlapping_patches_bincounts( + xp.conj(probe) * exit_waves_copy + ) + * probe_normalization + ) + + # back-transmit + exit_waves_copy *= xp.conj(obj) + + if s > 0: + # back-propagate + exit_waves_copy = self._propagate_array( + exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) + ) + + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + exit_waves_copy, + axis=0, + ) + * object_normalization + ) + + return current_object, current_probe From 472e8b621c3a9954363879955fc5cae23106f12a Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 13:38:30 -0800 Subject: [PATCH 024/128] added necessary mixed-state forward and adjoint methods Former-commit-id: 39d1f84e1ad6bc5f7adbbdb33f87190f26417ca8 --- .../iterative_mixedstate_ptychography.py | 461 +----------------- .../phase/iterative_ptychographic_methods.py | 331 +++++++++++++ 2 files changed, 335 insertions(+), 457 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index c5f0f450b..9ff1fd976 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -28,6 +28,8 @@ ) from py4DSTEM.process.phase.iterative_ptychographic_methods import ( ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ObjectNDProbeMixedMethodsMixin, ProbeMethodsMixin, ProbeMixedMethodsMixin, ) @@ -47,6 +49,8 @@ class MixedstatePtychographicReconstruction( ProbeMixedConstraintsMixin, ProbeConstraintsMixin, ObjectNDConstraintsMixin, + ObjectNDProbeMixedMethodsMixin, + ObjectNDProbeMethodsMixin, ProbeMixedMethodsMixin, ProbeMethodsMixin, ObjectNDMethodsMixin, @@ -543,463 +547,6 @@ def preprocess( return self - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. 
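The base-class order added above is what makes the new mixins take effect: ObjectNDProbeMixedMethodsMixin is listed before ObjectNDProbeMethodsMixin, so Python's method resolution order picks the mixed-probe overrides while anything not overridden falls through to the single-probe versions. A toy sketch of that pattern, with stand-in class names and a simplified driver method (the real mixins split the forward/adjoint drivers and the per-variant overlap/projection methods in the same spirit):

class SingleProbeMethods:                      # stands in for ObjectNDProbeMethodsMixin
    def _overlap_projection(self):
        return "single-probe overlap"

    def _forward(self):
        # shared driver: dispatches to whichever override wins the MRO
        return f"forward using: {self._overlap_projection()}"

class MixedProbeMethods:                       # stands in for ObjectNDProbeMixedMethodsMixin
    def _overlap_projection(self):
        return "mixed-probe overlap"

class Reconstruction(MixedProbeMethods, SingleProbeMethods):
    pass

assert Reconstruction()._forward() == "forward using: mixed-probe overlap"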
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - shifted_probes * object_patches - """ - - xp = self._xp - - shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - object_patches = complex_object[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - overlap = shifted_probes * xp.expand_dims(object_patches, axis=1) - - return shifted_probes, object_patches, overlap - - def _gradient_descent_fourier_projection(self, amplitudes, overlap): - """ - Ptychographic fourier projection method for GD method. - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - - Returns - -------- - exit_waves:np.ndarray - Difference between modified and estimated exit waves - error: float - Reconstruction error - """ - - xp = self._xp - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) - error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) - - intensity_norm[intensity_norm == 0.0] = np.inf - amplitude_modification = amplitudes / intensity_norm - - fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap - modified_overlap = xp.fft.ifft2(fourier_modified_overlap) - - exit_waves = modified_overlap - overlap - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. 
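For mixed probes the measured amplitude constrains the incoherent sum over probe modes, so every mode at a given scan position is rescaled by the same ratio, mirroring the gradient-descent Fourier projection shown above. A self-contained sketch follows; the (positions, modes, n, n) shapes, the function name, and the toy data are assumptions for illustration:

import numpy as np

def mixed_state_amplitude_projection(overlap, amplitudes):
    # overlap:    (N, M, n, n) exit waves, N positions by M probe modes
    # amplitudes: (N, n, n)    measured Fourier amplitudes
    # Returns (exit_wave_difference, summed squared error).
    fourier_overlap = np.fft.fft2(overlap)                 # over the last two axes
    intensity_norm = np.sqrt(np.sum(np.abs(fourier_overlap) ** 2, axis=1))
    error = np.sum(np.abs(amplitudes - intensity_norm) ** 2)

    intensity_norm[intensity_norm == 0.0] = np.inf         # avoid division by zero
    modification = amplitudes / intensity_norm             # (N, n, n), shared by all modes

    modified = np.fft.ifft2(modification[:, None] * fourier_overlap)
    return modified - overlap, error

# Toy usage.
rng = np.random.default_rng(4)
overlap = rng.standard_normal((5, 2, 8, 8)) + 1j * rng.standard_normal((5, 2, 8, 8))
amplitudes = np.abs(np.fft.fft2(rng.standard_normal((5, 8, 8))))
exit_waves, err = mixed_state_amplitude_projection(overlap, amplitudes)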
- Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = overlap.copy() - - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) - error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) - - factor_to_be_projected = projection_c * overlap + projection_y * exit_waves - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - intensity_norm_projected = xp.sqrt( - xp.sum(xp.abs(fourier_projected_factor) ** 2, axis=1) - ) - intensity_norm_projected[intensity_norm_projected == 0.0] = np.inf - - amplitude_modification = amplitudes / intensity_norm_projected - fourier_projected_factor *= amplitude_modification[:, None] - - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * overlap - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - object * probe overlap - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - shifted_probes, object_patches, overlap = self._overlap_projection( - current_object, current_probe - ) - if use_projection_scheme: - exit_waves, error = self._projection_sets_fourier_projection( - amplitudes, - overlap, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, overlap - ) - - return shifted_probes, object_patches, overlap, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - probe_normalization = xp.zeros_like(current_object) - object_update = xp.zeros_like(current_object) - - for i_probe in range(self._num_probes): - probe_normalization += self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes[:, i_probe]) ** 2 - ) - if self._object_type == "potential": - object_update += step_size * self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(object_patches) - * xp.conj(shifted_probes[:, i_probe]) - * exit_waves[:, i_probe] - ) - ) - else: - object_update += step_size * self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe] - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object += object_update * probe_normalization - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(object_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += step_size * ( - xp.sum( - xp.expand_dims(xp.conj(object_patches), axis=1) * exit_waves, - axis=0, - ) - * object_normalization[None] - ) - - return current_object, current_probe - - def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - probe_normalization = xp.zeros_like(current_object) - current_object = xp.zeros_like(current_object) - - for i_probe in range(self._num_probes): - probe_normalization += self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes[:, i_probe]) ** 2 - ) - if self._object_type == "potential": - current_object += self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(object_patches) - * xp.conj(shifted_probes[:, i_probe]) - * exit_waves[:, i_probe] - ) - ) - else: - current_object += self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe] - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object *= probe_normalization - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(object_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - xp.expand_dims(xp.conj(object_patches), axis=1) * exit_waves, - axis=0, - ) - * object_normalization[None] - ) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - def _constraints( self, current_object, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index 40677b67d..85ee9b5f1 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -1760,6 +1760,7 @@ def _adjoint( class Object2p5DProbeMethodsMixin: """ Mixin class for methods unique to 2.5D objects using a single probe. + Overwrites ObjectNDProbeMethodsMixin. """ def _overlap_projection(self, current_object, current_probe): @@ -2015,3 +2016,333 @@ def _projection_sets_adjoint( ) return current_object, current_probe + + +class ObjectNDProbeMixedMethodsMixin: + """ + Mixin class for methods applicable to 2D, 2.5D, and 3D objects using mixed probes. + Overwrites ObjectNDProbeMethodsMixin. + """ + + def _overlap_projection(self, current_object, current_probe): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + shifted_probes:np.ndarray + fractionally-shifted probes + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + shifted_probes * object_patches + """ + + xp = self._xp + + shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) + + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object + + object_patches = complex_object[ + self._vectorized_patch_indices_row, self._vectorized_patch_indices_col + ] + + overlap = shifted_probes * xp.expand_dims(object_patches, axis=1) + + return shifted_probes, object_patches, overlap + + def _gradient_descent_fourier_projection(self, amplitudes, overlap): + """ + Ptychographic fourier projection method for GD method. 
+ + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + overlap: np.ndarray + object * probe overlap + + Returns + -------- + exit_waves:np.ndarray + Difference between modified and estimated exit waves + error: float + Reconstruction error + """ + + xp = self._xp + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) + + intensity_norm[intensity_norm == 0.0] = np.inf + amplitude_modification = amplitudes / intensity_norm + + fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap + modified_overlap = xp.fft.ifft2(fourier_modified_overlap) + + exit_waves = modified_overlap - overlap + + return exit_waves, error + + def _projection_sets_fourier_projection( + self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c + ): + """ + Ptychographic fourier projection method for DM_AP and RAAR methods. + Generalized projection using three parameters: a,b,c + + DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha + DM: DM_AP(1.0), AP: DM_AP(0.0) + + RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 + DM : RAAR(1.0) + + RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 + DM: RRR(1.0) + + SUPERFLIP : a = 0, b = 1, c = 2 + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + overlap: np.ndarray + object * probe overlap + exit_waves: np.ndarray + previously estimated exit waves + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + exit_waves:np.ndarray + Updated exit_waves + error: float + Reconstruction error + """ + + xp = self._xp + projection_x = 1 - projection_a - projection_b + projection_y = 1 - projection_c + + if exit_waves is None: + exit_waves = overlap.copy() + + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) + + factor_to_be_projected = projection_c * overlap + projection_y * exit_waves + fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) + + intensity_norm_projected = xp.sqrt( + xp.sum(xp.abs(fourier_projected_factor) ** 2, axis=1) + ) + intensity_norm_projected[intensity_norm_projected == 0.0] = np.inf + + amplitude_modification = amplitudes / intensity_norm_projected + fourier_projected_factor *= amplitude_modification[:, None] + + projected_factor = xp.fft.ifft2(fourier_projected_factor) + + exit_waves = ( + projection_x * exit_waves + + projection_a * overlap + + projection_b * projected_factor + ) + + return exit_waves, error + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_normalization = xp.zeros_like(current_object) + object_update = xp.zeros_like(current_object) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes[:, i_probe]) ** 2 + ) + if self._object_type == "potential": + object_update += step_size * self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(object_patches) + * xp.conj(shifted_probes[:, i_probe]) + * exit_waves[:, i_probe] + ) + ) + else: + object_update += step_size * self._sum_overlapping_patches_bincounts( + xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe] + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object += object_update * probe_normalization + + if not fix_probe: + object_normalization = xp.sum( + (xp.abs(object_patches) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += step_size * ( + xp.sum( + xp.expand_dims(xp.conj(object_patches), axis=1) * exit_waves, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_normalization = xp.zeros_like(current_object) + current_object = xp.zeros_like(current_object) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes[:, i_probe]) ** 2 + ) + if self._object_type == "potential": + current_object += self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(object_patches) + * xp.conj(shifted_probes[:, i_probe]) + * exit_waves[:, i_probe] + ) + ) + else: + current_object += self._sum_overlapping_patches_bincounts( + xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe] + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object *= probe_normalization + + if not fix_probe: + object_normalization = xp.sum( + (xp.abs(object_patches) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + xp.expand_dims(xp.conj(object_patches), axis=1) * exit_waves, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe From 7ac1c83b842ace5f5b8418a5fe69363b914d90dc Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 13:52:46 -0800 Subject: [PATCH 025/128] added necessary mixed-state-multi-slce forward and adjoint methods Former-commit-id: c1481dd4f8de640d206fce69c8c06a4012cec738 --- ...tive_mixedstate_multislice_ptychography.py | 542 +----------------- .../phase/iterative_ptychographic_methods.py | 291 ++++++++++ 2 files changed, 297 insertions(+), 536 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index ecad16247..5c6379d60 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -29,7 +29,10 @@ ) from py4DSTEM.process.phase.iterative_ptychographic_methods import ( Object2p5DMethodsMixin, + Object2p5DProbeMixedMethodsMixin, ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ObjectNDProbeMixedMethodsMixin, ProbeMethodsMixin, ProbeMixedMethodsMixin, ) @@ -50,6 +53,9 @@ class MixedstateMultislicePtychographicReconstruction( ProbeConstraintsMixin, Object2p5DConstraintsMixin, ObjectNDConstraintsMixin, + Object2p5DProbeMixedMethodsMixin, + ObjectNDProbeMixedMethodsMixin, + ObjectNDProbeMethodsMixin, ProbeMixedMethodsMixin, ProbeMethodsMixin, Object2p5DMethodsMixin, @@ -638,542 +644,6 @@ def preprocess( return self - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - """ - - xp = self._xp - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - object_patches = complex_object[ - :, - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ] - - num_probe_positions = object_patches.shape[1] - - propagated_shape = ( - self._num_slices, - num_probe_positions, - self._num_probes, - self._region_of_interest_shape[0], - self._region_of_interest_shape[1], - ) - propagated_probes = xp.empty(propagated_shape, dtype=object_patches.dtype) - propagated_probes[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes = ( - xp.expand_dims(object_patches[s], axis=1) * propagated_probes[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes[s + 1] = self._propagate_array( - transmitted_probes, self._propagator_arrays[s] - ) - - return propagated_probes, object_patches, transmitted_probes - - def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): - """ - Ptychographic fourier projection method for GD method. - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - - Returns - -------- - exit_waves:np.ndarray - Exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) - error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) - - intensity_norm[intensity_norm == 0.0] = np.inf - amplitude_modification = amplitudes / intensity_norm - - fourier_modified_overlap = amplitude_modification[:, None] * fourier_exit_waves - modified_exit_wave = xp.fft.ifft2(fourier_modified_overlap) - - exit_waves = modified_exit_wave - transmitted_probes - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. 
- Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = transmitted_probes.copy() - - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) - error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) - - factor_to_be_projected = ( - projection_c * transmitted_probes + projection_y * exit_waves - ) - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - intensity_norm_projected = xp.sqrt( - xp.sum(xp.abs(fourier_projected_factor) ** 2, axis=1) - ) - intensity_norm_projected[intensity_norm_projected == 0.0] = np.inf - - amplitude_modification = amplitudes / intensity_norm_projected - fourier_projected_factor *= amplitude_modification[:, None] - - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * transmitted_probes - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - ( - propagated_probes, - object_patches, - transmitted_probes, - ) = self._overlap_projection(current_object, current_probe) - - if use_projection_scheme: - exit_waves, error = self._projection_sets_fourier_projection( - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, transmitted_probes - ) - - return propagated_probes, object_patches, transmitted_probes, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. 
- Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = xp.zeros_like(current_object[s]) - object_update = xp.zeros_like(current_object[s]) - - for i_probe in range(self._num_probes): - probe_normalization += self._sum_overlapping_patches_bincounts( - xp.abs(probe[:, i_probe]) ** 2 - ) - - if self._object_type == "potential": - object_update += ( - step_size - * self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(obj) - * xp.conj(probe[:, i_probe]) - * exit_waves[:, i_probe] - ) - ) - ) - else: - object_update += ( - step_size - * self._sum_overlapping_patches_bincounts( - xp.conj(probe[:, i_probe]) * exit_waves[:, i_probe] - ) - ) - - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object[s] += object_update * probe_normalization - - # back-transmit - exit_waves *= xp.expand_dims(xp.conj(obj), axis=1) # / xp.abs(obj) ** 2 - - if s > 0: - # back-propagate - exit_waves = self._propagate_array( - exit_waves, xp.conj(self._propagator_arrays[s - 1]) - ) - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += ( - step_size - * xp.sum( - exit_waves, - axis=0, - ) - * object_normalization[None] - ) - - return current_object, current_probe - - def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - # careful not to modify exit_waves in-place for projection set methods - exit_waves_copy = exit_waves.copy() - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = xp.zeros_like(current_object[s]) - object_update = xp.zeros_like(current_object[s]) - - for i_probe in range(self._num_probes): - probe_normalization += self._sum_overlapping_patches_bincounts( - xp.abs(probe[:, i_probe]) ** 2 - ) - - if self._object_type == "potential": - object_update += self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(obj) - * xp.conj(probe[:, i_probe]) - * exit_waves_copy[:, i_probe] - ) - ) - else: - object_update += self._sum_overlapping_patches_bincounts( - xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe] - ) - - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object[s] = object_update * probe_normalization - - # back-transmit - exit_waves_copy *= xp.expand_dims( - xp.conj(obj), axis=1 - ) # / xp.abs(obj) ** 2 - - if s > 0: - # back-propagate - exit_waves_copy = self._propagate_array( - exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) - ) - - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - exit_waves_copy, - axis=0, - ) - * object_normalization[None] - ) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - def _position_correction( self, current_object, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index 85ee9b5f1..0a3dd5e26 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -2346,3 +2346,294 @@ def _projection_sets_adjoint( ) return current_object, current_probe + + +class Object2p5DProbeMixedMethodsMixin: + """ + Mixin class for methods unique to 2.5D objects using mixed probes. + Overwrites ObjectNDProbeMethodsMixin and ObjectNDProbeMixedMethodsMixin. + """ + + def _overlap_projection(self, current_object, current_probe): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + propagated_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + """ + + xp = self._xp + + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object + + object_patches = complex_object[ + :, + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ] + + num_probe_positions = object_patches.shape[1] + + shifted_shape = ( + self._num_slices, + num_probe_positions, + self._num_probes, + self._region_of_interest_shape[0], + self._region_of_interest_shape[1], + ) + + shifted_probes = xp.empty(shifted_shape, dtype=object_patches.dtype) + shifted_probes[0] = fft_shift(current_probe, self._positions_px_fractional, xp) + + for s in range(self._num_slices): + # transmit + overlap = xp.expand_dims(object_patches[s], axis=1) * shifted_probes[s] + + # propagate + if s + 1 < self._num_slices: + shifted_probes[s + 1] = self._propagate_array( + overlap, self._propagator_arrays[s] + ) + + return shifted_probes, object_patches, overlap + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. 
+ Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + for s in reversed(range(self._num_slices)): + probe = shifted_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2 + ) + + if self._object_type == "potential": + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves[:, i_probe] + ) + ) + ) + else: + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves[:, i_probe] + ) + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] += object_update * probe_normalization + + # back-transmit + exit_waves *= xp.expand_dims(xp.conj(obj), axis=1) + + if s > 0: + # back-propagate + exit_waves = self._propagate_array( + exit_waves, xp.conj(self._propagator_arrays[s - 1]) + ) + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += ( + step_size + * xp.sum( + exit_waves, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + # careful not to modify exit_waves in-place for projection set methods + exit_waves_copy = exit_waves.copy() + for s in reversed(range(self._num_slices)): + probe = shifted_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2 + ) + + if self._object_type == "potential": + object_update += self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves_copy[:, i_probe] + ) + ) + else: + object_update += self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe] + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] = object_update * probe_normalization + + # back-transmit + exit_waves_copy *= xp.expand_dims( + xp.conj(obj), axis=1 + ) # / xp.abs(obj) ** 2 + + if s > 0: + # back-propagate + exit_waves_copy = self._propagate_array( + exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) + ) + + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + exit_waves_copy, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe From ee28b2ce98e9b6be4bd4abf484ccd9c991a51712 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 15:05:06 -0800 Subject: [PATCH 026/128] removed redundant forward and adjoint methods from overlap tomo Former-commit-id: c8a2a840afd7b2db8a2c13ace083804f1e5ad00e --- .../phase/iterative_overlap_tomography.py | 482 +----------------- 1 file changed, 4 insertions(+), 478 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index f3abbb644..1ef2d54cb 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -30,8 +30,10 @@ ) from py4DSTEM.process.phase.iterative_ptychographic_methods import ( Object2p5DMethodsMixin, + Object2p5DProbeMethodsMixin, Object3DMethodsMixin, ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, ProbeListMethodsMixin, ProbeMethodsMixin, ) @@ -52,6 +54,8 @@ class OverlapTomographicReconstruction( Object3DConstraintsMixin, Object2p5DConstraintsMixin, ObjectNDConstraintsMixin, + Object2p5DProbeMethodsMixin, + ObjectNDProbeMethodsMixin, ProbeListMethodsMixin, ProbeMethodsMixin, Object3DMethodsMixin, @@ 
-691,484 +695,6 @@ def preprocess( return self - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - """ - - xp = self._xp - - complex_object = xp.exp(1j * current_object) - object_patches = complex_object[ - :, self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - propagated_probes = xp.empty_like(object_patches) - propagated_probes[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes = object_patches[s] * propagated_probes[s] - - # propagate - if s + 1 < self._num_slices: - propagated_probes[s + 1] = self._propagate_array( - transmitted_probes, self._propagator_arrays[s] - ) - - return propagated_probes, object_patches, transmitted_probes - - def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): - """ - Ptychographic fourier projection method for GD method. - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - modified_exit_wave = xp.fft.ifft2( - amplitudes * xp.exp(1j * xp.angle(fourier_exit_waves)) - ) - - exit_waves = modified_exit_wave - transmitted_probes - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. 
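The amplitude replacement used above, amplitudes * exp(1j * angle(F)), is the usual modulus-constraint projection: it keeps the current Fourier phase and swaps in the measured modulus. The quick check below, with made-up arrays, verifies that the projected wave then satisfies the constraint to machine precision:

import numpy as np

rng = np.random.default_rng(5)
exit_wave = rng.standard_normal((16, 16)) + 1j * rng.standard_normal((16, 16))
measured_amplitudes = np.abs(np.fft.fft2(rng.standard_normal((16, 16))))

fourier = np.fft.fft2(exit_wave)
projected = np.fft.ifft2(measured_amplitudes * np.exp(1j * np.angle(fourier)))

# The projected wave now has exactly the measured Fourier modulus.
assert np.allclose(np.abs(np.fft.fft2(projected)), measured_amplitudes)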
- Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = transmitted_probes.copy() - - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - factor_to_be_projected = ( - projection_c * transmitted_probes + projection_y * exit_waves - ) - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amplitudes * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * transmitted_probes - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - propagated_probes:np.ndarray - Prop[object^n*probe^n] - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - ( - propagated_probes, - object_patches, - transmitted_probes, - ) = self._overlap_projection(current_object, current_probe) - - if use_projection_scheme: - ( - exit_waves[self._active_tilt_index], - error, - ) = self._projection_sets_fourier_projection( - amplitudes, - transmitted_probes, - exit_waves[self._active_tilt_index], - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, transmitted_probes - ) - - return propagated_probes, object_patches, transmitted_probes, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object[s] += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves) - ) - * probe_normalization - ) - - # back-transmit - exit_waves *= xp.conj(obj) - - if s > 0: - # back-propagate - exit_waves = self._propagate_array( - exit_waves, xp.conj(self._propagator_arrays[s - 1]) - ) - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += ( - step_size - * xp.sum( - exit_waves, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. 
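The object and probe updates above divide by a soft-clipped overlap intensity rather than the raw intensity, with normalization_min acting as a floor relative to the peak overlap. A sketch of that normalization factor on its own:

import numpy as np

def regularized_inverse_overlap(overlap_intensity, normalization_min, eps=1e-16):
    # overlap_intensity is e.g. the summed |probe|**2 patches (object update)
    # or the summed |object|**2 patches (probe update)
    return 1 / np.sqrt(
        eps
        + ((1 - normalization_min) * overlap_intensity) ** 2
        + (normalization_min * overlap_intensity.max()) ** 2
    )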
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - # careful not to modify exit_waves in-place for projection set methods - exit_waves_copy = exit_waves.copy() - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object[s] = ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves_copy) - ) - * probe_normalization - ) - - # back-transmit - exit_waves_copy *= xp.conj(obj) - - if s > 0: - # back-propagate - exit_waves_copy = self._propagate_array( - exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) - ) - - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - exit_waves_copy, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves[self._active_tilt_index], - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - def _position_correction( self, current_object, From d56a4b6302968b668c4b055d8d74bff88b6e26d3 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 16:48:59 -0800 Subject: [PATCH 027/128] cleaned up single-slice reconstruct Former-commit-id: f87a0b009e16bb322362f84dbe0e3a1289a40b50 --- .../process/phase/iterative_base_class.py | 180 ++++++++++++++ .../phase/iterative_ptychographic_methods.py | 42 ++++ .../iterative_singleslice_ptychography.py | 234 ++++-------------- 3 files changed, 273 insertions(+), 183 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index e6e3a4547..682cb87ff 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1805,6 +1805,186 @@ def _extract_vectorized_patch_indices(self): return vectorized_patch_indices_row, vectorized_patch_indices_col + def _set_reconstruction_method_parameters( + self, + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ): + """""" + + if reconstruction_method == "generalized-projections": + if ( + reconstruction_parameter_a is None + or reconstruction_parameter_b is None + or reconstruction_parameter_c is None + ): + raise ValueError( + ( + "reconstruction_parameter_a/b/c must all be specified " + "when using reconstruction_method='generalized-projections'." 
+ ) + ) + + use_projection_scheme = True + projection_a = reconstruction_parameter_a + projection_b = reconstruction_parameter_b + projection_c = reconstruction_parameter_c + reconstruction_parameter = None + step_size = None + elif ( + reconstruction_method == "DM_AP" + or reconstruction_method == "difference-map_alternating-projections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = 1 + projection_c = 1 + reconstruction_parameter + step_size = None + elif ( + reconstruction_method == "RAAR" + or reconstruction_method == "relaxed-averaged-alternating-reflections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = 1 - 2 * reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "RRR" + or reconstruction_method == "relax-reflect-reflect" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: + raise ValueError("reconstruction_parameter must be between 0-2.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "SUPERFLIP" + or reconstruction_method == "charge-flipping" + ): + use_projection_scheme = True + projection_a = 0 + projection_b = 1 + projection_c = 2 + reconstruction_parameter = None + step_size = None + elif ( + reconstruction_method == "GD" or reconstruction_method == "gradient-descent" + ): + use_projection_scheme = False + projection_a = None + projection_b = None + projection_c = None + reconstruction_parameter = None + else: + raise ValueError( + ( + "reconstruction_method must be one of 'generalized-projections', " + "'DM_AP' (or 'difference-map_alternating-projections'), " + "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " + "'RRR' (or 'relax-reflect-reflect'), " + "'SUPERFLIP' (or 'charge-flipping'), " + f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." + ) + ) + + return ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) + + def _report_reconstruction_summary( + self, + max_iter, + switch_object_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + max_batch_size, + step_size, + ): + """ """ + + # object type + if switch_object_iter > max_iter: + first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " + else: + switch_object_type = ( + "complex" if self._object_type == "potential" else "potential" + ) + first_line = ( + f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " + f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " + ) + + # stochastic gradient descent + if max_batch_size is not None: + if use_projection_scheme: + raise ValueError( + ( + "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " + "Use reconstruction_method='GD' or set max_batch_size=None." 
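A hypothetical call to the helper added above, showing how a named method expands into the (a, b, c) constants; ptycho stands in for any reconstruction instance inheriting this base class:

(
    use_projection_scheme,
    projection_a,
    projection_b,
    projection_c,
    reconstruction_parameter,
    step_size,
) = ptycho._set_reconstruction_method_parameters(
    reconstruction_method="RAAR",
    reconstruction_parameter=0.75,  # beta
    reconstruction_parameter_a=None,
    reconstruction_parameter_b=None,
    reconstruction_parameter_c=None,
    step_size=0.5,  # ignored by projection-set methods
)
# RAAR(0.75) -> use_projection_scheme=True, (a, b, c) = (-0.5, 0.75, 2), step_size=None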
+ ) + ) + else: + print( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}, " + f"in batches of max {max_batch_size} measurements." + ) + ) + + else: + # named projection set method + if reconstruction_parameter is not None: + print( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." + ) + ) + + # generalized projections (or the even more rare charge-flipping) + elif projection_a is not None: + print( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and (a,b,c): " + f"{projection_a, projection_b, projection_c}." + ) + ) + + # gradient descent + else: + print( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}." + ) + ) + def _position_correction( self, relevant_object, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index 0a3dd5e26..172eacd57 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -1,3 +1,4 @@ +import warnings from typing import Sequence, Tuple import matplotlib.pyplot as plt @@ -20,6 +21,8 @@ except (ModuleNotFoundError, ImportError): cp = np +warnings.simplefilter(action="always", category=UserWarning) + class ObjectNDMethodsMixin: """ @@ -1756,6 +1759,45 @@ def _adjoint( return current_object, current_probe + def _reset_reconstruction( + self, + store_iterations, + reset, + ): + """ """ + if store_iterations and (not hasattr(self, "object_iterations") or reset): + self.object_iterations = [] + self.probe_iterations = [] + + # reset can be True, False, or None (default) + if reset is True: + self.error_iterations = [] + self._object = self._object_initial.copy() + self._probe = self._probe_initial.copy() + self._positions_px = self._positions_px_initial.copy() + self._object_type = self._object_type_initial + self._exit_waves = None + + # delete positions affine transform + if hasattr(self, "_tf"): + del self._tf + + elif reset is None: + # continued run + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." 
+ ), + UserWarning, + ) + + # first start + else: + self.error_iterations = [] + self._exit_waves = None + class Object2p5DProbeMethodsMixin: """ diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 7d054a6de..286b51491 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -633,11 +633,13 @@ def _constraints( Constrained positions estimate """ + # object constraints + + # smoothness if gaussian_filter: current_object = self._object_gaussian_constraint( current_object, gaussian_filter_sigma, pure_phase_object ) - if butterworth_filter: current_object = self._object_butterworth_constraint( current_object, @@ -645,12 +647,12 @@ def _constraints( q_highpass, butterworth_order, ) - if tv_denoise: current_object = self._object_denoise_tv_pylops( current_object, tv_denoise_weight, tv_denoise_inner_iter ) + # L1-norm pushing vacuum to zero if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( current_object, @@ -658,6 +660,7 @@ def _constraints( object_mask, ) + # amplitude threshold (complex) or positivity (potential) if self._object_type == "complex": current_object = self._object_threshold_constraint( current_object, pure_phase_object @@ -665,9 +668,13 @@ def _constraints( elif object_positivity: current_object = self._object_positivity_constraint(current_object) + # probe constraints + + # CoM corner-centering if fix_com: current_probe = self._probe_center_of_mass_constraint(current_probe) + # Fourier amplitude (aperture) constraints if fix_probe_aperture: current_probe = self._probe_aperture_constraint( current_probe, @@ -680,6 +687,7 @@ def _constraints( constrain_probe_fourier_amplitude_constant_intensity, ) + # Fourier phase (aberrations) fitting if fit_probe_aberrations: current_probe = self._probe_aberration_fitting_constraint( current_probe, @@ -687,6 +695,7 @@ def _constraints( fit_probe_aberrations_max_radial_order, ) + # Real-space amplitude constraint if constrain_probe_amplitude: current_probe = self._probe_amplitude_constraint( current_probe, @@ -694,11 +703,15 @@ def _constraints( constrain_probe_amplitude_relative_width, ) + # position constraints if not fix_positions: + # CoM centering current_positions = self._positions_center_of_mass_constraint( current_positions ) + # global affine transformation + # TO-DO: generalize to higher-order basis? if global_affine_transformation: current_positions = self._positions_affine_transformation_constraint( self._positions_px_initial, current_positions @@ -860,157 +873,39 @@ def reconstruct( asnumpy = self._asnumpy xp = self._xp - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
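Hypothetical calls illustrating the tri-state reset argument handled by the _reset_reconstruction helper introduced above (ptycho is a placeholder instance):

ptycho.reconstruct(reset=True)   # fresh start: object, probe, positions, object type re-initialized
ptycho.reconstruct(reset=None)   # default: warns and continues if a previous result exists
ptycho.reconstruct(reset=False)  # continues silently from the previous result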
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - "'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) if self._verbose: - if switch_object_iter > max_iter: - first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - else: - switch_object_type = ( - "complex" if self._object_type == "potential" else "potential" - ) - first_line = ( - f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " - f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " - ) - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." 
- ) - ) - else: - print( - ( - first_line + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) + self._report_reconstruction_summary( + max_iter, + switch_object_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + step_size, + max_batch_size, + ) - # Batching + # batching shuffled_indices = np.arange(self._num_diffraction_patterns) unshuffled_indices = np.zeros_like(shuffled_indices) @@ -1020,38 +915,7 @@ def reconstruct( max_batch_size = self._num_diffraction_patterns # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self.error_iterations = [] - self._object = self._object_initial.copy() - self._probe = self._probe_initial.copy() - self._positions_px = self._positions_px_initial.copy() - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - self._exit_waves = None - self._object_type = self._object_type_initial - if hasattr(self, "_tf"): - del self._tf - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." 
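A self-contained sketch of the shuffle/unshuffle bookkeeping set up above for stochastic (mini-batch) updates; generate_batches is replaced here by a plain range loop:

import numpy as np

num_diffraction_patterns, max_batch_size = 10, 4

shuffled_indices = np.arange(num_diffraction_patterns)
unshuffled_indices = np.zeros_like(shuffled_indices)

np.random.shuffle(shuffled_indices)
unshuffled_indices[shuffled_indices] = np.arange(num_diffraction_patterns)

for start in range(0, num_diffraction_patterns, max_batch_size):
    end = min(start + max_batch_size, num_diffraction_patterns)
    batch = shuffled_indices[start:end]
    # ... forward / adjoint on the measurements indexed by `batch` ...

# unshuffled_indices maps shuffled order back to acquisition order,
# e.g. to restore the corrected positions after the batch loop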
- ), - UserWarning, - ) - else: - self.error_iterations = [] - self._exit_waves = None + self._reset_reconstruction(store_iterations, reset) # main loop for a0 in tqdmnd( @@ -1066,16 +930,18 @@ def reconstruct( if self._object_type == "potential": self._object_type = "complex" self._object = xp.exp(1j * self._object) - elif self._object_type == "complex": + else: self._object_type = "potential" self._object = xp.angle(self._object) # randomize if not use_projection_scheme: np.random.shuffle(shuffled_indices) + unshuffled_indices[shuffled_indices] = np.arange( self._num_diffraction_patterns ) + positions_px = self._positions_px.copy()[shuffled_indices] for start, end in generate_batches( @@ -1086,6 +952,7 @@ def reconstruct( self._positions_px_fractional = self._positions_px - xp.round( self._positions_px ) + ( self._vectorized_patch_indices_row, self._vectorized_patch_indices_col, @@ -1185,6 +1052,7 @@ def reconstruct( ) self.error_iterations.append(error.item()) + if store_iterations: self.object_iterations.append(asnumpy(self._object.copy())) self.probe_iterations.append(self.probe_centered) From 92d642480986f92f9836a99876f6289c948a2b63 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 16:57:40 -0800 Subject: [PATCH 028/128] cleaned up multi-slice reconstruct Former-commit-id: b3c92d94324f54804d5d2a3e0bdb542863ce1fec --- .../iterative_multislice_ptychography.py | 230 ++++-------------- .../iterative_singleslice_ptychography.py | 2 +- 2 files changed, 42 insertions(+), 190 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 868f5a246..64b79f501 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -1004,7 +1004,7 @@ def _constraints( def reconstruct( self, - max_iter: int = 64, + max_iter: int = 8, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, reconstruction_parameter_a: float = None, @@ -1037,7 +1037,7 @@ def reconstruct( q_highpass: float = None, butterworth_order: float = 2, kz_regularization_filter_iter: int = np.inf, - kz_regularization_gamma: Union[float, np.ndarray] = None, + kz_regularization_gamma: float = None, identical_slices_iter: int = 0, object_positivity: bool = True, shrinkage_rad: float = 0.0, @@ -1174,155 +1174,37 @@ def reconstruct( asnumpy = self._asnumpy xp = self._xp - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
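A one-line check of the potential-to-complex object switch performed at switch_object_iter in the loop above (values here are arbitrary):

import numpy as np

potential = np.random.uniform(0, 1, (4, 4))   # real "potential" object
complex_object = np.exp(1j * potential)       # potential -> complex
recovered = np.angle(complex_object)          # complex -> potential
assert np.allclose(potential, recovered)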
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - "'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) if self._verbose: - if switch_object_iter > max_iter: - first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - else: - switch_object_type = ( - "complex" if self._object_type == "potential" else "potential" - ) - first_line = ( - f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " - f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " - ) - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." 
- ) - ) - else: - print( - ( - first_line + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) + self._report_reconstruction_summary( + max_iter, + switch_object_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + step_size, + max_batch_size, + ) # Batching shuffled_indices = np.arange(self._num_diffraction_patterns) @@ -1334,38 +1216,7 @@ def reconstruct( max_batch_size = self._num_diffraction_patterns # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self.error_iterations = [] - self._object = self._object_initial.copy() - self._probe = self._probe_initial.copy() - self._positions_px = self._positions_px_initial.copy() - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - self._exit_waves = None - self._object_type = self._object_type_initial - if hasattr(self, "_tf"): - del self._tf - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." 
- ), - UserWarning, - ) - else: - self.error_iterations = [] - self._exit_waves = None + self._reset_reconstruction(store_iterations, reset) # main loop for a0 in tqdmnd( @@ -1380,16 +1231,18 @@ def reconstruct( if self._object_type == "potential": self._object_type = "complex" self._object = xp.exp(1j * self._object) - elif self._object_type == "complex": + else: self._object_type = "potential" self._object = xp.angle(self._object) # randomize if not use_projection_scheme: np.random.shuffle(shuffled_indices) + unshuffled_indices[shuffled_indices] = np.arange( self._num_diffraction_patterns ) + positions_px = self._positions_px.copy()[shuffled_indices] for start, end in generate_batches( @@ -1400,6 +1253,7 @@ def reconstruct( self._positions_px_fractional = self._positions_px - xp.round( self._positions_px ) + ( self._vectorized_patch_indices_row, self._vectorized_patch_indices_col, @@ -1408,9 +1262,9 @@ def reconstruct( # forward operator ( - propagated_probes, + shifted_probes, object_patches, - self._transmitted_probes, + overlap, self._exit_waves, batch_error, ) = self._forward( @@ -1429,7 +1283,7 @@ def reconstruct( self._object, self._probe, object_patches, - propagated_probes, + shifted_probes, self._exit_waves, use_projection_scheme=use_projection_scheme, step_size=step_size, @@ -1488,10 +1342,7 @@ def reconstruct( butterworth_order=butterworth_order, kz_regularization_filter=a0 < kz_regularization_filter_iter and kz_regularization_gamma is not None, - kz_regularization_gamma=kz_regularization_gamma[a0] - if kz_regularization_gamma is not None - and isinstance(kz_regularization_gamma, np.ndarray) - else kz_regularization_gamma, + kz_regularization_gamma=kz_regularization_gamma, identical_slices=a0 < identical_slices_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, @@ -1510,6 +1361,7 @@ def reconstruct( ) self.error_iterations.append(error.item()) + if store_iterations: self.object_iterations.append(asnumpy(self._object.copy())) self.probe_iterations.append(self.probe_centered) diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 286b51491..02330638f 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -721,7 +721,7 @@ def _constraints( def reconstruct( self, - max_iter: int = 64, + max_iter: int = 8, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, reconstruction_parameter_a: float = None, From 6f69e6da71fa80e57d460dcb8fb013769e666378 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 17:04:46 -0800 Subject: [PATCH 029/128] cleaned up mixed-state reconstruct Former-commit-id: 65c8d07c558cc945acb7d38333b60fa1685edaae --- .../iterative_mixedstate_ptychography.py | 217 +++--------------- 1 file changed, 36 insertions(+), 181 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 9ff1fd976..911961bf0 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -719,7 +719,7 @@ def _constraints( def reconstruct( self, - max_iter: int = 64, + max_iter: int = 8, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, reconstruction_parameter_a: float = None, @@ -871,155 +871,37 @@ def reconstruct( asnumpy = self._asnumpy 
xp = self._xp - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." - ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - "'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." 
- ) - ) + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) if self._verbose: - if switch_object_iter > max_iter: - first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - else: - switch_object_type = ( - "complex" if self._object_type == "potential" else "potential" - ) - first_line = ( - f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " - f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " - ) - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - first_line + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) + self._report_reconstruction_summary( + max_iter, + switch_object_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + step_size, + max_batch_size, + ) # Batching shuffled_indices = np.arange(self._num_diffraction_patterns) @@ -1031,38 +913,7 @@ def reconstruct( max_batch_size = self._num_diffraction_patterns # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self._object = self._object_initial.copy() - self.error_iterations = [] - self._probe = self._probe_initial.copy() - self._positions_px = self._positions_px_initial.copy() - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - self._exit_waves = None - self._object_type = self._object_type_initial - if hasattr(self, "_tf"): - del self._tf - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." 
- ), - UserWarning, - ) - else: - self.error_iterations = [] - self._exit_waves = None + self._reset_reconstruction(store_iterations, reset) # main loop for a0 in tqdmnd( @@ -1077,16 +928,18 @@ def reconstruct( if self._object_type == "potential": self._object_type = "complex" self._object = xp.exp(1j * self._object) - elif self._object_type == "complex": + else: self._object_type = "potential" self._object = xp.angle(self._object) # randomize if not use_projection_scheme: np.random.shuffle(shuffled_indices) + unshuffled_indices[shuffled_indices] = np.arange( self._num_diffraction_patterns ) + positions_px = self._positions_px.copy()[shuffled_indices] for start, end in generate_batches( @@ -1097,6 +950,7 @@ def reconstruct( self._positions_px_fractional = self._positions_px - xp.round( self._positions_px ) + ( self._vectorized_patch_indices_row, self._vectorized_patch_indices_col, @@ -1197,6 +1051,7 @@ def reconstruct( ) self.error_iterations.append(error.item()) + if store_iterations: self.object_iterations.append(asnumpy(self._object.copy())) self.probe_iterations.append(self.probe_centered) From 32d34d7eb2d7b095ce7d7add7a0f7e991ae32bd8 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 17:08:46 -0800 Subject: [PATCH 030/128] cleaned up mixed-state multi-slice reconstruct Former-commit-id: 09ec6bca64eb4b90d4f1a63f3471d96cccec49f8 --- ...tive_mixedstate_multislice_ptychography.py | 227 ++++-------------- .../iterative_multislice_ptychography.py | 2 +- 2 files changed, 42 insertions(+), 187 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 5c6379d60..e263b2dc2 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -1019,7 +1019,7 @@ def _constraints( def reconstruct( self, - max_iter: int = 64, + max_iter: int = 8, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, reconstruction_parameter_a: float = None, @@ -1188,157 +1188,39 @@ def reconstruct( asnumpy = self._asnumpy xp = self._xp - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - "'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) if self._verbose: - if switch_object_iter > max_iter: - first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - else: - switch_object_type = ( - "complex" if self._object_type == "potential" else "potential" - ) - first_line = ( - f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " - f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " - ) - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." 
- ) - ) - else: - print( - ( - first_line + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) + self._report_reconstruction_summary( + max_iter, + switch_object_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + step_size, + max_batch_size, + ) - # Batching + # batching shuffled_indices = np.arange(self._num_diffraction_patterns) unshuffled_indices = np.zeros_like(shuffled_indices) @@ -1348,38 +1230,7 @@ def reconstruct( max_batch_size = self._num_diffraction_patterns # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self.error_iterations = [] - self._object = self._object_initial.copy() - self._probe = self._probe_initial.copy() - self._positions_px = self._positions_px_initial.copy() - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - self._exit_waves = None - self._object_type = self._object_type_initial - if hasattr(self, "_tf"): - del self._tf - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." 
- ), - UserWarning, - ) - else: - self.error_iterations = [] - self._exit_waves = None + self._reset_reconstruction(store_iterations, reset) # main loop for a0 in tqdmnd( @@ -1394,16 +1245,18 @@ def reconstruct( if self._object_type == "potential": self._object_type = "complex" self._object = xp.exp(1j * self._object) - elif self._object_type == "complex": + else: self._object_type = "potential" self._object = xp.angle(self._object) # randomize if not use_projection_scheme: np.random.shuffle(shuffled_indices) + unshuffled_indices[shuffled_indices] = np.arange( self._num_diffraction_patterns ) + positions_px = self._positions_px.copy()[shuffled_indices] for start, end in generate_batches( @@ -1414,6 +1267,7 @@ def reconstruct( self._positions_px_fractional = self._positions_px - xp.round( self._positions_px ) + ( self._vectorized_patch_indices_row, self._vectorized_patch_indices_col, @@ -1422,9 +1276,9 @@ def reconstruct( # forward operator ( - propagated_probes, + shifted_probes, object_patches, - self._transmitted_probes, + overlap, self._exit_waves, batch_error, ) = self._forward( @@ -1443,7 +1297,7 @@ def reconstruct( self._object, self._probe, object_patches, - propagated_probes, + shifted_probes, self._exit_waves, use_projection_scheme=use_projection_scheme, step_size=step_size, @@ -1456,7 +1310,7 @@ def reconstruct( positions_px[start:end] = self._position_correction( self._object, self._probe[0], - self._transmitted_probes[:, 0], + overlap[:, 0], amplitudes, self._positions_px, positions_step_size, @@ -1525,6 +1379,7 @@ def reconstruct( ) self.error_iterations.append(error.item()) + if store_iterations: self.object_iterations.append(asnumpy(self._object.copy())) self.probe_iterations.append(self.probe_centered) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 64b79f501..a16ab2db5 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -1296,7 +1296,7 @@ def reconstruct( positions_px[start:end] = self._position_correction( self._object, self._probe, - self._transmitted_probes, + overlap, amplitudes, self._positions_px, positions_step_size, From 79fb7ee4d1101b88b9e597a337938c5b4119a39c Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 18:31:00 -0800 Subject: [PATCH 031/128] cleaned up overlap tomo reconstruct, different probes per tilt Former-commit-id: 14c1e66341db84c695401ce824cedb055ce4ec1e --- .../phase/iterative_overlap_tomography.py | 242 ++++-------------- .../phase/iterative_ptychographic_methods.py | 125 ++++++--- 2 files changed, 142 insertions(+), 225 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 1ef2d54cb..782739d74 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -516,8 +516,8 @@ def preprocess( ) self._probes_all.append(_probe) - self._probes_all_initial = _probe.copy() - self._probes_all_initial_aperture = xp.abs(xp.fft.fft2(_probe)) + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) del self._probe_init @@ -1034,7 +1034,7 @@ def _constraints( def reconstruct( self, - max_iter: int = 64, + max_iter: int = 8, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, reconstruction_parameter_a: 
float = None, @@ -1181,148 +1181,41 @@ def reconstruct( asnumpy = self._asnumpy xp = self._xp - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." - ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - "'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) if self._verbose: - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." 
- ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) - - # Position Correction + Collective Updates not yet implemented - if fix_positions_iter < max_iter: - raise NotImplementedError( - "Position correction is currently incompatible with collective updates." + self._report_reconstruction_summary( + max_iter, + np.inf, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + step_size, + max_batch_size, ) - # Batching + # batching + shuffled_indices = np.arange(self._num_diffraction_patterns) + unshuffled_indices = np.zeros_like(shuffled_indices) if max_batch_size is not None: xp.random.seed(seed_random) @@ -1330,37 +1223,7 @@ def reconstruct( max_batch_size = self._num_diffraction_patterns # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self._object = self._object_initial.copy() - self.error_iterations = [] - self._probe = self._probe_initial.copy() - self._positions_px_all = self._positions_px_initial_all.copy() - if hasattr(self, "_tf"): - del self._tf - - if use_projection_scheme: - self._exit_waves = [None] * self._num_tilts - else: - self._exit_waves = None - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." 
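A sketch of the per-tilt probe bookkeeping this patch introduces and which the main loop below relies on: each tilt carries its own probe and initial aperture, selected by the active tilt index (shapes and values here are placeholders):

import numpy as np

num_tilts, probe_shape = 3, (8, 8)

probes_all = [np.ones(probe_shape, dtype=np.complex64) for _ in range(num_tilts)]
probes_all_initial = [probe.copy() for probe in probes_all]
probes_all_initial_aperture = [np.abs(np.fft.fft2(probe)) for probe in probes_all]

active_tilt_index = 1
probe = probes_all[active_tilt_index]
initial_aperture = probes_all_initial_aperture[active_tilt_index]
# forward / adjoint / probe constraints for this tilt act on `probe` only,
# leaving the probes of the other tilts untouched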
- ), - UserWarning, - ) - else: - self.error_iterations = [] - if use_projection_scheme: - self._exit_waves = [None] * self._num_tilts - else: - self._exit_waves = None + self._reset_reconstruction(store_iterations, reset, use_projection_scheme) # main loop for a0 in tqdmnd( @@ -1393,6 +1256,12 @@ def reconstruct( object_sliced = self._project_sliced_object( self._object, self._num_slices ) + + _probe = self._probes_all[self._active_tilt_index] + _probe_initial_aperture = self._probes_all_initial_aperture[ + self._active_tilt_index + ] + if not use_projection_scheme: object_sliced_old = object_sliced.copy() @@ -1440,14 +1309,14 @@ def reconstruct( # forward operator ( - propagated_probes, + shifted_probes, object_patches, - transmitted_probes, + overlap, self._exit_waves, batch_error, ) = self._forward( object_sliced, - self._probe, + _probe, amplitudes, self._exit_waves, use_projection_scheme, @@ -1457,11 +1326,11 @@ def reconstruct( ) # adjoint operator - object_sliced, self._probe = self._adjoint( + object_sliced, _probe = self._adjoint( object_sliced, - self._probe, + _probe, object_patches, - propagated_probes, + shifted_probes, self._exit_waves, use_projection_scheme=use_projection_scheme, step_size=step_size, @@ -1473,8 +1342,8 @@ def reconstruct( if a0 >= fix_positions_iter: positions_px[start:end] = self._position_correction( object_sliced, - self._probe, - transmitted_probes, + _probe, + overlap, amplitudes, self._positions_px, positions_step_size, @@ -1514,11 +1383,11 @@ def reconstruct( if not collective_tilt_updates: ( self._object, - self._probe, + _probe, self._positions_px_all[start_tilt:end_tilt], ) = self._constraints( self._object, - self._probe, + _probe, self._positions_px_all[start_tilt:end_tilt], fix_com=fix_com and a0 >= fix_probe_iter, constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter @@ -1535,7 +1404,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, + initial_probe_aperture=_probe_initial_aperture, fix_positions=a0 < fix_positions_iter, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter @@ -1568,11 +1437,11 @@ def reconstruct( ( self._object, - self._probe, + _probe, _, ) = self._constraints( self._object, - self._probe, + _probe, None, fix_com=fix_com and a0 >= fix_probe_iter, constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter @@ -1589,7 +1458,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, + initial_probe_aperture=_probe_initial_aperture, fix_positions=True, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter @@ -1612,6 +1481,7 @@ def reconstruct( ) self.error_iterations.append(error.item()) + if store_iterations: self.object_iterations.append(asnumpy(self._object.copy())) self.probe_iterations.append(self.probe_centered) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index 172eacd57..f58507cf2 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ 
b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -174,6 +174,45 @@ def show_object_fft(self, obj=None, **kwargs): **kwargs, ) + def _reset_reconstruction( + self, + store_iterations, + reset, + ): + """ """ + if store_iterations and (not hasattr(self, "object_iterations") or reset): + self.object_iterations = [] + self.probe_iterations = [] + + # reset can be True, False, or None (default) + if reset is True: + self.error_iterations = [] + self._object = self._object_initial.copy() + self._probe = self._probe_initial.copy() + self._positions_px = self._positions_px_initial.copy() + self._object_type = self._object_type_initial + self._exit_waves = None + + # delete positions affine transform + if hasattr(self, "_tf"): + del self._tf + + elif reset is None: + # continued run + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + + # first start + else: + self.error_iterations = [] + self._exit_waves = None + @property def object_fft(self): """Fourier transform of current object estimate""" @@ -1291,6 +1330,53 @@ class ProbeListMethodsMixin: Overwrites ProbeMethodsMixin. """ + def _reset_reconstruction( + self, + store_iterations, + reset, + use_projection_scheme, + ): + """ """ + if store_iterations and (not hasattr(self, "object_iterations") or reset): + self.object_iterations = [] + self.probe_iterations = [] + + # reset can be True, False, or None (default) + if reset is True: + self.error_iterations = [] + self._object = self._object_initial.copy() + self._probes_all = [pr.copy() for pr in self._probes_all_initial] + self._positions_px_all = self._positions_px_initial_all.copy() + self._object_type = self._object_type_initial + + if use_projection_scheme: + self._exit_waves = [None] * self._num_tilts + else: + self._exit_waves = None + + # delete positions affine transform + if hasattr(self, "_tf"): + del self._tf + + elif reset is None: + # continued run + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + + # first start + else: + self.error_iterations = [] + if use_projection_scheme: + self._exit_waves = [None] * self._num_tilts + else: + self._exit_waves = None + @property def _probe(self): """Dummy property to return average probe""" @@ -1759,45 +1845,6 @@ def _adjoint( return current_object, current_probe - def _reset_reconstruction( - self, - store_iterations, - reset, - ): - """ """ - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - # reset can be True, False, or None (default) - if reset is True: - self.error_iterations = [] - self._object = self._object_initial.copy() - self._probe = self._probe_initial.copy() - self._positions_px = self._positions_px_initial.copy() - self._object_type = self._object_type_initial - self._exit_waves = None - - # delete positions affine transform - if hasattr(self, "_tf"): - del self._tf - - elif reset is None: - # continued run - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." 
- ), - UserWarning, - ) - - # first start - else: - self.error_iterations = [] - self._exit_waves = None - class Object2p5DProbeMethodsMixin: """ From daea57354582f43a9ddd2498c93cc489a63f3344 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 19:24:21 -0800 Subject: [PATCH 032/128] moved show_transmitted probe Former-commit-id: 4fcc778926e396ac760cf365620fae1daca4e282 --- ...tive_mixedstate_multislice_ptychography.py | 75 +---------- .../iterative_multislice_ptychography.py | 75 +---------- .../phase/iterative_ptychographic_methods.py | 121 ++++++++++++++++++ 3 files changed, 123 insertions(+), 148 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index e263b2dc2..029724dbf 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -10,7 +10,7 @@ import numpy as np from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg try: import cupy as cp @@ -1944,79 +1944,6 @@ def visualize( ) return self - def show_transmitted_probe( - self, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations=False, - **kwargs, - ): - """ - Plots the min, max, and mean transmitted probe after propagation and transmission. - - Parameters - ---------- - plot_fourier_probe: boolean, optional - If True, the transmitted probes are also plotted in Fourier space - kwargs: - Passed to show_complex - """ - - xp = self._xp - asnumpy = self._asnumpy - - transmitted_probe_intensities = xp.sum( - xp.abs(self._transmitted_probes[:, 0]) ** 2, axis=(-2, -1) - ) - min_intensity_transmitted = self._transmitted_probes[ - xp.argmin(transmitted_probe_intensities), 0 - ] - max_intensity_transmitted = self._transmitted_probes[ - xp.argmax(transmitted_probe_intensities), 0 - ] - mean_transmitted = self._transmitted_probes[:, 0].mean(0) - probes = [ - asnumpy(self._return_centered_probe(probe)) - for probe in [ - mean_transmitted, - min_intensity_transmitted, - max_intensity_transmitted, - ] - ] - title = [ - "Mean Transmitted Probe", - "Min Intensity Transmitted Probe", - "Max Intensity Transmitted Probe", - ] - - if plot_fourier_probe: - bottom_row = [ - asnumpy( - self._return_fourier_probe( - probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - for probe in [ - mean_transmitted, - min_intensity_transmitted, - max_intensity_transmitted, - ] - ] - probes = [probes, bottom_row] - - title += [ - "Mean Transmitted Fourier Probe", - "Min Intensity Transmitted Fourier Probe", - "Max Intensity Transmitted Fourier Probe", - ] - - title = kwargs.get("title", title) - show_complex( - probes, - title=title, - **kwargs, - ) - def _return_self_consistency_errors( self, max_batch_size=None, diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index a16ab2db5..af55318be 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -10,7 +10,7 @@ import numpy as np from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable -from 
py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg try: import cupy as cp @@ -1923,76 +1923,3 @@ def visualize( **kwargs, ) return self - - def show_transmitted_probe( - self, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations=False, - **kwargs, - ): - """ - Plots the min, max, and mean transmitted probe after propagation and transmission. - - Parameters - ---------- - plot_fourier_probe: boolean, optional - If True, the transmitted probes are also plotted in Fourier space - kwargs: - Passed to show_complex - """ - - xp = self._xp - asnumpy = self._asnumpy - - transmitted_probe_intensities = xp.sum( - xp.abs(self._transmitted_probes) ** 2, axis=(-2, -1) - ) - min_intensity_transmitted = self._transmitted_probes[ - xp.argmin(transmitted_probe_intensities) - ] - max_intensity_transmitted = self._transmitted_probes[ - xp.argmax(transmitted_probe_intensities) - ] - mean_transmitted = self._transmitted_probes.mean(0) - probes = [ - asnumpy(self._return_centered_probe(probe)) - for probe in [ - mean_transmitted, - min_intensity_transmitted, - max_intensity_transmitted, - ] - ] - title = [ - "Mean Transmitted Probe", - "Min Intensity Transmitted Probe", - "Max Intensity Transmitted Probe", - ] - - if plot_fourier_probe: - bottom_row = [ - asnumpy( - self._return_fourier_probe( - probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - for probe in [ - mean_transmitted, - min_intensity_transmitted, - max_intensity_transmitted, - ] - ] - probes = [probes, bottom_row] - - title += [ - "Mean Transmitted Fourier Probe", - "Min Intensity Transmitted Fourier Probe", - "Max Intensity Transmitted Fourier Probe", - ] - - title = kwargs.get("title", title) - show_complex( - probes, - title=title, - **kwargs, - ) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index f58507cf2..e50258855 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -9,6 +9,7 @@ AffineTransform, ComplexProbe, fft_shift, + generate_batches, rotate_point, spatial_frequencies, ) @@ -2106,6 +2107,120 @@ def _projection_sets_adjoint( return current_object, current_probe + def show_transmitted_probe( + self, + max_batch_size=None, + plot_fourier_probe: bool = False, + remove_initial_probe_aberrations=False, + **kwargs, + ): + """ + Plots the min, max, and mean transmitted probe after propagation and transmission. 
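+        Transmitted probes are recomputed on the fly in batches of at most
+        max_batch_size positions, keeping a running mean and the minimum- and
+        maximum-intensity transmitted probes, so the full set of transmitted
+        probes never needs to be held in memory at once.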
+ + Parameters + ---------- + max_batch_size: int, optional + Max number of probes to calculate at once + plot_fourier_probe: boolean, optional + If True, the transmitted probes are also plotted in Fourier space + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + kwargs: + Passed to show_complex + """ + + xp = self._xp + asnumpy = self._asnumpy + + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + positions_px = self._positions_px.copy() + + mean_transmitted = xp.zeros_like(self._probe) + intensities_compare = [np.inf, 0] + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + + # store relevant arrays + mean_transmitted += overlap.sum(0) + + intensities = xp.sum(xp.abs(overlap) ** 2, axis=(-2, -1)) + min_intensity = intensities.min() + max_intensity = intensities.max() + + if min_intensity < intensities_compare[0]: + min_intensity_transmitted = overlap[xp.argmin(intensities)] + intensities_compare[0] = min_intensity + + if max_intensity > intensities_compare[1]: + max_intensity_transmitted = overlap[xp.argmax(intensities)] + intensities_compare[1] = max_intensity + + mean_transmitted /= self._num_diffraction_patterns + + probes = [ + asnumpy(self._return_centered_probe(probe)) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + title = [ + "Mean Transmitted Probe", + "Min Intensity Transmitted Probe", + "Max Intensity Transmitted Probe", + ] + + if plot_fourier_probe: + bottom_row = [ + asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + probes = [probes, bottom_row] + + title += [ + "Mean Transmitted Fourier Probe", + "Min Intensity Transmitted Fourier Probe", + "Max Intensity Transmitted Fourier Probe", + ] + + title = kwargs.get("title", title) + ticks = kwargs.get("ticks", False) + axsize = kwargs.get("axsize", (4.5, 4.5)) + + show_complex( + probes, + title=title, + ticks=ticks, + axsize=axsize, + **kwargs, + ) + class ObjectNDProbeMixedMethodsMixin: """ @@ -2726,3 +2841,9 @@ def _projection_sets_adjoint( ) return current_object, current_probe + + def show_transmitted_probe( + self, + **kwargs, + ): + raise NotImplementedError() From 90fc82da2d897531c8b22ffd733bbacb5f75bc8b Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 19:53:56 -0800 Subject: [PATCH 033/128] cleaned up self-consistency viz Former-commit-id: 38205165be8ac7ad0c737be117c688dc23ede033 --- .../process/phase/iterative_base_class.py | 132 +++++++----------- .../phase/iterative_ptychographic_methods.py | 74 +++++++++- 2 files changed, 115 insertions(+), 91 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 682cb87ff..8ebaa1290 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -20,11 +20,7 @@ from py4DSTEM.data import Calibration 
from py4DSTEM.datacube import DataCube from py4DSTEM.process.calibration import fit_origin -from py4DSTEM.process.phase.utils import ( - AffineTransform, - generate_batches, - polar_aliases, -) +from py4DSTEM.process.phase.utils import AffineTransform, polar_aliases from py4DSTEM.process.utils import ( electron_wavelength_angstrom, fourier_resample, @@ -2169,52 +2165,6 @@ def plot_position_correction( ax.set_aspect("equal") ax.set_title("Probe positions correction") - def _return_self_consistency_errors( - self, - max_batch_size=None, - ): - """Compute the self-consistency errors for each probe position""" - - xp = self._xp - asnumpy = self._asnumpy - - # Batch-size - if max_batch_size is None: - max_batch_size = self._num_diffraction_patterns - - # Re-initialize fractional positions and vector patches - errors = np.array([]) - positions_px = self._positions_px.copy() - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[start:end] - - # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap) - - # Normalized mean-squared errors - batch_errors = xp.sum( - xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) - ) - errors = np.hstack((errors, batch_errors)) - - self._positions_px = positions_px.copy() - errors /= self._mean_diffraction_intensity - - return asnumpy(errors) - def show_uncertainty_visualization( self, errors=None, @@ -2227,8 +2177,13 @@ def show_uncertainty_visualization( ): """Plot uncertainty visualization using self-consistency errors""" + xp = self._xp + asnumpy = self._asnumpy + gaussian_filter = self._gaussian_filter + if errors is None: errors = self._return_self_consistency_errors(max_batch_size=max_batch_size) + errors_xp = xp.asarray(errors) if projected_cropped_potential is None: projected_cropped_potential = self._return_projected_cropped_potential() @@ -2236,10 +2191,6 @@ def show_uncertainty_visualization( if kde_sigma is None: kde_sigma = 0.5 * self._scan_sampling[0] / self.sampling[0] - xp = self._xp - asnumpy = self._asnumpy - gaussian_filter = self._gaussian_filter - ## Kernel Density Estimation # rotated basis @@ -2270,46 +2221,57 @@ def show_uncertainty_visualization( dy = ya - yF # resampling - inds_1D = xp.ravel_multi_index( - xp.hstack( - [ - [xF, yF], - [xF + 1, yF], - [xF, yF + 1], - [xF + 1, yF + 1], - ] - ), - pixel_output, - mode=["wrap", "wrap"], - ) + all_inds = [ + [xF, yF], + [xF + 1, yF], + [xF, yF + 1], + [xF + 1, yF + 1], + ] - weights = xp.hstack( - ( - (1 - dx) * (1 - dy), - (dx) * (1 - dy), - (1 - dx) * (dy), - (dx) * (dy), + all_weights = [ + (1 - dx) * (1 - dy), + (dx) * (1 - dy), + (1 - dx) * (dy), + (dx) * (dy), + ] + + pix_count = xp.zeros(pixel_size, dtype=xp.float32) + pix_output = xp.zeros(pixel_size, dtype=xp.float32) + + for inds, weights in zip(all_inds, all_weights): + inds_1D = xp.ravel_multi_index( + inds, + pixel_output, + mode=["wrap", "wrap"], ) - ) + pix_count += xp.bincount( + inds_1D, + weights=weights, + minlength=pixel_size, + ) + pix_output += xp.bincount( + inds_1D, + weights=weights * errors_xp, + minlength=pixel_size, + ) + + # reshape 1D arrays to 2D pix_count = xp.reshape( - 
xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output + pix_count, + pixel_output, ) - pix_output = xp.reshape( - xp.bincount( - inds_1D, - weights=weights * xp.tile(xp.asarray(errors), 4), - minlength=pixel_size, - ), + pix_output, pixel_output, ) # kernel density estimate - pix_count = gaussian_filter(pix_count, kde_sigma, mode="wrap") - pix_count[pix_count == 0.0] = np.inf - pix_output = gaussian_filter(pix_output, kde_sigma, mode="wrap") - pix_output /= pix_count + pix_count = gaussian_filter(pix_count, kde_sigma) + pix_output = gaussian_filter(pix_output, kde_sigma) + sub = pix_count > 1e-3 + pix_output[sub] /= pix_count[sub] + pix_output[np.logical_not(sub)] = 1 pix_output = pix_output[padding[0] : -padding[0], padding[1] : -padding[1]] pix_output, _, _ = return_scaled_histogram_ordering( pix_output.get(), normalize=True diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index e50258855..60b2fec40 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -1434,6 +1434,12 @@ def _overlap_projection(self, current_object, current_probe): return shifted_probes, object_patches, overlap + def _return_farfield_amplitudes(self, fourier_overlap): + """Small utility to de-duplicate mixed-state Fourier projection.""" + + xp = self._xp + return xp.abs(fourier_overlap) + def _gradient_descent_fourier_projection(self, amplitudes, overlap): """ Ptychographic fourier projection method for GD method. @@ -1455,7 +1461,8 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap): xp = self._xp fourier_overlap = xp.fft.fft2(overlap) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2) + farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) + error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) fourier_modified_overlap = amplitudes * xp.exp(1j * xp.angle(fourier_overlap)) modified_overlap = xp.fft.ifft2(fourier_modified_overlap) @@ -1510,7 +1517,8 @@ def _projection_sets_fourier_projection( exit_waves = overlap.copy() fourier_overlap = xp.fft.fft2(overlap) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2) + farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) + error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) factor_to_be_projected = projection_c * overlap + projection_y * exit_waves fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) @@ -1846,6 +1854,54 @@ def _adjoint( return current_object, current_probe + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) 
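+            # the self-consistency error below is the per-position residual between
+            # measured and modeled far-field amplitudes,
+            #   error_i = sum_q (A_i(q) - |F{object * probe}_i(q)|)**2,
+            # normalized at the end by the mean diffraction intensity;
+            # _return_farfield_amplitudes keeps the same expression valid for
+            # mixed-state probes (incoherent sum over probe modes)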
+ fourier_overlap = xp.fft.fft2(overlap) + farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - farfield_amplitudes) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + class Object2p5DProbeMethodsMixin: """ @@ -2266,6 +2322,12 @@ def _overlap_projection(self, current_object, current_probe): return shifted_probes, object_patches, overlap + def _return_farfield_amplitudes(self, fourier_overlap): + """Small utility to de-duplicate mixed-state Fourier projection.""" + + xp = self._xp + return xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + def _gradient_descent_fourier_projection(self, amplitudes, overlap): """ Ptychographic fourier projection method for GD method. @@ -2287,11 +2349,11 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap): xp = self._xp fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) - error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) + farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) + error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) - intensity_norm[intensity_norm == 0.0] = np.inf - amplitude_modification = amplitudes / intensity_norm + farfield_amplitudes[farfield_amplitudes == 0.0] = np.inf + amplitude_modification = amplitudes / farfield_amplitudes fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap modified_overlap = xp.fft.ifft2(fourier_modified_overlap) From 1541c7eadc51da5005db4604af7ee08404fc8e08 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 19:56:18 -0800 Subject: [PATCH 034/128] oops, forgot to delete duplicate code Former-commit-id: 2d1d4e43268d77719b0b6fc1621bbab2779ac716 --- ...tive_mixedstate_multislice_ptychography.py | 47 ------------------- .../iterative_mixedstate_ptychography.py | 47 ------------------- 2 files changed, 94 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 029724dbf..4e3bd2cf0 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -1943,50 +1943,3 @@ def visualize( **kwargs, ) return self - - def _return_self_consistency_errors( - self, - max_batch_size=None, - ): - """Compute the self-consistency errors for each probe position""" - - xp = self._xp - asnumpy = self._asnumpy - - # Batch-size - if max_batch_size is None: - max_batch_size = self._num_diffraction_patterns - - # Re-initialize fractional positions and vector patches - errors = np.array([]) - positions_px = self._positions_px.copy() - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[start:end] - - # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) 
** 2, axis=1)) - - # Normalized mean-squared errors - batch_errors = xp.sum( - xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) - ) - errors = np.hstack((errors, batch_errors)) - - self._positions_px = positions_px.copy() - errors /= self._mean_diffraction_intensity - - return asnumpy(errors) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 911961bf0..c87d5b8e1 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -1606,50 +1606,3 @@ def visualize( ) return self - - def _return_self_consistency_errors( - self, - max_batch_size=None, - ): - """Compute the self-consistency errors for each probe position""" - - xp = self._xp - asnumpy = self._asnumpy - - # Batch-size - if max_batch_size is None: - max_batch_size = self._num_diffraction_patterns - - # Re-initialize fractional positions and vector patches - errors = np.array([]) - positions_px = self._positions_px.copy() - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[start:end] - - # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) - - # Normalized mean-squared errors - batch_errors = xp.sum( - xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) - ) - errors = np.hstack((errors, batch_errors)) - - self._positions_px = positions_px.copy() - errors /= self._mean_diffraction_intensity - - return asnumpy(errors) From ad911dc21e2d269885219b1826fc86e168f5a1cd Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 22:00:02 -0800 Subject: [PATCH 035/128] cleaned up position correction - is mixed state correct? Former-commit-id: 35516980da08673517d60cc7ca4542fcb554bac9 --- .../process/phase/iterative_base_class.py | 127 ------------- ...tive_mixedstate_multislice_ptychography.py | 176 +----------------- .../iterative_mixedstate_ptychography.py | 12 +- .../iterative_multislice_ptychography.py | 174 +---------------- .../phase/iterative_overlap_tomography.py | 171 +---------------- .../phase/iterative_ptychographic_methods.py | 136 ++++++++++++-- .../iterative_singleslice_ptychography.py | 9 +- 7 files changed, 147 insertions(+), 658 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 8ebaa1290..68d769816 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1981,133 +1981,6 @@ def _report_reconstruction_summary( ) ) - def _position_correction( - self, - relevant_object, - relevant_probes, - relevant_overlap, - relevant_amplitudes, - current_positions, - positions_step_size, - constrain_position_distance, - ): - """ - Position correction using estimated intensity gradient. 
- - Parameters - -------- - relevant_object: np.ndarray - Current object estimate - relevant_probes:np.ndarray - fractionally-shifted probes - relevant_overlap: np.ndarray - object * probe overlap - relevant_amplitudes: np.ndarray - Measured amplitudes - current_positions: np.ndarray - Current positions estimate - positions_step_size: float - Positions step size - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A - - Returns - -------- - updated_positions: np.ndarray - Updated positions estimate - """ - - xp = self._xp - - if self._object_type == "potential": - complex_object = xp.exp(1j * relevant_object) - else: - complex_object = relevant_object - - obj_rolled_x_patches = complex_object[ - (self._vectorized_patch_indices_row + 1) % self._object_shape[0], - self._vectorized_patch_indices_col, - ] - obj_rolled_y_patches = complex_object[ - self._vectorized_patch_indices_row, - (self._vectorized_patch_indices_col + 1) % self._object_shape[1], - ] - - overlap_fft = xp.fft.fft2(relevant_overlap) - - exit_waves_dx_fft = overlap_fft - xp.fft.fft2( - obj_rolled_x_patches * relevant_probes - ) - exit_waves_dy_fft = overlap_fft - xp.fft.fft2( - obj_rolled_y_patches * relevant_probes - ) - - overlap_fft_conj = xp.conj(overlap_fft) - estimated_intensity = xp.abs(overlap_fft) ** 2 - measured_intensity = relevant_amplitudes**2 - - flat_shape = (relevant_overlap.shape[0], -1) - difference_intensity = (measured_intensity - estimated_intensity).reshape( - flat_shape - ) - - partial_intensity_dx = 2 * xp.real( - exit_waves_dx_fft * overlap_fft_conj - ).reshape(flat_shape) - partial_intensity_dy = 2 * xp.real( - exit_waves_dy_fft * overlap_fft_conj - ).reshape(flat_shape) - - coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) - - # positions_update = xp.einsum( - # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity - # ) - - coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) - positions_update = ( - xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) - @ coefficients_matrix_T - @ difference_intensity[..., None] - ) - - if constrain_position_distance is not None: - constrain_position_distance /= xp.sqrt( - self.sampling[0] ** 2 + self.sampling[1] ** 2 - ) - x1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 0 - ] - y1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 1 - ] - x0 = self._positions_px_initial[:, 0] - y0 = self._positions_px_initial[:, 1] - if self._rotation_best_transpose: - x0, y0 = xp.array([y0, x0]) - x1, y1 = xp.array([y1, x1]) - - if self._rotation_best_rad is not None: - rotation_angle = self._rotation_best_rad - x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( - -rotation_angle - ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) - x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( - -rotation_angle - ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) - - outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( - x1 < (xp.min(x0) - constrain_position_distance) - ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( - y1 < (xp.min(y0) - constrain_position_distance) - ) > 0 - - positions_update[..., 0][outlier_ind] = 0 - - current_positions -= positions_step_size * positions_update[..., 0] - return current_positions - def plot_position_correction( self, scale_arrows=1, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py 
b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 4e3bd2cf0..2101eec54 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -644,172 +644,6 @@ def preprocess( return self - def _position_correction( - self, - current_object, - current_probe, - transmitted_probes, - amplitudes, - current_positions, - positions_step_size, - constrain_position_distance, - ): - """ - Position correction using estimated intensity gradient. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe:np.ndarray - fractionally-shifted probes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - amplitudes: np.ndarray - Measured amplitudes - current_positions: np.ndarray - Current positions estimate - positions_step_size: float - Positions step size - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A - - Returns - -------- - updated_positions: np.ndarray - Updated positions estimate - """ - - xp = self._xp - - # Intensity gradient - exit_waves_fft = xp.fft.fft2(transmitted_probes) - exit_waves_fft_conj = xp.conj(exit_waves_fft) - estimated_intensity = xp.abs(exit_waves_fft) ** 2 - measured_intensity = amplitudes**2 - - flat_shape = (transmitted_probes.shape[0], -1) - difference_intensity = (measured_intensity - estimated_intensity).reshape( - flat_shape - ) - - # Computing perturbed exit waves one at a time to save on memory - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - # dx - obj_rolled_patches = complex_object[ - :, - (self._vectorized_patch_indices_row + 1) % self._object_shape[0], - self._vectorized_patch_indices_col, - ] - - propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - # dy - obj_rolled_patches = complex_object[ - :, - self._vectorized_patch_indices_row, - (self._vectorized_patch_indices_col + 1) % self._object_shape[1], - ] - - propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - partial_intensity_dx = 2 * xp.real( - exit_waves_dx_fft * exit_waves_fft_conj - ).reshape(flat_shape) - partial_intensity_dy = 2 * xp.real( - exit_waves_dy_fft * exit_waves_fft_conj - ).reshape(flat_shape) - - coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) - - # positions_update = xp.einsum( - # "idk,ik->id", 
xp.linalg.pinv(coefficients_matrix), difference_intensity - # ) - - coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) - positions_update = ( - xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) - @ coefficients_matrix_T - @ difference_intensity[..., None] - ) - - if constrain_position_distance is not None: - constrain_position_distance /= xp.sqrt( - self.sampling[0] ** 2 + self.sampling[1] ** 2 - ) - x1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 0 - ] - y1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 1 - ] - x0 = self._positions_px_initial[:, 0] - y0 = self._positions_px_initial[:, 1] - if self._rotation_best_transpose: - x0, y0 = xp.array([y0, x0]) - x1, y1 = xp.array([y1, x1]) - - if self._rotation_best_rad is not None: - rotation_angle = self._rotation_best_rad - x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( - -rotation_angle - ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) - x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( - -rotation_angle - ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) - - outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( - x1 < (xp.min(x0) - constrain_position_distance) - ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( - y1 < (xp.min(y0) - constrain_position_distance) - ) > 0 - - positions_update[..., 0][outlier_ind] = 0 - - current_positions -= positions_step_size * positions_update[..., 0] - - return current_positions - def _constraints( self, current_object, @@ -1041,7 +875,7 @@ def reconstruct( constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, + max_position_update_distance: float = None, global_affine_transformation: bool = True, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, @@ -1123,6 +957,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. 
fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + max_position_update_distance: float, optional + Maximum allowed distance for update in A global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -1309,12 +1145,12 @@ def reconstruct( if a0 >= fix_positions_iter: positions_px[start:end] = self._position_correction( self._object, - self._probe[0], - overlap[:, 0], + self._probe, + overlap, amplitudes, self._positions_px, positions_step_size, - constrain_position_distance, + max_position_update_distance, ) error += batch_error diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index c87d5b8e1..0de0227c3 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -743,7 +743,7 @@ def reconstruct( constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, global_affine_transformation: bool = True, - constrain_position_distance: float = None, + max_position_update_distance: float = None, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, fit_probe_aberrations_iter: int = 0, @@ -819,8 +819,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float - Distance to constrain position correction within original field of view in A + max_position_update_distance: float, optional + Maximum allowed distance for update in A global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -992,12 +992,12 @@ def reconstruct( if a0 >= fix_positions_iter: positions_px[start:end] = self._position_correction( self._object, - shifted_probes[:, 0], - overlap[:, 0], + shifted_probes, + overlap, amplitudes, self._positions_px, positions_step_size, - constrain_position_distance, + max_position_update_distance, ) error += batch_error diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index af55318be..95f7bd4a4 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -619,172 +619,6 @@ def preprocess( return self - def _position_correction( - self, - current_object, - current_probe, - transmitted_probes, - amplitudes, - current_positions, - positions_step_size, - constrain_position_distance, - ): - """ - Position correction using estimated intensity gradient. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe:np.ndarray - fractionally-shifted probes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - amplitudes: np.ndarray - Measured amplitudes - current_positions: np.ndarray - Current positions estimate - positions_step_size: float - Positions step size - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A - - Returns - -------- - updated_positions: np.ndarray - Updated positions estimate - """ - - xp = self._xp - - # Intensity gradient - exit_waves_fft = xp.fft.fft2(transmitted_probes) - exit_waves_fft_conj = xp.conj(exit_waves_fft) - estimated_intensity = xp.abs(exit_waves_fft) ** 2 - measured_intensity = amplitudes**2 - - flat_shape = (transmitted_probes.shape[0], -1) - difference_intensity = (measured_intensity - estimated_intensity).reshape( - flat_shape - ) - - # Computing perturbed exit waves one at a time to save on memory - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - # dx - obj_rolled_patches = complex_object[ - :, - (self._vectorized_patch_indices_row + 1) % self._object_shape[0], - self._vectorized_patch_indices_col, - ] - - propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - # dy - obj_rolled_patches = complex_object[ - :, - self._vectorized_patch_indices_row, - (self._vectorized_patch_indices_col + 1) % self._object_shape[1], - ] - - propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - partial_intensity_dx = 2 * xp.real( - exit_waves_dx_fft * exit_waves_fft_conj - ).reshape(flat_shape) - partial_intensity_dy = 2 * xp.real( - exit_waves_dy_fft * exit_waves_fft_conj - ).reshape(flat_shape) - - coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) - - # positions_update = xp.einsum( - # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity - # ) - - coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) - positions_update = ( - xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) - @ coefficients_matrix_T - @ difference_intensity[..., None] - ) - - if constrain_position_distance is not None: - constrain_position_distance /= xp.sqrt( - self.sampling[0] ** 2 + self.sampling[1] ** 2 - ) - x1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 0 - ] - y1 = (current_positions - positions_step_size 
* positions_update[..., 0])[ - :, 1 - ] - x0 = self._positions_px_initial[:, 0] - y0 = self._positions_px_initial[:, 1] - if self._rotation_best_transpose: - x0, y0 = xp.array([y0, x0]) - x1, y1 = xp.array([y1, x1]) - - if self._rotation_best_rad is not None: - rotation_angle = self._rotation_best_rad - x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( - -rotation_angle - ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) - x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( - -rotation_angle - ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) - - outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( - x1 < (xp.min(x0) - constrain_position_distance) - ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( - y1 < (xp.min(y0) - constrain_position_distance) - ) > 0 - - positions_update[..., 0][outlier_ind] = 0 - - current_positions -= positions_step_size * positions_update[..., 0] - - return current_positions - def _constraints( self, current_object, @@ -1025,7 +859,7 @@ def reconstruct( constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, + max_position_update_distance: float = None, global_affine_transformation: bool = True, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, @@ -1107,8 +941,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float - Distance to constrain position correction within original field of view in A + max_position_update_distance: float, optional + Maximum allowed distance for update in A global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -1300,7 +1134,7 @@ def reconstruct( amplitudes, self._positions_px, positions_step_size, - constrain_position_distance, + max_position_update_distance, ) error += batch_error diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 782739d74..c4c0c1f1a 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -695,168 +695,6 @@ def preprocess( return self - def _position_correction( - self, - current_object, - current_probe, - transmitted_probes, - amplitudes, - current_positions, - positions_step_size, - constrain_position_distance, - ): - """ - Position correction using estimated intensity gradient. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe:np.ndarray - fractionally-shifted probes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - amplitudes: np.ndarray - Measured amplitudes - current_positions: np.ndarray - Current positions estimate - positions_step_size: float - Positions step size - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A - - Returns - -------- - updated_positions: np.ndarray - Updated positions estimate - """ - - xp = self._xp - - # Intensity gradient - exit_waves_fft = xp.fft.fft2(transmitted_probes) - exit_waves_fft_conj = xp.conj(exit_waves_fft) - estimated_intensity = xp.abs(exit_waves_fft) ** 2 - measured_intensity = amplitudes**2 - - flat_shape = (transmitted_probes.shape[0], -1) - difference_intensity = (measured_intensity - estimated_intensity).reshape( - flat_shape - ) - - # Computing perturbed exit waves one at a time to save on memory - - complex_object = xp.exp(1j * current_object) - - # dx - obj_rolled_patches = complex_object[ - :, - (self._vectorized_patch_indices_row + 1) % self._object_shape[0], - self._vectorized_patch_indices_col, - ] - - propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - # dy - obj_rolled_patches = complex_object[ - :, - self._vectorized_patch_indices_row, - (self._vectorized_patch_indices_col + 1) % self._object_shape[1], - ] - - propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - partial_intensity_dx = 2 * xp.real( - exit_waves_dx_fft * exit_waves_fft_conj - ).reshape(flat_shape) - partial_intensity_dy = 2 * xp.real( - exit_waves_dy_fft * exit_waves_fft_conj - ).reshape(flat_shape) - - coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) - - # positions_update = xp.einsum( - # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity - # ) - - coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) - positions_update = ( - xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) - @ coefficients_matrix_T - @ difference_intensity[..., None] - ) - - if constrain_position_distance is not None: - constrain_position_distance /= xp.sqrt( - self.sampling[0] ** 2 + self.sampling[1] ** 2 - ) - x1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 0 - ] - y1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 1 - ] - x0 = self._positions_px_initial[:, 0] 
- y0 = self._positions_px_initial[:, 1] - if self._rotation_best_transpose: - x0, y0 = xp.array([y0, x0]) - x1, y1 = xp.array([y1, x1]) - - if self._rotation_best_rad is not None: - rotation_angle = self._rotation_best_rad - x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( - -rotation_angle - ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) - x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( - -rotation_angle - ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) - - outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( - x1 < (xp.min(x0) - constrain_position_distance) - ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( - y1 < (xp.min(y0) - constrain_position_distance) - ) > 0 - - positions_update[..., 0][outlier_ind] = 0 - current_positions -= positions_step_size * positions_update[..., 0] - - return current_positions - def _constraints( self, current_object, @@ -1055,7 +893,7 @@ def reconstruct( constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, + max_position_update_distance: float = None, global_affine_transformation: bool = True, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, @@ -1130,9 +968,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float, optional - Distance to constrain position correction within original - field of view in A + max_position_update_distance: float, optional + Maximum allowed distance for update in A global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -1347,7 +1184,7 @@ def reconstruct( amplitudes, self._positions_px, positions_step_size, - constrain_position_distance, + max_position_update_distance, ) tilt_error += batch_error diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index 60b2fec40..3b0143b79 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -1396,7 +1396,14 @@ class ObjectNDProbeMethodsMixin: Mixin class for methods applicable to 2D, 2.5D, and 3D objects using a single probe. """ - def _overlap_projection(self, current_object, current_probe): + def _return_shifted_probes(self, current_probe): + """Simple utlity to de-duplicate _overlap_projection""" + + xp = self._xp + shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) + return shifted_probes + + def _overlap_projection(self, current_object, shifted_probes): """ Ptychographic overlap projection method. 
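+        Expects probes that have already been fractionally shifted to the scan
+        positions (see _return_shifted_probes above); this method extracts object
+        patches at those positions and multiplies them by the shifted probes.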
@@ -1419,8 +1426,6 @@ def _overlap_projection(self, current_object, current_probe): xp = self._xp - shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - if self._object_type == "potential": complex_object = xp.exp(1j * current_object) else: @@ -1581,8 +1586,9 @@ def _forward( Reconstruction error """ + shifted_probes = self._return_shifted_probes(current_probe) shifted_probes, object_patches, overlap = self._overlap_projection( - current_object, current_probe + current_object, shifted_probes ) if use_projection_scheme: @@ -1854,6 +1860,110 @@ def _adjoint( return current_object, current_probe + def _position_correction( + self, + current_object, + shifted_probes, + overlap, + amplitudes, + current_positions, + positions_step_size, + max_position_update_distance, + ): + """ + Position correction using estimated intensity gradient. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + shifted_probes:np.ndarray + fractionally-shifted probes + overlap: np.ndarray + object * probe overlap + amplitudes: np.ndarray + Measured amplitudes + current_positions: np.ndarray + Current positions estimate + positions_step_size: float + Positions step size + max_position_update_distance: float + Maximum allowed distance for update in A + + Returns + -------- + updated_positions: np.ndarray + Updated positions estimate + """ + + xp = self._xp + + # unperturbed + overlap_fft = xp.fft.fft2(overlap) + overlap_fft_conj = xp.conj(overlap_fft) + estimated_intensity = self._return_farfield_amplitudes(overlap_fft) ** 2 + measured_intensity = amplitudes**2 + + # book-keeping + flat_shape = (measured_intensity.shape[0], -1) + difference_intensity = (measured_intensity - estimated_intensity).reshape( + flat_shape + ) + vectorized_patch_indices_row = self._vectorized_patch_indices_row.copy() + vectorized_patch_indices_col = self._vectorized_patch_indices_col.copy() + + # dx overlap projection perturbation + self._vectorized_patch_indices_row = ( + vectorized_patch_indices_row + 1 + ) % self._object_shape[0] + _, _, overlap_dx = self._overlap_projection(current_object, shifted_probes) + self._vectorized_patch_indices_row = vectorized_patch_indices_row.copy() + + # dy overlap projection perturbation + self._vectorized_patch_indices_col = ( + vectorized_patch_indices_col + 1 + ) % self._object_shape[1] + _, _, overlap_dy = self._overlap_projection(current_object, shifted_probes) + self._vectorized_patch_indices_col = vectorized_patch_indices_col.copy() + + # partial intensities + overlap_dx_fft = overlap_fft - xp.fft.fft2(overlap_dx) + overlap_dy_fft = overlap_fft - xp.fft.fft2(overlap_dy) + partial_intensity_dx = 2 * xp.real(overlap_dx_fft * overlap_fft_conj) + partial_intensity_dy = 2 * xp.real(overlap_dy_fft * overlap_fft_conj) + + # handle mixed-state, is this correct? 
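+        # note: for mutually incoherent probe modes the far-field intensity is the
+        # sum of the per-mode intensities, I(q) = sum_k |F{psi_k}(q)|**2, so its
+        # first-order change under a probe shift is the sum of the per-mode
+        # partial intensities; summing over the mode axis here is consistent with
+        # that incoherent-sum model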
+ if partial_intensity_dx.ndim == 4: + partial_intensity_dx = partial_intensity_dx.sum(1) + partial_intensity_dy = partial_intensity_dy.sum(1) + + partial_intensity_dx = partial_intensity_dx.reshape(flat_shape) + partial_intensity_dy = partial_intensity_dy.reshape(flat_shape) + + # least-squares fit + coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) + coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) + positions_update = ( + xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) + @ coefficients_matrix_T + @ difference_intensity[..., None] + ) + + positions_update = positions_update[..., 0] * positions_step_size + + if max_position_update_distance is not None: + max_position_update_distance /= xp.sqrt( + self.sampling[0] ** 2 + self.sampling[1] ** 2 + ) + update_norms = xp.linalg.norm(positions_update, axis=1) + outlier_ind = update_norms > max_position_update_distance + positions_update[outlier_ind] /= ( + update_norms[outlier_ind, None] / max_position_update_distance + ) + + current_positions -= positions_update + return current_positions + def _return_self_consistency_errors( self, max_batch_size=None, @@ -1887,7 +1997,8 @@ def _return_self_consistency_errors( amplitudes = self._amplitudes[start:end] # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) + shifted_probes = self._return_shifted_probes(self._probe) + _, _, overlap = self._overlap_projection(self._object, shifted_probes) fourier_overlap = xp.fft.fft2(overlap) farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) @@ -1909,7 +2020,7 @@ class Object2p5DProbeMethodsMixin: Overwrites ObjectNDProbeMethodsMixin. """ - def _overlap_projection(self, current_object, current_probe): + def _overlap_projection(self, current_object, shifted_probes_in): """ Ptychographic overlap projection method. @@ -1944,7 +2055,7 @@ def _overlap_projection(self, current_object, current_probe): ] shifted_probes = xp.empty_like(object_patches) - shifted_probes[0] = fft_shift(current_probe, self._positions_px_fractional, xp) + shifted_probes[0] = shifted_probes_in for s in range(self._num_slices): # transmit @@ -2210,7 +2321,8 @@ def show_transmitted_probe( ) = self._extract_vectorized_patch_indices() # overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) + shifted_probes = self._return_shifted_probes(self._probe) + _, _, overlap = self._overlap_projection(self._object, shifted_probes) # store relevant arrays mean_transmitted += overlap.sum(0) @@ -2284,7 +2396,7 @@ class ObjectNDProbeMixedMethodsMixin: Overwrites ObjectNDProbeMethodsMixin. """ - def _overlap_projection(self, current_object, current_probe): + def _overlap_projection(self, current_object, shifted_probes): """ Ptychographic overlap projection method. @@ -2307,8 +2419,6 @@ def _overlap_projection(self, current_object, current_probe): xp = self._xp - shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - if self._object_type == "potential": complex_object = xp.exp(1j * current_object) else: @@ -2620,7 +2730,7 @@ class Object2p5DProbeMixedMethodsMixin: Overwrites ObjectNDProbeMethodsMixin and ObjectNDProbeMixedMethodsMixin. """ - def _overlap_projection(self, current_object, current_probe): + def _overlap_projection(self, current_object, shifted_probes_in): """ Ptychographic overlap projection method. 
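# Minimal numpy sketch of the least-squares step and per-update clamp in
# _position_correction above; shapes and values are placeholders. Each probe
# position solves A @ dr = dI, where the two columns of A are the finite-
# difference intensity derivatives and dI is the measured-minus-estimated intensity.
import numpy as np

rng = np.random.default_rng(0)
n_positions, n_pixels = 6, 24 * 24
A = rng.standard_normal((n_positions, n_pixels, 2))    # [dI/dx, dI/dy] columns
dI = rng.standard_normal((n_positions, n_pixels, 1))   # intensity residuals

AT = A.swapaxes(-1, -2)
positions_update = (np.linalg.inv(AT @ A) @ AT @ dI)[..., 0]  # (n_positions, 2)
positions_update *= 0.9  # positions_step_size

# rescale any update longer than the allowed distance (already converted to pixels)
max_update = 0.5
norms = np.linalg.norm(positions_update, axis=1)
outliers = norms > max_update
positions_update[outliers] /= norms[outliers, None] / max_update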
@@ -2665,7 +2775,7 @@ def _overlap_projection(self, current_object, current_probe): ) shifted_probes = xp.empty(shifted_shape, dtype=object_patches.dtype) - shifted_probes[0] = fft_shift(current_probe, self._positions_px_fractional, xp) + shifted_probes[0] = shifted_probes_in for s in range(self._num_slices): # transmit diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 02330638f..e80b42110 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -743,7 +743,7 @@ def reconstruct( constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, + max_position_update_distance: float = None, global_affine_transformation: bool = True, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, @@ -820,9 +820,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float, optional - Distance to constrain position correction within original - field of view in A + max_position_update_distance: float, optional + Maximum allowed distance for update in A global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -999,7 +998,7 @@ def reconstruct( amplitudes, self._positions_px, positions_step_size, - constrain_position_distance, + max_position_update_distance, ) error += batch_error From a7b9dbb10500dadc5408f5bd158381a8e7c27a60 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 22:06:08 -0800 Subject: [PATCH 036/128] more natural location for 3D function Former-commit-id: e4d9c94a1588d9611bb5ee02c3992a0b65774c7d --- .../phase/iterative_overlap_tomography.py | 20 ------------------- .../phase/iterative_ptychographic_methods.py | 7 +++++++ 2 files changed, 7 insertions(+), 20 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index c4c0c1f1a..248b510eb 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -1977,23 +1977,3 @@ def positions(self): positions_all.append(asnumpy(positions)) return np.asarray(positions_all) - - def _return_self_consistency_errors( - self, - max_batch_size=None, - ): - """Compute the self-consistency errors for each probe position""" - raise NotImplementedError() - - def show_uncertainty_visualization( - self, - errors=None, - max_batch_size=None, - projected_cropped_potential=None, - kde_sigma=None, - plot_histogram=True, - plot_contours=False, - **kwargs, - ): - """Plot uncertainty visualization using self-consistency errors""" - raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index 3b0143b79..e08b72a8c 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -935,6 +935,13 @@ def show_object_fft( **kwargs, ) + def _return_self_consistency_errors( + self, + **kwargs, + ): 
+ """Compute the self-consistency errors for each probe position""" + raise NotImplementedError() + @property def object_supersliced(self): """Returns super-sliced object""" From 3cbec59f1224c9bf7da0becb6fa03a89fbca8f9e Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 2 Jan 2024 10:32:18 -0800 Subject: [PATCH 037/128] added and cleaned up total position update functionality back in Former-commit-id: 1d75a733ef3d3c7042e9152a8f9bd9da768bbd29 --- .../iterative_mixedstate_multislice_ptychography.py | 6 ++++++ .../phase/iterative_mixedstate_ptychography.py | 6 ++++++ .../phase/iterative_multislice_ptychography.py | 6 ++++++ .../process/phase/iterative_overlap_tomography.py | 5 +++++ .../phase/iterative_ptychographic_methods.py | 13 +++++++++++++ .../phase/iterative_singleslice_ptychography.py | 6 ++++++ 6 files changed, 42 insertions(+) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 2101eec54..bc0f2e0dc 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -876,6 +876,7 @@ def reconstruct( constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, max_position_update_distance: float = None, + max_position_total_distance: float = None, global_affine_transformation: bool = True, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, @@ -959,6 +960,8 @@ def reconstruct( Number of iterations to run with fixed positions before updating positions estimate max_position_update_distance: float, optional Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -1094,6 +1097,7 @@ def reconstruct( ) positions_px = self._positions_px.copy()[shuffled_indices] + positions_px_initial = self._positions_px_initial.copy()[shuffled_indices] for start, end in generate_batches( self._num_diffraction_patterns, max_batch=max_batch_size @@ -1149,8 +1153,10 @@ def reconstruct( overlap, amplitudes, self._positions_px, + positions_px_initial[start:end], positions_step_size, max_position_update_distance, + max_position_total_distance, ) error += batch_error diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 0de0227c3..e6c4b34ab 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -744,6 +744,7 @@ def reconstruct( fix_positions_iter: int = np.inf, global_affine_transformation: bool = True, max_position_update_distance: float = None, + max_position_total_distance: float = None, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, fit_probe_aberrations_iter: int = 0, @@ -821,6 +822,8 @@ def reconstruct( Number of iterations to run with fixed positions before updating positions estimate max_position_update_distance: float, optional Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan 
gaussian_filter_sigma: float, optional @@ -941,6 +944,7 @@ def reconstruct( ) positions_px = self._positions_px.copy()[shuffled_indices] + positions_px_initial = self._positions_px_initial.copy()[shuffled_indices] for start, end in generate_batches( self._num_diffraction_patterns, max_batch=max_batch_size @@ -996,8 +1000,10 @@ def reconstruct( overlap, amplitudes, self._positions_px, + positions_px_initial[start:end], positions_step_size, max_position_update_distance, + max_position_total_distance, ) error += batch_error diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 95f7bd4a4..91a681822 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -860,6 +860,7 @@ def reconstruct( constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, max_position_update_distance: float = None, + max_position_total_distance: float = None, global_affine_transformation: bool = True, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, @@ -943,6 +944,8 @@ def reconstruct( Number of iterations to run with fixed positions before updating positions estimate max_position_update_distance: float, optional Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -1078,6 +1081,7 @@ def reconstruct( ) positions_px = self._positions_px.copy()[shuffled_indices] + positions_px_initial = self._positions_px_initial.copy()[shuffled_indices] for start, end in generate_batches( self._num_diffraction_patterns, max_batch=max_batch_size @@ -1133,8 +1137,10 @@ def reconstruct( overlap, amplitudes, self._positions_px, + positions_px_initial[start:end], positions_step_size, max_position_update_distance, + max_position_total_distance, ) error += batch_error diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 248b510eb..0fa1d730c 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -894,6 +894,7 @@ def reconstruct( constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, max_position_update_distance: float = None, + max_position_total_distance: float = None, global_affine_transformation: bool = True, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, @@ -970,6 +971,8 @@ def reconstruct( Number of iterations to run with fixed positions before updating positions estimate max_position_update_distance: float, optional Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -1183,8 +1186,10 @@ def reconstruct( overlap, amplitudes, self._positions_px, + self._positions_px_initial, positions_step_size, max_position_update_distance, + max_position_total_distance, ) tilt_error += batch_error diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py 
b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index e08b72a8c..073516cfd 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -1874,8 +1874,10 @@ def _position_correction( overlap, amplitudes, current_positions, + current_positions_initial, positions_step_size, max_position_update_distance, + max_position_total_distance, ): """ Position correction using estimated intensity gradient. @@ -1896,6 +1898,8 @@ def _position_correction( Positions step size max_position_update_distance: float Maximum allowed distance for update in A + max_position_total_distance: float + Maximum allowed distance from initial probe positions Returns -------- @@ -1968,6 +1972,15 @@ def _position_correction( update_norms[outlier_ind, None] / max_position_update_distance ) + if max_position_total_distance is not None: + max_position_total_distance /= xp.sqrt( + self.sampling[0] ** 2 + self.sampling[1] ** 2 + ) + deltas = current_positions - positions_update - current_positions_initial + dsts = xp.linalg.norm(deltas, axis=1) + outlier_ind = dsts > max_position_total_distance + positions_update[outlier_ind] = 0 + current_positions -= positions_update return current_positions diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index e80b42110..8fb4a829f 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -744,6 +744,7 @@ def reconstruct( constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, max_position_update_distance: float = None, + max_position_total_distance: float = None, global_affine_transformation: bool = True, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, @@ -822,6 +823,8 @@ def reconstruct( Number of iterations to run with fixed positions before updating positions estimate max_position_update_distance: float, optional Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -942,6 +945,7 @@ def reconstruct( ) positions_px = self._positions_px.copy()[shuffled_indices] + positions_px_initial = self._positions_px_initial.copy()[shuffled_indices] for start, end in generate_batches( self._num_diffraction_patterns, max_batch=max_batch_size @@ -997,8 +1001,10 @@ def reconstruct( overlap, amplitudes, self._positions_px, + positions_px_initial[start:end], positions_step_size, max_position_update_distance, + max_position_total_distance, ) error += batch_error From 1fcc8a025073e82629b4451b01ec1b70452ffa3b Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 2 Jan 2024 11:50:51 -0800 Subject: [PATCH 038/128] modernized probe position correction viz Former-commit-id: b323e4c3591bf7ce583a2eb5d568b57916463f35 --- .../process/phase/iterative_base_class.py | 82 +++++++++++++++---- 1 file changed, 64 insertions(+), 18 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 68d769816..8280dc4a4 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -7,7 +7,7 @@ import matplotlib.pyplot as plt import 
numpy as np from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import ImageGrid +from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable from py4DSTEM.visualize import return_scaled_histogram_ordering, show_complex from scipy.ndimage import zoom @@ -1981,10 +1981,12 @@ def _report_reconstruction_summary( ) ) - def plot_position_correction( + def show_updated_positions( self, scale_arrows=1, - plot_arrow_freq=1, + plot_arrow_freq=None, + plot_cropped_rotated_fov=True, + cbar=True, verbose=True, **kwargs, ): @@ -1995,48 +1997,92 @@ def plot_position_correction( ---------- scale_arrows: float, optional scaling factor to be applied on vectors prior to plt.quiver call + plot_arrow_freq: int, optional + thinning parameter to only plot a subset of probe positions + assumes grid position verbose: bool, optional if True, prints AffineTransformation if positions have been updated """ + if verbose: if hasattr(self, "_tf"): print(self._tf) asnumpy = self._asnumpy + initial_pos = asnumpy(self._positions_initial) + pos = self.positions + + if plot_cropped_rotated_fov: + angle = ( + self._rotation_best_rad + if self._rotation_best_transpose + else -self._rotation_best_rad + ) + + tf = AffineTransform(angle=angle) + initial_pos = tf(initial_pos, origin=np.mean(pos, axis=0)) + pos = tf(pos, origin=np.mean(pos, axis=0)) + + obj_shape = self.object_cropped.shape[-2:] + initial_pos_com = np.mean(initial_pos, axis=0) + center_shift = initial_pos_com - ( + np.array(obj_shape) / 2 * np.array(self.sampling) + ) + initial_pos -= center_shift + pos -= center_shift + + else: + obj_shape = self._object_shape + + if plot_arrow_freq is not None: + rshape = self._datacube.Rshape + (2,) + freq = plot_arrow_freq + + initial_pos = initial_pos.reshape(rshape)[::freq, ::freq].reshape(-1, 2) + pos = pos.reshape(rshape)[::freq, ::freq].reshape(-1, 2) + + deltas = pos - initial_pos + norms = np.linalg.norm(deltas, axis=1) + extent = [ 0, - self.sampling[1] * self._object_shape[1], - self.sampling[0] * self._object_shape[0], + self.sampling[1] * obj_shape[1], + self.sampling[0] * obj_shape[0], 0, ] - initial_pos = asnumpy(self._positions_initial) - pos = self.positions - figsize = kwargs.pop("figsize", (6, 6)) - color = kwargs.pop("color", (1, 0, 0, 1)) + cmap = kwargs.pop("cmap", "Reds") fig, ax = plt.subplots(figsize=figsize) - ax.quiver( - initial_pos[::plot_arrow_freq, 1], - initial_pos[::plot_arrow_freq, 0], - (pos[::plot_arrow_freq, 1] - initial_pos[::plot_arrow_freq, 1]) - * scale_arrows, - (pos[::plot_arrow_freq, 0] - initial_pos[::plot_arrow_freq, 0]) - * scale_arrows, + + im = ax.quiver( + initial_pos[:, 1], + initial_pos[:, 0], + deltas[:, 1] * scale_arrows, + deltas[:, 0] * scale_arrows, + norms, scale_units="xy", scale=1, - color=color, + cmap=cmap, **kwargs, ) + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + cb = fig.colorbar(im, cax=ax_cb) + cb.set_label("Δ [A]", rotation=0, ha="left", va="bottom") + cb.ax.yaxis.set_label_coords(0.5, 1.01) + ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") ax.set_xlim((extent[0], extent[1])) ax.set_ylim((extent[2], extent[3])) ax.set_aspect("equal") - ax.set_title("Probe positions correction") + ax.set_title("Updated probe positions") def show_uncertainty_visualization( self, From 27886977c42d76ee2ddaf2cda2d9f857f68ae6d3 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 2 Jan 2024 12:03:05 -0800 Subject: [PATCH 039/128] starting viz overhaul Former-commit-id: 
1275dab1d4219cc04be70ac1d3ba3b18856c8d60 --- .../process/phase/iterative_base_class.py | 300 +---------------- ...tive_mixedstate_multislice_ptychography.py | 4 + .../iterative_mixedstate_ptychography.py | 4 + .../iterative_multislice_ptychography.py | 4 + .../phase/iterative_overlap_tomography.py | 4 + .../iterative_ptychographic_visualizations.py | 317 ++++++++++++++++++ .../iterative_singleslice_ptychography.py | 4 + 7 files changed, 339 insertions(+), 298 deletions(-) create mode 100644 py4DSTEM/process/phase/iterative_ptychographic_visualizations.py diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 8280dc4a4..3a066de4e 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -6,9 +6,8 @@ import matplotlib.pyplot as plt import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable -from py4DSTEM.visualize import return_scaled_histogram_ordering, show_complex +from mpl_toolkits.axes_grid1 import ImageGrid +from py4DSTEM.visualize import show_complex from scipy.ndimage import zoom try: @@ -1981,301 +1980,6 @@ def _report_reconstruction_summary( ) ) - def show_updated_positions( - self, - scale_arrows=1, - plot_arrow_freq=None, - plot_cropped_rotated_fov=True, - cbar=True, - verbose=True, - **kwargs, - ): - """ - Function to plot changes to probe positions during ptychography reconstruciton - - Parameters - ---------- - scale_arrows: float, optional - scaling factor to be applied on vectors prior to plt.quiver call - plot_arrow_freq: int, optional - thinning parameter to only plot a subset of probe positions - assumes grid position - verbose: bool, optional - if True, prints AffineTransformation if positions have been updated - """ - - if verbose: - if hasattr(self, "_tf"): - print(self._tf) - - asnumpy = self._asnumpy - - initial_pos = asnumpy(self._positions_initial) - pos = self.positions - - if plot_cropped_rotated_fov: - angle = ( - self._rotation_best_rad - if self._rotation_best_transpose - else -self._rotation_best_rad - ) - - tf = AffineTransform(angle=angle) - initial_pos = tf(initial_pos, origin=np.mean(pos, axis=0)) - pos = tf(pos, origin=np.mean(pos, axis=0)) - - obj_shape = self.object_cropped.shape[-2:] - initial_pos_com = np.mean(initial_pos, axis=0) - center_shift = initial_pos_com - ( - np.array(obj_shape) / 2 * np.array(self.sampling) - ) - initial_pos -= center_shift - pos -= center_shift - - else: - obj_shape = self._object_shape - - if plot_arrow_freq is not None: - rshape = self._datacube.Rshape + (2,) - freq = plot_arrow_freq - - initial_pos = initial_pos.reshape(rshape)[::freq, ::freq].reshape(-1, 2) - pos = pos.reshape(rshape)[::freq, ::freq].reshape(-1, 2) - - deltas = pos - initial_pos - norms = np.linalg.norm(deltas, axis=1) - - extent = [ - 0, - self.sampling[1] * obj_shape[1], - self.sampling[0] * obj_shape[0], - 0, - ] - - figsize = kwargs.pop("figsize", (6, 6)) - cmap = kwargs.pop("cmap", "Reds") - - fig, ax = plt.subplots(figsize=figsize) - - im = ax.quiver( - initial_pos[:, 1], - initial_pos[:, 0], - deltas[:, 1] * scale_arrows, - deltas[:, 0] * scale_arrows, - norms, - scale_units="xy", - scale=1, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - cb = fig.colorbar(im, cax=ax_cb) - cb.set_label("Δ [A]", rotation=0, ha="left", va="bottom") - 
cb.ax.yaxis.set_label_coords(0.5, 1.01) - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_xlim((extent[0], extent[1])) - ax.set_ylim((extent[2], extent[3])) - ax.set_aspect("equal") - ax.set_title("Updated probe positions") - - def show_uncertainty_visualization( - self, - errors=None, - max_batch_size=None, - projected_cropped_potential=None, - kde_sigma=None, - plot_histogram=True, - plot_contours=False, - **kwargs, - ): - """Plot uncertainty visualization using self-consistency errors""" - - xp = self._xp - asnumpy = self._asnumpy - gaussian_filter = self._gaussian_filter - - if errors is None: - errors = self._return_self_consistency_errors(max_batch_size=max_batch_size) - errors_xp = xp.asarray(errors) - - if projected_cropped_potential is None: - projected_cropped_potential = self._return_projected_cropped_potential() - - if kde_sigma is None: - kde_sigma = 0.5 * self._scan_sampling[0] / self.sampling[0] - - ## Kernel Density Estimation - - # rotated basis - angle = ( - self._rotation_best_rad - if self._rotation_best_transpose - else -self._rotation_best_rad - ) - - tf = AffineTransform(angle=angle) - rotated_points = tf(self._positions_px, origin=self._positions_px_com, xp=xp) - - padding = xp.min(rotated_points, axis=0).astype("int") - - # bilinear sampling - pixel_output = np.array(projected_cropped_potential.shape) + asnumpy( - 2 * padding - ) - pixel_size = pixel_output.prod() - - xa = rotated_points[:, 0] - ya = rotated_points[:, 1] - - # bilinear sampling - xF = xp.floor(xa).astype("int") - yF = xp.floor(ya).astype("int") - dx = xa - xF - dy = ya - yF - - # resampling - all_inds = [ - [xF, yF], - [xF + 1, yF], - [xF, yF + 1], - [xF + 1, yF + 1], - ] - - all_weights = [ - (1 - dx) * (1 - dy), - (dx) * (1 - dy), - (1 - dx) * (dy), - (dx) * (dy), - ] - - pix_count = xp.zeros(pixel_size, dtype=xp.float32) - pix_output = xp.zeros(pixel_size, dtype=xp.float32) - - for inds, weights in zip(all_inds, all_weights): - inds_1D = xp.ravel_multi_index( - inds, - pixel_output, - mode=["wrap", "wrap"], - ) - - pix_count += xp.bincount( - inds_1D, - weights=weights, - minlength=pixel_size, - ) - pix_output += xp.bincount( - inds_1D, - weights=weights * errors_xp, - minlength=pixel_size, - ) - - # reshape 1D arrays to 2D - pix_count = xp.reshape( - pix_count, - pixel_output, - ) - pix_output = xp.reshape( - pix_output, - pixel_output, - ) - - # kernel density estimate - pix_count = gaussian_filter(pix_count, kde_sigma) - pix_output = gaussian_filter(pix_output, kde_sigma) - sub = pix_count > 1e-3 - pix_output[sub] /= pix_count[sub] - pix_output[np.logical_not(sub)] = 1 - pix_output = pix_output[padding[0] : -padding[0], padding[1] : -padding[1]] - pix_output, _, _ = return_scaled_histogram_ordering( - pix_output.get(), normalize=True - ) - - ## Visualization - if plot_histogram: - spec = GridSpec( - ncols=1, - nrows=2, - height_ratios=[1, 4], - hspace=0.15, - ) - auto_figsize = (4, 5.25) - else: - spec = GridSpec( - ncols=1, - nrows=1, - ) - auto_figsize = (4, 4) - - figsize = kwargs.pop("figsize", auto_figsize) - - fig = plt.figure(figsize=figsize) - - if plot_histogram: - ax_hist = fig.add_subplot(spec[0]) - - counts, bins = np.histogram(errors, bins=50) - ax_hist.hist(bins[:-1], bins, weights=counts, color="#5ac8c8", alpha=0.5) - ax_hist.set_ylabel("Counts") - ax_hist.set_xlabel("Normalized Squared Error") - - ax = fig.add_subplot(spec[-1]) - - cmap = kwargs.pop("cmap", "magma") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - - projected_cropped_potential, 
vmin, vmax = return_scaled_histogram_ordering( - projected_cropped_potential, - vmin=vmin, - vmax=vmax, - ) - - extent = [ - 0, - self.sampling[1] * projected_cropped_potential.shape[1], - self.sampling[0] * projected_cropped_potential.shape[0], - 0, - ] - - ax.imshow( - projected_cropped_potential, - vmin=vmin, - vmax=vmax, - extent=extent, - alpha=1 - pix_output, - cmap=cmap, - **kwargs, - ) - - if plot_contours: - aligned_points = asnumpy(rotated_points - padding) - aligned_points[:, 0] *= self.sampling[0] - aligned_points[:, 1] *= self.sampling[1] - - ax.tricontour( - aligned_points[:, 1], - aligned_points[:, 0], - errors, - colors="grey", - levels=5, - # linestyles='dashed', - linewidths=0.5, - ) - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_xlim((extent[0], extent[1])) - ax.set_ylim((extent[2], extent[3])) - ax.xaxis.set_ticks_position("bottom") - - spec.tight_layout(fig) - @property def angular_sampling(self): """Angular sampling [mrad]""" diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index bc0f2e0dc..43f08a157 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -36,6 +36,9 @@ ProbeMethodsMixin, ProbeMixedMethodsMixin, ) +from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( + VisualizationsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -48,6 +51,7 @@ class MixedstateMultislicePtychographicReconstruction( + VisualizationsMixin, PositionsConstraintsMixin, ProbeMixedConstraintsMixin, ProbeConstraintsMixin, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index e6c4b34ab..accab6714 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -33,6 +33,9 @@ ProbeMethodsMixin, ProbeMixedMethodsMixin, ) +from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( + VisualizationsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -45,6 +48,7 @@ class MixedstatePtychographicReconstruction( + VisualizationsMixin, PositionsConstraintsMixin, ProbeMixedConstraintsMixin, ProbeConstraintsMixin, diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 91a681822..0db30f4cb 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -33,6 +33,9 @@ ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) +from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( + VisualizationsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -45,6 +48,7 @@ class MultislicePtychographicReconstruction( + VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, Object2p5DConstraintsMixin, diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 0fa1d730c..1ccbf50b1 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -37,6 +37,9 @@ ProbeListMethodsMixin, ProbeMethodsMixin, ) +from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( + 
VisualizationsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -49,6 +52,7 @@ class OverlapTomographicReconstruction( + VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, Object3DConstraintsMixin, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py new file mode 100644 index 000000000..b43825e88 --- /dev/null +++ b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py @@ -0,0 +1,317 @@ +import warnings + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.process.phase.utils import AffineTransform +from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex +from scipy.ndimage import gaussian_filter, rotate + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + +warnings.simplefilter(action="always", category=UserWarning) + + +class VisualizationsMixin: + """ + Mixin class for various visualization methods. + """ + + def show_updated_positions( + self, + scale_arrows=1, + plot_arrow_freq=None, + plot_cropped_rotated_fov=True, + cbar=True, + verbose=True, + **kwargs, + ): + """ + Function to plot changes to probe positions during ptychography reconstruciton + + Parameters + ---------- + scale_arrows: float, optional + scaling factor to be applied on vectors prior to plt.quiver call + plot_arrow_freq: int, optional + thinning parameter to only plot a subset of probe positions + assumes grid position + verbose: bool, optional + if True, prints AffineTransformation if positions have been updated + """ + + if verbose: + if hasattr(self, "_tf"): + print(self._tf) + + asnumpy = self._asnumpy + + initial_pos = asnumpy(self._positions_initial) + pos = self.positions + + if plot_cropped_rotated_fov: + angle = ( + self._rotation_best_rad + if self._rotation_best_transpose + else -self._rotation_best_rad + ) + + tf = AffineTransform(angle=angle) + initial_pos = tf(initial_pos, origin=np.mean(pos, axis=0)) + pos = tf(pos, origin=np.mean(pos, axis=0)) + + obj_shape = self.object_cropped.shape[-2:] + initial_pos_com = np.mean(initial_pos, axis=0) + center_shift = initial_pos_com - ( + np.array(obj_shape) / 2 * np.array(self.sampling) + ) + initial_pos -= center_shift + pos -= center_shift + + else: + obj_shape = self._object_shape + + if plot_arrow_freq is not None: + rshape = self._datacube.Rshape + (2,) + freq = plot_arrow_freq + + initial_pos = initial_pos.reshape(rshape)[::freq, ::freq].reshape(-1, 2) + pos = pos.reshape(rshape)[::freq, ::freq].reshape(-1, 2) + + deltas = pos - initial_pos + norms = np.linalg.norm(deltas, axis=1) + + extent = [ + 0, + self.sampling[1] * obj_shape[1], + self.sampling[0] * obj_shape[0], + 0, + ] + + figsize = kwargs.pop("figsize", (6, 6)) + cmap = kwargs.pop("cmap", "Reds") + + fig, ax = plt.subplots(figsize=figsize) + + im = ax.quiver( + initial_pos[:, 1], + initial_pos[:, 0], + deltas[:, 1] * scale_arrows, + deltas[:, 0] * scale_arrows, + norms, + scale_units="xy", + scale=1, + cmap=cmap, + **kwargs, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + cb = fig.colorbar(im, cax=ax_cb) + cb.set_label("Δ [A]", rotation=0, ha="left", va="bottom") + cb.ax.yaxis.set_label_coords(0.5, 1.01) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_xlim((extent[0], 
extent[1])) + ax.set_ylim((extent[2], extent[3])) + ax.set_aspect("equal") + ax.set_title("Updated probe positions") + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + projected_cropped_potential=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + + xp = self._xp + asnumpy = self._asnumpy + gaussian_filter = self._gaussian_filter + + if errors is None: + errors = self._return_self_consistency_errors(max_batch_size=max_batch_size) + errors_xp = xp.asarray(errors) + + if projected_cropped_potential is None: + projected_cropped_potential = self._return_projected_cropped_potential() + + if kde_sigma is None: + kde_sigma = 0.5 * self._scan_sampling[0] / self.sampling[0] + + ## Kernel Density Estimation + + # rotated basis + angle = ( + self._rotation_best_rad + if self._rotation_best_transpose + else -self._rotation_best_rad + ) + + tf = AffineTransform(angle=angle) + rotated_points = tf(self._positions_px, origin=self._positions_px_com, xp=xp) + + padding = xp.min(rotated_points, axis=0).astype("int") + + # bilinear sampling + pixel_output = np.array(projected_cropped_potential.shape) + asnumpy( + 2 * padding + ) + pixel_size = pixel_output.prod() + + xa = rotated_points[:, 0] + ya = rotated_points[:, 1] + + # bilinear sampling + xF = xp.floor(xa).astype("int") + yF = xp.floor(ya).astype("int") + dx = xa - xF + dy = ya - yF + + # resampling + all_inds = [ + [xF, yF], + [xF + 1, yF], + [xF, yF + 1], + [xF + 1, yF + 1], + ] + + all_weights = [ + (1 - dx) * (1 - dy), + (dx) * (1 - dy), + (1 - dx) * (dy), + (dx) * (dy), + ] + + pix_count = xp.zeros(pixel_size, dtype=xp.float32) + pix_output = xp.zeros(pixel_size, dtype=xp.float32) + + for inds, weights in zip(all_inds, all_weights): + inds_1D = xp.ravel_multi_index( + inds, + pixel_output, + mode=["wrap", "wrap"], + ) + + pix_count += xp.bincount( + inds_1D, + weights=weights, + minlength=pixel_size, + ) + pix_output += xp.bincount( + inds_1D, + weights=weights * errors_xp, + minlength=pixel_size, + ) + + # reshape 1D arrays to 2D + pix_count = xp.reshape( + pix_count, + pixel_output, + ) + pix_output = xp.reshape( + pix_output, + pixel_output, + ) + + # kernel density estimate + pix_count = gaussian_filter(pix_count, kde_sigma) + pix_output = gaussian_filter(pix_output, kde_sigma) + sub = pix_count > 1e-3 + pix_output[sub] /= pix_count[sub] + pix_output[np.logical_not(sub)] = 1 + pix_output = pix_output[padding[0] : -padding[0], padding[1] : -padding[1]] + pix_output, _, _ = return_scaled_histogram_ordering( + pix_output.get(), normalize=True + ) + + ## Visualization + if plot_histogram: + spec = GridSpec( + ncols=1, + nrows=2, + height_ratios=[1, 4], + hspace=0.15, + ) + auto_figsize = (4, 5.25) + else: + spec = GridSpec( + ncols=1, + nrows=1, + ) + auto_figsize = (4, 4) + + figsize = kwargs.pop("figsize", auto_figsize) + + fig = plt.figure(figsize=figsize) + + if plot_histogram: + ax_hist = fig.add_subplot(spec[0]) + + counts, bins = np.histogram(errors, bins=50) + ax_hist.hist(bins[:-1], bins, weights=counts, color="#5ac8c8", alpha=0.5) + ax_hist.set_ylabel("Counts") + ax_hist.set_xlabel("Normalized Squared Error") + + ax = fig.add_subplot(spec[-1]) + + cmap = kwargs.pop("cmap", "magma") + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + + projected_cropped_potential, vmin, vmax = return_scaled_histogram_ordering( + projected_cropped_potential, + vmin=vmin, + vmax=vmax, + ) + + extent 
= [ + 0, + self.sampling[1] * projected_cropped_potential.shape[1], + self.sampling[0] * projected_cropped_potential.shape[0], + 0, + ] + + ax.imshow( + projected_cropped_potential, + vmin=vmin, + vmax=vmax, + extent=extent, + alpha=1 - pix_output, + cmap=cmap, + **kwargs, + ) + + if plot_contours: + aligned_points = asnumpy(rotated_points - padding) + aligned_points[:, 0] *= self.sampling[0] + aligned_points[:, 1] *= self.sampling[1] + + ax.tricontour( + aligned_points[:, 1], + aligned_points[:, 0], + errors, + colors="grey", + levels=5, + # linestyles='dashed', + linewidths=0.5, + ) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_xlim((extent[0], extent[1])) + ax.set_ylim((extent[2], extent[3])) + ax.xaxis.set_ticks_position("bottom") + + spec.tight_layout(fig) diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 8fb4a829f..42f5956cf 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -30,6 +30,9 @@ ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) +from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( + VisualizationsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -42,6 +45,7 @@ class SingleslicePtychographicReconstruction( + VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ObjectNDConstraintsMixin, From 10a9f4ab97cbf6376f77eaf2d6cd517439c6f9e4 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 2 Jan 2024 12:09:18 -0800 Subject: [PATCH 040/128] remove redundant figax viz functions Former-commit-id: 3345b0198724352c5c65652fd1520751ca2a4244 --- ...tive_mixedstate_multislice_ptychography.py | 65 --------------- .../iterative_mixedstate_ptychography.py | 61 -------------- .../iterative_multislice_ptychography.py | 65 --------------- .../phase/iterative_overlap_tomography.py | 83 ------------------- .../iterative_singleslice_ptychography.py | 61 -------------- 5 files changed, 335 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 43f08a157..863eeb35e 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -1241,71 +1241,6 @@ def reconstruct( return self - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax, - cbar: bool, - padding: int = 0, - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. 
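# Stand-alone matplotlib sketch of the quiver styling show_updated_positions uses
# (synthetic positions; every value below is a placeholder):
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable

rng = np.random.default_rng(0)
initial_pos = 2.0 * np.stack(np.meshgrid(np.arange(8), np.arange(8)), -1).reshape(-1, 2)
deltas = 0.3 * rng.standard_normal(initial_pos.shape)
norms = np.linalg.norm(deltas, axis=1)

fig, ax = plt.subplots(figsize=(6, 6))
im = ax.quiver(
    initial_pos[:, 1], initial_pos[:, 0],  # (row, col) plotted as (y, x)
    deltas[:, 1], deltas[:, 0],
    norms, scale_units="xy", scale=1, cmap="Reds",
)
cax = make_axes_locatable(ax).append_axes("right", size="5%", pad="2.5%")
fig.colorbar(im, cax=cax, label="Δ [A]")
ax.set_aspect("equal")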
- - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - cmap = kwargs.pop("cmap", "magma") - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov( - np.sum(obj, axis=0), padding=padding - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - errors = np.array(self.error_iterations) - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - - convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) - def _visualize_last_iteration( self, fig, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index accab6714..398cd679b 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -1077,67 +1077,6 @@ def reconstruct( return self - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax: None, - cbar: bool, - padding: int = 0, - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. 
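# Compact numpy sketch of the kernel-density step in show_uncertainty_visualization:
# scatter per-position errors onto a pixel grid with bilinear weights, then blur.
# Sizes, sigma, and the random inputs are placeholders.
import numpy as np
from scipy.ndimage import gaussian_filter

rng = np.random.default_rng(0)
shape = (64, 64)
points = rng.uniform(0, 63, size=(500, 2))
errors = rng.uniform(0, 1, size=500)

xF, yF = np.floor(points[:, 0]).astype(int), np.floor(points[:, 1]).astype(int)
dx, dy = points[:, 0] - xF, points[:, 1] - yF

counts = np.zeros(shape)
values = np.zeros(shape)
for inds, w in [
    ((xF, yF), (1 - dx) * (1 - dy)),
    ((xF + 1, yF), dx * (1 - dy)),
    ((xF, yF + 1), (1 - dx) * dy),
    ((xF + 1, yF + 1), dx * dy),
]:
    flat = np.ravel_multi_index(inds, shape, mode="wrap")
    counts += np.bincount(flat, weights=w, minlength=counts.size).reshape(shape)
    values += np.bincount(flat, weights=w * errors, minlength=values.size).reshape(shape)

counts = gaussian_filter(counts, 2.0)
values = gaussian_filter(values, 2.0)
kde = np.where(counts > 1e-3, values / np.maximum(counts, 1e-3), 1.0)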
- - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - cmap = kwargs.pop("cmap", "magma") - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov(obj, padding=padding) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) - def _visualize_last_iteration( self, fig, diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 0db30f4cb..573a16873 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -1221,71 +1221,6 @@ def reconstruct( return self - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax, - cbar: bool, - padding: int = 0, - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. 
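# Minimal stand-alone equivalent of the _visualize_last_iteration_figax helpers
# being removed here, for plotting an object estimate onto an existing axes;
# `obj` and `sampling` are placeholders for the corresponding class attributes.
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable

obj = np.random.default_rng(0).random((128, 96))
sampling = (0.2, 0.2)  # A per pixel

fig, ax = plt.subplots()
extent = [0, sampling[1] * obj.shape[1], sampling[0] * obj.shape[0], 0]
im = ax.imshow(obj, extent=extent, cmap="magma")
cax = make_axes_locatable(ax).append_axes("right", size="5%", pad="2.5%")
fig.colorbar(im, cax=cax)
ax.set_xlabel("y [A]")
ax.set_ylabel("x [A]")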
- - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - cmap = kwargs.pop("cmap", "magma") - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov( - np.sum(obj, axis=0), padding=padding - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - errors = np.array(self.error_iterations) - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - - convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) - def _visualize_last_iteration( self, fig, diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 1ccbf50b1..37a2ee9b1 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -1343,89 +1343,6 @@ def reconstruct( return self - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax, - cbar: bool, - projection_angle_deg: float, - projection_axes: Tuple[int, int], - x_lims: Tuple[int, int], - y_lims: Tuple[int, int], - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. 
- - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - - cmap = kwargs.pop("cmap", "magma") - - asnumpy = self._asnumpy - - if projection_angle_deg is not None: - rotated_3d_obj = self._rotate( - self._object, - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - rotated_3d_obj = asnumpy(rotated_3d_obj) - else: - rotated_3d_obj = self.object - - rotated_object = self._crop_rotate_object_manually( - rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - errors = np.array(self.error_iterations) - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - convergence_ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - def _visualize_last_iteration( self, fig, diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 42f5956cf..e94d196c9 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -1077,67 +1077,6 @@ def reconstruct( return self - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax: None, - cbar: bool, - padding: int = 0, - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. 
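# Sketch of the projection the removed tomography helper performed: rotate the
# 3D volume about the projection axes, then sum along the first axis. Assumes
# scipy.ndimage.rotate behaves like the self._rotate wrapper used above.
import numpy as np
from scipy.ndimage import rotate

volume = np.random.default_rng(0).random((32, 48, 48))
projection_angle_deg, projection_axes = 30.0, (0, 2)

rotated = rotate(volume, projection_angle_deg, axes=projection_axes, reshape=False, order=2)
projected = rotated.sum(0)  # 2D image, ready for imshow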
- - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - cmap = kwargs.pop("cmap", "magma") - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov(obj, padding=padding) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) - def _visualize_last_iteration( self, fig, From 8d98677937b29ec75b4d38aa50cc5cdecbdd25b3 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 2 Jan 2024 14:50:10 -0800 Subject: [PATCH 041/128] added visualize_last functionality Former-commit-id: 9257242b3171cff87c709a98b648a0e6a92dce73 --- ...tive_mixedstate_multislice_ptychography.py | 203 ---------------- .../iterative_mixedstate_ptychography.py | 197 --------------- .../iterative_multislice_ptychography.py | 203 ---------------- .../phase/iterative_overlap_tomography.py | 228 ------------------ .../phase/iterative_ptychographic_methods.py | 72 +++++- .../iterative_ptychographic_visualizations.py | 212 +++++++++++++++- .../iterative_singleslice_ptychography.py | 202 +--------------- 7 files changed, 274 insertions(+), 1043 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 863eeb35e..8a6e0a078 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -1241,204 +1241,6 @@ def reconstruct( return self - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - padding: int, - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - padding : int, optional - Pixels to pad by post rotating-cropping object - - """ - figsize = kwargs.pop("figsize", (8, 5)) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov( - np.sum(obj, axis=0), padding=padding - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=2, - height_ratios=[4, 1], - hspace=0.15, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=1, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - if plot_probe or plot_fourier_probe: - # Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Probe - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - - ax = fig.add_subplot(spec[0, 1]) - if plot_fourier_probe: - if remove_initial_probe_aberrations: - probe_array = self.probe_fourier_residual[0] - else: - probe_array = self.probe_fourier[0] - - probe_array = Complex2RGB( - probe_array, - chroma_boost=chroma_boost, - ) - - ax.set_title("Reconstructed Fourier probe[0]") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - self.probe[0], power=2, chroma_boost=chroma_boost - ) - ax.set_title("Reconstructed probe[0] intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", 
size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - - else: - ax = fig.add_subplot(spec[0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if plot_convergence and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = np.array(self.error_iterations) - if plot_probe: - ax = fig.add_subplot(spec[1, :]) - else: - ax = fig.add_subplot(spec[1]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") - spec.tight_layout(fig) - def _visualize_all_iterations( self, fig, @@ -1668,7 +1470,6 @@ def visualize( plot_fourier_probe: bool = False, remove_initial_probe_aberrations: bool = False, cbar: bool = True, - padding: int = 0, **kwargs, ): """ @@ -1691,8 +1492,6 @@ def visualize( remove_initial_probe_aberrations: bool, optional If true, when plotting fourier probe, removes initial probe to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object Returns -------- @@ -1708,7 +1507,6 @@ def visualize( plot_fourier_probe=plot_fourier_probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, - padding=padding, **kwargs, ) else: @@ -1720,7 +1518,6 @@ def visualize( plot_fourier_probe=plot_fourier_probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, - padding=padding, **kwargs, ) return self diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 398cd679b..c14d33c14 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -1077,200 +1077,6 @@ def reconstruct( return self - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - padding: int, - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. 
- - Parameters - -------- - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - figsize = kwargs.pop("figsize", (8, 5)) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov(obj, padding=padding) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=2, - height_ratios=[4, 1], - hspace=0.15, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=1, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - if plot_probe or plot_fourier_probe: - # Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Probe - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - ax = fig.add_subplot(spec[0, 1]) - - if plot_fourier_probe: - if remove_initial_probe_aberrations: - probe_array = self.probe_fourier_residual[0] - else: - probe_array = self.probe_fourier[0] - - probe_array = Complex2RGB( - probe_array, - chroma_boost=chroma_boost, - ) - - ax.set_title("Reconstructed Fourier probe[0]") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - self.probe[0], - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title("Reconstructed probe[0] intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, 
chroma_boost=chroma_boost) - - else: - ax = fig.add_subplot(spec[0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if plot_convergence and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = np.array(self.error_iterations) - if plot_probe: - ax = fig.add_subplot(spec[1, :]) - else: - ax = fig.add_subplot(spec[1]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") - spec.tight_layout(fig) - def _visualize_all_iterations( self, fig, @@ -1498,7 +1304,6 @@ def visualize( plot_fourier_probe: bool = False, remove_initial_probe_aberrations: bool = False, cbar: bool = True, - padding: int = 0, **kwargs, ): """ @@ -1538,7 +1343,6 @@ def visualize( plot_fourier_probe=plot_fourier_probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, - padding=padding, **kwargs, ) else: @@ -1550,7 +1354,6 @@ def visualize( plot_fourier_probe=plot_fourier_probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, - padding=padding, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 573a16873..47c9ca85f 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -1221,204 +1221,6 @@ def reconstruct( return self - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - padding: int, - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. 
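# A minimal sketch of how the probe panels above are rendered, assuming the
# py4DSTEM helpers behave as they are called in this patch: Complex2RGB maps a
# complex array to an RGB image (phase encoded as hue, amplitude as brightness;
# power=2 displays intensity), and add_colorbar_arg draws the matching 2D
# colorbar on an axes appended next to the probe panel.
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg

x = np.linspace(-1, 1, 128)
probe = np.exp(-(x[:, None] ** 2 + x[None, :] ** 2) + 1j * 4 * np.pi * x[None, :])  # toy complex probe

fig, ax = plt.subplots()
ax.imshow(Complex2RGB(probe, power=2, chroma_boost=1))
ax.set_title("probe intensity")

divider = make_axes_locatable(ax)
ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
add_colorbar_arg(ax_cb, chroma_boost=1)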
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - padding : int, optional - Pixels to pad by post rotating-cropping object - - """ - figsize = kwargs.pop("figsize", (8, 5)) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov( - np.sum(obj, axis=0), padding=padding - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=2, - height_ratios=[4, 1], - hspace=0.15, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=1, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - if plot_probe or plot_fourier_probe: - # Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Probe - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - - ax = fig.add_subplot(spec[0, 1]) - if plot_fourier_probe: - if remove_initial_probe_aberrations: - probe_array = self.probe_fourier_residual - else: - probe_array = self.probe_fourier - - probe_array = Complex2RGB( - probe_array, - chroma_boost=chroma_boost, - ) - - ax.set_title("Reconstructed Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - self.probe, power=2, chroma_boost=chroma_boost - ) - ax.set_title("Reconstructed probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", 
pad="2.5%") - add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - - else: - ax = fig.add_subplot(spec[0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if plot_convergence and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = np.array(self.error_iterations) - if plot_probe: - ax = fig.add_subplot(spec[1, :]) - else: - ax = fig.add_subplot(spec[1]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") - spec.tight_layout(fig) - def _visualize_all_iterations( self, fig, @@ -1646,7 +1448,6 @@ def visualize( plot_fourier_probe: bool = False, remove_initial_probe_aberrations: bool = False, cbar: bool = True, - padding: int = 0, **kwargs, ): """ @@ -1669,8 +1470,6 @@ def visualize( remove_initial_probe_aberrations: bool, optional If true, when plotting fourier probe, removes initial probe to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object Returns -------- @@ -1686,7 +1485,6 @@ def visualize( plot_fourier_probe=plot_fourier_probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, - padding=padding, **kwargs, ) else: @@ -1698,7 +1496,6 @@ def visualize( plot_fourier_probe=plot_fourier_probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, - padding=padding, **kwargs, ) return self diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 37a2ee9b1..9a025c947 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -1343,222 +1343,6 @@ def reconstruct( return self - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - projection_angle_deg: float, - projection_axes: Tuple[int, int], - x_lims: Tuple[int, int], - y_lims: Tuple[int, int], - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - asnumpy = self._asnumpy - - figsize = kwargs.pop("figsize", (8, 5)) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - asnumpy = self._asnumpy - - if projection_angle_deg is not None: - rotated_3d_obj = self._rotate( - self._object, - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - rotated_3d_obj = asnumpy(rotated_3d_obj) - else: - rotated_3d_obj = self.object - - rotated_object = self._crop_rotate_object_manually( - rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=2, - height_ratios=[4, 1], - hspace=0.15, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=1, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - if plot_probe or plot_fourier_probe: - # Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Reconstructed object projection") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Probe - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - - ax = fig.add_subplot(spec[0, 1]) - if plot_fourier_probe: - if remove_initial_probe_aberrations: - probe_array = self.probe_fourier_residual - else: - probe_array = self.probe_fourier - - probe_array = Complex2RGB( - probe_array, - chroma_boost=chroma_boost, - ) - - ax.set_title("Reconstructed Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - self.probe, 
- power=2, - chroma_boost=chroma_boost, - ) - ax.set_title("Reconstructed probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - ax_cb, - chroma_boost=chroma_boost, - ) - else: - ax = fig.add_subplot(spec[0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Reconstructed object projection") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if plot_convergence and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = np.array(self.error_iterations) - if plot_probe: - ax = fig.add_subplot(spec[1, :]) - else: - ax = fig.add_subplot(spec[1]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") - spec.tight_layout(fig) - def _visualize_all_iterations( self, fig, @@ -1810,10 +1594,6 @@ def visualize( plot_fourier_probe: bool = False, remove_initial_probe_aberrations: bool = False, cbar: bool = True, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims=(None, None), - y_lims=(None, None), **kwargs, ): """ @@ -1859,10 +1639,6 @@ def visualize( plot_fourier_probe=plot_fourier_probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, **kwargs, ) else: @@ -1874,10 +1650,6 @@ def visualize( plot_fourier_probe=plot_fourier_probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index 073516cfd..16a891fed 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -108,6 +108,8 @@ def _crop_rotate_object_fov( def _return_projected_cropped_potential( self, + return_kwargs=False, + **kwargs, ): """Utility function to accommodate multiple classes""" if self._object_type == "complex": @@ -115,7 +117,10 @@ def _return_projected_cropped_potential( else: projected_cropped_potential = self.object_cropped - return projected_cropped_potential + if return_kwargs: + return projected_cropped_potential, kwargs + else: + return projected_cropped_potential def _return_object_fft( self, @@ -358,6 +363,8 @@ def _initialize_object( def _return_projected_cropped_potential( self, + return_kwargs=False, + **kwargs, ): """Utility function to accommodate multiple classes""" if self._object_type == "complex": @@ -365,7 +372,10 @@ def _return_projected_cropped_potential( else: projected_cropped_potential = self.object_cropped.sum(0) - return projected_cropped_potential + if return_kwargs: + return projected_cropped_potential, kwargs + else: + return projected_cropped_potential def _return_object_fft( self, @@ -824,9 +834,35 @@ def _crop_rotate_object_manually( def 
_return_projected_cropped_potential( self, + return_kwargs=False, + **kwargs, ): """Utility function to accommodate multiple classes""" - raise NotImplementedError() + + projection_angle_deg = kwargs.pop("projection_angle_deg", None) + projection_axes = kwargs.pop("projection_axes", (0, 2)) + x_lims = kwargs.pop("x_lims", (None, None)) + y_lims = kwargs.pop("y_lims", (None, None)) + + if projection_angle_deg is not None: + obj = self._rotate( + self._object, + projection_angle_deg, + axes=projection_axes, + reshape=False, + order=2, + ) + else: + obj = self._object + + obj = self._crop_rotate_object_manually( + obj, angle=None, x_lims=x_lims, y_lims=y_lims + ).sum(0) + + if return_kwargs: + return obj, kwargs + else: + return obj def _return_object_fft( self, @@ -1177,6 +1213,13 @@ def show_fourier_probe( **kwargs, ) + def _return_single_probe(self): + """Current probe estimate""" + if not hasattr(self, "_probe"): + return None + + return self._probe + @property def probe_fourier(self): """Current probe estimate in Fourier space""" @@ -1331,6 +1374,13 @@ def show_fourier_probe( **kwargs, ) + def _return_single_probe(self): + """Current probe estimate""" + if not hasattr(self, "_probe"): + return None + + return self._probe[0] + class ProbeListMethodsMixin: """ @@ -1385,9 +1435,10 @@ def _reset_reconstruction( else: self._exit_waves = None - @property - def _probe(self): - """Dummy property to return average probe""" + def _return_single_probe(self): + """Current probe estimate""" + if not hasattr(self, "_probes_all"): + return None xp = self._xp probe = xp.zeros(self._region_of_interest_shape, dtype=np.complex64) @@ -1397,6 +1448,11 @@ def _probe(self): return probe / self._num_tilts + @property + def _probe(self): + """Dummy property to make single-probe functions work""" + return self._return_single_probe() + class ObjectNDProbeMethodsMixin: """ @@ -3002,9 +3058,7 @@ def _projection_sets_adjoint( current_object[s] = object_update * probe_normalization # back-transmit - exit_waves_copy *= xp.expand_dims( - xp.conj(obj), axis=1 - ) # / xp.abs(obj) ** 2 + exit_waves_copy *= xp.expand_dims(xp.conj(obj), axis=1) if s > 0: # back-propagate diff --git a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py index b43825e88..829c72730 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py @@ -5,8 +5,8 @@ from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import make_axes_locatable from py4DSTEM.process.phase.utils import AffineTransform -from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex -from scipy.ndimage import gaussian_filter, rotate +from py4DSTEM.visualize import return_scaled_histogram_ordering +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg try: import cupy as cp @@ -21,6 +21,214 @@ class VisualizationsMixin: Mixin class for various visualization methods. """ + def _visualize_last_iteration( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, + **kwargs, + ): + """ + Displays last reconstructed object and probe iterations. 
+ + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool, optional + If true, the reconstructed complex probe is displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + """ + + asnumpy = self._asnumpy + + figsize = kwargs.pop("figsize", (8, 5)) + cmap = kwargs.pop("cmap", "magma") + chroma_boost = kwargs.pop("chroma_boost", 1) + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + + # get scaled arrays + obj, kwargs = self._return_projected_cropped_potential( + return_kwargs=True, **kwargs + ) + probe = self._return_single_probe() + + obj, vmin, vmax = return_scaled_histogram_ordering(obj, vmin, vmax) + + extent = [ + 0, + self.sampling[1] * obj.shape[1], + self.sampling[0] * obj.shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=2, + nrows=2, + height_ratios=[4, 1], + hspace=0.15, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) + + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=2, + nrows=1, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + if plot_probe or plot_fourier_probe: + # Object + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + obj, + extent=extent, + cmap=cmap, + vmin=vmin, + vmax=vmax, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Probe + ax = fig.add_subplot(spec[0, 1]) + if plot_fourier_probe: + probe = asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + + probe_array = Complex2RGB( + probe, + chroma_boost=chroma_boost, + ) + + ax.set_title("Reconstructed Fourier probe") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + asnumpy(self._return_centered_probe(probe)), + power=2, + chroma_boost=chroma_boost, + ) + ax.set_title("Reconstructed probe intensity") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = 
divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + else: + # Object + ax = fig.add_subplot(spec[0]) + im = ax.imshow( + obj, + extent=extent, + cmap=cmap, + vmin=vmin, + vmax=vmax, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if plot_convergence and hasattr(self, "error_iterations"): + errors = np.array(self.error_iterations) + + if plot_probe: + ax = fig.add_subplot(spec[1, :]) + else: + ax = fig.add_subplot(spec[1]) + + ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax.set_ylabel("NMSE") + ax.set_xlabel("Iteration number") + ax.yaxis.tick_right() + + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") + spec.tight_layout(fig) + def show_updated_positions( self, scale_arrows=1, diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index e94d196c9..39636b4bd 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -1077,203 +1077,6 @@ def reconstruct( return self - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - padding: int, - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - figsize = kwargs.pop("figsize", (8, 5)) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov(obj, padding=padding) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=2, - height_ratios=[4, 1], - hspace=0.15, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / 
probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=1, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - if plot_probe or plot_fourier_probe: - # Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Probe - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - - ax = fig.add_subplot(spec[0, 1]) - if plot_fourier_probe: - if remove_initial_probe_aberrations: - probe_array = self.probe_fourier_residual - else: - probe_array = self.probe_fourier - - probe_array = Complex2RGB( - probe_array, - chroma_boost=chroma_boost, - ) - - ax.set_title("Reconstructed Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - self.probe, - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title("Reconstructed probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - - else: - ax = fig.add_subplot(spec[0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if plot_convergence and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = np.array(self.error_iterations) - if plot_probe: - ax = fig.add_subplot(spec[1, :]) - else: - ax = fig.add_subplot(spec[1]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") - spec.tight_layout(fig) - def _visualize_all_iterations( self, fig, @@ -1502,7 +1305,6 @@ def visualize( plot_fourier_probe: bool = False, remove_initial_probe_aberrations: bool = False, cbar: bool = True, - padding: int = 0, **kwargs, ): """ @@ -1540,9 +1342,8 @@ def visualize( plot_convergence=plot_convergence, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, - cbar=cbar, remove_initial_probe_aberrations=remove_initial_probe_aberrations, - padding=padding, + cbar=cbar, **kwargs, ) else: @@ -1554,7 +1355,6 @@ def visualize( plot_fourier_probe=plot_fourier_probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, - padding=padding, **kwargs, ) From 425af405b6f30c2ece806243ba2647dccfb3ce3e Mon 
Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 2 Jan 2024 15:50:07 -0800 Subject: [PATCH 042/128] visualize_all for single-slice Former-commit-id: f193c3cc5ac9b7b308998b717a316eb85944f490 --- .../phase/iterative_ptychographic_methods.py | 73 ++++-- .../iterative_ptychographic_visualizations.py | 237 +++++++++++++++++- .../iterative_singleslice_ptychography.py | 222 +--------------- 3 files changed, 285 insertions(+), 247 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index 16a891fed..ab702df82 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -108,19 +108,23 @@ def _crop_rotate_object_fov( def _return_projected_cropped_potential( self, + obj=None, return_kwargs=False, **kwargs, ): """Utility function to accommodate multiple classes""" - if self._object_type == "complex": - projected_cropped_potential = np.angle(self.object_cropped) + if obj is None: + obj = self.object_cropped else: - projected_cropped_potential = self.object_cropped + obj = self._crop_rotate_object_fov(obj) + + if np.iscomplexobj(obj): + obj = np.angle(obj) if return_kwargs: - return projected_cropped_potential, kwargs + return obj, kwargs else: - return projected_cropped_potential + return obj def _return_object_fft( self, @@ -363,19 +367,26 @@ def _initialize_object( def _return_projected_cropped_potential( self, + obj=None, return_kwargs=False, **kwargs, ): """Utility function to accommodate multiple classes""" - if self._object_type == "complex": - projected_cropped_potential = np.angle(self.object_cropped).sum(0) + + if obj is None: + obj = self.object_cropped + else: + obj = self._crop_rotate_object_fov(obj) + + if np.iscomplexobj(obj): + obj = np.angle(obj).sum(0) else: - projected_cropped_potential = self.object_cropped.sum(0) + obj = obj.sum(0) if return_kwargs: - return projected_cropped_potential, kwargs + return obj, kwargs else: - return projected_cropped_potential + return obj def _return_object_fft( self, @@ -834,6 +845,7 @@ def _crop_rotate_object_manually( def _return_projected_cropped_potential( self, + obj=None, return_kwargs=False, **kwargs, ): @@ -844,16 +856,17 @@ def _return_projected_cropped_potential( x_lims = kwargs.pop("x_lims", (None, None)) y_lims = kwargs.pop("y_lims", (None, None)) + if obj is None: + obj = self._object + if projection_angle_deg is not None: obj = self._rotate( - self._object, + obj, projection_angle_deg, axes=projection_axes, reshape=False, order=2, ) - else: - obj = self._object obj = self._crop_rotate_object_manually( obj, angle=None, x_lims=x_lims, y_lims=y_lims @@ -1213,12 +1226,15 @@ def show_fourier_probe( **kwargs, ) - def _return_single_probe(self): + def _return_single_probe(self, probe=None): """Current probe estimate""" - if not hasattr(self, "_probe"): - return None + if probe is not None: + return probe + else: + if not hasattr(self, "_probe"): + return None - return self._probe + return self._probe @property def probe_fourier(self): @@ -1374,12 +1390,15 @@ def show_fourier_probe( **kwargs, ) - def _return_single_probe(self): + def _return_single_probe(self, probe=None): """Current probe estimate""" - if not hasattr(self, "_probe"): - return None + if probe is not None: + return probe[0] + else: + if not hasattr(self, "_probe"): + return None - return self._probe[0] + return self._probe[0] class ProbeListMethodsMixin: @@ -1435,15 +1454,19 @@ def _reset_reconstruction( 
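# A minimal sketch, with hypothetical class and attribute names, of the hook
# pattern these hunks introduce: the shared visualization code only ever calls
# _return_projected_cropped_potential and _return_single_probe, and each
# reconstruction class overrides them, e.g. returning probe[0] for mixed-state
# probes or the average over tilts for a probe list.
import numpy as np


class _BaseHooks:
    def _return_projected_cropped_potential(self, obj=None, return_kwargs=False, **kwargs):
        obj = self._object if obj is None else obj
        if np.iscomplexobj(obj):
            obj = np.angle(obj)  # display the phase of a complex object
        return (obj, kwargs) if return_kwargs else obj

    def _return_single_probe(self, probe=None):
        return self._probe if probe is None else probe


class _MixedStateHooks(_BaseHooks):
    def _return_single_probe(self, probe=None):
        probe = self._probe if probe is None else probe
        return probe[0]  # first mixed-state component


class _ProbeListHooks(_BaseHooks):
    def _return_single_probe(self, probe=None):
        probes = self._probes_all if probe is None else probe
        return sum(probes) / len(probes)  # average probe over tilts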
else: self._exit_waves = None - def _return_single_probe(self): + def _return_single_probe(self, probe=None): """Current probe estimate""" - if not hasattr(self, "_probes_all"): - return None + if probe is not None: + _probes = probe + else: + if not hasattr(self, "_probes_all"): + return None + _probes = self._probes_all xp = self._xp probe = xp.zeros(self._region_of_interest_shape, dtype=np.complex64) - for pr in self._probes_all: + for pr in _probes: probe += pr return probe / self._num_tilts diff --git a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py index 829c72730..a7a262e9c 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py @@ -1,9 +1,10 @@ import warnings +from typing import Tuple import matplotlib.pyplot as plt import numpy as np from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import make_axes_locatable +from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable from py4DSTEM.process.phase.utils import AffineTransform from py4DSTEM.visualize import return_scaled_histogram_ordering from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg @@ -229,6 +230,240 @@ def _visualize_last_iteration( fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) + def _visualize_all_iterations( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, + iterations_grid: Tuple[int, int], + **kwargs, + ): + """ + Displays all reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed complex probe is displayed + plot_fourier_probe: bool + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + """ + asnumpy = self._asnumpy + + if not hasattr(self, "object_iterations"): + raise ValueError( + ( + "Object and probe iterations were not saved during reconstruction. " + "Please re-run using store_iterations=True." 
+ ) + ) + + num_iter = len(self.object_iterations) + + if iterations_grid == "auto": + if num_iter == 1: + return self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + cbar=cbar, + **kwargs, + ) + + elif plot_probe or plot_fourier_probe: + iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) + + else: + iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) + + else: + if plot_probe or plot_fourier_probe: + if iterations_grid[0] != 2: + raise ValueError() + else: + if iterations_grid[0] * iterations_grid[1] > num_iter: + raise ValueError() + + auto_figsize = ( + (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) + if plot_convergence + else (3 * iterations_grid[1], 3 * iterations_grid[0]) + ) + + figsize = kwargs.pop("figsize", auto_figsize) + cmap = kwargs.pop("cmap", "magma") + chroma_boost = kwargs.pop("chroma_boost", 1) + + # most recent errors + errors = np.array(self.error_iterations)[-num_iter:] + + max_iter = num_iter - 1 + if plot_probe or plot_fourier_probe: + total_grids = (np.prod(iterations_grid) / 2).astype("int") + grid_range = np.arange(0, max_iter + 1, max_iter // (total_grids - 1)) + probes = [ + self._return_single_probe(self.probe_iterations[idx]) + for idx in grid_range + ] + else: + total_grids = np.prod(iterations_grid) + grid_range = np.arange(0, max_iter + 1, max_iter // (total_grids - 1)) + + objects = [] + + for idx in grid_range: + if idx < grid_range[-1]: + obj = self._return_projected_cropped_potential( + obj=self.object_iterations[idx], + return_kwargs=False, + **kwargs, + ) + else: + obj, kwargs = self._return_projected_cropped_potential( + obj=self.object_iterations[idx], return_kwargs=True, **kwargs + ) + + objects.append(obj) + + extent = [ + 0, + self.sampling[1] * objects[0].shape[1], + self.sampling[0] * objects[0].shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) + + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=2) + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + grid = ImageGrid( + fig, + spec[0], + nrows_ncols=(1, iterations_grid[1]) + if (plot_probe or plot_fourier_probe) + else iterations_grid, + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + im = ax.imshow( + objects[n], + extent=extent, + cmap=cmap, + **kwargs, + ) + ax.set_title(f"Iter: {grid_range[n]} potential") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if cbar: + grid.cbar_axes[n].colorbar(im) + + if plot_probe or plot_fourier_probe: + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + + grid = ImageGrid( + fig, + spec[1], + nrows_ncols=(1, iterations_grid[1]), + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" 
if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + if plot_fourier_probe: + probe_array = asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[n], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + + probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) + ax.set_title(f"Iter: {grid_range[n]} Fourier probe") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + + else: + probe_array = Complex2RGB( + probes[n], + power=2, + chroma_boost=chroma_boost, + ) + ax.set_title(f"Iter: {grid_range[n]} probe intensity") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + add_colorbar_arg( + grid.cbar_axes[n], + chroma_boost=chroma_boost, + ) + + if plot_convergence: + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + if plot_probe: + ax2 = fig.add_subplot(spec[2]) + else: + ax2 = fig.add_subplot(spec[1]) + ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax2.set_ylabel("NMSE") + ax2.set_xlabel("Iteration number") + ax2.yaxis.tick_right() + + spec.tight_layout(fig) + def show_updated_positions( self, scale_arrows=1, diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 39636b4bd..d0106fe97 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -8,8 +8,7 @@ import matplotlib.pyplot as plt import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable +from mpl_toolkits.axes_grid1 import make_axes_locatable from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg try: @@ -1077,225 +1076,6 @@ def reconstruct( return self - def _visualize_all_iterations( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - iterations_grid: Tuple[int, int], - padding: int, - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - asnumpy = self._asnumpy - - if not hasattr(self, "object_iterations"): - raise ValueError( - ( - "Object and probe iterations were not saved during reconstruction. " - "Please re-run using store_iterations=True." 
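# A minimal, standalone sketch of the iteration-grid logic added above, with
# illustrative values: evenly spaced iteration indices are selected from the
# stored snapshots, and ImageGrid lays them out with one colorbar per panel
# when cbar_mode="each".
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import ImageGrid

num_iter = 13
iterations_grid = (1, 4)
objects = [np.random.rand(64, 64) * (n + 1) for n in range(num_iter)]  # hypothetical snapshots

total_grids = int(np.prod(iterations_grid))
max_iter = num_iter - 1
grid_range = np.arange(0, max_iter + 1, max_iter // (total_grids - 1))  # e.g. [0 4 8 12]

fig = plt.figure(figsize=(12, 3))
grid = ImageGrid(
    fig,
    111,
    nrows_ncols=iterations_grid,
    axes_pad=(0.75, 0.5),
    cbar_mode="each",
    cbar_pad="2.5%",
)
for n, ax in enumerate(grid):
    im = ax.imshow(objects[grid_range[n]], cmap="magma")
    ax.set_title(f"Iter: {grid_range[n]} potential")
    grid.cbar_axes[n].colorbar(im)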
- ) - ) - - if iterations_grid == "auto": - num_iter = len(self.error_iterations) - - if num_iter == 1: - return self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - cbar=cbar, - padding=padding, - **kwargs, - ) - elif plot_probe or plot_fourier_probe: - iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) - else: - iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) - else: - if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: - raise ValueError() - - auto_figsize = ( - (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) - if plot_convergence - else (3 * iterations_grid[1], 3 * iterations_grid[0]) - ) - figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - errors = np.array(self.error_iterations) - - objects = [] - object_type = [] - - for obj in self.object_iterations: - if np.iscomplexobj(obj): - obj = np.angle(obj) - object_type.append("phase") - else: - object_type.append("potential") - objects.append(self._crop_rotate_object_fov(obj, padding=padding)) - - if plot_probe or plot_fourier_probe: - total_grids = (np.prod(iterations_grid) / 2).astype("int") - probes = self.probe_iterations - else: - total_grids = np.prod(iterations_grid) - max_iter = len(objects) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) - - extent = [ - 0, - self.sampling[1] * objects[0].shape[1], - self.sampling[0] * objects[0].shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=2) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - grid = ImageGrid( - fig, - spec[0], - nrows_ncols=(1, iterations_grid[1]) - if (plot_probe or plot_fourier_probe) - else iterations_grid, - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - im = ax.imshow( - objects[grid_range[n]], - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: - grid.cbar_axes[n].colorbar(im) - - if plot_probe or plot_fourier_probe: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - - grid = ImageGrid( - fig, - spec[1], - nrows_ncols=(1, iterations_grid[1]), - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - if plot_fourier_probe: - probe_array = asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]], - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - - 
probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) - ax.set_title(f"Iter: {grid_range[n]} Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - - else: - probe_array = Complex2RGB( - probes[grid_range[n]], - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title(f"Iter: {grid_range[n]} probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - add_colorbar_arg( - grid.cbar_axes[n], - chroma_boost=chroma_boost, - ) - - if plot_convergence: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - if plot_probe: - ax2 = fig.add_subplot(spec[2]) - else: - ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration number") - ax2.yaxis.tick_right() - - spec.tight_layout(fig) - def visualize( self, fig=None, From 375e835bef60e2f6f5f7db6e1407e7687001ad6e Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 2 Jan 2024 16:01:12 -0800 Subject: [PATCH 043/128] visualize_all for all Former-commit-id: fbb83375dcaada925861db6f96f22291c3164739 --- ...tive_mixedstate_multislice_ptychography.py | 284 +--------------- .../iterative_mixedstate_ptychography.py | 285 +--------------- .../iterative_multislice_ptychography.py | 282 +--------------- .../phase/iterative_overlap_tomography.py | 316 +----------------- .../iterative_ptychographic_visualizations.py | 65 ++++ .../iterative_singleslice_ptychography.py | 64 ---- 6 files changed, 69 insertions(+), 1227 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 8a6e0a078..b00fea11b 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -8,8 +8,7 @@ import matplotlib.pyplot as plt import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable +from mpl_toolkits.axes_grid1 import make_axes_locatable from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg try: @@ -1240,284 +1239,3 @@ def reconstruct( xp.clear_memo() return self - - def _visualize_all_iterations( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - iterations_grid: Tuple[int, int], - padding: int, - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - asnumpy = self._asnumpy - - if not hasattr(self, "object_iterations"): - raise ValueError( - ( - "Object and probe iterations were not saved during reconstruction. 
" - "Please re-run using store_iterations=True." - ) - ) - - if iterations_grid == "auto": - num_iter = len(self.error_iterations) - - if num_iter == 1: - return self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - cbar=cbar, - padding=padding, - **kwargs, - ) - elif plot_probe or plot_fourier_probe: - iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) - else: - iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) - else: - if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: - raise ValueError() - - auto_figsize = ( - (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) - if plot_convergence - else (3 * iterations_grid[1], 3 * iterations_grid[0]) - ) - figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - errors = np.array(self.error_iterations) - - objects = [] - object_type = [] - - for obj in self.object_iterations: - if np.iscomplexobj(obj): - obj = np.angle(obj) - object_type.append("phase") - else: - object_type.append("potential") - objects.append( - self._crop_rotate_object_fov(np.sum(obj, axis=0), padding=padding) - ) - - if plot_probe or plot_fourier_probe: - total_grids = (np.prod(iterations_grid) / 2).astype("int") - probes = self.probe_iterations - else: - total_grids = np.prod(iterations_grid) - max_iter = len(objects) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) - - extent = [ - 0, - self.sampling[1] * objects[0].shape[1], - self.sampling[0] * objects[0].shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=2) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - grid = ImageGrid( - fig, - spec[0], - nrows_ncols=(1, iterations_grid[1]) - if (plot_probe or plot_fourier_probe) - else iterations_grid, - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - im = ax.imshow( - objects[grid_range[n]], - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: - grid.cbar_axes[n].colorbar(im) - - if plot_probe or plot_fourier_probe: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - grid = ImageGrid( - fig, - spec[1], - nrows_ncols=(1, iterations_grid[1]), - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - if plot_fourier_probe: - probe_array = asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]][0], - 
remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) - - ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - probes[grid_range[n]][0], - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title(f"Iter: {grid_range[n]} probe[0]") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - add_colorbar_arg( - grid.cbar_axes[n], - chroma_boost=chroma_boost, - ) - - if plot_convergence: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - if plot_probe: - ax2 = fig.add_subplot(spec[2]) - else: - ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration number") - ax2.yaxis.tick_right() - - spec.tight_layout(fig) - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - **kwargs, - ): - """ - Displays reconstructed object and probe. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - **kwargs, - ) - return self diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index c14d33c14..317f9316c 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -8,8 +8,7 @@ import matplotlib.pyplot as plt import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable +from mpl_toolkits.axes_grid1 import make_axes_locatable from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg try: @@ -1076,285 +1075,3 @@ def reconstruct( xp.clear_memo() return self - - def _visualize_all_iterations( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - iterations_grid: Tuple[int, int], - padding: int, - **kwargs, - ): - """ - Displays all 
reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - asnumpy = self._asnumpy - - if not hasattr(self, "object_iterations"): - raise ValueError( - ( - "Object and probe iterations were not saved during reconstruction. " - "Please re-run using store_iterations=True." - ) - ) - - if iterations_grid == "auto": - num_iter = len(self.error_iterations) - - if num_iter == 1: - return self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - cbar=cbar, - padding=padding, - **kwargs, - ) - elif plot_probe or plot_fourier_probe: - iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) - else: - iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) - else: - if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: - raise ValueError() - - auto_figsize = ( - (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) - if plot_convergence - else (3 * iterations_grid[1], 3 * iterations_grid[0]) - ) - figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - errors = np.array(self.error_iterations) - - objects = [] - object_type = [] - - for obj in self.object_iterations: - if np.iscomplexobj(obj): - obj = np.angle(obj) - object_type.append("phase") - else: - object_type.append("potential") - objects.append(self._crop_rotate_object_fov(obj, padding=padding)) - - if plot_probe or plot_fourier_probe: - total_grids = (np.prod(iterations_grid) / 2).astype("int") - probes = self.probe_iterations - else: - total_grids = np.prod(iterations_grid) - max_iter = len(objects) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) - - extent = [ - 0, - self.sampling[1] * objects[0].shape[1], - self.sampling[0] * objects[0].shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=2) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - grid = ImageGrid( - fig, - spec[0], - nrows_ncols=(1, iterations_grid[1]) - if (plot_probe or 
plot_fourier_probe) - else iterations_grid, - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - im = ax.imshow( - objects[grid_range[n]], - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: - grid.cbar_axes[n].colorbar(im) - - if plot_probe or plot_fourier_probe: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - grid = ImageGrid( - fig, - spec[1], - nrows_ncols=(1, iterations_grid[1]), - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - if plot_fourier_probe: - probe_array = asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]][0], - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) - - ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - probes[grid_range[n]][0], - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title(f"Iter: {grid_range[n]} probe[0] intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - add_colorbar_arg( - grid.cbar_axes[n], - chroma_boost=chroma_boost, - ) - - if plot_convergence: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - if plot_probe: - ax2 = fig.add_subplot(spec[2]) - else: - ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration number") - ax2.yaxis.tick_right() - - spec.tight_layout(fig) - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - **kwargs, - ): - """ - Displays reconstructed object and probe. 
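# --- Illustrative usage sketch (not part of the patch) ---
# The visualize() method documented above dispatches to either
# _visualize_last_iteration (iterations_grid=None) or _visualize_all_iterations.
# `recon` is a hypothetical, already-reconstructed instance whose reconstruct()
# call used store_iterations=True (required for per-iteration plots); every
# keyword shown appears in the signature above.
recon.visualize(
    iterations_grid="auto",                 # subsample stored iterations onto a 2-row grid
    plot_convergence=True,                  # append the NMSE-vs-iteration panel
    plot_fourier_probe=True,                # show the complex Fourier probe per iteration
    remove_initial_probe_aberrations=True,  # remove the initial probe aberrations to visualize changes
    cbar=True,                              # add a colorbar to each panel
)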
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - **kwargs, - ) - - return self diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 47c9ca85f..6bb14d91a 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -8,8 +8,7 @@ import matplotlib.pyplot as plt import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable +from mpl_toolkits.axes_grid1 import make_axes_locatable from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg try: @@ -1220,282 +1219,3 @@ def reconstruct( xp.clear_memo() return self - - def _visualize_all_iterations( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - iterations_grid: Tuple[int, int], - padding: int, - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - asnumpy = self._asnumpy - - if not hasattr(self, "object_iterations"): - raise ValueError( - ( - "Object and probe iterations were not saved during reconstruction. " - "Please re-run using store_iterations=True." 
- ) - ) - - if iterations_grid == "auto": - num_iter = len(self.error_iterations) - - if num_iter == 1: - return self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - cbar=cbar, - padding=padding, - **kwargs, - ) - elif plot_probe or plot_fourier_probe: - iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) - else: - iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) - else: - if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: - raise ValueError() - - auto_figsize = ( - (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) - if plot_convergence - else (3 * iterations_grid[1], 3 * iterations_grid[0]) - ) - figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - errors = np.array(self.error_iterations) - - objects = [] - object_type = [] - - for obj in self.object_iterations: - if np.iscomplexobj(obj): - obj = np.angle(obj) - object_type.append("phase") - else: - object_type.append("potential") - objects.append( - self._crop_rotate_object_fov(np.sum(obj, axis=0), padding=padding) - ) - - if plot_probe or plot_fourier_probe: - total_grids = (np.prod(iterations_grid) / 2).astype("int") - probes = self.probe_iterations - else: - total_grids = np.prod(iterations_grid) - max_iter = len(objects) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) - - extent = [ - 0, - self.sampling[1] * objects[0].shape[1], - self.sampling[0] * objects[0].shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=2) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - grid = ImageGrid( - fig, - spec[0], - nrows_ncols=(1, iterations_grid[1]) - if (plot_probe or plot_fourier_probe) - else iterations_grid, - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - im = ax.imshow( - objects[grid_range[n]], - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: - grid.cbar_axes[n].colorbar(im) - - if plot_probe or plot_fourier_probe: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - grid = ImageGrid( - fig, - spec[1], - nrows_ncols=(1, iterations_grid[1]), - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - if plot_fourier_probe: - probe_array = asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]], - remove_initial_probe_aberrations=remove_initial_probe_aberrations, 
- ) - ) - - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) - - ax.set_title(f"Iter: {grid_range[n]} Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - probes[grid_range[n]], power=2, chroma_boost=chroma_boost - ) - ax.set_title(f"Iter: {grid_range[n]} probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - add_colorbar_arg( - grid.cbar_axes[n], - chroma_boost=chroma_boost, - ) - - if plot_convergence: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - if plot_probe: - ax2 = fig.add_subplot(spec[2]) - else: - ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration number") - ax2.yaxis.tick_right() - - spec.tight_layout(fig) - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - **kwargs, - ): - """ - Displays reconstructed object and probe. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - **kwargs, - ) - return self diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 9a025c947..336fcda89 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -8,10 +8,8 @@ import matplotlib.pyplot as plt import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable +from mpl_toolkits.axes_grid1 import make_axes_locatable from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg -from scipy.ndimage import rotate as rotate_np try: import cupy as cp @@ -1343,318 +1341,6 @@ def reconstruct( return self - def _visualize_all_iterations( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - iterations_grid: Tuple[int, int], - projection_angle_deg: float, - projection_axes: Tuple[int, int], - x_lims: Tuple[int, int], - 
y_lims: Tuple[int, int], - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - asnumpy = self._asnumpy - - if not hasattr(self, "object_iterations"): - raise ValueError( - ( - "Object and probe iterations were not saved during reconstruction. " - "Please re-run using store_iterations=True." - ) - ) - - if iterations_grid == "auto": - num_iter = len(self.error_iterations) - - if num_iter == 1: - return self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - cbar=cbar, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - **kwargs, - ) - elif plot_probe or plot_fourier_probe: - iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) - else: - iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) - else: - if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: - raise ValueError() - - auto_figsize = ( - (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) - if plot_convergence - else (3 * iterations_grid[1], 3 * iterations_grid[0]) - ) - figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - errors = np.array(self.error_iterations) - - if projection_angle_deg is not None: - objects = [ - self._crop_rotate_object_manually( - rotate_np( - obj, - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ).sum(0), - angle=None, - x_lims=x_lims, - y_lims=y_lims, - ) - for obj in self.object_iterations - ] - else: - objects = [ - self._crop_rotate_object_manually( - obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - for obj in self.object_iterations - ] - - if plot_probe or plot_fourier_probe: - total_grids = (np.prod(iterations_grid) / 2).astype("int") - probes = self.probe_iterations - else: - total_grids = np.prod(iterations_grid) - max_iter = len(objects) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) - - extent = [ - 0, - self.sampling[1] * objects[0].shape[1], - self.sampling[0] * objects[0].shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * 
self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=2) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - grid = ImageGrid( - fig, - spec[0], - nrows_ncols=(1, iterations_grid[1]) if plot_probe else iterations_grid, - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - im = ax.imshow( - objects[grid_range[n]], - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_title(f"Iter: {grid_range[n]} Object") - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: - grid.cbar_axes[n].colorbar(im) - - if plot_probe or plot_fourier_probe: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - grid = ImageGrid( - fig, - spec[1], - nrows_ncols=(1, iterations_grid[1]), - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - if plot_fourier_probe: - probe_array = asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]], - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) - - ax.set_title(f"Iter: {grid_range[n]} Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - probes[grid_range[n]], - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title(f"Iter: {grid_range[n]} probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - add_colorbar_arg( - grid.cbar_axes[n], - chroma_boost=chroma_boost, - ) - - if plot_convergence: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - if plot_probe: - ax2 = fig.add_subplot(spec[2]) - else: - ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration number") - ax2.yaxis.tick_right() - - spec.tight_layout(fig) - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - **kwargs, - ): - """ - Displays reconstructed object and probe. 
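# --- Self-contained aside (not part of the patch) ---
# A sketch of how _visualize_all_iterations above decides which stored
# iterations end up on the grid: with probes plotted, half the grid cells show
# objects and the stored iterations are subsampled with an even integer stride.
# Only numpy is needed; the numbers are made up.
import numpy as np

num_iter = 17                                                 # stored iterations (example value)
iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter)   # the "auto" choice when probes are plotted
total_grids = (np.prod(iterations_grid) / 2).astype("int")    # 4 object panels; the other 4 hold probes
max_iter = num_iter - 1
grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1))
print(list(grid_range))                                       # -> [0, 5, 10, 15]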
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - - Returns - -------- - self: OverlapTomographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - **kwargs, - ) - - return self - @property def positions(self): """Probe positions [A]""" diff --git a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py index a7a262e9c..9db32d767 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py @@ -464,6 +464,71 @@ def _visualize_all_iterations( spec.tight_layout(fig) + def visualize( + self, + fig=None, + iterations_grid: Tuple[int, int] = None, + plot_convergence: bool = True, + plot_probe: bool = True, + plot_fourier_probe: bool = False, + remove_initial_probe_aberrations: bool = False, + cbar: bool = True, + **kwargs, + ): + """ + Displays reconstructed object and probe. 
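# --- Illustrative aside (not part of the patch) ---
# The probe panels above are rendered by converting the complex probe to an
# RGB image with Complex2RGB (imported from py4DSTEM.visualize.vis_special in
# these modules) and handing the result to imshow; power=2 displays intensity
# rather than amplitude. The random array below is only a stand-in for a
# reconstructed probe.
import numpy as np
import matplotlib.pyplot as plt
from py4DSTEM.visualize.vis_special import Complex2RGB

rng = np.random.default_rng(0)
fake_probe = rng.standard_normal((64, 64)) + 1j * rng.standard_normal((64, 64))
rgb = Complex2RGB(fake_probe, power=2, chroma_boost=1)  # same call pattern as in the methods above

fig, ax = plt.subplots()
ax.imshow(rgb)
ax.set_ylabel("x [A]")
ax.set_xlabel("y [A]")
plt.show()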
+ + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + padding : int, optional + Pixels to pad by post rotating-cropping object + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + + if iterations_grid is None: + self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + cbar=cbar, + **kwargs, + ) + + else: + self._visualize_all_iterations( + fig=fig, + plot_convergence=plot_convergence, + iterations_grid=iterations_grid, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + cbar=cbar, + **kwargs, + ) + + return self + def show_updated_positions( self, scale_arrows=1, diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index d0106fe97..d218651c7 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -1075,67 +1075,3 @@ def reconstruct( xp.clear_memo() return self - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - **kwargs, - ): - """ - Displays reconstructed object and probe. 
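# --- Self-contained aside (not part of the patch) ---
# The plot_convergence panel in these methods is a semilog plot of the
# per-iteration normalized mean squared error (NMSE), as in the
# ax2.semilogy(...) calls above. The error values below are placeholders.
import numpy as np
import matplotlib.pyplot as plt

errors = np.array([0.31, 0.12, 0.071, 0.052, 0.044])  # placeholder NMSE per iteration

fig, ax2 = plt.subplots(figsize=(4, 1.5))
ax2.semilogy(np.arange(errors.shape[0]), errors)
ax2.set_ylabel("NMSE")
ax2.set_xlabel("Iteration number")
ax2.yaxis.tick_right()
plt.show()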
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - **kwargs, - ) - - return self From 18c6be67e3bf9d292ee7a2a4ee77b1873a077a01 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 2 Jan 2024 16:42:30 -0800 Subject: [PATCH 044/128] small np cp bugs Former-commit-id: 74821b656fa714ba38e88b735a48d375b9e8d98f --- .../process/phase/iterative_overlap_tomography.py | 7 ++++++- .../phase/iterative_ptychographic_methods.py | 13 +++++++++---- .../phase/iterative_ptychographic_visualizations.py | 2 +- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 336fcda89..a3e80046a 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -1328,7 +1328,12 @@ def reconstruct( if store_iterations: self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(self.probe_centered) + self.probe_iterations.append( + [ + asnumpy(self._return_centered_probe(pr.copy())) + for pr in self._probes_all + ] + ) # store result self.object = asnumpy(self._object) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index ab702df82..e45aac081 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -1228,8 +1228,10 @@ def show_fourier_probe( def _return_single_probe(self, probe=None): """Current probe estimate""" + xp = self._xp + if probe is not None: - return probe + return xp.asarray(probe) else: if not hasattr(self, "_probe"): return None @@ -1392,8 +1394,10 @@ def show_fourier_probe( def _return_single_probe(self, probe=None): """Current probe estimate""" + xp = self._xp + if probe is not None: - return probe[0] + return xp.asarray(probe[0]) else: if not hasattr(self, "_probe"): return None @@ -1456,14 +1460,15 @@ def _reset_reconstruction( def _return_single_probe(self, probe=None): """Current probe estimate""" + xp = self._xp + if probe is not None: - _probes = probe + _probes = [xp.asarray(pr) for pr in probe] else: if not 
hasattr(self, "_probes_all"): return None _probes = self._probes_all - xp = self._xp probe = xp.zeros(self._region_of_interest_shape, dtype=np.complex64) for pr in _probes: diff --git a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py index 9db32d767..05d194e3f 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py @@ -431,7 +431,7 @@ def _visualize_all_iterations( else: probe_array = Complex2RGB( - probes[n], + asnumpy(probes[n]), power=2, chroma_boost=chroma_boost, ) From c784b143d13e0eaf6b890ceaedd64a64d20c89e7 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 2 Jan 2024 16:46:05 -0800 Subject: [PATCH 045/128] some figsize defaults Former-commit-id: 550ee7080013215f2324732b613fa74c2c0030fb --- .../process/phase/iterative_ptychographic_visualizations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py index 05d194e3f..4f53e153d 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py @@ -752,13 +752,13 @@ def show_uncertainty_visualization( height_ratios=[1, 4], hspace=0.15, ) - auto_figsize = (4, 5.25) + auto_figsize = (6, 8) else: spec = GridSpec( ncols=1, nrows=1, ) - auto_figsize = (4, 4) + auto_figsize = (6, 6) figsize = kwargs.pop("figsize", auto_figsize) From 93c3e6f88b6d0d8b28a58250a915a037666a2e00 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 2 Jan 2024 17:30:47 -0800 Subject: [PATCH 046/128] histogram scaling for all Former-commit-id: b141ddb488d9700e9f09b1561446caab7874a9c5 --- .../iterative_ptychographic_visualizations.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py index 4f53e153d..4208cfd96 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py @@ -308,6 +308,8 @@ def _visualize_all_iterations( figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") chroma_boost = kwargs.pop("chroma_boost", 1) + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) # most recent errors errors = np.array(self.error_iterations)[-num_iter:] @@ -390,22 +392,25 @@ def _visualize_all_iterations( ) for n, ax in enumerate(grid): + obj, vmin_n, vmax_n = return_scaled_histogram_ordering( + objects[n], vmin=vmin, vmax=vmax + ) im = ax.imshow( - objects[n], + obj, extent=extent, cmap=cmap, + vmin=vmin_n, + vmax=vmax_n, **kwargs, ) ax.set_title(f"Iter: {grid_range[n]} potential") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") + if cbar: grid.cbar_axes[n].colorbar(im) if plot_probe or plot_fourier_probe: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - grid = ImageGrid( fig, spec[1], @@ -451,8 +456,6 @@ def _visualize_all_iterations( ) if plot_convergence: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) if plot_probe: ax2 = fig.add_subplot(spec[2]) else: From 0e246e25cb852257e69f492c460926eec6a18a5f Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 3 Jan 2024 18:09:54 -0800 Subject: [PATCH 047/128] renamed to ptycho-tomo, started magnetic 
refactoring Former-commit-id: e88e27f1717270756c549b6209279ea879be8c0c --- py4DSTEM/process/phase/__init__.py | 6 +- ...tive_magnetic_ptychographic_tomography.py} | 8 +- ....py => iterative_magnetic_ptychography.py} | 723 +++++++----------- .../phase/iterative_ptychographic_methods.py | 2 +- ... => iterative_ptychographic_tomography.py} | 10 +- 5 files changed, 281 insertions(+), 468 deletions(-) rename py4DSTEM/process/phase/{iterative_overlap_magnetic_tomography.py => iterative_magnetic_ptychographic_tomography.py} (99%) rename py4DSTEM/process/phase/{iterative_simultaneous_ptychography.py => iterative_magnetic_ptychography.py} (86%) rename py4DSTEM/process/phase/{iterative_overlap_tomography.py => iterative_ptychographic_tomography.py} (99%) diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index 1005a619d..2069ffebf 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -3,13 +3,13 @@ _emd_hook = True from py4DSTEM.process.phase.iterative_dpc import DPCReconstruction +from py4DSTEM.process.phase.iterative_magnetic_ptychographic_tomography import MagneticPtychographicTomographyReconstruction +from py4DSTEM.process.phase.iterative_magnetic_ptychography import MagneticPtychographicReconstruction from py4DSTEM.process.phase.iterative_mixedstate_multislice_ptychography import MixedstateMultislicePtychographicReconstruction from py4DSTEM.process.phase.iterative_mixedstate_ptychography import MixedstatePtychographicReconstruction from py4DSTEM.process.phase.iterative_multislice_ptychography import MultislicePtychographicReconstruction -from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import OverlapMagneticTomographicReconstruction -from py4DSTEM.process.phase.iterative_overlap_tomography import OverlapTomographicReconstruction from py4DSTEM.process.phase.iterative_parallax import ParallaxReconstruction -from py4DSTEM.process.phase.iterative_simultaneous_ptychography import SimultaneousPtychographicReconstruction +from py4DSTEM.process.phase.iterative_ptychographic_tomography import PtychographicTomographyReconstruction from py4DSTEM.process.phase.iterative_singleslice_ptychography import SingleslicePtychographicReconstruction from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_magnetic_ptychographic_tomography.py similarity index 99% rename from py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py rename to py4DSTEM/process/phase/iterative_magnetic_ptychographic_tomography.py index db2feaf10..816f7185b 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_magnetic_ptychographic_tomography.py @@ -1,6 +1,6 @@ """ Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, -namely overlap magnetic tomography. +namely magnetic ptychographic tomography. """ import warnings @@ -45,7 +45,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class OverlapMagneticTomographicReconstruction( +class MagneticPtychographicTomographyReconstruction( PositionsConstraintsMixin, ProbeConstraintsMixin, Object3DConstraintsMixin, @@ -56,7 +56,7 @@ class OverlapMagneticTomographicReconstruction( PtychographicReconstruction, ): """ - Overlap Magnetic Tomographic Reconstruction Class. + Magnetic Ptychographic Tomography Reconstruction Class. 
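# --- Migration aside (not part of the patch) ---
# After this rename the classes are re-exported from the phase subpackage
# __init__ shown above, so downstream code would import them as below; the old
# OverlapMagneticTomographicReconstruction, OverlapTomographicReconstruction,
# and SimultaneousPtychographicReconstruction names are no longer exported.
from py4DSTEM.process.phase import (
    MagneticPtychographicReconstruction,
    MagneticPtychographicTomographyReconstruction,
    PtychographicTomographyReconstruction,
)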
List of diffraction intensities dimensions : (Rx,Ry,Qx,Qy) Reconstructed probe dimensions : (Sx,Sy) @@ -140,7 +140,7 @@ def __init__( initial_scan_positions: Sequence[np.ndarray] = None, verbose: bool = True, device: str = "cpu", - name: str = "overlap-magnetic-tomographic_reconstruction", + name: str = "magnetic-ptychographic-tomography_reconstruction", **kwargs, ): Custom.__init__(self, name=name) diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_magnetic_ptychography.py similarity index 86% rename from py4DSTEM/process/phase/iterative_simultaneous_ptychography.py rename to py4DSTEM/process/phase/iterative_magnetic_ptychography.py index aafd58134..545a2c5e7 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_magnetic_ptychography.py @@ -1,6 +1,6 @@ """ Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, -namely joint ptychography. +namely magnetic ptychography. """ import warnings @@ -27,8 +27,13 @@ ) from py4DSTEM.process.phase.iterative_ptychographic_methods import ( ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeListMethodsMixin, ProbeMethodsMixin, ) +from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( + VisualizationsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -36,21 +41,23 @@ polar_aliases, polar_symbols, ) -from py4DSTEM.process.utils import get_CoM, get_shifted_ar warnings.simplefilter(action="always", category=UserWarning) -class SimultaneousPtychographicReconstruction( +class MagneticPtychographicReconstruction( + VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ObjectNDConstraintsMixin, + ObjectNDProbeMethodsMixin, + ProbeListMethodsMixin, ProbeMethodsMixin, ObjectNDMethodsMixin, PtychographicReconstruction, ): """ - Iterative Simultaneous Ptychographic Reconstruction Class. + Iterative Magnetic Ptychographic Reconstruction Class. Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) (for each measurement) Reconstructed probe dimensions : (Sx,Sy) @@ -66,8 +73,8 @@ class SimultaneousPtychographicReconstruction( Tuple of input 4D diffraction pattern intensities energy: float The electron energy of the wave functions in eV - simultaneous_measurements_mode: str, optional - One of '-+', '-0+', '0+', where -/0/+ refer to the sign of the magnetic potential + magnetic_contribution_sign: str, optional + One of '-+', '-0+', '0+' semiangle_cutoff: float, optional Semiangle cutoff for the initial probe guess in mrad semiangle_cutoff_pixels: float, optional @@ -85,8 +92,8 @@ class SimultaneousPtychographicReconstruction( positions_mask: np.ndarray, optional Boolean real space mask to select positions in datacube to skip for reconstruction initial_object_guess: np.ndarray, optional - Initial guess for complex-valued object of dimensions (Px,Py) - If None, initialized to 1.0j + Initial guess for complex-valued object of dimensions (2,Px,Py) + If None, initialized to 1.0j for complex objects and 0.0 for potential objects initial_probe_guess: np.ndarray, optional Initial guess for complex-valued probe of dimensions (Sx,Sy). 
If None, initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations @@ -107,13 +114,13 @@ class SimultaneousPtychographicReconstruction( """ # Class-specific Metadata - _class_specific_metadata = ("_simultaneous_measurements_mode",) + _class_specific_metadata = ("_magnetic_contribution_sign",) def __init__( self, energy: float, datacube: Sequence[DataCube] = None, - simultaneous_measurements_mode: str = "-+", + magnetic_contribution_sign: str = "-+", semiangle_cutoff: float = None, semiangle_cutoff_pixels: float = None, rolloff: float = 2.0, @@ -127,7 +134,7 @@ def __init__( object_type: str = "complex", verbose: bool = True, device: str = "cpu", - name: str = "simultaneous_ptychographic_reconstruction", + name: str = "magnetic_ptychographic_reconstruction", **kwargs, ): Custom.__init__(self, name=name) @@ -175,7 +182,7 @@ def __init__( # Data self._datacube = datacube self._object = initial_object_guess - self._probe = initial_probe_guess + self._probe_init = initial_probe_guess # Common Metadata self._vacuum_probe_intensity = vacuum_probe_intensity @@ -192,7 +199,7 @@ def __init__( self._preprocessed = False # Class-specific Metadata - self._simultaneous_measurements_mode = simultaneous_measurements_mode + self._magnetic_contribution_sign = magnetic_contribution_sign def preprocess( self, @@ -203,7 +210,7 @@ def preprocess( fit_function: str = "plane", plot_rotation: bool = True, maximize_divergence: bool = False, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + rotation_angles_deg: np.ndarray = None, plot_probe_overlaps: bool = True, force_com_rotation: float = None, force_com_transpose: float = None, @@ -211,6 +218,7 @@ def preprocess( force_scan_sampling: float = None, force_angular_sampling: float = None, force_reciprocal_sampling: float = None, + progress_bar: bool = True, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, **kwargs, @@ -292,65 +300,49 @@ def preprocess( ) ) - if self._simultaneous_measurements_mode == "-+": - self._sim_recon_mode = 0 - self._num_sim_measurements = 2 - if self._verbose: - print( - ( - "Magnetic vector potential sign in first meaurement assumed to be negative.\n" - "Magnetic vector potential sign in second meaurement assumed to be positive." - ) - ) - if len(self._datacube) != 2: - raise ValueError( - f"datacube must be a set of two measurements, not length {len(self._datacube)}." - ) - if self._datacube[0].shape != self._datacube[1].shape: - raise ValueError("datacube intensities must be the same size.") - elif self._simultaneous_measurements_mode == "-0+": - self._sim_recon_mode = 1 - self._num_sim_measurements = 3 - if self._verbose: - print( - ( - "Magnetic vector potential sign in first meaurement assumed to be negative.\n" - "Magnetic vector potential assumed to be zero in second meaurement.\n" - "Magnetic vector potential sign in third meaurement assumed to be positive." - ) - ) - if len(self._datacube) != 3: - raise ValueError( - f"datacube must be a set of three measurements, not length {len(self._datacube)}." - ) - if ( - self._datacube[0].shape != self._datacube[1].shape - or self._datacube[0].shape != self._datacube[2].shape - ): - raise ValueError("datacube intensities must be the same size.") - elif self._simultaneous_measurements_mode == "0+": - self._sim_recon_mode = 2 - self._num_sim_measurements = 2 - if self._verbose: - print( - ( - "Magnetic vector potential assumed to be zero in first meaurement.\n" - "Magnetic vector potential sign in second meaurement assumed to be positive." 
- ) - ) - if len(self._datacube) != 2: - raise ValueError( - f"datacube must be a set of two measurements, not length {len(self._datacube)}." - ) - if self._datacube[0].shape != self._datacube[1].shape: - raise ValueError("datacube intensities must be the same size.") + if self._magnetic_contribution_sign == "-+": + self._recon_mode = 0 + self._num_measurements = 2 + magnetic_contribution_msg = ( + "Magnetic vector potential sign in first meaurement assumed to be negative.\n" + "Magnetic vector potential sign in second meaurement assumed to be positive." + ) + + elif self._magnetic_contribution_sign == "-0+": + self._recon_mode = 1 + self._num_measurements = 3 + magnetic_contribution_msg = ( + "Magnetic vector potential sign in first meaurement assumed to be negative.\n" + "Magnetic vector potential assumed to be zero in second meaurement.\n" + "Magnetic vector potential sign in third meaurement assumed to be positive." + ) + + elif self._magnetic_contribution_sign == "0+": + self._recon_mode = 2 + self._num_measurements = 2 + magnetic_contribution_msg = ( + "Magnetic vector potential assumed to be zero in first meaurement.\n" + "Magnetic vector potential sign in second meaurement assumed to be positive." + ) else: raise ValueError( - f"simultaneous_measurements_mode must be either '-+', '-0+', or '0+', not {self._simultaneous_measurements_mode}" + f"magnetic_contribution_sign must be either '-+', '-0+', or '0+', not {self._magnetic_contribution_sign}" ) + if self._verbose: + print(magnetic_contribution_msg) + + if len(self._datacube) != self._num_measurements: + raise ValueError( + f"datacube must be the same length as magnetic_contribution_sign, not length {len(self._datacube)}." + ) + + dc_shapes = [dc.shape for dc in self._datacube] + if dc_shapes.count(dc_shapes[0]) != self._num_measurements: + raise ValueError("datacube intensities must be the same size.") + if self._positions_mask is not None: - self._positions_mask = np.asarray(self._positions_mask) + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") if self._positions_mask.ndim == 2: warnings.warn( @@ -358,313 +350,178 @@ def preprocess( UserWarning, ) self._positions_mask = np.tile( - self._positions_mask, (self._num_sim_measurements, 1, 1) + self._positions_mask, (self._num_measurements, 1, 1) ) - if self._positions_mask.dtype != "bool": - warnings.warn( - "`positions_mask` converted to `bool` array.", - UserWarning, - ) - self._positions_mask = self._positions_mask.astype("bool") - else: - self._positions_mask = [None] * self._num_sim_measurements + num_probes_per_tilt = np.insert( + self._positions_mask.sum(axis=(-2, -1)), 0, 0 + ) - if force_com_shifts is None: - force_com_shifts = [None, None, None] - elif len(force_com_shifts) == self._num_sim_measurements: - force_com_shifts = list(force_com_shifts) else: - raise ValueError( - ( - "force_com_shifts must be a sequence of tuples " - "with the same length as the datasets." 
- ) - ) + self._positions_mask = [None] * self._num_measurements + num_probes_per_tilt = [0] + [dc.R_N for dc in self._datacube] + num_probes_per_tilt = np.array(num_probes_per_tilt) + + # prepopulate relevant arrays + self._mean_diffraction_intensity = [] + self._num_diffraction_patterns = num_probes_per_tilt.sum() + self._cum_probes_per_tilt = np.cumsum(num_probes_per_tilt) + self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) + + # calculate roi_shape + roi_shape = self._datacube[0].Qshape + if diffraction_intensities_shape is not None: + roi_shape = diffraction_intensities_shape + if probe_roi_shape is not None: + roi_shape = tuple(max(q, s) for q, s in zip(roi_shape, probe_roi_shape)) + + self._amplitudes = xp.empty((self._num_diffraction_patterns,) + roi_shape) + self._region_of_interest_shape = np.array(roi_shape) + + # TO-DO: generalize this + if force_com_shifts is None: + force_com_shifts = [None] * self._num_measurements + + if self._scan_positions is None: + self._scan_positions = [None] * self._num_measurements # Ensure plot_center_of_mass is not in kwargs kwargs.pop("plot_center_of_mass", None) - # 1st measurement sets rotation angle and transposition - ( - measurement_0, - self._vacuum_probe_intensity, - self._dp_mask, - force_com_shifts[0], - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[0], - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=self._vacuum_probe_intensity, - dp_mask=self._dp_mask, - com_shifts=force_com_shifts[0], - ) - - intensities_0 = self._extract_intensities_and_calibrations_from_datacube( - measurement_0, - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - com_measured_x_0, - com_measured_y_0, - com_fitted_x_0, - com_fitted_y_0, - com_normalized_x_0, - com_normalized_y_0, - ) = self._calculate_intensities_center_of_mass( - intensities_0, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts[0], - ) - - ( - self._rotation_best_rad, - self._rotation_best_transpose, - _com_x_0, - _com_y_0, - com_x_0, - com_y_0, - ) = self._solve_for_center_of_mass_relative_rotation( - com_measured_x_0, - com_measured_y_0, - com_normalized_x_0, - com_normalized_y_0, - rotation_angles_deg=rotation_angles_deg, - plot_rotation=plot_rotation, - plot_center_of_mass=False, - maximize_divergence=maximize_divergence, - force_com_rotation=force_com_rotation, - force_com_transpose=force_com_transpose, - **kwargs, - ) - - ( - amplitudes_0, - mean_diffraction_intensity_0, - ) = self._normalize_diffraction_intensities( - intensities_0, - com_fitted_x_0, - com_fitted_y_0, - crop_patterns, - self._positions_mask[0], - ) - - # explicitly delete namescapes - del ( - intensities_0, - com_measured_x_0, - com_measured_y_0, - com_fitted_x_0, - com_fitted_y_0, - com_normalized_x_0, - com_normalized_y_0, - _com_x_0, - _com_y_0, - com_x_0, - com_y_0, - ) - - # 2nd measurement - ( - measurement_1, - _, - _, - force_com_shifts[1], - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[1], - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=None, - dp_mask=None, - com_shifts=force_com_shifts[1], - ) - - intensities_1 = 
self._extract_intensities_and_calibrations_from_datacube( - measurement_1, - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - com_measured_x_1, - com_measured_y_1, - com_fitted_x_1, - com_fitted_y_1, - com_normalized_x_1, - com_normalized_y_1, - ) = self._calculate_intensities_center_of_mass( - intensities_1, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts[1], - ) - - ( - _, - _, - _com_x_1, - _com_y_1, - com_x_1, - com_y_1, - ) = self._solve_for_center_of_mass_relative_rotation( - com_measured_x_1, - com_measured_y_1, - com_normalized_x_1, - com_normalized_y_1, - rotation_angles_deg=rotation_angles_deg, - plot_rotation=plot_rotation, - plot_center_of_mass=False, - maximize_divergence=maximize_divergence, - force_com_rotation=np.rad2deg(self._rotation_best_rad), - force_com_transpose=self._rotation_best_transpose, - **kwargs, - ) - - ( - amplitudes_1, - mean_diffraction_intensity_1, - ) = self._normalize_diffraction_intensities( - intensities_1, - com_fitted_x_1, - com_fitted_y_1, - crop_patterns, - self._positions_mask[1], - ) - - # explicitly delete namescapes - del ( - intensities_1, - com_measured_x_1, - com_measured_y_1, - com_fitted_x_1, - com_fitted_y_1, - com_normalized_x_1, - com_normalized_y_1, - _com_x_1, - _com_y_1, - com_x_1, - com_y_1, - ) + # prepopulate relevant arrays + self._mean_diffraction_intensity = [] + self._num_diffraction_patterns = num_probes_per_tilt.sum() + self._cum_probes_per_tilt = np.cumsum(num_probes_per_tilt) + self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) + + # loop over DPs for preprocessing + for index in tqdmnd( + self._num_measurements, + desc="Preprocessing data", + unit="measurement", + disable=not progress_bar, + ): + # preprocess datacube, vacuum and masks only for first measurement + if index == 0: + ( + self._datacube[index], + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + probe_roi_shape=self._probe_roi_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts[index], + ) - # Optionally, 3rd measurement - if self._num_sim_measurements == 3: - ( - measurement_2, - _, - _, - force_com_shifts[2], - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[2], - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=None, - dp_mask=None, - com_shifts=force_com_shifts[2], - ) - - intensities_2 = self._extract_intensities_and_calibrations_from_datacube( - measurement_2, + else: + ( + self._datacube[index], + _, + _, + force_com_shifts[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + probe_roi_shape=self._probe_roi_shape, + vacuum_probe_intensity=None, + dp_mask=None, + com_shifts=force_com_shifts[index], + ) + + # calibrations + intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube[index], require_calibrations=True, force_scan_sampling=force_scan_sampling, 
force_angular_sampling=force_angular_sampling, force_reciprocal_sampling=force_reciprocal_sampling, ) + # calculate CoM ( - com_measured_x_2, - com_measured_y_2, - com_fitted_x_2, - com_fitted_y_2, - com_normalized_x_2, - com_normalized_y_2, + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, ) = self._calculate_intensities_center_of_mass( - intensities_2, + intensities, dp_mask=self._dp_mask, fit_function=fit_function, - com_shifts=force_com_shifts[2], + com_shifts=force_com_shifts[index], ) - ( - _, - _, - _com_x_2, - _com_y_2, - com_x_2, - com_y_2, - ) = self._solve_for_center_of_mass_relative_rotation( - com_measured_x_2, - com_measured_y_2, - com_normalized_x_2, - com_normalized_y_2, - rotation_angles_deg=rotation_angles_deg, - plot_rotation=plot_rotation, - plot_center_of_mass=False, - maximize_divergence=maximize_divergence, - force_com_rotation=np.rad2deg(self._rotation_best_rad), - force_com_transpose=self._rotation_best_transpose, - **kwargs, - ) + # estimate rotation / transpose using first measurement + if index == 0: + # silence warnings to play nice with progress bar + verbose = self._verbose + self._verbose = False + ( + self._rotation_best_rad, + self._rotation_best_transpose, + _com_x, + _com_y, + com_x, + com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + com_measured_x, + com_measured_y, + com_normalized_x, + com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=False, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + self._verbose = verbose + + # corner-center amplitudes + idx_start = self._cum_probes_per_tilt[index] + idx_end = self._cum_probes_per_tilt[index + 1] ( - amplitudes_2, - mean_diffraction_intensity_2, + self._amplitudes[idx_start:idx_end], + mean_diffraction_intensity_temp, + self._crop_mask, ) = self._normalize_diffraction_intensities( - intensities_2, - com_fitted_x_2, - com_fitted_y_2, + intensities, + com_fitted_x, + com_fitted_y, + self._positions_mask[index], crop_patterns, - self._positions_mask[2], ) - # explicitly delete namescapes - del ( - intensities_2, - com_measured_x_2, - com_measured_y_2, - com_fitted_x_2, - com_fitted_y_2, - com_normalized_x_2, - com_normalized_y_2, - _com_x_2, - _com_y_2, - com_x_2, - com_y_2, - ) - - self._amplitudes = (amplitudes_0, amplitudes_1, amplitudes_2) - self._mean_diffraction_intensity = ( - mean_diffraction_intensity_0 - + mean_diffraction_intensity_1 - + mean_diffraction_intensity_2 - ) / 3 - - del amplitudes_0, amplitudes_1, amplitudes_2 - - else: - self._amplitudes = (amplitudes_0, amplitudes_1) - self._mean_diffraction_intensity = ( - mean_diffraction_intensity_0 + mean_diffraction_intensity_1 - ) / 2 - - del amplitudes_0, amplitudes_1 + self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) - # explicitly delete namespace - self._num_diffraction_patterns = self._amplitudes[0].shape[0] - self._region_of_interest_shape = np.array(self._amplitudes[0].shape[-2:]) + del ( + intensities, + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) - self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions, self._positions_mask[0] - ) # TO-DO: generaltize to per-dataset probe positions + # initialize probe positions + ( + self._positions_px_all[idx_start:idx_end], + self._object_padding_px, + ) = 
self._calculate_scan_positions_in_pixels( + self._scan_positions[index], + self._positions_mask[index], + self._object_padding_px, + ) # handle semiangle specified in pixels if self._semiangle_cutoff_pixels: @@ -673,114 +530,61 @@ def preprocess( ) # Object Initialization + obj = self._initialize_object( + self._object, + self._positions_px_all, + self._object_type, + ) + if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - if self._object_type == "potential": - object_e = xp.zeros((p, q), dtype=xp.float32) - elif self._object_type == "complex": - object_e = xp.ones((p, q), dtype=xp.complex64) - object_m = xp.zeros((p, q), dtype=xp.float32) + self._object = xp.full((2,) + obj.shape, obj) else: - if self._object_type == "potential": - object_e = xp.asarray(self._object[0], dtype=xp.float32) - elif self._object_type == "complex": - object_e = xp.asarray(self._object[0], dtype=xp.complex64) - object_m = xp.asarray(self._object[1], dtype=xp.float32) + self._object = obj - self._object = (object_e, object_m) - self._object_initial = (object_e.copy(), object_m.copy()) + self._object_initial = self._object.copy() self._object_type_initial = self._object_type - self._object_shape = self._object[0].shape - - self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - self._positions_px_initial = self._positions_px.copy() - self._positions_initial = self._positions_px_initial.copy() - self._positions_initial[:, 0] *= self.sampling[0] - self._positions_initial[:, 1] *= self.sampling[1] - - # Vectorized Patches - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - # Probe Initialization - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, device=self._device - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - self._probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array + self._object_shape = self._object.shape[-2:] + + # center probe positions + self._positions_px_all = xp.asarray(self._positions_px_all, dtype=xp.float32) + + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_tilt[index] + idx_end = self._cum_probes_per_tilt[index + 1] + self._positions_px = 
self._positions_px_all[idx_start:idx_end] + self._positions_px_com = xp.mean(self._positions_px, axis=0) + self._positions_px -= ( + self._positions_px_com - xp.array(self._object_shape) / 2 + ) + self._positions_px_all[idx_start:idx_end] = self._positions_px.copy() + + self._positions_px_initial_all = self._positions_px_all.copy() + self._positions_initial_all = self._positions_px_initial_all.copy() + self._positions_initial_all[:, 0] *= self.sampling[0] + self._positions_initial_all[:, 1] *= self.sampling[1] + + # initialize probe + self._probes_all = [] + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + list_Q = isinstance(self._probe_init, (list, tuple)) + + for index in range(self._num_measurements): + _probe, self._semiangle_cutoff = self._initialize_probe( + self._probe_init[index] if list_Q else self._probe_init, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity[index], + self._semiangle_cutoff, + crop_patterns, ) - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) - - else: - if isinstance(self._probe, ComplexProbe): - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - self._probe = self._probe._array - else: - self._probe._xp = xp - self._probe = self._probe.build()._array - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - self._mean_diffraction_intensity / probe_intensity - ) - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) + self._probes_all.append(_probe) + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + del self._probe_init + # initialize aberrations self._known_aberrations_array = ComplexProbe( energy=self._energy, gpts=self._region_of_interest_shape, @@ -790,25 +594,37 @@ def preprocess( )._evaluate_ctf() # overlaps - shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) + idx_end = self._cum_probes_per_tilt[1] + self._positions_px = self._positions_px_all[0:idx_end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + shifted_probes = fft_shift( + self._probes_all[0], self._positions_px_fractional, xp + ) probe_intensities = xp.abs(shifted_probes) ** 2 probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) + # initialize object_fov_mask if object_fov_mask is None: - self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) + probe_overlap_blurred = self._gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) else: self._object_fov_mask = np.asarray(object_fov_mask) self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + # plot probe overlaps if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - power=2, + power=power, chroma_boost=chroma_boost, ) @@ -835,10 +651,7 @@ def preprocess( divider = make_axes_locatable(ax1) cax1 = 
divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, - chroma_boost=chroma_boost, - ) + add_colorbar_arg(cax1, chroma_boost=chroma_boost) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") ax1.set_title("Initial probe intensity") @@ -846,7 +659,7 @@ def preprocess( ax2.imshow( asnumpy(probe_overlap), extent=extent, - cmap="Greys_r", + cmap="gray", ) ax2.scatter( self.positions[:, 1], diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index e45aac081..2fb4397c7 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -1474,7 +1474,7 @@ def _return_single_probe(self, probe=None): for pr in _probes: probe += pr - return probe / self._num_tilts + return probe / len(_probes) @property def _probe(self): diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py similarity index 99% rename from py4DSTEM/process/phase/iterative_overlap_tomography.py rename to py4DSTEM/process/phase/iterative_ptychographic_tomography.py index a3e80046a..de26667e5 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py @@ -1,6 +1,6 @@ """ Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, -namely overlap tomography. +namely joint ptychographic tomography. """ import warnings @@ -49,7 +49,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class OverlapTomographicReconstruction( +class PtychographicTomographyReconstruction( VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, @@ -66,7 +66,7 @@ class OverlapTomographicReconstruction( PtychographicReconstruction, ): """ - Overlap Tomographic Reconstruction Class. + Ptychographic Tomography Reconstruction Class. 
List of diffraction intensities dimensions : (Rx,Ry,Qx,Qy) Reconstructed probe dimensions : (Sx,Sy) @@ -146,7 +146,7 @@ def __init__( initial_scan_positions: Sequence[np.ndarray] = None, verbose: bool = True, device: str = "cpu", - name: str = "overlap-tomographic_reconstruction", + name: str = "ptychographic-tomography_reconstruction", **kwargs, ): Custom.__init__(self, name=name) @@ -353,6 +353,7 @@ def preprocess( roi_shape = diffraction_intensities_shape if probe_roi_shape is not None: roi_shape = tuple(max(q, s) for q, s in zip(roi_shape, probe_roi_shape)) + self._amplitudes = xp.empty((self._num_diffraction_patterns,) + roi_shape) self._region_of_interest_shape = np.array(roi_shape) @@ -494,7 +495,6 @@ def preprocess( self._positions_px -= ( self._positions_px_com - xp.array(self._object_shape) / 2 ) - self._positions_px_all[idx_start:idx_end] = self._positions_px.copy() self._positions_px_initial_all = self._positions_px_all.copy() From 88b61654ec7fddd2d6eb57a2aa867406842308a9 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 5 Jan 2024 11:31:19 -0800 Subject: [PATCH 048/128] reconstruct functionality for iterative magnetic Former-commit-id: 6223139d9e8d37884319ee8e0779c2682539387c --- .../phase/iterative_magnetic_ptychography.py | 2071 ++++------------- .../phase/iterative_ptychographic_methods.py | 4 +- .../iterative_ptychographic_tomography.py | 40 +- 3 files changed, 470 insertions(+), 1645 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_magnetic_ptychography.py b/py4DSTEM/process/phase/iterative_magnetic_ptychography.py index 545a2c5e7..b97820e04 100644 --- a/py4DSTEM/process/phase/iterative_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/iterative_magnetic_ptychography.py @@ -683,7 +683,7 @@ def preprocess( return self - def _warmup_overlap_projection(self, current_object, current_probe): + def _overlap_projection(self, current_object, shifted_probes): """ Ptychographic overlap projection method. @@ -706,1027 +706,48 @@ def _warmup_overlap_projection(self, current_object, current_probe): xp = self._xp - shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - - electrostatic_obj, _ = current_object - - if self._object_type == "potential": - complex_object = xp.exp(1j * electrostatic_obj) - else: - complex_object = electrostatic_obj - - electrostatic_obj_patches = complex_object[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - object_patches = (electrostatic_obj_patches, None) - overlap = (shifted_probes * electrostatic_obj_patches, None) - - return shifted_probes, object_patches, overlap - - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - shifted_probes * object_patches - """ - - xp = self._xp - - shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - - electrostatic_obj, magnetic_obj = current_object - - if self._object_type == "potential": - complex_object_e = xp.exp(1j * electrostatic_obj) - else: - complex_object_e = electrostatic_obj - - complex_object_m = xp.exp(1j * magnetic_obj) - - electrostatic_obj_patches = complex_object_e[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - magnetic_obj_patches = complex_object_m[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - object_patches = (electrostatic_obj_patches, magnetic_obj_patches) - - if self._sim_recon_mode == 0: - overlap_reverse = ( - shifted_probes - * electrostatic_obj_patches - * xp.conj(magnetic_obj_patches) - ) - overlap_forward = ( - shifted_probes * electrostatic_obj_patches * magnetic_obj_patches - ) - overlap = (overlap_reverse, overlap_forward) - elif self._sim_recon_mode == 1: - overlap_reverse = ( - shifted_probes - * electrostatic_obj_patches - * xp.conj(magnetic_obj_patches) - ) - overlap_neutral = shifted_probes * electrostatic_obj_patches - overlap_forward = ( - shifted_probes * electrostatic_obj_patches * magnetic_obj_patches - ) - overlap = (overlap_reverse, overlap_neutral, overlap_forward) - else: - overlap_neutral = shifted_probes * electrostatic_obj_patches - overlap_forward = ( - shifted_probes * electrostatic_obj_patches * magnetic_obj_patches - ) - overlap = (overlap_neutral, overlap_forward) - - return shifted_probes, object_patches, overlap - - def _warmup_gradient_descent_fourier_projection(self, amplitudes, overlap): - """ - Ptychographic fourier projection method for GD method. - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - - Returns - -------- - exit_waves:np.ndarray - Difference between modified and estimated exit waves - error: float - Reconstruction error - """ - - xp = self._xp - - fourier_overlap = xp.fft.fft2(overlap[0]) - error = xp.sum(xp.abs(amplitudes[0] - xp.abs(fourier_overlap)) ** 2) - - fourier_modified_overlap = amplitudes[0] * xp.exp( - 1j * xp.angle(fourier_overlap) - ) - modified_overlap = xp.fft.ifft2(fourier_modified_overlap) - - exit_waves = (modified_overlap - overlap[0],) + (None,) * ( - self._num_sim_measurements - 1 - ) - - return exit_waves, error - - def _gradient_descent_fourier_projection(self, amplitudes, overlap): - """ - Ptychographic fourier projection method for GD method. 
- - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - - Returns - -------- - exit_waves:np.ndarray - Difference between modified and estimated exit waves - error: float - Reconstruction error - """ - - xp = self._xp - - error = 0.0 - exit_waves = [] - for amp, overl in zip(amplitudes, overlap): - fourier_overl = xp.fft.fft2(overl) - error += xp.sum(xp.abs(amp - xp.abs(fourier_overl)) ** 2) - - fourier_modified_overl = amp * xp.exp(1j * xp.angle(fourier_overl)) - modified_overl = xp.fft.ifft2(fourier_modified_overl) - - exit_waves.append(modified_overl - overl) - - error /= len(exit_waves) - exit_waves = tuple(exit_waves) - - return exit_waves, error - - def _warmup_projection_sets_fourier_projection( - self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. - Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - exit_wave = exit_waves[0] - - if exit_wave is None: - exit_wave = overlap[0].copy() - - fourier_overlap = xp.fft.fft2(overlap[0]) - error = xp.sum(xp.abs(amplitudes[0] - xp.abs(fourier_overlap)) ** 2) - - factor_to_be_projected = projection_c * overlap[0] + projection_y * exit_wave - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amplitudes[0] * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_wave = ( - projection_x * exit_wave - + projection_a * overlap[0] - + projection_b * projected_factor - ) - - exit_waves = (exit_wave,) + (None,) * (self._num_sim_measurements - 1) - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. 
- Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - error = 0.0 - _exit_waves = [] - for amp, overl, exit_wave in zip(amplitudes, overlap, exit_waves): - if exit_wave is None: - exit_wave = overl.copy() - - fourier_overl = xp.fft.fft2(overl) - error += xp.sum(xp.abs(amp - xp.abs(fourier_overl)) ** 2) - - factor_to_be_projected = projection_c * overl + projection_y * exit_wave - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amp * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - _exit_waves.append( - projection_x * exit_wave - + projection_a * overl - + projection_b * projected_factor - ) - - error /= len(_exit_waves) - exit_waves = tuple(_exit_waves) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - warmup_iteration, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - object * probe overlap - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - if warmup_iteration: - shifted_probes, object_patches, overlap = self._warmup_overlap_projection( - current_object, current_probe - ) - if use_projection_scheme: - exit_waves, error = self._warmup_projection_sets_fourier_projection( - amplitudes, - overlap, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._warmup_gradient_descent_fourier_projection( - amplitudes, overlap - ) - - else: - shifted_probes, object_patches, overlap = self._overlap_projection( - current_object, current_probe - ) - if use_projection_scheme: - exit_waves, error = self._projection_sets_fourier_projection( - amplitudes, - overlap, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, overlap - ) - - return shifted_probes, object_patches, overlap, exit_waves, error - - def _warmup_gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - electrostatic_obj, _ = current_object - electrostatic_obj_patches, _ = object_patches - - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._object_type == "potential": - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(electrostatic_obj_patches) - * xp.conj(shifted_probes) - * exit_waves[0] - ) - ) - * probe_normalization - ) - elif self._object_type == "complex": - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes) * exit_waves[0] - ) - * probe_normalization - ) - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(electrostatic_obj_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += step_size * ( - xp.sum( - xp.conj(electrostatic_obj_patches) * exit_waves[0], - axis=0, - ) - * object_normalization - ) - - return (electrostatic_obj, None), current_probe - - def _gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - electrostatic_obj, magnetic_obj = current_object - probe_conj = xp.conj(shifted_probes) - - electrostatic_obj_patches, magnetic_obj_patches = object_patches - electrostatic_conj = xp.conj(electrostatic_obj_patches) - magnetic_conj = xp.conj(magnetic_obj_patches) - - probe_electrostatic_abs = xp.abs(shifted_probes * electrostatic_obj_patches) - probe_magnetic_abs = xp.abs(shifted_probes * magnetic_obj_patches) - - probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts( - probe_electrostatic_abs**2 - ) - probe_electrostatic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_electrostatic_normalization) ** 2 - + (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2 - ) - - probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( - probe_magnetic_abs**2 - ) - probe_magnetic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_magnetic_normalization) ** 2 - + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 - ) - - if self._sim_recon_mode > 0: - probe_abs = xp.abs(shifted_probes) - probe_normalization = self._sum_overlapping_patches_bincounts( - probe_abs**2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._sim_recon_mode == 0: - exit_waves_reverse, exit_waves_forward = exit_waves - - if self._object_type == "potential": - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_obj_patches - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_reverse - ) - ) - * probe_magnetic_normalization - ) - / 2 - ) - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * probe_magnetic_normalization - ) - / 2 - ) - - elif self._object_type == "complex": - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_obj_patches * exit_waves_reverse - ) - * probe_magnetic_normalization - ) - / 2 - ) - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_conj * exit_waves_forward - ) - * probe_magnetic_normalization - / 2 - ) - - magnetic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - 1j - * magnetic_obj_patches - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_reverse - ) - ) - * probe_electrostatic_normalization - ) - / 2 - ) - magnetic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * 
probe_electrostatic_normalization - ) - / 2 - ) - - elif self._sim_recon_mode == 1: - exit_waves_reverse, exit_waves_neutral, exit_waves_forward = exit_waves - - if self._object_type == "potential": - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_obj_patches - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_reverse - ) - ) - * probe_magnetic_normalization - ) - / 3 - ) - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_neutral - ) - ) - * probe_normalization - ) - / 3 - ) - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * probe_magnetic_normalization - ) - / 3 - ) - - elif self._object_type == "complex": - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_obj_patches * exit_waves_reverse - ) - * probe_magnetic_normalization - ) - / 3 - ) - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - probe_conj * exit_waves_neutral - ) - * probe_normalization - / 3 - ) - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_conj * exit_waves_forward - ) - * probe_magnetic_normalization - / 3 - ) - - magnetic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - 1j - * magnetic_obj_patches - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_reverse - ) - ) - * probe_electrostatic_normalization - ) - / 2 - ) - magnetic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * probe_electrostatic_normalization - ) - / 2 - ) - - else: - exit_waves_neutral, exit_waves_forward = exit_waves - - if self._object_type == "potential": - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_neutral - ) - ) - * probe_normalization - ) - / 2 - ) - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * probe_magnetic_normalization - ) - / 2 - ) - - elif self._object_type == "complex": - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - probe_conj * exit_waves_neutral - ) - * probe_normalization - / 2 - ) - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_conj * exit_waves_forward - ) - * probe_magnetic_normalization - / 2 - ) - - magnetic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * probe_electrostatic_normalization - ) - / 3 - ) - - if not fix_probe: - electrostatic_magnetic_abs = xp.abs( - electrostatic_obj_patches * magnetic_obj_patches - ) - electrostatic_magnetic_normalization = xp.sum( - electrostatic_magnetic_abs**2, - axis=0, - ) - electrostatic_magnetic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * electrostatic_magnetic_normalization) ** 2 - + 
(normalization_min * xp.max(electrostatic_magnetic_normalization)) - ** 2 - ) - - if self._sim_recon_mode > 0: - electrostatic_abs = xp.abs(electrostatic_obj_patches) - electrostatic_normalization = xp.sum( - electrostatic_abs**2, - axis=0, - ) - electrostatic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * electrostatic_normalization) ** 2 - + (normalization_min * xp.max(electrostatic_normalization)) ** 2 - ) - - if self._sim_recon_mode == 0: - current_probe += step_size * ( - xp.sum( - electrostatic_conj * magnetic_obj_patches * exit_waves_reverse, - axis=0, - ) - * electrostatic_magnetic_normalization - / 2 - ) - - current_probe += step_size * ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, - ) - * electrostatic_magnetic_normalization - / 2 - ) - - elif self._sim_recon_mode == 1: - current_probe += step_size * ( - xp.sum( - electrostatic_conj * magnetic_obj_patches * exit_waves_reverse, - axis=0, - ) - * electrostatic_magnetic_normalization - / 3 - ) - - current_probe += step_size * ( - xp.sum( - electrostatic_conj * exit_waves_neutral, - axis=0, - ) - * electrostatic_normalization - / 3 - ) - - current_probe += step_size * ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, - ) - * electrostatic_magnetic_normalization - / 3 - ) - else: - current_probe += step_size * ( - xp.sum( - electrostatic_conj * exit_waves_neutral, - axis=0, - ) - * electrostatic_normalization - / 2 - ) - - current_probe += step_size * ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, - ) - * electrostatic_magnetic_normalization - / 2 - ) - - current_object = (electrostatic_obj, magnetic_obj) - - return current_object, current_probe - - def _warmup_projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - electrostatic_obj, _ = current_object - electrostatic_obj_patches, _ = object_patches - - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object - electrostatic_obj = ( - self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes) * exit_waves[0] - ) - * probe_normalization + object_patches = xp.empty( + (self._num_measurements,) + shifted_probes.shape, dtype=xp.complex64 ) + object_patches[0] = complex_object[ + 0, self._vectorized_patch_indices_row, self._vectorized_patch_indices_col + ] + object_patches[1] = complex_object[ + 1, self._vectorized_patch_indices_row, self._vectorized_patch_indices_col + ] - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(electrostatic_obj_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) + overlap_base = shifted_probes * object_patches[0] - current_probe = ( - xp.sum( - xp.conj(electrostatic_obj_patches) * exit_waves[0], - axis=0, - ) - * object_normalization - ) + match (self._recon_mode, self._active_measurement_index): + case (0, 0) | (1, 0): # reverse + overlap = overlap_base * xp.conj(object_patches[1]) + case (0, 1) | (1, 2) | (2, 1): # forward + overlap = overlap_base * object_patches[1] + case (1, 1) | (2, 0): # neutral + overlap = overlap_base + case _: + raise ValueError() - return (electrostatic_obj, None), current_probe + return shifted_probes, object_patches, overlap - def _projection_sets_adjoint( + def _gradient_descent_adjoint( self, current_object, current_probe, object_patches, shifted_probes, exit_waves, + step_size, normalization_min, fix_probe, ): """ - Ptychographic adjoint operator for DM_AP and RAAR methods. + Ptychographic adjoint operator for GD method. Computes object and probe update steps. 
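+        The electrostatic and magnetic updates follow the same sign convention
+        as _overlap_projection: reverse measurements use an exp(i(v - m))
+        transmission, neutral measurements exp(i v), and forward measurements
+        exp(i(v + m)).
+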
Parameters @@ -1741,6 +762,8 @@ def _projection_sets_adjoint( fractionally-shifted probes exit_waves:np.ndarray Updated exit_waves + step_size: float, optional + Update step size normalization_min: float, optional Probe normalization minimum as a fraction of the maximum overlap intensity fix_probe: bool, optional @@ -1753,19 +776,12 @@ def _projection_sets_adjoint( updated_probe: np.ndarray Updated probe estimate """ - xp = self._xp - electrostatic_obj, magnetic_obj = current_object - probe_conj = xp.conj(shifted_probes) - - electrostatic_obj_patches, magnetic_obj_patches = object_patches - electrostatic_conj = xp.conj(electrostatic_obj_patches) - magnetic_conj = xp.conj(magnetic_obj_patches) - - probe_electrostatic_abs = xp.abs(shifted_probes * electrostatic_obj_patches) - probe_magnetic_abs = xp.abs(shifted_probes * magnetic_obj_patches) + probe_conj = xp.conj(shifted_probes) # P* + electrostatic_conj = xp.conj(object_patches[0]) # V* = exp(-i v) + probe_electrostatic_abs = xp.abs(shifted_probes * object_patches[0]) probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts( probe_electrostatic_abs**2 ) @@ -1775,6 +791,7 @@ def _projection_sets_adjoint( + (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2 ) + probe_magnetic_abs = xp.abs(shifted_probes * object_patches[1]) probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( probe_magnetic_abs**2 ) @@ -1784,104 +801,10 @@ def _projection_sets_adjoint( + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 ) - if self._sim_recon_mode > 0: - probe_abs = xp.abs(shifted_probes) - - probe_normalization = self._sum_overlapping_patches_bincounts( - probe_abs**2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._sim_recon_mode == 0: - exit_waves_reverse, exit_waves_forward = exit_waves - - electrostatic_obj = ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_obj_patches * exit_waves_reverse - ) - * probe_magnetic_normalization - / 2 - ) - - electrostatic_obj += ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_conj * exit_waves_forward - ) - * probe_magnetic_normalization - / 2 - ) - - magnetic_obj = xp.conj( - self._sum_overlapping_patches_bincounts( - probe_conj * electrostatic_conj * exit_waves_reverse - ) - * probe_electrostatic_normalization - / 2 - ) - - magnetic_obj += ( - self._sum_overlapping_patches_bincounts( - probe_conj * electrostatic_conj * exit_waves_forward - ) - * probe_electrostatic_normalization - / 2 - ) - - elif self._sim_recon_mode == 1: - exit_waves_reverse, exit_waves_neutral, exit_waves_forward = exit_waves - - electrostatic_obj = ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_obj_patches * exit_waves_reverse - ) - * probe_magnetic_normalization - / 3 - ) - - electrostatic_obj += ( - self._sum_overlapping_patches_bincounts(probe_conj * exit_waves_neutral) - * probe_normalization - / 3 - ) - - electrostatic_obj += ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_conj * exit_waves_forward - ) - * probe_magnetic_normalization - / 3 - ) - - magnetic_obj = xp.conj( - self._sum_overlapping_patches_bincounts( - probe_conj * electrostatic_conj * exit_waves_reverse - ) - * probe_electrostatic_normalization - / 2 - ) - - magnetic_obj += ( - self._sum_overlapping_patches_bincounts( - probe_conj * electrostatic_conj * exit_waves_forward - ) - * 
probe_electrostatic_normalization - / 2 - ) - - else: - raise NotImplementedError() - if not fix_probe: - electrostatic_magnetic_abs = xp.abs( - electrostatic_obj_patches * magnetic_obj_patches - ) - + electrostatic_magnetic_abs = xp.abs(object_patches[0] * object_patches[1]) electrostatic_magnetic_normalization = xp.sum( - (electrostatic_magnetic_abs**2), + electrostatic_magnetic_abs**2, axis=0, ) electrostatic_magnetic_normalization = 1 / xp.sqrt( @@ -1891,10 +814,10 @@ def _projection_sets_adjoint( ** 2 ) - if self._sim_recon_mode > 0: - electrostatic_abs = xp.abs(electrostatic_obj_patches) + if self._recon_mode > 0: + electrostatic_abs = xp.abs(object_patches[0]) electrostatic_normalization = xp.sum( - (electrostatic_abs**2), + electrostatic_abs**2, axis=0, ) electrostatic_normalization = 1 / xp.sqrt( @@ -1903,167 +826,139 @@ def _projection_sets_adjoint( + (normalization_min * xp.max(electrostatic_normalization)) ** 2 ) - if self._sim_recon_mode == 0: - current_probe = ( - xp.sum( - electrostatic_conj * magnetic_obj_patches * exit_waves_reverse, - axis=0, + match (self._recon_mode, self._active_measurement_index): + case (0, 0) | (1, 0): # reverse + if self._object_type == "potential": + # -i exp(-i v) exp(i m) P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * object_patches[1] + * electrostatic_conj + * probe_conj + * exit_waves + ) ) - * electrostatic_magnetic_normalization - / 2 - ) - current_probe += ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, - ) - * electrostatic_magnetic_normalization - / 2 - ) + # i exp(-i v) exp(i m) P* + magnetic_update = -electrostatic_update - elif self._sim_recon_mode == 1: - current_probe = ( - xp.sum( - electrostatic_conj * magnetic_obj_patches * exit_waves_reverse, - axis=0, + else: + # M P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + probe_conj * object_patches[1] * exit_waves ) - * electrostatic_magnetic_normalization - / 3 - ) - current_probe += ( - xp.sum( - electrostatic_conj * exit_waves_neutral, - axis=0, + # V* P* + magnetic_update = xp.conj( + self._sum_overlapping_patches_bincounts( + probe_conj * electrostatic_conj * exit_waves + ) ) - * electrostatic_normalization - / 3 - ) - current_probe += ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, - ) - * electrostatic_magnetic_normalization - / 3 + current_object[0] += ( + step_size * electrostatic_update * probe_magnetic_normalization ) - else: - current_probe = ( - xp.sum( - electrostatic_conj * exit_waves_neutral, - axis=0, - ) - * electrostatic_normalization - / 2 + current_object[1] += ( + step_size * magnetic_update * probe_electrostatic_normalization ) - current_probe += ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, + if not fix_probe: + # M V* + current_probe += step_size * ( + xp.sum( + electrostatic_conj * object_patches[1] * exit_waves, + axis=0, + ) + * electrostatic_magnetic_normalization ) - * electrostatic_magnetic_normalization - / 2 - ) - current_object = (electrostatic_obj, magnetic_obj) + case (0, 1) | (1, 2) | (2, 1): # forward + magnetic_conj = xp.conj(object_patches[1]) # M* = exp(-i m) - return current_object, current_probe + if self._object_type == "potential": + # -i exp(-i v) exp(-i m) P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * magnetic_conj + * electrostatic_conj + * probe_conj + * exit_waves + ) + ) - def _adjoint( - self, - current_object, - current_probe, 
- object_patches, - shifted_probes, - exit_waves, - warmup_iteration: bool, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. + # -i exp(-i v) exp(-i m) P* + magnetic_update = electrostatic_update - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated + else: + # M* P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + probe_conj * magnetic_conj * exit_waves + ) - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ + # V* P* + magnetic_update = self._sum_overlapping_patches_bincounts( + probe_conj * electrostatic_conj * exit_waves + ) - if warmup_iteration: - if use_projection_scheme: - current_object, current_probe = self._warmup_projection_sets_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, + current_object[0] += ( + step_size * electrostatic_update * probe_magnetic_normalization ) - else: - current_object, current_probe = self._warmup_gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, + current_object[1] += ( + step_size * magnetic_update * probe_electrostatic_normalization ) - else: - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, + if not fix_probe: + # M* V* + current_probe += step_size * ( + xp.sum( + electrostatic_conj * magnetic_conj * exit_waves, + axis=0, + ) + * electrostatic_magnetic_normalization + ) + + case (1, 1) | (2, 0): # neutral + probe_abs = xp.abs(shifted_probes) + probe_normalization = self._sum_overlapping_patches_bincounts( + probe_abs**2 ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + if self._object_type == "potential": + # -i exp(-i v) P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + xp.real(-1j * electrostatic_conj * probe_conj * exit_waves) + ) + + else: + # P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + probe_conj * exit_waves + ) + + current_object[0] += ( + step_size * electrostatic_update * probe_normalization ) + if not fix_probe: + # V* + current_probe += step_size * ( + xp.sum( + electrostatic_conj * exit_waves, + axis=0, + ) + * electrostatic_normalization + ) + + case _: + raise ValueError() + return current_object, current_probe def _constraints( @@ -2098,7 +993,6 @@ def 
_constraints( tv_denoise, tv_denoise_weight, tv_denoise_inner_iter, - warmup_iteration, object_positivity, shrinkage_rad, object_mask, @@ -2191,65 +1085,60 @@ def _constraints( Constrained positions estimate """ - electrostatic_obj, magnetic_obj = current_object + # object constraints + # smoothness if gaussian_filter: - electrostatic_obj = self._object_gaussian_constraint( - electrostatic_obj, gaussian_filter_sigma_e, pure_phase_object + current_object[0] = self._object_gaussian_constraint( + current_object[0], gaussian_filter_sigma_e, pure_phase_object + ) + current_object[1] = self._object_gaussian_constraint( + current_object[1], gaussian_filter_sigma_m, True ) - if not warmup_iteration: - magnetic_obj = self._object_gaussian_constraint( - magnetic_obj, - gaussian_filter_sigma_m, - pure_phase_object, - ) - if butterworth_filter: - electrostatic_obj = self._object_butterworth_constraint( - electrostatic_obj, + current_object[0] = self._object_butterworth_constraint( + current_object[0], q_lowpass_e, q_highpass_e, butterworth_order, ) - if not warmup_iteration: - magnetic_obj = self._object_butterworth_constraint( - magnetic_obj, - q_lowpass_m, - q_highpass_m, - butterworth_order, - ) - - if self._object_type == "complex": - magnetic_obj = magnetic_obj.real + current_object[1] = self._object_butterworth_constraint( + current_object[1], + q_lowpass_m, + q_highpass_m, + butterworth_order, + ) if tv_denoise: - electrostatic_obj = self._object_denoise_tv_pylops( - electrostatic_obj, tv_denoise_weight, tv_denoise_inner_iter + current_object[0] = self._object_denoise_tv_pylops( + current_object[0], tv_denoise_weight, tv_denoise_inner_iter ) - if not warmup_iteration: - magnetic_obj = self._object_denoise_tv_pylops( - magnetic_obj, tv_denoise_weight, tv_denoise_inner_iter - ) - + # L1-norm pushing vacuum to zero if shrinkage_rad > 0.0 or object_mask is not None: - electrostatic_obj = self._object_shrinkage_constraint( - electrostatic_obj, + current_object[0] = self._object_shrinkage_constraint( + current_object[0], shrinkage_rad, object_mask, ) + # amplitude threshold (complex) or positivity (potential) if self._object_type == "complex": - electrostatic_obj = self._object_threshold_constraint( - electrostatic_obj, pure_phase_object + current_object[0] = self._object_threshold_constraint( + current_object[0], pure_phase_object + ) + current_object[1] = self._object_threshold_constraint( + current_object[1], True ) elif object_positivity: - electrostatic_obj = self._object_positivity_constraint(electrostatic_obj) + current_object[0] = self._object_positivity_constraint(current_object[0]) - current_object = (electrostatic_obj, magnetic_obj) + # probe constraints + # CoM corner-centering if fix_com: current_probe = self._probe_center_of_mass_constraint(current_probe) + # Fourier amplitude (aperture) constraints if fix_probe_aperture: current_probe = self._probe_aperture_constraint( current_probe, @@ -2262,6 +1151,7 @@ def _constraints( constrain_probe_fourier_amplitude_constant_intensity, ) + # Fourier phase (aberrations) fitting if fit_probe_aberrations: current_probe = self._probe_aberration_fitting_constraint( current_probe, @@ -2269,6 +1159,7 @@ def _constraints( fit_probe_aberrations_max_radial_order, ) + # Real-space amplitude constraint if constrain_probe_amplitude: current_probe = self._probe_amplitude_constraint( current_probe, @@ -2276,11 +1167,15 @@ def _constraints( constrain_probe_amplitude_relative_width, ) + # position constraints if not fix_positions: + # CoM centering 
current_positions = self._positions_center_of_mass_constraint( current_positions ) + # global affine transformation + # TO-DO: generalize to higher-order basis? if global_affine_transformation: current_positions = self._positions_affine_transformation_constraint( self._positions_px_initial, current_positions @@ -2290,7 +1185,7 @@ def _constraints( def reconstruct( self, - max_iter: int = 64, + max_iter: int = 8, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, reconstruction_parameter_a: float = None, @@ -2304,7 +1199,6 @@ def reconstruct( pure_phase_object_iter: int = 0, fix_com: bool = True, fix_probe_iter: int = 0, - warmup_iter: int = 0, fix_probe_aperture_iter: int = 0, constrain_probe_amplitude_iter: int = 0, constrain_probe_amplitude_relative_radius: float = 0.5, @@ -2313,7 +1207,8 @@ def reconstruct( constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, + max_position_update_distance: float = None, + max_position_total_distance: float = None, global_affine_transformation: bool = True, gaussian_filter_sigma_e: float = None, gaussian_filter_sigma_m: float = None, @@ -2335,6 +1230,7 @@ def reconstruct( fix_potential_baseline: bool = True, switch_object_iter: int = np.inf, store_iterations: bool = False, + collective_measurement_updates: bool = True, progress_bar: bool = True, reset: bool = None, ): @@ -2393,9 +1289,10 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A + max_position_update_distance: float, optional + Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma_e: float @@ -2439,6 +1336,8 @@ def reconstruct( 'potential' and 'complex' store_iterations: bool, optional If True, reconstructed objects and probes are stored at each iteration + collective_measurement_updates: bool + if True perform collective updates for all measurements progress_bar: bool, optional If True, reconstruction progress is displayed reset: bool, optional @@ -2452,164 +1351,42 @@ def reconstruct( asnumpy = self._asnumpy xp = self._xp - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - "'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) - if use_projection_scheme and self._sim_recon_mode == 2: + if use_projection_scheme: raise NotImplementedError( - "simultaneous_measurements_mode == '0+' and projection set algorithms are currently incompatible." + "Magnetic ptychography currently only implemented for gradient descent." 
) if self._verbose: - if switch_object_iter > max_iter: - first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - else: - switch_object_type = ( - "complex" if self._object_type == "potential" else "potential" - ) - first_line = ( - f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " - f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " - ) - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - first_line + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) - - # Batching - shuffled_indices = np.arange(self._num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) + self._report_reconstruction_summary( + max_iter, + switch_object_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + step_size, + max_batch_size, + ) if max_batch_size is not None: xp.random.seed(seed_random) @@ -2617,41 +1394,7 @@ def reconstruct( max_batch_size = self._num_diffraction_patterns # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self._object = ( - self._object_initial[0].copy(), - self._object_initial[1].copy(), - ) - self._probe = self._probe_initial.copy() - self.error_iterations = [] - self._positions_px = self._positions_px_initial.copy() - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - self._exit_waves = (None,) * self._num_sim_measurements - self._object_type = self._object_type_initial - if hasattr(self, "_tf"): - del self._tf - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." 
- ), - UserWarning, - ) - else: - self.error_iterations = [] - self._exit_waves = (None,) * self._num_sim_measurements + self._reset_reconstruction(store_iterations, reset, use_projection_scheme) if gaussian_filter_sigma_m is None: gaussian_filter_sigma_m = gaussian_filter_sigma_e @@ -2671,158 +1414,248 @@ def reconstruct( if a0 == switch_object_iter: if self._object_type == "potential": self._object_type = "complex" - self._object = (xp.exp(1j * self._object[0]), self._object[1]) - elif self._object_type == "complex": + self._object = xp.exp(1j * self._object) + else: self._object_type = "potential" - self._object = (xp.angle(self._object[0]), self._object[1]) + self._object = xp.angle(self._object) - if a0 == warmup_iter: - self._object = (self._object[0], self._object_initial[1].copy()) + if collective_measurement_updates: + collective_object = xp.zeros_like(self._object) - # randomize - if not use_projection_scheme: - np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - self._num_diffraction_patterns - ) - positions_px = self._positions_px.copy()[shuffled_indices] - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() + measurement_indices = np.arange(self._num_measurements) + np.random.shuffle(measurement_indices) - amps = [] - for amplitudes in self._amplitudes: - amps.append(amplitudes[shuffled_indices[start:end]]) - amplitudes = tuple(amps) + for measurement_index in measurement_indices: + self._active_measurement_index = measurement_index - # forward operator - ( - shifted_probes, - object_patches, - overlap, - self._exit_waves, - batch_error, - ) = self._forward( - self._object, - self._probe, - amplitudes, - self._exit_waves, - warmup_iteration=a0 < warmup_iter, - use_projection_scheme=use_projection_scheme, - projection_a=projection_a, - projection_b=projection_b, - projection_c=projection_c, - ) + measurement_error = 0.0 - # adjoint operator - self._object, self._probe = self._adjoint( - self._object, - self._probe, - object_patches, - shifted_probes, - self._exit_waves, - warmup_iteration=a0 < warmup_iter, - use_projection_scheme=use_projection_scheme, - step_size=step_size, - normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, + _probe = self._probes_all[self._active_measurement_index] + _probe_initial_aperture = self._probes_all_initial_aperture[ + self._active_measurement_index + ] + + start_idx = self._cum_probes_per_tilt[self._active_measurement_index] + end_idx = self._cum_probes_per_tilt[self._active_measurement_index + 1] + + num_diffraction_patterns = end_idx - start_idx + shuffled_indices = np.arange(num_diffraction_patterns) + unshuffled_indices = np.zeros_like(shuffled_indices) + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + + unshuffled_indices[shuffled_indices] = np.arange( + num_diffraction_patterns ) - # position correction - if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( - self._object[0], - shifted_probes, - overlap[0], - amplitudes[0], - self._positions_px, - positions_step_size, - constrain_position_distance, + positions_px = self._positions_px_all[start_idx:end_idx].copy()[ + shuffled_indices + ] + 
initial_positions_px = self._positions_px_initial_all[ + start_idx:end_idx + ].copy()[shuffled_indices] + + for start, end in generate_batches( + num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_initial = initial_positions_px[start:end] + self._positions_px_com = xp.mean(self._positions_px, axis=0) + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px ) - error += batch_error - - # Normalize Error - error /= self._mean_diffraction_intensity * self._num_diffraction_patterns - - # constraints - self._positions_px = positions_px.copy()[unshuffled_indices] - self._object, self._probe, self._positions_px = self._constraints( - self._object, - self._probe, - self._positions_px, - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - global_affine_transformation=global_affine_transformation, - warmup_iteration=a0 < warmup_iter, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma_m is not None, - gaussian_filter_sigma_e=gaussian_filter_sigma_e, - gaussian_filter_sigma_m=gaussian_filter_sigma_m, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass_m is not None or q_highpass_m is not None), - q_lowpass_e=q_lowpass_e, - q_lowpass_m=q_lowpass_m, - q_highpass_e=q_highpass_e, - q_highpass_m=q_highpass_m, - butterworth_order=butterworth_order, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, - tv_denoise_weight=tv_denoise_weight, - tv_denoise_inner_iter=tv_denoise_inner_iter, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 - else None, - pure_phase_object=a0 < pure_phase_object_iter - and self._object_type == "complex", - ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() - self.error_iterations.append(error.item()) - if store_iterations: - if a0 < warmup_iter: - self.object_iterations.append( - (asnumpy(self._object[0].copy()), None) + amplitudes = self._amplitudes[start_idx:end_idx][ + shuffled_indices[start:end] + ] + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + _probe, + amplitudes, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + projection_a=projection_a, + projection_b=projection_b, + projection_c=projection_c, ) - else: - 
self.object_iterations.append( - ( - asnumpy(self._object[0].copy()), - asnumpy(self._object[1].copy()), + + # adjoint operator + object_update, _probe = self._adjoint( + self._object.copy(), + _probe, + object_patches, + shifted_probes, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=a0 < fix_probe_iter, + ) + + object_update -= self._object + + # position correction + if a0 >= fix_positions_iter: + positions_px[start:end] = self._position_correction( + self._object, + shifted_probes, + overlap, + amplitudes, + self._positions_px, + self._positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, ) + + measurement_error += batch_error + + if collective_measurement_updates: + collective_object += object_update + else: + self._object += object_update + + # Normalize Error + measurement_error /= ( + self._mean_diffraction_intensity[self._active_measurement_index] + * num_diffraction_patterns + ) + error += measurement_error + + # constraints + self._positions_px_all[start_idx:end_idx] = positions_px.copy()[ + unshuffled_indices + ] + + if not collective_measurement_updates: + ( + self._object, + _probe, + self._positions_px_all[start_idx:end_idx], + ) = self._constraints( + self._object, + _probe, + self._positions_px_all[start_idx:end_idx], + fix_com=fix_com and a0 >= fix_probe_iter, + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter + and a0 >= fix_probe_iter, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=_probe_initial_aperture, + fix_positions=a0 < fix_positions_iter, + global_affine_transformation=global_affine_transformation, + gaussian_filter=a0 < gaussian_filter_iter + and gaussian_filter_sigma_m is not None, + gaussian_filter_sigma_e=gaussian_filter_sigma_e, + gaussian_filter_sigma_m=gaussian_filter_sigma_m, + butterworth_filter=a0 < butterworth_filter_iter + and (q_lowpass_m is not None or q_highpass_m is not None), + q_lowpass_e=q_lowpass_e, + q_lowpass_m=q_lowpass_m, + q_highpass_e=q_highpass_e, + q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, + tv_denoise=a0 < tv_denoise_iter + and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline + and self._object_fov_mask_inverse.sum() > 0 + else None, + pure_phase_object=a0 < pure_phase_object_iter + and self._object_type == "complex", ) - self.probe_iterations.append(self.probe_centered) + + # Normalize Error Over Tilts + error /= self._num_measurements + + if collective_measurement_updates: + self._object 
+= collective_object / self._num_measurements + + self._object, _, _ = self._constraints( + self._object, + None, + None, + fix_com=False, + constrain_probe_amplitude=False, + constrain_probe_amplitude_relative_radius=None, + constrain_probe_amplitude_relative_width=None, + constrain_probe_fourier_amplitude=False, + constrain_probe_fourier_amplitude_max_width_pixels=None, + constrain_probe_fourier_amplitude_constant_intensity=None, + fit_probe_aberrations=False, + fit_probe_aberrations_max_angular_order=None, + fit_probe_aberrations_max_radial_order=None, + fix_probe_aperture=False, + initial_probe_aperture=None, + fix_positions=True, + global_affine_transformation=None, + gaussian_filter=a0 < gaussian_filter_iter + and gaussian_filter_sigma_m is not None, + gaussian_filter_sigma_e=gaussian_filter_sigma_e, + gaussian_filter_sigma_m=gaussian_filter_sigma_m, + butterworth_filter=a0 < butterworth_filter_iter + and (q_lowpass_m is not None or q_highpass_m is not None), + q_lowpass_e=q_lowpass_e, + q_lowpass_m=q_lowpass_m, + q_highpass_e=q_highpass_e, + q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline + and self._object_fov_mask_inverse.sum() > 0 + else None, + pure_phase_object=a0 < pure_phase_object_iter + and self._object_type == "complex", + ) + + self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append( + [ + asnumpy(self._return_centered_probe(pr.copy())) + for pr in self._probes_all + ] + ) # store result - if a0 < warmup_iter: - self.object = (asnumpy(self._object[0]), None) - else: - self.object = (asnumpy(self._object[0]), asnumpy(self._object[1])) + self.object = asnumpy(self._object) self.probe = self.probe_centered self.error = error.item() diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index 2fb4397c7..391c45717 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -1431,7 +1431,7 @@ def _reset_reconstruction( self._object_type = self._object_type_initial if use_projection_scheme: - self._exit_waves = [None] * self._num_tilts + self._exit_waves = [None] * len(self._probes_all) else: self._exit_waves = None @@ -1454,7 +1454,7 @@ def _reset_reconstruction( else: self.error_iterations = [] if use_projection_scheme: - self._exit_waves = [None] * self._num_tilts + self._exit_waves = [None] * len(self._probes_all) else: self._exit_waves = None diff --git a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py index de26667e5..d341ffdc9 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py @@ -1055,10 +1055,6 @@ def reconstruct( max_batch_size, ) - # batching - shuffled_indices = np.arange(self._num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) - if max_batch_size is not None: xp.random.seed(seed_random) else: @@ -1184,7 +1180,7 @@ def reconstruct( if a0 >= fix_positions_iter: positions_px[start:end] = self._position_correction( 
object_sliced, - _probe, + shifted_probes, overlap, amplitudes, self._positions_px, @@ -1281,30 +1277,26 @@ def reconstruct( ( self._object, - _probe, + _, _, ) = self._constraints( self._object, - _probe, None, - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=_probe_initial_aperture, + None, + fix_com=False, + constrain_probe_amplitude=False, + constrain_probe_amplitude_relative_radius=None, + constrain_probe_amplitude_relative_width=None, + constrain_probe_fourier_amplitude=False, + constrain_probe_fourier_amplitude_max_width_pixels=None, + constrain_probe_fourier_amplitude_constant_intensity=None, + fit_probe_aberrations=False, + fit_probe_aberrations_max_angular_order=None, + fit_probe_aberrations_max_radial_order=None, + fix_probe_aperture=False, + initial_probe_aperture=None, fix_positions=True, - global_affine_transformation=global_affine_transformation, + global_affine_transformation=None, gaussian_filter=a0 < gaussian_filter_iter and gaussian_filter_sigma is not None, gaussian_filter_sigma=gaussian_filter_sigma, From b6c85795a4109c5413374d3d72ccafc9ee714d76 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 5 Jan 2024 13:02:51 -0800 Subject: [PATCH 049/128] splitting up constraints Former-commit-id: 85d1b76b9aca83e01a0842c8f3ff01aa0d619717 --- .../process/phase/iterative_base_class.py | 9 + ...tive_mixedstate_multislice_ptychography.py | 209 +----------- .../iterative_mixedstate_ptychography.py | 172 +--------- .../iterative_multislice_ptychography.py | 219 +----------- .../iterative_ptychographic_constraints.py | 322 ++++++++++++++++++ .../iterative_ptychographic_tomography.py | 177 +--------- .../iterative_singleslice_ptychography.py | 195 +---------- 7 files changed, 336 insertions(+), 967 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 3a066de4e..f06b3aa29 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1980,6 +1980,15 @@ def _report_reconstruction_summary( ) ) + def _constraints(self, current_object, current_probe, current_positions, **kwargs): + """Wrapper function for all classes to inherit""" + + current_object = self._object_constraints(current_object, **kwargs) + current_probe = self._probe_constraints(current_probe, **kwargs) + current_positions = self._positions_constraints(current_positions, **kwargs) + + return current_object, current_probe, current_positions + @property def angular_sampling(self): """Angular sampling [mrad]""" diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py 
b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index b00fea11b..79930c740 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -647,213 +647,6 @@ def preprocess( return self - def _constraints( - self, - current_object, - current_probe, - current_positions, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - kz_regularization_filter, - kz_regularization_gamma, - identical_slices, - object_positivity, - shrinkage_rad, - object_mask, - pure_phase_object, - tv_denoise_chambolle, - tv_denoise_weight_chambolle, - tv_denoise_pad_chambolle, - tv_denoise, - tv_denoise_weights, - tv_denoise_inner_iter, - orthogonalize_probe, - ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool - If True, probe Fourier amplitude is replaced by initial_probe_aperture - initial_probe_aperture: np.ndarray - Initial probe aperture to use in replacing probe Fourier amplitude - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter in A - gaussian_filter_sigma: float - Standard deviation of gaussian kernel - butterworth_filter: bool - If True, applies fourier-space butterworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - kz_regularization_filter: bool - If True, applies fourier-space arctan regularization filter - kz_regularization_gamma: float - Slice regularization strength - identical_slices: bool - If True, forces all object slices to be identical - object_positivity: bool - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - pure_phase_object: bool - If True, object amplitude is set to unity - tv_denoise_chambolle: bool - If True, performs TV denoising along z - tv_denoise_weight_chambolle: float - weight of tv denoising constraint - tv_denoise_pad_chambolle: int - If not None, pads object at top and bottom with this many zeros before applying denoising - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - orthogonalize_probe: bool - If True, probe will be orthogonalized - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - - if identical_slices: - current_object = self._object_identical_slices_constraint(current_object) - elif kz_regularization_filter: - current_object = self._object_kz_regularization_constraint( - current_object, - kz_regularization_gamma, - z_padding=1, - ) - elif tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, - tv_denoise_weights, - tv_denoise_inner_iter, - z_padding=1, - ) - elif tv_denoise_chambolle: - current_object = self._object_denoise_tv_chambolle( - current_object, - tv_denoise_weight_chambolle, - axis=0, - padding=tv_denoise_pad_chambolle, - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - if self._object_type == "complex": - current_object = self._object_threshold_constraint( - current_object, pure_phase_object - ) - elif object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - # These constraints don't _really_ make sense for mixed-state - if fix_probe_aperture: - raise NotImplementedError() - elif constrain_probe_fourier_amplitude: - raise NotImplementedError() - if fit_probe_aberrations: - raise NotImplementedError() - if constrain_probe_amplitude: - raise NotImplementedError() - - if orthogonalize_probe: - current_probe = self._probe_orthogonalization_constraint(current_probe) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, 
current_probe, current_positions - def reconstruct( self, max_iter: int = 8, @@ -1058,8 +851,8 @@ def reconstruct( projection_b, projection_c, normalization_min, - step_size, max_batch_size, + step_size, ) # batching diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 317f9316c..765f6b8b6 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -550,176 +550,6 @@ def preprocess( return self - def _constraints( - self, - current_object, - current_probe, - current_positions, - pure_phase_object, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - tv_denoise, - tv_denoise_weight, - tv_denoise_inner_iter, - orthogonalize_probe, - object_positivity, - shrinkage_rad, - object_mask, - ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - pure_phase_object: bool - If True, object amplitude is set to unity - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe fourier amplitude is replaced by initial probe aperture. - initial_probe_aperture: np.ndarray - initial probe aperture to use in replacing probe fourier amplitude - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - butterworth_filter: bool - If True, applies high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - orthogonalize_probe: bool - If True, probe will be orthogonalized - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - object_positivity: bool - If True, clips negative potential values - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - - if tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, tv_denoise_weight, tv_denoise_inner_iter - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - if self._object_type == "complex": - current_object = self._object_threshold_constraint( - current_object, pure_phase_object - ) - elif object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - # These constraints don't _really_ make sense for mixed-state - if fix_probe_aperture: - raise NotImplementedError() - elif constrain_probe_fourier_amplitude: - raise NotImplementedError() - if fit_probe_aberrations: - raise NotImplementedError() - if constrain_probe_amplitude: - raise NotImplementedError() - - if orthogonalize_probe: - current_probe = self._probe_orthogonalization_constraint(current_probe) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - def reconstruct( self, max_iter: int = 8, @@ -905,8 +735,8 @@ def reconstruct( projection_b, projection_c, normalization_min, - step_size, max_batch_size, + step_size, ) # Batching diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 6bb14d91a..91f9ea076 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -622,223 +622,6 @@ def preprocess( return self - def _constraints( - self, - current_object, - current_probe, - current_positions, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - 
initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - kz_regularization_filter, - kz_regularization_gamma, - identical_slices, - object_positivity, - shrinkage_rad, - object_mask, - pure_phase_object, - tv_denoise_chambolle, - tv_denoise_weight_chambolle, - tv_denoise_pad_chambolle, - tv_denoise, - tv_denoise_weights, - tv_denoise_inner_iter, - ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool - If True, probe Fourier amplitude is replaced by initial_probe_aperture - initial_probe_aperture: np.ndarray - Initial probe aperture to use in replacing probe Fourier amplitude - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter in A - gaussian_filter_sigma: float - Standard deviation of gaussian kernel - butterworth_filter: bool - If True, applies fourier-space butterworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - kz_regularization_filter: bool - If True, applies fourier-space arctan regularization filter - kz_regularization_gamma: float - Slice regularization strength - identical_slices: bool - If True, forces all object slices to be identical - object_positivity: bool - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - pure_phase_object: bool - If True, object amplitude is set to unity - tv_denoise_chambolle: bool - If True, performs TV denoising along z - tv_denoise_weight_chambolle: float - weight of tv denoising constraint - tv_denoise_pad_chambolle: int - if not None, pads object at top and bottom with this many zeros before applying denoising - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. 
The greater `weight`, - the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - - if identical_slices: - current_object = self._object_identical_slices_constraint(current_object) - elif kz_regularization_filter: - current_object = self._object_kz_regularization_constraint( - current_object, - kz_regularization_gamma, - z_padding=1, - ) - elif tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, - tv_denoise_weights, - tv_denoise_inner_iter, - z_padding=1, - ) - elif tv_denoise_chambolle: - current_object = self._object_denoise_tv_chambolle( - current_object, - tv_denoise_weight_chambolle, - axis=0, - padding=tv_denoise_pad_chambolle, - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - if self._object_type == "complex": - current_object = self._object_threshold_constraint( - current_object, pure_phase_object - ) - elif object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - def reconstruct( self, max_iter: int = 8, @@ -1042,8 +825,8 @@ def reconstruct( projection_b, projection_c, normalization_min, - step_size, max_batch_size, + step_size, ) # Batching diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 538b9d7aa..5181984ed 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -408,6 +408,61 @@ def _object_denoise_tv_chambolle( return updated_object + def _object_constraints( + self, + current_object, + gaussian_filter, + gaussian_filter_sigma, + pure_phase_object, + butterworth_filter, + butterworth_order, + 
q_lowpass, + q_highpass, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ): + """ObjectNDConstraints wrapper function""" + + # smoothness + if gaussian_filter: + current_object = self._object_gaussian_constraint( + current_object, gaussian_filter_sigma, pure_phase_object + ) + if butterworth_filter: + current_object = self._object_butterworth_constraint( + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, tv_denoise_weight, tv_denoise_inner_iter + ) + + # L1-norm pushing vacuum to zero + if shrinkage_rad > 0.0 or object_mask is not None: + current_object = self._object_shrinkage_constraint( + current_object, + shrinkage_rad, + object_mask, + ) + + # amplitude threshold (complex) or positivity (potential) + if self._object_type == "complex": + current_object = self._object_threshold_constraint( + current_object, pure_phase_object + ) + elif object_positivity: + current_object = self._object_positivity_constraint(current_object) + + return current_object + class Object2p5DConstraintsMixin: """ @@ -591,6 +646,85 @@ def _object_identical_slices_constraint(self, current_object): return current_object + def _object_constraints( + self, + current_object, + gaussian_filter, + gaussian_filter_sigma, + pure_phase_object, + butterworth_filter, + butterworth_order, + q_lowpass, + q_highpass, + identical_slices, + kz_regularization_filter, + kz_regularization_gamma, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + tv_denoise_chambolle, + tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ): + """Object2p5DConstraints wrapper function""" + + # smoothness + if gaussian_filter: + current_object = self._object_gaussian_constraint( + current_object, gaussian_filter_sigma, pure_phase_object + ) + if butterworth_filter: + current_object = self._object_butterworth_constraint( + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ) + if identical_slices: + current_object = self._object_identical_slices_constraint(current_object) + elif kz_regularization_filter: + current_object = self._object_kz_regularization_constraint( + current_object, + kz_regularization_gamma, + z_padding=1, + ) + elif tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + z_padding=1, + ) + elif tv_denoise_chambolle: + current_object = self._object_denoise_tv_chambolle( + current_object, + tv_denoise_weight_chambolle, + axis=0, + padding=tv_denoise_pad_chambolle, + ) + + # L1-norm pushing vacuum to zero + if shrinkage_rad > 0.0 or object_mask is not None: + current_object = self._object_shrinkage_constraint( + current_object, + shrinkage_rad, + object_mask, + ) + + # amplitude threshold (complex) or positivity (potential) + if self._object_type == "complex": + current_object = self._object_threshold_constraint( + current_object, pure_phase_object + ) + elif object_positivity: + current_object = self._object_positivity_constraint(current_object) + + return current_object + class Object3DConstraintsMixin: """ @@ -699,6 +833,58 @@ def _object_butterworth_constraint( return current_object + def _object_constraints( + self, + current_object, + gaussian_filter, + gaussian_filter_sigma, + butterworth_filter, + butterworth_order, + q_lowpass, + q_highpass, + tv_denoise, + tv_denoise_weights, + 
tv_denoise_inner_iter, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ): + """Object3DConstraints wrapper function""" + + # smoothness + if gaussian_filter: + current_object = self._object_gaussian_constraint( + current_object, gaussian_filter_sigma, pure_phase_object=False + ) + if butterworth_filter: + current_object = self._object_butterworth_constraint( + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + # L1-norm pushing vacuum to zero + if shrinkage_rad > 0.0 or object_mask is not None: + current_object = self._object_shrinkage_constraint( + current_object, + shrinkage_rad, + object_mask, + ) + + # Positivity + if object_positivity: + current_object = self._object_positivity_constraint(current_object) + + return current_object + class ProbeConstraintsMixin: """ @@ -892,6 +1078,60 @@ def _probe_aberration_fitting_constraint( return current_probe + def _probe_constraints( + self, + current_probe, + fix_com, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + fix_probe_aperture, + initial_probe_aperture, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + **kwargs, + ): + """ProbeConstraints wrapper function""" + + # CoM corner-centering + if fix_com: + current_probe = self._probe_center_of_mass_constraint(current_probe) + + # Fourier phase (aberrations) fitting + if fit_probe_aberrations: + current_probe = self._probe_aberration_fitting_constraint( + current_probe, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + ) + + # Fourier amplitude (aperture) constraints + if fix_probe_aperture: + current_probe = self._probe_aperture_constraint( + current_probe, + initial_probe_aperture, + ) + elif constrain_probe_fourier_amplitude: + current_probe = self._probe_fourier_amplitude_constraint( + current_probe, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + ) + + # Real-space amplitude constraint + if constrain_probe_amplitude: + current_probe = self._probe_amplitude_constraint( + current_probe, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + ) + + return current_probe + class ProbeMixedConstraintsMixin: """ @@ -961,6 +1201,67 @@ def _probe_orthogonalization_constraint(self, current_probe): intensities_order = xp.argsort(intensities, axis=None)[::-1] return current_probe[intensities_order] + def _probe_constraints( + self, + current_probe, + fix_com, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + fix_probe_aperture, + initial_probe_aperture, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + orthogonalize_probe, + **kwargs, + ): + """ProbeMixedConstraints wrapper function""" + + # CoM corner-centering + if fix_com: + current_probe = self._probe_center_of_mass_constraint(current_probe) + + # Fourier phase (aberrations) fitting + if 
fit_probe_aberrations: + for probe_idx in range(self._num_probes): + current_probe[probe_idx] = self._probe_aberration_fitting_constraint( + current_probe[probe_idx], + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + ) + + # Fourier amplitude (aperture) constraints + if fix_probe_aperture: + current_probe[0] = self._probe_aperture_constraint( + current_probe[0], + initial_probe_aperture, + ) + elif constrain_probe_fourier_amplitude: + current_probe[0] = self._probe_fourier_amplitude_constraint( + current_probe[0], + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + ) + + # Real-space amplitude constraint + if constrain_probe_amplitude: + for probe_idx in range(self._num_probes): + current_probe[probe_idx] = self._probe_amplitude_constraint( + current_probe[probe_idx], + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + ) + + # Probe orthogonalization + if orthogonalize_probe: + current_probe = self._probe_orthogonalization_constraint(current_probe) + + return current_probe + class PositionsConstraintsMixin: """ @@ -1032,3 +1333,24 @@ def _positions_affine_transformation_constraint( current_positions = tf(initial_positions, origin=self._positions_px_com, xp=xp) return current_positions + + def _positions_constraints( + self, + current_positions, + fix_positions, + global_affine_transformation, + **kwargs, + ): + """PositionsConstraints wrapper function""" + + if not fix_positions: + current_positions = self._positions_center_of_mass_constraint( + current_positions + ) + + if global_affine_transformation: + current_positions = self._positions_affine_transformation_constraint( + self._positions_px_initial, current_positions + ) + + return current_positions diff --git a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py index d341ffdc9..5b4892c5e 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py @@ -697,181 +697,6 @@ def preprocess( return self - def _constraints( - self, - current_object, - current_probe, - current_positions, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - object_positivity, - shrinkage_rad, - object_mask, - tv_denoise, - tv_denoise_weights, - tv_denoise_inner_iter, - ): - """ - Ptychographic constraints operator. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe Fourier amplitude is replaced by initial probe aperture. - initial_probe_aperture: np.ndarray, - Initial probe aperture to use in replacing probe Fourier amplitude. - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - butterworth_filter: bool - If True, applies fourier-space butterworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - object_positivity: bool - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. 
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object=False - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - if tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, - tv_denoise_weights, - tv_denoise_inner_iter, - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - if object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - def reconstruct( self, max_iter: int = 8, @@ -1051,8 +876,8 @@ def reconstruct( projection_b, projection_c, normalization_min, - step_size, max_batch_size, + step_size, ) if max_batch_size is not None: diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index d218651c7..3a0f5d77b 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -529,199 +529,6 @@ def preprocess( return self - def _constraints( - self, - current_object, - current_probe, - current_positions, - pure_phase_object, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - tv_denoise, - tv_denoise_weight, - tv_denoise_inner_iter, - object_positivity, - shrinkage_rad, - 
object_mask, - ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - pure_phase_object: bool - If True, object amplitude is set to unity - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe Fourier amplitude is replaced by initial probe aperture. - initial_probe_aperture: np.ndarray, - Initial probe aperture to use in replacing probe Fourier amplitude. - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - butterworth_filter: bool - If True, applies high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising. 
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - object_positivity: bool - If True, clips negative potential values - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - # object constraints - - # smoothness - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object - ) - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - if tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, tv_denoise_weight, tv_denoise_inner_iter - ) - - # L1-norm pushing vacuum to zero - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - # amplitude threshold (complex) or positivity (potential) - if self._object_type == "complex": - current_object = self._object_threshold_constraint( - current_object, pure_phase_object - ) - elif object_positivity: - current_object = self._object_positivity_constraint(current_object) - - # probe constraints - - # CoM corner-centering - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - # Fourier amplitude (aperture) constraints - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - # Fourier phase (aberrations) fitting - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - # Real-space amplitude constraint - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - # position constraints - if not fix_positions: - # CoM centering - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - # global affine transformation - # TO-DO: generalize to higher-order basis? 
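# A minimal, illustrative sketch of the "global affine transformation" position
# constraint referenced in the comments above: fit one affine map (a 2x2 matrix plus
# a translation) taking the initial scan positions to the refined ones by least
# squares, and keep only the fitted positions, discarding per-position noise.
# This is a sketch under stated assumptions, not the body of
# _positions_affine_transformation_constraint; the helper name fit_global_affine
# and its variable names are hypothetical.
import numpy as np

def fit_global_affine(initial_positions, current_positions):
    """Return current_positions projected onto a single affine model of initial_positions."""
    ones = np.ones((initial_positions.shape[0], 1))
    design = np.hstack((initial_positions, ones))                        # (N, 3): [x, y, 1]
    coeffs, *_ = np.linalg.lstsq(design, current_positions, rcond=None)  # (3, 2): stacked [A.T; t]
    return design @ coeffs                                               # (N, 2) constrained positions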
- if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - def reconstruct( self, max_iter: int = 8, @@ -906,8 +713,8 @@ def reconstruct( projection_b, projection_c, normalization_min, - step_size, max_batch_size, + step_size, ) # batching From d7b63cfccf3fba5ecb934336393b5852d20d5ce8 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 5 Jan 2024 13:12:09 -0800 Subject: [PATCH 050/128] bumping python version to 3.10 Former-commit-id: 360741afaa228f37239a2c8f4bfc4863d1932468 --- .github/workflows/check_install_dev.yml | 2 +- .github/workflows/check_install_main.yml | 2 +- .github/workflows/check_install_quick.yml | 4 ++-- setup.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/check_install_dev.yml b/.github/workflows/check_install_dev.yml index 4e9d16f77..a960dc2f2 100644 --- a/.github/workflows/check_install_dev.yml +++ b/.github/workflows/check_install_dev.yml @@ -16,7 +16,7 @@ jobs: allow_failure: [false] runs-on: [ubuntu-latest] architecture: [x86_64] - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12"] # include: # - python-version: "3.12.0-beta.4" # runs-on: ubuntu-latest diff --git a/.github/workflows/check_install_main.yml b/.github/workflows/check_install_main.yml index a276cab17..d27278ba9 100644 --- a/.github/workflows/check_install_main.yml +++ b/.github/workflows/check_install_main.yml @@ -16,7 +16,7 @@ jobs: allow_failure: [false] runs-on: [ubuntu-latest, windows-latest, macos-latest] architecture: [x86_64] - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12"] #include: # - python-version: "3.12.0-beta.4" # runs-on: ubuntu-latest diff --git a/.github/workflows/check_install_quick.yml b/.github/workflows/check_install_quick.yml index f83ee0b73..0d20bd759 100644 --- a/.github/workflows/check_install_quick.yml +++ b/.github/workflows/check_install_quick.yml @@ -20,7 +20,7 @@ jobs: allow_failure: [false] runs-on: [ubuntu-latest] architecture: [x86_64] - python-version: ["3.9", "3.12"] + python-version: ["3.10", "3.12"] # Currently no public runners available for this but this or arm64 should work next time # include: # - python-version: "3.10" @@ -42,4 +42,4 @@ jobs: python -c "import py4DSTEM; print(py4DSTEM.__version__)" # - name: Check machine arch # run: | - # python -c "import platform; print(platform.machine())" \ No newline at end of file + # python -c "import platform; print(platform.machine())" diff --git a/setup.py b/setup.py index 2d828289a..21e244a4c 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ author_email="ben.savitzky@gmail.com", license="GNU GPLv3", keywords="STEM 4DSTEM", - python_requires=">=3.9,<3.13", + python_requires=">=3.10", install_requires=[ "numpy >= 1.19", "scipy >= 1.5.2", From 070e3ba4312399b074990e0eff3cd289cb4d4865 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 5 Jan 2024 13:36:05 -0800 Subject: [PATCH 051/128] correctly handling collective updates constraints Former-commit-id: f23d8cc83f33d5b7b50b3ee4dbc9689e53707c66 --- .../phase/iterative_magnetic_ptychography.py | 222 ++++-------------- .../iterative_ptychographic_tomography.py | 58 +++-- 2 files changed, 81 insertions(+), 199 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_magnetic_ptychography.py b/py4DSTEM/process/phase/iterative_magnetic_ptychography.py index 
b97820e04..c2bbce64e 100644 --- a/py4DSTEM/process/phase/iterative_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/iterative_magnetic_ptychography.py @@ -961,131 +961,28 @@ def _gradient_descent_adjoint( return current_object, current_probe - def _constraints( + def _object_constraints( self, current_object, - current_probe, - current_positions, pure_phase_object, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, gaussian_filter, gaussian_filter_sigma_e, gaussian_filter_sigma_m, butterworth_filter, + butterworth_order, q_lowpass_e, q_lowpass_m, q_highpass_e, q_highpass_m, - butterworth_order, tv_denoise, tv_denoise_weight, tv_denoise_inner_iter, object_positivity, shrinkage_rad, object_mask, + **kwargs, ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - pure_phase_object: bool - If True, object amplitude is set to unity - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe Fourier amplitude is replaced by initial probe aperture. - initial_probe_aperture: np.ndarray, - Initial probe aperture to use in replacing probe Fourier amplitude. 
- fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma_e: float - Standard deviation of gaussian kernel for electrostatic object in A - gaussian_filter_sigma_m: float - Standard deviation of gaussian kernel for magnetic object in A - probe_gaussian_filter: bool - If True, applies reciprocal-space gaussian filtering on residual aberrations - probe_gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A^-1 - probe_gaussian_filter_fix_amplitude: bool - If True, only the probe phase is smoothed - butterworth_filter: bool - If True, applies high-pass butteworth filter - q_lowpass_e: float - Cut-off frequency in A^-1 for low-pass filtering electrostatic object - q_lowpass_m: float - Cut-off frequency in A^-1 for low-pass filtering magnetic object - q_highpass_e: float - Cut-off frequency in A^-1 for high-pass filtering electrostatic object - q_highpass_m: float - Cut-off frequency in A^-1 for high-pass filtering magnetic object - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - warmup_iteration: bool - If True, constraints electrostatic object only - object_positivity: bool - If True, clips negative potential values - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - # object constraints + """MagneticObjectNDConstraints wrapper function""" # smoothness if gaussian_filter: @@ -1132,56 +1029,7 @@ def _constraints( elif object_positivity: current_object[0] = self._object_positivity_constraint(current_object[0]) - # probe constraints - - # CoM corner-centering - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - # Fourier amplitude (aperture) constraints - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - # Fourier phase (aberrations) fitting - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - # Real-space amplitude constraint - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - # position constraints - if not fix_positions: - # CoM centering - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - # global affine transformation - # TO-DO: generalize to higher-order basis? 
- if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions + return current_object def reconstruct( self, @@ -1351,6 +1199,12 @@ def reconstruct( asnumpy = self._asnumpy xp = self._xp + if not collective_measurement_updates and self._verbose: + warnings.warn( + "Magnetic ptychography is much more robust with `collective_measurement_updates=True`.", + UserWarning, + ) + # set and report reconstruction method ( use_projection_scheme, @@ -1370,7 +1224,7 @@ def reconstruct( if use_projection_scheme: raise NotImplementedError( - "Magnetic ptychography currently only implemented for gradient descent." + "Magnetic ptychography is currently only implemented for gradient descent." ) if self._verbose: @@ -1543,7 +1397,38 @@ def reconstruct( unshuffled_indices ] - if not collective_measurement_updates: + if collective_measurement_updates: + # probe and positions + _probe = self._probe_constraints( + _probe, + fix_com=fix_com and a0 >= fix_probe_iter, + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter + and a0 >= fix_probe_iter, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=_probe_initial_aperture, + ) + + self._positions_px_all[ + start_idx:end_idx + ] = self._positions_constraints( + self._positions_px_all[start_idx:end_idx], + fix_positions=a0 < fix_positions_iter, + global_affine_transformation=global_affine_transformation, + ) + + else: + # object, probe, and positions ( self._object, _probe, @@ -1601,24 +1486,9 @@ def reconstruct( if collective_measurement_updates: self._object += collective_object / self._num_measurements - self._object, _, _ = self._constraints( + # object only + self._object = self._object_constraints( self._object, - None, - None, - fix_com=False, - constrain_probe_amplitude=False, - constrain_probe_amplitude_relative_radius=None, - constrain_probe_amplitude_relative_width=None, - constrain_probe_fourier_amplitude=False, - constrain_probe_fourier_amplitude_max_width_pixels=None, - constrain_probe_fourier_amplitude_constant_intensity=None, - fit_probe_aberrations=False, - fit_probe_aberrations_max_angular_order=None, - fit_probe_aberrations_max_radial_order=None, - fix_probe_aperture=False, - initial_probe_aperture=None, - fix_positions=True, - global_affine_transformation=None, gaussian_filter=a0 < gaussian_filter_iter and gaussian_filter_sigma_m is not None, gaussian_filter_sigma_e=gaussian_filter_sigma_e, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py index 5b4892c5e..c4374c93a 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py +++ 
b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py @@ -738,7 +738,7 @@ def reconstruct( tv_denoise_iter=np.inf, tv_denoise_weights=None, tv_denoise_inner_iter=40, - collective_tilt_updates: bool = False, + collective_tilt_updates: bool = True, store_iterations: bool = False, progress_bar: bool = True, reset: bool = None, @@ -1045,7 +1045,38 @@ def reconstruct( unshuffled_indices ] - if not collective_tilt_updates: + if collective_tilt_updates: + # probe and positions + _probe = self._probe_constraints( + _probe, + fix_com=fix_com and a0 >= fix_probe_iter, + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter + and a0 >= fix_probe_iter, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=_probe_initial_aperture, + ) + + self._positions_px_all[ + start_tilt:end_tilt + ] = self._positions_constraints( + self._positions_px_all[start_tilt:end_tilt], + fix_positions=a0 < fix_positions_iter, + global_affine_transformation=global_affine_transformation, + ) + + else: + # object, probe, and positions ( self._object, _probe, @@ -1100,28 +1131,9 @@ def reconstruct( if collective_tilt_updates: self._object += collective_object / self._num_tilts - ( - self._object, - _, - _, - ) = self._constraints( + # object only + self._object = self._object_constraints( self._object, - None, - None, - fix_com=False, - constrain_probe_amplitude=False, - constrain_probe_amplitude_relative_radius=None, - constrain_probe_amplitude_relative_width=None, - constrain_probe_fourier_amplitude=False, - constrain_probe_fourier_amplitude_max_width_pixels=None, - constrain_probe_fourier_amplitude_constant_intensity=None, - fit_probe_aberrations=False, - fit_probe_aberrations_max_angular_order=None, - fit_probe_aberrations_max_radial_order=None, - fix_probe_aperture=False, - initial_probe_aperture=None, - fix_positions=True, - global_affine_transformation=None, gaussian_filter=a0 < gaussian_filter_iter and gaussian_filter_sigma is not None, gaussian_filter_sigma=gaussian_filter_sigma, From d2bfe726e0dd6ae3842d02c6298b5ee595ed3cab Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 5 Jan 2024 15:57:56 -0800 Subject: [PATCH 052/128] finished with magnetic ptycho Former-commit-id: 1e4e59222d7e64fd95f1e394f4b58a684eac70ea --- .../phase/iterative_magnetic_ptychography.py | 425 +++++------------- .../phase/iterative_ptychographic_methods.py | 219 +++++---- .../iterative_ptychographic_tomography.py | 175 ++++---- .../iterative_ptychographic_visualizations.py | 20 +- 4 files changed, 333 insertions(+), 506 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_magnetic_ptychography.py b/py4DSTEM/process/phase/iterative_magnetic_ptychography.py index c2bbce64e..01817e5fa 100644 --- a/py4DSTEM/process/phase/iterative_magnetic_ptychography.py +++ 
b/py4DSTEM/process/phase/iterative_magnetic_ptychography.py @@ -10,7 +10,11 @@ import numpy as np from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import make_axes_locatable -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg +from py4DSTEM.visualize.vis_special import ( + Complex2RGB, + add_colorbar_arg, + return_scaled_histogram_ordering, +) try: import cupy as cp @@ -26,9 +30,9 @@ ProbeConstraintsMixin, ) from py4DSTEM.process.phase.iterative_ptychographic_methods import ( + MultipleMeasurementsMethodsMixin, ObjectNDMethodsMixin, ObjectNDProbeMethodsMixin, - ProbeListMethodsMixin, ProbeMethodsMixin, ) from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( @@ -50,8 +54,8 @@ class MagneticPtychographicReconstruction( PositionsConstraintsMixin, ProbeConstraintsMixin, ObjectNDConstraintsMixin, + MultipleMeasurementsMethodsMixin, ObjectNDProbeMethodsMixin, - ProbeListMethodsMixin, ProbeMethodsMixin, ObjectNDMethodsMixin, PtychographicReconstruction, @@ -353,19 +357,19 @@ def preprocess( self._positions_mask, (self._num_measurements, 1, 1) ) - num_probes_per_tilt = np.insert( + num_probes_per_measurement = np.insert( self._positions_mask.sum(axis=(-2, -1)), 0, 0 ) else: self._positions_mask = [None] * self._num_measurements - num_probes_per_tilt = [0] + [dc.R_N for dc in self._datacube] - num_probes_per_tilt = np.array(num_probes_per_tilt) + num_probes_per_measurement = [0] + [dc.R_N for dc in self._datacube] + num_probes_per_measurement = np.array(num_probes_per_measurement) # prepopulate relevant arrays self._mean_diffraction_intensity = [] - self._num_diffraction_patterns = num_probes_per_tilt.sum() - self._cum_probes_per_tilt = np.cumsum(num_probes_per_tilt) + self._num_diffraction_patterns = num_probes_per_measurement.sum() + self._cum_probes_per_measurement = np.cumsum(num_probes_per_measurement) self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) # calculate roi_shape @@ -388,12 +392,6 @@ def preprocess( # Ensure plot_center_of_mass is not in kwargs kwargs.pop("plot_center_of_mass", None) - # prepopulate relevant arrays - self._mean_diffraction_intensity = [] - self._num_diffraction_patterns = num_probes_per_tilt.sum() - self._cum_probes_per_tilt = np.cumsum(num_probes_per_tilt) - self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) - # loop over DPs for preprocessing for index in tqdmnd( self._num_measurements, @@ -487,8 +485,8 @@ def preprocess( self._verbose = verbose # corner-center amplitudes - idx_start = self._cum_probes_per_tilt[index] - idx_end = self._cum_probes_per_tilt[index + 1] + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] ( self._amplitudes[idx_start:idx_end], mean_diffraction_intensity_temp, @@ -549,8 +547,8 @@ def preprocess( self._positions_px_all = xp.asarray(self._positions_px_all, dtype=xp.float32) for index in range(self._num_measurements): - idx_start = self._cum_probes_per_tilt[index] - idx_end = self._cum_probes_per_tilt[index + 1] + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] self._positions_px = self._positions_px_all[idx_start:idx_end] self._positions_px_com = xp.mean(self._positions_px, axis=0) self._positions_px -= ( @@ -563,6 +561,11 @@ def preprocess( self._positions_initial_all[:, 0] *= self.sampling[0] self._positions_initial_all[:, 1] *= self.sampling[1] + self._positions_initial = self._return_average_positions() + if 
self._positions_initial is not None: + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + # initialize probe self._probes_all = [] self._probes_all_initial = [] @@ -594,7 +597,7 @@ def preprocess( )._evaluate_ctf() # overlaps - idx_end = self._cum_probes_per_tilt[1] + idx_end = self._cum_probes_per_measurement[1] self._positions_px = self._positions_px_all[0:idx_end] self._positions_px_fractional = self._positions_px - xp.round( self._positions_px @@ -662,8 +665,8 @@ def preprocess( cmap="gray", ) ax2.scatter( - self.positions[:, 1], - self.positions[:, 0], + self.positions[0, :, 1], + self.positions[0, :, 0], s=2.5, color=(1, 0, 0, 1), ) @@ -1256,6 +1259,9 @@ def reconstruct( if q_lowpass_m is None: q_lowpass_m = q_lowpass_e + if fix_positions_iter < 1: + fix_positions_iter = 1 # give position correction a chance + # main loop for a0 in tqdmnd( max_iter, @@ -1289,8 +1295,12 @@ def reconstruct( self._active_measurement_index ] - start_idx = self._cum_probes_per_tilt[self._active_measurement_index] - end_idx = self._cum_probes_per_tilt[self._active_measurement_index + 1] + start_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + ] + end_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + 1 + ] num_diffraction_patterns = end_idx - start_idx shuffled_indices = np.arange(num_diffraction_patterns) @@ -1535,66 +1545,8 @@ def reconstruct( return self - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax: None, - cbar: bool, - padding: int = 0, - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. - - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - cmap = kwargs.pop("cmap", "magma") - - if self._object_type == "complex": - obj = np.angle(self.object[0]) - else: - obj = self.object[0] - - rotated_object = self._crop_rotate_object_fov(obj, padding=padding) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) + def _visualize_all_iterations(self, **kwargs): + raise NotImplementedError() def _visualize_last_iteration( self, @@ -1604,7 +1556,6 @@ def _visualize_last_iteration( plot_probe: bool, plot_fourier_probe: bool, remove_initial_probe_aberrations: bool, - padding: int, **kwargs, ): """ @@ -1625,39 +1576,35 @@ def _visualize_last_iteration( remove_initial_probe_aberrations: bool, optional If true, when plotting fourier probe, removes initial probe to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object """ + + asnumpy = self._asnumpy + figsize = kwargs.pop("figsize", (12, 5)) cmap_e = kwargs.pop("cmap_e", "magma") cmap_m = 
kwargs.pop("cmap_m", "PuOr") + chroma_boost = kwargs.pop("chroma_boost", 1) + # get scaled arrays + probe = self._return_single_probe() + obj = self.object_cropped if self._object_type == "complex": - obj_e = np.angle(self.object[0]) - obj_m = self.object[1] - else: - obj_e, obj_m = self.object + obj = np.angle(obj) - rotated_electrostatic = self._crop_rotate_object_fov(obj_e, padding=padding) - rotated_magnetic = self._crop_rotate_object_fov(obj_m, padding=padding) - rotated_shape = rotated_electrostatic.shape - - min_e = rotated_electrostatic.min() - max_e = rotated_electrostatic.max() - max_m = np.abs(rotated_magnetic).max() - min_m = -max_m - - vmin_e = kwargs.pop("vmin_e", min_e) - vmax_e = kwargs.pop("vmax_e", max_e) - vmin_m = kwargs.pop("vmin_m", min_m) - vmax_m = kwargs.pop("vmax_m", max_m) + vmin_e = kwargs.pop("vmin_e", None) + vmax_e = kwargs.pop("vmax_e", None) + obj[0], vmin_e, vmax_e = return_scaled_histogram_ordering( + obj[0], vmin_e, vmax_e + ) - chroma_boost = kwargs.pop("chroma_boost", 1) + _, _, _vmax_m = return_scaled_histogram_ordering(np.abs(obj[1])) + vmin_m = kwargs.pop("vmin_m", -_vmax_m) + vmax_m = kwargs.pop("vmax_m", _vmax_m) extent = [ 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], + self.sampling[1] * obj.shape[2], + self.sampling[0] * obj.shape[1], 0, ] @@ -1668,6 +1615,7 @@ def _visualize_last_iteration( self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, ] + elif plot_probe: probe_extent = [ 0, @@ -1684,26 +1632,29 @@ def _visualize_last_iteration( height_ratios=[4, 1], hspace=0.15, width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), 1, - 1, - (probe_extent[1] / probe_extent[2]) / (extent[1] / extent[2]), ], wspace=0.35, ) + else: spec = GridSpec(ncols=2, nrows=2, height_ratios=[4, 1], hspace=0.15) + else: if plot_probe or plot_fourier_probe: spec = GridSpec( ncols=3, nrows=1, width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), 1, - 1, - (probe_extent[1] / probe_extent[2]) / (extent[1] / extent[2]), ], wspace=0.35, ) + else: spec = GridSpec(ncols=2, nrows=1) @@ -1711,10 +1662,10 @@ def _visualize_last_iteration( fig = plt.figure(figsize=figsize) if plot_probe or plot_fourier_probe: - # Electrostatic Object + # Object_e ax = fig.add_subplot(spec[0, 0]) im = ax.imshow( - rotated_electrostatic, + obj[0], extent=extent, cmap=cmap_e, vmin=vmin_e, @@ -1723,10 +1674,11 @@ def _visualize_last_iteration( ) ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") + if self._object_type == "potential": - ax.set_title("Reconstructed electrostatic potential") + ax.set_title("Electrostatic potential") elif self._object_type == "complex": - ax.set_title("Reconstructed electrostatic phase") + ax.set_title("Electrostatic phase") if cbar: divider = make_axes_locatable(ax) @@ -1734,10 +1686,10 @@ def _visualize_last_iteration( fig.add_axes(ax_cb) fig.colorbar(im, cax=ax_cb) - # Magnetic Object + # Object_m ax = fig.add_subplot(spec[0, 1]) im = ax.imshow( - rotated_magnetic, + obj[1], extent=extent, cmap=cmap_m, vmin=vmin_m, @@ -1746,7 +1698,11 @@ def _visualize_last_iteration( ) ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") - ax.set_title("Reconstructed magnetic potential") + + if self._object_type == "potential": + ax.set_title("Magnetic potential") + elif self._object_type == "complex": 
+ ax.set_title("Magnetic phase") if cbar: divider = make_axes_locatable(ax) @@ -1757,21 +1713,26 @@ def _visualize_last_iteration( # Probe ax = fig.add_subplot(spec[0, 2]) if plot_fourier_probe: - if remove_initial_probe_aberrations: - probe_array = self.probe_fourier_residual - else: - probe_array = self.probe_fourier + probe = asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) probe_array = Complex2RGB( - probe_array, + probe, chroma_boost=chroma_boost, ) + ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, power=2, chroma_boost=chroma_boost + asnumpy(self._return_centered_probe(probe)), + power=2, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") @@ -1788,10 +1749,10 @@ def _visualize_last_iteration( add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: - # Electrostatic Object + # Object_e ax = fig.add_subplot(spec[0, 0]) im = ax.imshow( - rotated_electrostatic, + obj[0], extent=extent, cmap=cmap_e, vmin=vmin_e, @@ -1800,10 +1761,11 @@ def _visualize_last_iteration( ) ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") + if self._object_type == "potential": - ax.set_title("Reconstructed electrostatic potential") + ax.set_title("Electrostatic potential") elif self._object_type == "complex": - ax.set_title("Reconstructed electrostatic phase") + ax.set_title("Electrostatic phase") if cbar: divider = make_axes_locatable(ax) @@ -1811,10 +1773,10 @@ def _visualize_last_iteration( fig.add_axes(ax_cb) fig.colorbar(im, cax=ax_cb) - # Magnetic Object + # Object_e ax = fig.add_subplot(spec[0, 1]) im = ax.imshow( - rotated_magnetic, + obj[1], extent=extent, cmap=cmap_m, vmin=vmin_m, @@ -1823,7 +1785,11 @@ def _visualize_last_iteration( ) ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") - ax.set_title("Reconstructed magnetic potential") + + if self._object_type == "potential": + ax.set_title("Magnetic potential") + elif self._object_type == "complex": + ax.set_title("Magnetic phase") if cbar: divider = make_axes_locatable(ax) @@ -1833,6 +1799,7 @@ def _visualize_last_iteration( if plot_convergence and hasattr(self, "error_iterations"): errors = np.array(self.error_iterations) + ax = fig.add_subplot(spec[1, :]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") @@ -1842,201 +1809,11 @@ def _visualize_last_iteration( fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) - def _visualize_all_iterations( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - iterations_grid: Tuple[int, int], - padding: int, - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - raise NotImplementedError() - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - padding: int = 0, - **kwargs, - ): - """ - Displays reconstructed object and probe. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - - return self - - @property - def self_consistency_errors(self): - """Compute the self-consistency errors for each probe position""" - - xp = self._xp - asnumpy = self._asnumpy - - # Re-initialize fractional positions and vector patches, max_batch_size = None - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - # Overlaps - _, _, overlap = self._warmup_overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap[0]) - - # Normalized mean-squared errors - error = xp.sum( - xp.abs(self._amplitudes[0] - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) - ) - error /= self._mean_diffraction_intensity - - return asnumpy(error) - - def _return_self_consistency_errors( - self, - max_batch_size=None, - ): - """Compute the self-consistency errors for each probe position""" - - xp = self._xp - asnumpy = self._asnumpy - - # Batch-size - if 
max_batch_size is None: - max_batch_size = self._num_diffraction_patterns - - # Re-initialize fractional positions and vector patches - errors = np.array([]) - positions_px = self._positions_px.copy() - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[0][start:end] - - # Overlaps - _, _, overlap = self._warmup_overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap[0]) - - # Normalized mean-squared errors - batch_errors = xp.sum( - xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) - ) - errors = np.hstack((errors, batch_errors)) - - self._positions_px = positions_px.copy() - errors /= self._mean_diffraction_intensity - - return asnumpy(errors) - - def _return_projected_cropped_potential( - self, - ): - """Utility function to accommodate multiple classes""" - if self._object_type == "complex": - projected_cropped_potential = np.angle(self.object_cropped[0]) - else: - projected_cropped_potential = self.object_cropped[0] - - return projected_cropped_potential - @property def object_cropped(self): """Cropped and rotated object""" - obj_e, obj_m = self._object - obj_e = self._crop_rotate_object_fov(obj_e) - obj_m = self._crop_rotate_object_fov(obj_m) - return (obj_e, obj_m) + cropped_e = self._crop_rotate_object_fov(self._object[0]) + cropped_m = self._crop_rotate_object_fov(self._object[1]) + + return np.array([cropped_e, cropped_m]) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index 391c45717..16fbc403f 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -984,13 +984,6 @@ def show_object_fft( **kwargs, ) - def _return_self_consistency_errors( - self, - **kwargs, - ): - """Compute the self-consistency errors for each probe position""" - raise NotImplementedError() - @property def object_supersliced(self): """Returns super-sliced object""" @@ -1405,83 +1398,6 @@ def _return_single_probe(self, probe=None): return self._probe[0] -class ProbeListMethodsMixin: - """ - Mixin class for probe methods unique to a list of single probes. - Overwrites ProbeMethodsMixin. - """ - - def _reset_reconstruction( - self, - store_iterations, - reset, - use_projection_scheme, - ): - """ """ - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - # reset can be True, False, or None (default) - if reset is True: - self.error_iterations = [] - self._object = self._object_initial.copy() - self._probes_all = [pr.copy() for pr in self._probes_all_initial] - self._positions_px_all = self._positions_px_initial_all.copy() - self._object_type = self._object_type_initial - - if use_projection_scheme: - self._exit_waves = [None] * len(self._probes_all) - else: - self._exit_waves = None - - # delete positions affine transform - if hasattr(self, "_tf"): - del self._tf - - elif reset is None: - # continued run - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." 
- ), - UserWarning, - ) - - # first start - else: - self.error_iterations = [] - if use_projection_scheme: - self._exit_waves = [None] * len(self._probes_all) - else: - self._exit_waves = None - - def _return_single_probe(self, probe=None): - """Current probe estimate""" - xp = self._xp - - if probe is not None: - _probes = [xp.asarray(pr) for pr in probe] - else: - if not hasattr(self, "_probes_all"): - return None - _probes = self._probes_all - - probe = xp.zeros(self._region_of_interest_shape, dtype=np.complex64) - - for pr in _probes: - probe += pr - - return probe / len(_probes) - - @property - def _probe(self): - """Dummy property to make single-probe functions work""" - return self._return_single_probe() - - class ObjectNDProbeMethodsMixin: """ Mixin class for methods applicable to 2D, 2.5D, and 3D objects using a single probe. @@ -3121,3 +3037,138 @@ def show_transmitted_probe( **kwargs, ): raise NotImplementedError() + + +class MultipleMeasurementsMethodsMixin: + """ + Mixin class for methods unique to classes with multiple measurements. + Overwrites various Mixins. + """ + + def _reset_reconstruction( + self, + store_iterations, + reset, + use_projection_scheme, + ): + """ """ + if store_iterations and (not hasattr(self, "object_iterations") or reset): + self.object_iterations = [] + self.probe_iterations = [] + + # reset can be True, False, or None (default) + if reset is True: + self.error_iterations = [] + self._object = self._object_initial.copy() + self._probes_all = [pr.copy() for pr in self._probes_all_initial] + self._positions_px_all = self._positions_px_initial_all.copy() + self._object_type = self._object_type_initial + + if use_projection_scheme: + self._exit_waves = [None] * len(self._probes_all) + else: + self._exit_waves = None + + # delete positions affine transform + if hasattr(self, "_tf"): + del self._tf + + elif reset is None: + # continued run + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." 
+ ), + UserWarning, + ) + + # first start + else: + self.error_iterations = [] + if use_projection_scheme: + self._exit_waves = [None] * len(self._probes_all) + else: + self._exit_waves = None + + def _return_single_probe(self, probe=None): + """Current probe estimate""" + xp = self._xp + + if probe is not None: + _probes = [xp.asarray(pr) for pr in probe] + else: + if not hasattr(self, "_probes_all"): + return None + _probes = self._probes_all + + probe = xp.zeros(self._region_of_interest_shape, dtype=np.complex64) + + for pr in _probes: + probe += pr + + return probe / len(_probes) + + def _return_average_positions( + self, positions=None, cum_probes_per_measurement=None + ): + """Average positions estimate""" + xp = self._xp + + if positions is not None: + _pos = xp.asarray(positions) + else: + if not hasattr(self, "_positions_px_all"): + return None + _pos = self._positions_px_all + + if cum_probes_per_measurement is None: + cum_probes_per_measurement = self._cum_probes_per_measurement + + num_probes_per_measurement = np.diff(cum_probes_per_measurement) + num_measurements = len(num_probes_per_measurement) + + if np.any(num_probes_per_measurement != num_probes_per_measurement[0]): + return None + + avg_positions = xp.zeros((num_probes_per_measurement[0], 2), dtype=xp.float32) + + for index in range(num_measurements): + start_idx = cum_probes_per_measurement[index] + end_idx = cum_probes_per_measurement[index + 1] + avg_positions += _pos[start_idx:end_idx] + + return avg_positions / num_measurements + + def _return_self_consistency_errors( + self, + **kwargs, + ): + """Compute the self-consistency errors for each probe position""" + raise NotImplementedError() + + @property + def _probe(self): + """Dummy property to make single-probe functions work""" + return self._return_single_probe() + + @property + def positions(self): + """Probe positions [A]""" + + if self.angular_sampling is None: + return None + + asnumpy = self._asnumpy + positions_all = [] + + for index in range(self._num_measurements): + start_idx = self._cum_probes_per_measurement[index] + end_idx = self._cum_probes_per_measurement[index + 1] + positions = self._positions_px_all[start_idx:end_idx].copy() + positions[:, 0] *= self.sampling[0] + positions[:, 1] *= self.sampling[1] + positions_all.append(asnumpy(positions)) + + return np.asarray(positions_all) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py index c4374c93a..590b907ea 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py @@ -27,12 +27,12 @@ ProbeConstraintsMixin, ) from py4DSTEM.process.phase.iterative_ptychographic_methods import ( + MultipleMeasurementsMethodsMixin, Object2p5DMethodsMixin, Object2p5DProbeMethodsMixin, Object3DMethodsMixin, ObjectNDMethodsMixin, ObjectNDProbeMethodsMixin, - ProbeListMethodsMixin, ProbeMethodsMixin, ) from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( @@ -56,9 +56,9 @@ class PtychographicTomographyReconstruction( Object3DConstraintsMixin, Object2p5DConstraintsMixin, ObjectNDConstraintsMixin, + MultipleMeasurementsMethodsMixin, Object2p5DProbeMethodsMixin, ObjectNDProbeMethodsMixin, - ProbeListMethodsMixin, ProbeMethodsMixin, Object3DMethodsMixin, Object2p5DMethodsMixin, @@ -226,7 +226,7 @@ def __init__( # Class-specific Metadata self._num_slices = num_slices self._tilt_orientation_matrices = tuple(tilt_orientation_matrices) - 
self._num_tilts = num_tilts + self._num_measurements = num_tilts def preprocess( self, @@ -329,22 +329,22 @@ def preprocess( UserWarning, ) self._positions_mask = np.tile( - self._positions_mask, (self._num_tilts, 1, 1) + self._positions_mask, (self._num_measurements, 1, 1) ) - num_probes_per_tilt = np.insert( + num_probes_per_measurement = np.insert( self._positions_mask.sum(axis=(-2, -1)), 0, 0 ) else: - self._positions_mask = [None] * self._num_tilts - num_probes_per_tilt = [0] + [dc.R_N for dc in self._datacube] - num_probes_per_tilt = np.array(num_probes_per_tilt) + self._positions_mask = [None] * self._num_measurements + num_probes_per_measurement = [0] + [dc.R_N for dc in self._datacube] + num_probes_per_measurement = np.array(num_probes_per_measurement) # prepopulate relevant arrays self._mean_diffraction_intensity = [] - self._num_diffraction_patterns = num_probes_per_tilt.sum() - self._cum_probes_per_tilt = np.cumsum(num_probes_per_tilt) + self._num_diffraction_patterns = num_probes_per_measurement.sum() + self._cum_probes_per_measurement = np.cumsum(num_probes_per_measurement) self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) # calculate roi_shape @@ -359,54 +359,54 @@ def preprocess( # TO-DO: generalize this if force_com_shifts is None: - force_com_shifts = [None] * self._num_tilts + force_com_shifts = [None] * self._num_measurements self._rotation_best_rad = np.deg2rad(diffraction_patterns_rotate_degrees) self._rotation_best_transpose = diffraction_patterns_transpose # loop over DPs for preprocessing - for tilt_index in tqdmnd( - self._num_tilts, + for index in tqdmnd( + self._num_measurements, desc="Preprocessing data", unit="tilt", disable=not progress_bar, ): # preprocess datacube, vacuum and masks only for first tilt - if tilt_index == 0: + if index == 0: ( - self._datacube[tilt_index], + self._datacube[index], self._vacuum_probe_intensity, self._dp_mask, - force_com_shifts[tilt_index], + force_com_shifts[index], ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[tilt_index], + self._datacube[index], diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, probe_roi_shape=self._probe_roi_shape, vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, - com_shifts=force_com_shifts[tilt_index], + com_shifts=force_com_shifts[index], ) else: ( - self._datacube[tilt_index], + self._datacube[index], _, _, - force_com_shifts[tilt_index], + force_com_shifts[index], ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[tilt_index], + self._datacube[index], diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, probe_roi_shape=self._probe_roi_shape, vacuum_probe_intensity=None, dp_mask=None, - com_shifts=force_com_shifts[tilt_index], + com_shifts=force_com_shifts[index], ) # calibrations intensities = self._extract_intensities_and_calibrations_from_datacube( - self._datacube[tilt_index], + self._datacube[index], require_calibrations=True, force_scan_sampling=force_scan_sampling, force_angular_sampling=force_angular_sampling, @@ -425,12 +425,12 @@ def preprocess( intensities, dp_mask=self._dp_mask, fit_function=fit_function, - com_shifts=force_com_shifts[tilt_index], + com_shifts=force_com_shifts[index], ) # corner-center amplitudes - idx_start = self._cum_probes_per_tilt[tilt_index] - idx_end = self._cum_probes_per_tilt[tilt_index + 1] + idx_start = self._cum_probes_per_measurement[index] + idx_end = 
self._cum_probes_per_measurement[index + 1] ( self._amplitudes[idx_start:idx_end], mean_diffraction_intensity_temp, @@ -439,7 +439,7 @@ def preprocess( intensities, com_fitted_x, com_fitted_y, - self._positions_mask[tilt_index], + self._positions_mask[index], crop_patterns, ) @@ -460,8 +460,8 @@ def preprocess( self._positions_px_all[idx_start:idx_end], self._object_padding_px, ) = self._calculate_scan_positions_in_pixels( - self._scan_positions[tilt_index], - self._positions_mask[tilt_index], + self._scan_positions[index], + self._positions_mask[index], self._object_padding_px, ) @@ -487,9 +487,9 @@ def preprocess( # center probe positions self._positions_px_all = xp.asarray(self._positions_px_all, dtype=xp.float32) - for tilt_index in range(self._num_tilts): - idx_start = self._cum_probes_per_tilt[tilt_index] - idx_end = self._cum_probes_per_tilt[tilt_index + 1] + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] self._positions_px = self._positions_px_all[idx_start:idx_end] self._positions_px_com = xp.mean(self._positions_px, axis=0) self._positions_px -= ( @@ -508,11 +508,11 @@ def preprocess( self._probes_all_initial_aperture = [] list_Q = isinstance(self._probe_init, (list, tuple)) - for tilt_index in range(self._num_tilts): + for index in range(self._num_measurements): _probe, self._semiangle_cutoff = self._initialize_probe( - self._probe_init[tilt_index] if list_Q else self._probe_init, + self._probe_init[index] if list_Q else self._probe_init, self._vacuum_probe_intensity, - self._mean_diffraction_intensity[tilt_index], + self._mean_diffraction_intensity[index], self._semiangle_cutoff, crop_patterns, ) @@ -549,10 +549,10 @@ def preprocess( probe_overlap_3D = xp.zeros_like(self._object) old_rot_matrix = np.eye(3) # identity - for tilt_index in range(self._num_tilts): - idx_start = self._cum_probes_per_tilt[tilt_index] - idx_end = self._cum_probes_per_tilt[tilt_index + 1] - rot_matrix = self._tilt_orientation_matrices[tilt_index] + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + rot_matrix = self._tilt_orientation_matrices[index] probe_overlap_3D = self._rotate_zxy_volume( probe_overlap_3D, @@ -564,7 +564,7 @@ def preprocess( self._positions_px ) shifted_probes = fft_shift( - self._probes_all[tilt_index], self._positions_px_fractional, xp + self._probes_all[index], self._positions_px_fractional, xp ) probe_intensities = xp.abs(shifted_probes) ** 2 probe_overlap = self._sum_overlapping_patches_bincounts( @@ -587,7 +587,9 @@ def preprocess( else: self._object_fov_mask = np.asarray(object_fov_mask) - self._positions_px = self._positions_px_all[: self._cum_probes_per_tilt[1]] + self._positions_px = self._positions_px_all[ + : self._cum_probes_per_measurement[1] + ] self._positions_px_fractional = self._positions_px - xp.round( self._positions_px ) @@ -738,7 +740,7 @@ def reconstruct( tv_denoise_iter=np.inf, tv_denoise_weights=None, tv_denoise_inner_iter=40, - collective_tilt_updates: bool = True, + collective_measurement_updates: bool = True, store_iterations: bool = False, progress_bar: bool = True, reset: bool = None, @@ -829,8 +831,8 @@ def reconstruct( the more denoising. 
tv_denoise_inner_iter: float Number of iterations to run in inner loop of TV denoising - collective_tilt_updates: bool - if True perform collective tilt updates + collective_measurement_updates: bool + if True perform collective measurement updates (i.e. one per tilt) shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration store_iterations: bool, optional @@ -897,20 +899,22 @@ def reconstruct( ): error = 0.0 - if collective_tilt_updates: + if collective_measurement_updates: collective_object = xp.zeros_like(self._object) - tilt_indices = np.arange(self._num_tilts) - np.random.shuffle(tilt_indices) + indices = np.arange(self._num_measurements) + np.random.shuffle(indices) old_rot_matrix = np.eye(3) # identity - for tilt_index in tilt_indices: - self._active_tilt_index = tilt_index + for index in indices: + self._active_measurement_index = index - tilt_error = 0.0 + measurement_error = 0.0 - rot_matrix = self._tilt_orientation_matrices[self._active_tilt_index] + rot_matrix = self._tilt_orientation_matrices[ + self._active_measurement_index + ] self._object = self._rotate_zxy_volume( self._object, rot_matrix @ old_rot_matrix.T, @@ -920,18 +924,22 @@ def reconstruct( self._object, self._num_slices ) - _probe = self._probes_all[self._active_tilt_index] + _probe = self._probes_all[self._active_measurement_index] _probe_initial_aperture = self._probes_all_initial_aperture[ - self._active_tilt_index + self._active_measurement_index ] if not use_projection_scheme: object_sliced_old = object_sliced.copy() - start_tilt = self._cum_probes_per_tilt[self._active_tilt_index] - end_tilt = self._cum_probes_per_tilt[self._active_tilt_index + 1] + start_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + ] + end_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + 1 + ] - num_diffraction_patterns = end_tilt - start_tilt + num_diffraction_patterns = end_idx - start_idx shuffled_indices = np.arange(num_diffraction_patterns) unshuffled_indices = np.zeros_like(shuffled_indices) @@ -943,11 +951,11 @@ def reconstruct( num_diffraction_patterns ) - positions_px = self._positions_px_all[start_tilt:end_tilt].copy()[ + positions_px = self._positions_px_all[start_idx:end_idx].copy()[ shuffled_indices ] initial_positions_px = self._positions_px_initial_all[ - start_tilt:end_tilt + start_idx:end_idx ].copy()[shuffled_indices] for start, end in generate_batches( @@ -966,7 +974,7 @@ def reconstruct( self._vectorized_patch_indices_col, ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[start_tilt:end_tilt][ + amplitudes = self._amplitudes[start_idx:end_idx][ shuffled_indices[start:end] ] @@ -1015,7 +1023,7 @@ def reconstruct( max_position_total_distance, ) - tilt_error += batch_error + measurement_error += batch_error if not use_projection_scheme: object_sliced -= object_sliced_old @@ -1024,7 +1032,7 @@ def reconstruct( object_sliced, self._num_voxels ) - if collective_tilt_updates: + if collective_measurement_updates: collective_object += self._rotate_zxy_volume( object_update, rot_matrix.T ) @@ -1034,18 +1042,18 @@ def reconstruct( old_rot_matrix = rot_matrix # Normalize Error - tilt_error /= ( - self._mean_diffraction_intensity[self._active_tilt_index] + measurement_error /= ( + self._mean_diffraction_intensity[self._active_measurement_index] * num_diffraction_patterns ) - error += tilt_error + error += measurement_error # constraints - self._positions_px_all[start_tilt:end_tilt] = positions_px.copy()[ + 
self._positions_px_all[start_idx:end_idx] = positions_px.copy()[ unshuffled_indices ] - if collective_tilt_updates: + if collective_measurement_updates: # probe and positions _probe = self._probe_constraints( _probe, @@ -1068,9 +1076,9 @@ def reconstruct( ) self._positions_px_all[ - start_tilt:end_tilt + start_idx:end_idx ] = self._positions_constraints( - self._positions_px_all[start_tilt:end_tilt], + self._positions_px_all[start_idx:end_idx], fix_positions=a0 < fix_positions_iter, global_affine_transformation=global_affine_transformation, ) @@ -1080,11 +1088,11 @@ def reconstruct( ( self._object, _probe, - self._positions_px_all[start_tilt:end_tilt], + self._positions_px_all[start_idx:end_idx], ) = self._constraints( self._object, _probe, - self._positions_px_all[start_tilt:end_tilt], + self._positions_px_all[start_idx:end_idx], fix_com=fix_com and a0 >= fix_probe_iter, constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter and a0 >= fix_probe_iter, @@ -1126,10 +1134,10 @@ def reconstruct( self._object = self._rotate_zxy_volume(self._object, old_rot_matrix.T) # Normalize Error Over Tilts - error /= self._num_tilts + error /= self._num_measurements - if collective_tilt_updates: - self._object += collective_object / self._num_tilts + if collective_measurement_updates: + self._object += collective_object / self._num_measurements # object only self._object = self._object_constraints( @@ -1174,24 +1182,3 @@ def reconstruct( xp.clear_memo() return self - - @property - def positions(self): - """Probe positions [A]""" - - if self.angular_sampling is None: - return None - - asnumpy = self._asnumpy - positions_all = [] - for tilt_index in range(self._num_tilts): - positions = self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ].copy() - positions[:, 0] *= self.sampling[0] - positions[:, 1] *= self.sampling[1] - positions_all.append(asnumpy(positions)) - - return np.asarray(positions_all) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py index 4208cfd96..79db7d16f 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py @@ -6,8 +6,11 @@ from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable from py4DSTEM.process.phase.utils import AffineTransform -from py4DSTEM.visualize import return_scaled_histogram_ordering -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg +from py4DSTEM.visualize.vis_special import ( + Complex2RGB, + add_colorbar_arg, + return_scaled_histogram_ordering, +) try: import cupy as cp @@ -534,6 +537,8 @@ def visualize( def show_updated_positions( self, + pos=None, + initial_pos=None, scale_arrows=1, plot_arrow_freq=None, plot_cropped_rotated_fov=True, @@ -561,8 +566,15 @@ def show_updated_positions( asnumpy = self._asnumpy - initial_pos = asnumpy(self._positions_initial) - pos = self.positions + if pos is None: + pos = self.positions + + # handle multiple measurements + if pos.ndim == 3: + pos = pos.mean(0) + + if initial_pos is None: + initial_pos = asnumpy(self._positions_initial) if plot_cropped_rotated_fov: angle = ( From d54bf374323691e2048408e7130508f7c327113b Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 5 Jan 2024 18:18:19 -0800 Subject: [PATCH 053/128] phase unwrapping bugfixes Former-commit-id: 
65448cb8651891b962e919812a32a8a3b74c9275 --- .../phase/iterative_magnetic_ptychography.py | 5 +++++ ...tive_mixedstate_multislice_ptychography.py | 4 ++++ .../iterative_mixedstate_ptychography.py | 4 ++++ .../iterative_multislice_ptychography.py | 4 ++++ .../iterative_ptychographic_constraints.py | 22 ++++++++++++++----- .../iterative_ptychographic_tomography.py | 5 +++++ .../iterative_singleslice_ptychography.py | 8 +++++-- py4DSTEM/process/phase/utils.py | 11 +++++++--- 8 files changed, 53 insertions(+), 10 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_magnetic_ptychography.py b/py4DSTEM/process/phase/iterative_magnetic_ptychography.py index 01817e5fa..3330e33d5 100644 --- a/py4DSTEM/process/phase/iterative_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/iterative_magnetic_ptychography.py @@ -1067,6 +1067,7 @@ def reconstruct( fit_probe_aberrations_iter: int = 0, fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, butterworth_filter_iter: int = np.inf, q_lowpass_e: float = None, q_lowpass_m: float = None, @@ -1158,6 +1159,8 @@ def reconstruct( Max angular order of probe aberrations basis functions fit_probe_aberrations_max_radial_order: bool Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass_e: float @@ -1425,6 +1428,7 @@ def reconstruct( and a0 >= fix_probe_iter, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=_probe_initial_aperture, ) @@ -1461,6 +1465,7 @@ def reconstruct( and a0 >= fix_probe_iter, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=_probe_initial_aperture, fix_positions=a0 < fix_positions_iter, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 79930c740..aa6f4fe56 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -679,6 +679,7 @@ def reconstruct( fit_probe_aberrations_iter: int = 0, fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, butterworth_filter_iter: int = np.inf, q_lowpass: float = None, q_highpass: float = None, @@ -770,6 +771,8 @@ def reconstruct( Max angular order of probe aberrations basis functions fit_probe_aberrations_max_radial_order: bool Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass: float @@ -980,6 +983,7 @@ def reconstruct( and a0 >= fix_probe_iter, 
fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 765f6b8b6..a57236bc5 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -583,6 +583,7 @@ def reconstruct( fit_probe_aberrations_iter: int = 0, fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, butterworth_filter_iter: int = np.inf, q_lowpass: float = None, q_highpass: float = None, @@ -669,6 +670,8 @@ def reconstruct( Max angular order of probe aberrations basis functions fit_probe_aberrations_max_radial_order: bool Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass: float @@ -864,6 +867,7 @@ def reconstruct( and a0 >= fix_probe_iter, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 91f9ea076..cd622f5e7 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -653,6 +653,7 @@ def reconstruct( fit_probe_aberrations_iter: int = 0, fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, butterworth_filter_iter: int = np.inf, q_lowpass: float = None, q_highpass: float = None, @@ -744,6 +745,8 @@ def reconstruct( Max angular order of probe aberrations basis functions fit_probe_aberrations_max_radial_order: bool Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass: float @@ -954,6 +957,7 @@ def reconstruct( and a0 >= fix_probe_iter, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 5181984ed..8d8a5d158 100644 --- 
a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -1037,19 +1037,21 @@ def _probe_aberration_fitting_constraint( current_probe, max_angular_order, max_radial_order, + remove_initial_probe_aberrations, ): """ Ptychographic probe smoothing constraint. - Removes/adds known (initialization) aberrations before/after smoothing. Parameters ---------- current_probe: np.ndarray Current positions estimate - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A^-1 - fix_amplitude: bool - If True, only the phase is smoothed + max_angular_order: bool + Max angular order of probe aberrations basis functions + max_radial_order: bool + Max radial order of probe aberrations basis functions + remove_initial_probe_aberrations: bool, optional + If true, initial probe aberrations are removed before fitting Returns -------- @@ -1060,6 +1062,9 @@ def _probe_aberration_fitting_constraint( xp = self._xp fourier_probe = xp.fft.fft2(current_probe) + if remove_initial_probe_aberrations: + fourier_probe *= xp.conj(self._known_aberrations_array) + fourier_probe_abs = xp.abs(fourier_probe) sampling = self.sampling energy = self._energy @@ -1074,6 +1079,9 @@ def _probe_aberration_fitting_constraint( ) fourier_probe = fourier_probe_abs * xp.exp(-1.0j * fitted_angle) + if remove_initial_probe_aberrations: + fourier_probe *= self._known_aberrations_array + current_probe = xp.fft.ifft2(fourier_probe) return current_probe @@ -1085,6 +1093,7 @@ def _probe_constraints( fit_probe_aberrations, fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial, fix_probe_aperture, initial_probe_aperture, constrain_probe_fourier_amplitude, @@ -1107,6 +1116,7 @@ def _probe_constraints( current_probe, fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial, ) # Fourier amplitude (aperture) constraints @@ -1208,6 +1218,7 @@ def _probe_constraints( fit_probe_aberrations, fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial, fix_probe_aperture, initial_probe_aperture, constrain_probe_fourier_amplitude, @@ -1232,6 +1243,7 @@ def _probe_constraints( current_probe[probe_idx], fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial, ) # Fourier amplitude (aperture) constraints diff --git a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py index 590b907ea..cc62f9d3f 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py @@ -730,6 +730,7 @@ def reconstruct( fit_probe_aberrations_iter: int = 0, fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, butterworth_filter_iter: int = np.inf, q_lowpass: float = None, q_highpass: float = None, @@ -814,6 +815,8 @@ def reconstruct( Max angular order of probe aberrations basis functions fit_probe_aberrations_max_radial_order: bool Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass: 
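The _probe_aberration_fitting_constraint hunk above divides out the known initial aberration surface before the low-order fit and multiplies it back afterwards, so only the residual aberrations are smoothed. A schematic of that divide-out / multiply-back pattern (NumPy only; fit_low_order_phase is a hypothetical stand-in for the basis fit performed by fit_aberration_surface):

    import numpy as np

    def fit_residual_aberrations(probe, known_aberrations, fit_low_order_phase,
                                 remove_initial=True):
        fourier_probe = np.fft.fft2(probe)
        if remove_initial:
            # divide out the known (initialization) aberration phase before fitting
            fourier_probe = fourier_probe * np.conj(known_aberrations)
        amplitude = np.abs(fourier_probe)
        fitted_angle = fit_low_order_phase(fourier_probe)   # low-order phase surface
        fourier_probe = amplitude * np.exp(-1.0j * fitted_angle)
        if remove_initial:
            # re-apply the known aberrations after constraining the residual
            fourier_probe = fourier_probe * known_aberrations
        return np.fft.ifft2(fourier_probe)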
float @@ -1071,6 +1074,7 @@ def reconstruct( and a0 >= fix_probe_iter, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=_probe_initial_aperture, ) @@ -1107,6 +1111,7 @@ def reconstruct( and a0 >= fix_probe_iter, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=_probe_initial_aperture, fix_positions=a0 < fix_positions_iter, diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 3a0f5d77b..b60157d35 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -561,6 +561,7 @@ def reconstruct( fit_probe_aberrations_iter: int = 0, fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, butterworth_filter_iter: int = np.inf, q_lowpass: float = None, q_highpass: float = None, @@ -643,10 +644,12 @@ def reconstruct( Number of iterations to run using object smoothness constraint fit_probe_aberrations_iter: int, optional Number of iterations to run while fitting the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool + fit_probe_aberrations_max_angular_order: int Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool + fit_probe_aberrations_max_radial_order: int Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass: float @@ -842,6 +845,7 @@ def reconstruct( and a0 >= fix_probe_iter, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 95dbc4511..90775dd64 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1657,13 +1657,16 @@ def unwrap_phase_2d(array, weights=None, gauge=None, corner_centered=True, xp=np dy = xp.mod(xp.diff(array, axis=1) + np.pi, 2 * np.pi) - np.pi if weights is not None: + # normalize weights + weights -= weights.min() + weights /= weights.max() + ww = weights**2 dx *= xp.minimum(ww[:-1, :], ww[1:, :]) dy *= xp.minimum(ww[:, :-1], ww[:, 1:]) - rho = xp.diff(dx, axis=0, prepend=0, append=0) + xp.diff( - dy, axis=1, prepend=0, append=0 - ) + rho = xp.diff(dx, axis=0, prepend=0, append=0) + rho += xp.diff(dy, axis=1, prepend=0, append=0) unwrapped_array = preconditioned_poisson_solver_dct(rho, gauge=gauge, xp=xp).real unwrapped_array -= unwrapped_array.min() @@ -1709,6 
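The unwrap_phase_2d bugfix above rescales the weights to [0, 1] before squaring, so the weighted wrapped gradients stay on the same scale as the unweighted ones, and builds the divergence in two in-place steps. A minimal NumPy-only illustration of that weighting, using a made-up wrapped phase and a hypothetical probe_amplitude weight map:

    import numpy as np

    phase = np.random.uniform(-np.pi, np.pi, (64, 64))   # wrapped phase (example data)
    probe_amplitude = np.random.rand(64, 64)             # hypothetical weights

    weights = probe_amplitude - probe_amplitude.min()
    weights /= weights.max()                             # normalize weights to [0, 1]
    ww = weights**2

    dx = np.mod(np.diff(phase, axis=0) + np.pi, 2 * np.pi) - np.pi
    dy = np.mod(np.diff(phase, axis=1) + np.pi, 2 * np.pi) - np.pi
    dx *= np.minimum(ww[:-1, :], ww[1:, :])              # down-weight unreliable gradients
    dy *= np.minimum(ww[:, :-1], ww[:, 1:])

    rho = np.diff(dx, axis=0, prepend=0, append=0)       # divergence, accumulated in two steps
    rho += np.diff(dy, axis=1, prepend=0, append=0)

rho is then handed to the DCT-based Poisson solver exactly as before.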
+1712,8 @@ def fit_aberration_surface( coeff = xp.linalg.lstsq(Aw, bw, rcond=None)[0] fitted_angle = xp.tensordot(raveled_basis, coeff, axes=1).reshape(probe_angle.shape) + angle_offset = fitted_angle[0, 0] - probe_angle[0, 0] + fitted_angle -= angle_offset return fitted_angle, coeff From 996da615cef886f9f8d64f38b1d3ed8256113a12 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 6 Jan 2024 12:50:14 -0800 Subject: [PATCH 054/128] better support for multiple measurements probe properties Former-commit-id: b4eb43cd06436cd78e044ca779c7d5795f3ec32e --- .../phase/iterative_magnetic_ptychography.py | 9 ++--- .../phase/iterative_ptychographic_methods.py | 33 +++++++++++++++++-- .../iterative_ptychographic_tomography.py | 9 ++--- 3 files changed, 34 insertions(+), 17 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_magnetic_ptychography.py b/py4DSTEM/process/phase/iterative_magnetic_ptychography.py index 3330e33d5..4868f7917 100644 --- a/py4DSTEM/process/phase/iterative_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/iterative_magnetic_ptychography.py @@ -626,7 +626,7 @@ def preprocess( # initial probe complex_probe_rgb = Complex2RGB( - self.probe_centered, + self.probe_centered[0], power=power, chroma_boost=chroma_boost, ) @@ -1532,12 +1532,7 @@ def reconstruct( if store_iterations: self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append( - [ - asnumpy(self._return_centered_probe(pr.copy())) - for pr in self._probes_all - ] - ) + self.probe_iterations.append(self.probe_centered) # store result self.object = asnumpy(self._object) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index 16fbc403f..4bce5e5a9 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -3149,9 +3149,36 @@ def _return_self_consistency_errors( raise NotImplementedError() @property - def _probe(self): - """Dummy property to make single-probe functions work""" - return self._return_single_probe() + def probe_fourier(self): + """Current probe estimate in Fourier space""" + if not hasattr(self, "_probes_all"): + return None + + asnumpy = self._asnumpy + return [asnumpy(self._return_fourier_probe(pr)) for pr in self._probes_all] + + @property + def probe_fourier_residual(self): + """Current probe estimate in Fourier space""" + if not hasattr(self, "_probe"): + return None + + asnumpy = self._asnumpy + return [ + asnumpy( + self._return_fourier_probe(pr, remove_initial_probe_aberrations=True) + ) + for pr in self._probes_all + ] + + @property + def probe_centered(self): + """Current probe estimate shifted to the center""" + if not hasattr(self, "_probes_all"): + return None + + asnumpy = self._asnumpy + return [asnumpy(self._return_centered_probe(pr)) for pr in self._probes_all] @property def positions(self): diff --git a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py index cc62f9d3f..a76095d1e 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py @@ -608,7 +608,7 @@ def preprocess( # initial probe complex_probe_rgb = Complex2RGB( - self.probe_centered, + self.probe_centered[0], power=power, chroma_boost=chroma_boost, ) @@ -1170,12 +1170,7 @@ def reconstruct( if store_iterations: self.object_iterations.append(asnumpy(self._object.copy())) - 
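The two lines added to fit_aberration_surface pin the fitted phase surface to the measured phase at the [0, 0] pixel, removing the arbitrary constant offset a least-squares fit is otherwise free to choose. A tiny numerical check of that gauge choice (illustrative values only):

    import numpy as np

    probe_angle = np.array([[0.2, 0.3], [0.4, 0.5]])     # measured (unwrapped) phase
    fitted_angle = np.array([[1.2, 1.3], [1.4, 1.5]])    # fit differing by a constant

    angle_offset = fitted_angle[0, 0] - probe_angle[0, 0]
    fitted_angle = fitted_angle - angle_offset           # surfaces now agree at the gauge pixel
    assert np.isclose(fitted_angle[0, 0], probe_angle[0, 0])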
self.probe_iterations.append( - [ - asnumpy(self._return_centered_probe(pr.copy())) - for pr in self._probes_all - ] - ) + self.probe_iterations.append(self.probe_centered) # store result self.object = asnumpy(self._object) From f0dd0470af724fb57435354623850a4618c3ad93 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 6 Jan 2024 14:00:05 -0800 Subject: [PATCH 055/128] generalized show_fourier_probe, added general show_probe, added intensity reporting Former-commit-id: a442fd0ba048cf6c7fd148f2ffbb71cb218743d4 --- .../phase/iterative_ptychographic_methods.py | 195 +++++++++++------- py4DSTEM/visualize/vis_special.py | 11 +- 2 files changed, 126 insertions(+), 80 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index 4bce5e5a9..f1469007c 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -1162,6 +1162,90 @@ def _return_centered_probe( return xp.fft.fftshift(probe, axes=(-2, -1)) + def _return_probe_intensities(self, probe): + """ + Returns probe intensities summing up to 1. + """ + if probe is None: + probe = self.probe_centered + + intensity_arrays = np.abs(np.array(probe, ndmin=3)) ** 2 + probe_ratio = list(intensity_arrays.sum((-2, -1)) / intensity_arrays.sum()) + + return probe_ratio + + def show_probe( + self, + probe=None, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, + ): + """ + Plot probe in real space + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses the `probe_fourier` property + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe + cbar: bool, optional + if True, adds colorbar + scalebar: bool, optional + if True, adds scalebar to probe + pixelunits: str, optional + units for scalebar, default is A^-1 + pixelsize: float, optional + default is probe reciprocal sampling + """ + asnumpy = self._asnumpy + + if pixelsize is None: + pixelsize = self.sampling[1] + if pixelunits is None: + pixelunits = r"$\AA$" + + intensities = self._return_probe_intensities(probe) + title = [ + f"Probe {iter} intensity: {ratio*100:.1f}%" + for iter, ratio in enumerate(intensities) + ] + + axsize = kwargs.pop("axsize", (4, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + ticks = kwargs.pop("ticks", False) + title = kwargs.pop("title", title if len(title) > 1 else title[0]) + + if probe is None: + probe = list(np.array(self.probe_centered, ndmin=3)) + else: + if isinstance(probe, np.ndarray) and probe.ndim == 2: + probe = [probe] + probe = [ + asnumpy( + self._return_centered_probe( + pr, + ) + ) + for pr in probe + ] + + show_complex( + probe, + cbar=cbar, + axsize=axsize, + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=ticks, + chroma_boost=chroma_boost, + title=title, + **kwargs, + ) + def show_fourier_probe( self, probe=None, @@ -1192,30 +1276,51 @@ def show_fourier_probe( """ asnumpy = self._asnumpy - probe = asnumpy( - self._return_fourier_probe( - probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations - ) - ) - if pixelsize is None: pixelsize = self._reciprocal_sampling[1] if pixelunits is None: pixelunits = r"$\AA^{-1}$" - figsize = kwargs.pop("figsize", (6, 6)) + intensities = self._return_probe_intensities(probe) + title = [ + f"Probe {iter} intensity: {ratio*100:.1f}%" + for iter, ratio in enumerate(intensities) + ] + + 
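_return_probe_intensities above reports what fraction of the total probe intensity each (mixed-state or per-measurement) probe carries, which the new show_probe/show_fourier_probe titles print as percentages. A quick stand-alone check of that arithmetic on two made-up probes:

    import numpy as np

    probes = np.stack([
        np.full((4, 4), 2.0 + 0j),   # |psi|^2 = 4 per pixel
        np.full((4, 4), 1.0 + 0j),   # |psi|^2 = 1 per pixel
    ])

    intensity_arrays = np.abs(np.array(probes, ndmin=3)) ** 2
    ratios = intensity_arrays.sum((-2, -1)) / intensity_arrays.sum()
    print(ratios)   # [0.8 0.2] -> titles read "Probe 0 intensity: 80.0%", "Probe 1 intensity: 20.0%"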
axsize = kwargs.pop("axsize", (4, 4)) chroma_boost = kwargs.pop("chroma_boost", 1) + ticks = kwargs.pop("ticks", False) + title = kwargs.pop("title", title if len(title) > 1 else title[0]) + + if probe is None: + if remove_initial_probe_aberrations: + probe = self.probe_fourier_residual + else: + probe = self.probe_fourier + probe = list(np.array(probe, ndmin=3)) + else: + if isinstance(probe, np.ndarray) and probe.ndim == 2: + probe = [probe] + probe = [ + asnumpy( + self._return_fourier_probe( + pr, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + for pr in probe + ] - fig, ax = plt.subplots(figsize=figsize) show_complex( probe, cbar=cbar, - figax=(fig, ax), + axsize=axsize, scalebar=scalebar, pixelsize=pixelsize, pixelunits=pixelunits, - ticks=False, + ticks=ticks, chroma_boost=chroma_boost, + title=title, **kwargs, ) @@ -1317,74 +1422,6 @@ def _initialize_probe( return _probes, semiangle_cutoff - def show_fourier_probe( - self, - probe=None, - remove_initial_probe_aberrations=False, - cbar=True, - scalebar=True, - pixelsize=None, - pixelunits=None, - **kwargs, - ): - """ - Plot probe in fourier space - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses the `probe_fourier` property - remove_initial_probe_aberrations: bool, optional - If True, removes initial probe aberrations from Fourier probe - scalebar: bool, optional - if True, adds scalebar to probe - pixelunits: str, optional - units for scalebar, default is A^-1 - pixelsize: float, optional - default is probe reciprocal sampling - """ - asnumpy = self._asnumpy - - if probe is None: - probe = list( - asnumpy( - self._return_fourier_probe( - probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - ) - else: - if isinstance(probe, np.ndarray) and probe.ndim == 2: - probe = [probe] - probe = [ - asnumpy( - self._return_fourier_probe( - pr, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - for pr in probe - ] - - if pixelsize is None: - pixelsize = self._reciprocal_sampling[1] - if pixelunits is None: - pixelunits = r"$\AA^{-1}$" - - chroma_boost = kwargs.pop("chroma_boost", 1) - - show_complex( - probe if len(probe) > 1 else probe[0], - cbar=cbar, - scalebar=scalebar, - pixelsize=pixelsize, - pixelunits=pixelunits, - ticks=False, - chroma_boost=chroma_boost, - **kwargs, - ) - def _return_single_probe(self, probe=None): """Current probe estimate""" xp = self._xp @@ -3160,7 +3197,7 @@ def probe_fourier(self): @property def probe_fourier_residual(self): """Current probe estimate in Fourier space""" - if not hasattr(self, "_probe"): + if not hasattr(self, "_probes_all"): return None asnumpy = self._asnumpy diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 125a2ce67..c2dd2d6f4 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -809,8 +809,17 @@ def show_complex( add_scalebar(ax[0, 0], scalebar) else: + figsize = kwargs.pop("axsize", None) + figsize = kwargs.pop("figsize", figsize) + fig, ax = show( - rgb, vmin=0, vmax=1, intensity_range="absolute", returnfig=True, **kwargs + rgb, + vmin=0, + vmax=1, + intensity_range="absolute", + returnfig=True, + figsize=figsize, + **kwargs, ) if scalebar is True: From f619fb453bd59f91e39a998ee74c2f8058f056f8 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 6 Jan 2024 14:14:44 -0800 Subject: [PATCH 056/128] visual tweaks Former-commit-id: bcd6f26f4abf6c0498c618d6b2071ac7c7f20a52 --- 
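The small vis_special.py change above lets show_complex forward either axsize or figsize to show on its single-image path, with an explicit figsize taking precedence. The fallback is just two dictionary pops; a minimal sketch of the same pattern:

    def resolve_figsize(**kwargs):
        # accept `axsize` as an alias, but let an explicit `figsize` win
        figsize = kwargs.pop("axsize", None)
        figsize = kwargs.pop("figsize", figsize)
        return figsize, kwargs

    print(resolve_figsize(axsize=(4, 4)))                  # ((4, 4), {})
    print(resolve_figsize(axsize=(4, 4), figsize=(6, 6)))  # ((6, 6), {})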
.../phase/iterative_ptychographic_methods.py | 14 +++++++------- .../iterative_ptychographic_visualizations.py | 8 ++++---- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index f1469007c..d2a63ce9a 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -2407,9 +2407,9 @@ def show_transmitted_probe( ] ] title = [ - "Mean Transmitted Probe", - "Min Intensity Transmitted Probe", - "Max Intensity Transmitted Probe", + "Mean transmitted probe", + "Min-intensity transmitted probe", + "Max-intensity transmitted probe", ] if plot_fourier_probe: @@ -2429,14 +2429,14 @@ def show_transmitted_probe( probes = [probes, bottom_row] title += [ - "Mean Transmitted Fourier Probe", - "Min Intensity Transmitted Fourier Probe", - "Max Intensity Transmitted Fourier Probe", + "Mean transmitted Fourier probe", + "Min-intensity transmitted Fourier probe", + "Max-intensity transmitted Fourier probe", ] title = kwargs.get("title", title) ticks = kwargs.get("ticks", False) - axsize = kwargs.get("axsize", (4.5, 4.5)) + axsize = kwargs.get("axsize", (4, 4)) show_complex( probes, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py index 79db7d16f..916ef0352 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py @@ -615,7 +615,7 @@ def show_updated_positions( 0, ] - figsize = kwargs.pop("figsize", (6, 6)) + figsize = kwargs.pop("figsize", (4, 4)) cmap = kwargs.pop("cmap", "Reds") fig, ax = plt.subplots(figsize=figsize) @@ -767,13 +767,13 @@ def show_uncertainty_visualization( height_ratios=[1, 4], hspace=0.15, ) - auto_figsize = (6, 8) + auto_figsize = (4, 5) else: spec = GridSpec( ncols=1, nrows=1, ) - auto_figsize = (6, 6) + auto_figsize = (4, 4) figsize = kwargs.pop("figsize", auto_figsize) @@ -785,7 +785,7 @@ def show_uncertainty_visualization( counts, bins = np.histogram(errors, bins=50) ax_hist.hist(bins[:-1], bins, weights=counts, color="#5ac8c8", alpha=0.5) ax_hist.set_ylabel("Counts") - ax_hist.set_xlabel("Normalized Squared Error") + ax_hist.set_xlabel("Normalized squared error") ax = fig.add_subplot(spec[-1]) From b6d188e2ebbeb4f84dbf364f9e9ae3cb53d07c83 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 6 Jan 2024 16:20:21 -0800 Subject: [PATCH 057/128] removed leading iterative_ from file names and trailing Reconstruction from class names Former-commit-id: 9114df505727362c515c5115860c9deb42dbaa1c --- py4DSTEM/process/phase/__init__.py | 18 +++++++++--------- .../process/phase/{iterative_dpc.py => dpc.py} | 4 ++-- ...py => magnetic_ptychographic_tomography.py} | 8 ++++---- ...tychography.py => magnetic_ptychography.py} | 12 +++++------- ...y => mixedstate_multislice_ptychography.py} | 12 +++++------- ...chography.py => mixedstate_ptychography.py} | 12 +++++------- ...chography.py => multislice_ptychography.py} | 12 +++++------- .../{iterative_parallax.py => parallax.py} | 4 ++-- py4DSTEM/process/phase/parameter_optimize.py | 2 +- ...ative_base_class.py => phase_base_class.py} | 0 ...traints.py => ptychographic_constraints.py} | 0 ...hic_methods.py => ptychographic_methods.py} | 0 ...mography.py => ptychographic_tomography.py} | 12 +++++------- ...ions.py => ptychographic_visualizations.py} | 0 
...hography.py => singleslice_ptychography.py} | 12 +++++------- 15 files changed, 48 insertions(+), 60 deletions(-) rename py4DSTEM/process/phase/{iterative_dpc.py => dpc.py} (99%) rename py4DSTEM/process/phase/{iterative_magnetic_ptychographic_tomography.py => magnetic_ptychographic_tomography.py} (99%) rename py4DSTEM/process/phase/{iterative_magnetic_ptychography.py => magnetic_ptychography.py} (99%) rename py4DSTEM/process/phase/{iterative_mixedstate_multislice_ptychography.py => mixedstate_multislice_ptychography.py} (99%) rename py4DSTEM/process/phase/{iterative_mixedstate_ptychography.py => mixedstate_ptychography.py} (98%) rename py4DSTEM/process/phase/{iterative_multislice_ptychography.py => multislice_ptychography.py} (99%) rename py4DSTEM/process/phase/{iterative_parallax.py => parallax.py} (99%) rename py4DSTEM/process/phase/{iterative_base_class.py => phase_base_class.py} (100%) rename py4DSTEM/process/phase/{iterative_ptychographic_constraints.py => ptychographic_constraints.py} (100%) rename py4DSTEM/process/phase/{iterative_ptychographic_methods.py => ptychographic_methods.py} (100%) rename py4DSTEM/process/phase/{iterative_ptychographic_tomography.py => ptychographic_tomography.py} (99%) rename py4DSTEM/process/phase/{iterative_ptychographic_visualizations.py => ptychographic_visualizations.py} (100%) rename py4DSTEM/process/phase/{iterative_singleslice_ptychography.py => singleslice_ptychography.py} (98%) diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index 2069ffebf..59da42559 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -2,15 +2,15 @@ _emd_hook = True -from py4DSTEM.process.phase.iterative_dpc import DPCReconstruction -from py4DSTEM.process.phase.iterative_magnetic_ptychographic_tomography import MagneticPtychographicTomographyReconstruction -from py4DSTEM.process.phase.iterative_magnetic_ptychography import MagneticPtychographicReconstruction -from py4DSTEM.process.phase.iterative_mixedstate_multislice_ptychography import MixedstateMultislicePtychographicReconstruction -from py4DSTEM.process.phase.iterative_mixedstate_ptychography import MixedstatePtychographicReconstruction -from py4DSTEM.process.phase.iterative_multislice_ptychography import MultislicePtychographicReconstruction -from py4DSTEM.process.phase.iterative_parallax import ParallaxReconstruction -from py4DSTEM.process.phase.iterative_ptychographic_tomography import PtychographicTomographyReconstruction -from py4DSTEM.process.phase.iterative_singleslice_ptychography import SingleslicePtychographicReconstruction +from py4DSTEM.process.phase.dpc import DPC +from py4DSTEM.process.phase.magnetic_ptychographic_tomography import MagneticPtychographicTomography +from py4DSTEM.process.phase.magnetic_ptychography import MagneticPtychography +from py4DSTEM.process.phase.mixedstate_multislice_ptychography import MixedstateMultislicePtychography +from py4DSTEM.process.phase.mixedstate_ptychography import MixedstatePtychography +from py4DSTEM.process.phase.multislice_ptychography import MultislicePtychography +from py4DSTEM.process.phase.parallax import Parallax +from py4DSTEM.process.phase.ptychographic_tomography import PtychographicTomography +from py4DSTEM.process.phase.singleslice_ptychography import SingleslicePtychographic from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer # fmt: on diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/dpc.py similarity index 
99% rename from py4DSTEM/process/phase/iterative_dpc.py rename to py4DSTEM/process/phase/dpc.py index 11adc0c70..7043afc9b 100644 --- a/py4DSTEM/process/phase/iterative_dpc.py +++ b/py4DSTEM/process/phase/dpc.py @@ -19,12 +19,12 @@ from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube -from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction warnings.simplefilter(action="always", category=UserWarning) -class DPCReconstruction(PhaseReconstruction): +class DPC(PhaseReconstruction): """ Iterative Differential Phase Constrast Reconstruction Class. diff --git a/py4DSTEM/process/phase/iterative_magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py similarity index 99% rename from py4DSTEM/process/phase/iterative_magnetic_ptychographic_tomography.py rename to py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 816f7185b..090201897 100644 --- a/py4DSTEM/process/phase/iterative_magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -19,14 +19,14 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( Object3DConstraintsMixin, ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_methods import ( +from py4DSTEM.process.phase.ptychographic_methods import ( Object3DMethodsMixin, ObjectNDMethodsMixin, ProbeMethodsMixin, @@ -45,7 +45,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class MagneticPtychographicTomographyReconstruction( +class MagneticPtychographicTomography( PositionsConstraintsMixin, ProbeConstraintsMixin, Object3DConstraintsMixin, diff --git a/py4DSTEM/process/phase/iterative_magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py similarity index 99% rename from py4DSTEM/process/phase/iterative_magnetic_ptychography.py rename to py4DSTEM/process/phase/magnetic_ptychography.py index 4868f7917..0e633ae71 100644 --- a/py4DSTEM/process/phase/iterative_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -23,21 +23,19 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_methods import ( +from py4DSTEM.process.phase.ptychographic_methods import ( MultipleMeasurementsMethodsMixin, ObjectNDMethodsMixin, ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( - VisualizationsMixin, -) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -49,7 
+47,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class MagneticPtychographicReconstruction( +class MagneticPtychography( VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py similarity index 99% rename from py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py rename to py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index aa6f4fe56..931985224 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -18,15 +18,15 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( Object2p5DConstraintsMixin, ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ProbeMixedConstraintsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_methods import ( +from py4DSTEM.process.phase.ptychographic_methods import ( Object2p5DMethodsMixin, Object2p5DProbeMixedMethodsMixin, ObjectNDMethodsMixin, @@ -35,9 +35,7 @@ ProbeMethodsMixin, ProbeMixedMethodsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( - VisualizationsMixin, -) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -49,7 +47,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class MixedstateMultislicePtychographicReconstruction( +class MixedstateMultislicePtychography( VisualizationsMixin, PositionsConstraintsMixin, ProbeMixedConstraintsMixin, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py similarity index 98% rename from py4DSTEM/process/phase/iterative_mixedstate_ptychography.py rename to py4DSTEM/process/phase/mixedstate_ptychography.py index a57236bc5..dbcb62e97 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -18,23 +18,21 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ProbeMixedConstraintsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_methods import ( +from py4DSTEM.process.phase.ptychographic_methods import ( ObjectNDMethodsMixin, ObjectNDProbeMethodsMixin, ObjectNDProbeMixedMethodsMixin, ProbeMethodsMixin, ProbeMixedMethodsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( - VisualizationsMixin, -) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -46,7 +44,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class 
MixedstatePtychographicReconstruction( +class MixedstatePtychography( VisualizationsMixin, PositionsConstraintsMixin, ProbeMixedConstraintsMixin, diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py similarity index 99% rename from py4DSTEM/process/phase/iterative_multislice_ptychography.py rename to py4DSTEM/process/phase/multislice_ptychography.py index cd622f5e7..69ad11330 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -18,23 +18,21 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( Object2p5DConstraintsMixin, ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_methods import ( +from py4DSTEM.process.phase.ptychographic_methods import ( Object2p5DMethodsMixin, Object2p5DProbeMethodsMixin, ObjectNDMethodsMixin, ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( - VisualizationsMixin, -) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -46,7 +44,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class MultislicePtychographicReconstruction( +class MultislicePtychography( VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/parallax.py similarity index 99% rename from py4DSTEM/process/phase/iterative_parallax.py rename to py4DSTEM/process/phase/parallax.py index a7251c9c7..6f29d9ca1 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -14,7 +14,7 @@ from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable from py4DSTEM import Calibration, DataCube from py4DSTEM.preprocess.utils import get_shifted_ar -from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction from py4DSTEM.process.phase.utils import ( AffineTransform, bilinear_kernel_density_estimate, @@ -56,7 +56,7 @@ } -class ParallaxReconstruction(PhaseReconstruction): +class Parallax(PhaseReconstruction): """ Iterative parallax reconstruction class. 
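For downstream code, the renames in this commit change both the module paths (the leading iterative_ is dropped) and the class names (the trailing Reconstruction is dropped), with the new names re-exported from py4DSTEM.process.phase as shown in the __init__.py hunk above. A before/after import sketch:

    # before this commit:
    #   from py4DSTEM.process.phase.iterative_multislice_ptychography import (
    #       MultislicePtychographicReconstruction,
    #   )

    # after this commit:
    from py4DSTEM.process.phase import DPC, MultislicePtychography, Parallax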
diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index 8744ec792..ff9982b44 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt import numpy as np from matplotlib.gridspec import GridSpec -from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction from py4DSTEM.process.phase.utils import AffineTransform from skopt import gp_minimize from skopt.plots import plot_convergence as skopt_plot_convergence diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/phase_base_class.py similarity index 100% rename from py4DSTEM/process/phase/iterative_base_class.py rename to py4DSTEM/process/phase/phase_base_class.py diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/ptychographic_constraints.py similarity index 100% rename from py4DSTEM/process/phase/iterative_ptychographic_constraints.py rename to py4DSTEM/process/phase/ptychographic_constraints.py diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py similarity index 100% rename from py4DSTEM/process/phase/iterative_ptychographic_methods.py rename to py4DSTEM/process/phase/ptychographic_methods.py diff --git a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py similarity index 99% rename from py4DSTEM/process/phase/iterative_ptychographic_tomography.py rename to py4DSTEM/process/phase/ptychographic_tomography.py index a76095d1e..1e1cb62b7 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -18,15 +18,15 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( Object2p5DConstraintsMixin, Object3DConstraintsMixin, ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_methods import ( +from py4DSTEM.process.phase.ptychographic_methods import ( MultipleMeasurementsMethodsMixin, Object2p5DMethodsMixin, Object2p5DProbeMethodsMixin, @@ -35,9 +35,7 @@ ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( - VisualizationsMixin, -) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -49,7 +47,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class PtychographicTomographyReconstruction( +class PtychographicTomography( VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py b/py4DSTEM/process/phase/ptychographic_visualizations.py similarity index 100% rename from py4DSTEM/process/phase/iterative_ptychographic_visualizations.py rename to py4DSTEM/process/phase/ptychographic_visualizations.py diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py 
b/py4DSTEM/process/phase/singleslice_ptychography.py similarity index 98% rename from py4DSTEM/process/phase/iterative_singleslice_ptychography.py rename to py4DSTEM/process/phase/singleslice_ptychography.py index b60157d35..57b8ec93c 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -18,20 +18,18 @@ from emdfile import Custom, tqdmnd from py4DSTEM.datacube import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_methods import ( +from py4DSTEM.process.phase.ptychographic_methods import ( ObjectNDMethodsMixin, ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( - VisualizationsMixin, -) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -43,7 +41,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class SingleslicePtychographicReconstruction( +class SingleslicePtychography( VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, From aafd8f79967517e19552e26bf9c2b8b7443a5ae8 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 6 Jan 2024 16:23:09 -0800 Subject: [PATCH 058/128] typo Former-commit-id: 5990942213ceb4ee0b8491c19a3c7dc06985c88f --- py4DSTEM/process/phase/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index 59da42559..ecfeaa1d2 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -10,7 +10,7 @@ from py4DSTEM.process.phase.multislice_ptychography import MultislicePtychography from py4DSTEM.process.phase.parallax import Parallax from py4DSTEM.process.phase.ptychographic_tomography import PtychographicTomography -from py4DSTEM.process.phase.singleslice_ptychography import SingleslicePtychographic +from py4DSTEM.process.phase.singleslice_ptychography import SingleslicePtychography from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer # fmt: on From 66ad5ae3d52d5cf7263241c5896587b3b33b82fc Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 6 Jan 2024 17:04:19 -0800 Subject: [PATCH 059/128] FFT plotting improvements Former-commit-id: 6669195c0d439d9d41fee6d8c6a6fc0860f0091e --- .../process/phase/ptychographic_methods.py | 161 ++++++++++-------- .../process/phase/ptychographic_tomography.py | 2 +- 2 files changed, 95 insertions(+), 68 deletions(-) diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index d2a63ce9a..b809c3d1b 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -129,6 +129,8 @@ def _return_projected_cropped_potential( def _return_object_fft( self, obj=None, + apply_hanning_window=False, + **kwargs, ): """ Returns absolute value of obj fft shifted to center of array @@ -137,6 +139,8 @@ def _return_object_fft( ---------- obj: array, optional if None is specified, uses 
self._object + apply_hanning_window: bool, optional + If True, a 2D Hann window is applied to the object before FFT Returns ------- @@ -153,34 +157,82 @@ def _return_object_fft( obj = xp.angle(obj) obj = self._crop_rotate_object_fov(asnumpy(obj)) + + if apply_hanning_window: + sx, sy = obj.shape + wx = np.hanning(sx) + wy = np.hanning(sy) + obj *= wx[:, None] * wy[None, :] + return np.abs(np.fft.fftshift(np.fft.fft2(obj))) - def show_object_fft(self, obj=None, **kwargs): + def show_object_fft( + self, + obj=None, + apply_hanning_window=True, + crop_to_min_frequency=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, + ): """ Plot FFT of reconstructed object Parameters ---------- obj: complex array, optional - if None is specified, uses the `object_fft` property + If None is specified, uses the `object_fft` property + apply_hanning_window: bool, optional + If True, a 2D Hann window is applied to the object before FFT + crop_to_min_frequency: bool, optional + If True, a square FFT is plotted, cropping to the smallest axis + scalebar: bool, optional + if True, adds scalebar to probe + pixelunits: str, optional + units for scalebar, default is A^-1 + pixelsize: float, optional + default is object FFT sampling """ - if obj is None: - object_fft = self.object_fft - else: - object_fft = self._return_object_fft(obj) - figsize = kwargs.pop("figsize", (6, 6)) + object_fft = self._return_object_fft( + obj, apply_hanning_window=apply_hanning_window, **kwargs + ) + + if pixelsize is None: + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) + if pixelunits is None: + pixelunits = r"$\AA^{-1}$" + + if crop_to_min_frequency: + sx, sy = object_fft.shape + s = min(sx, sy) + start_x = sx // 2 - (s // 2) + start_y = sy // 2 - (s // 2) + object_fft = object_fft[start_x : start_x + s, start_y : start_y + s] + + figsize = kwargs.pop("figsize", (4, 4)) cmap = kwargs.pop("cmap", "magma") + ticks = kwargs.pop("ticks", False) + vmin = kwargs.pop("vmin", 0.001) + vmax = kwargs.pop("vmax", 0.999) + + # remove additional 3D FFT parameters before passing to show + kwargs.pop("projection_angle_deg", None) + kwargs.pop("projection_axes", None) + kwargs.pop("x_lims", None) + kwargs.pop("y_lims", None) - pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, cmap=cmap, - scalebar=True, + scalebar=scalebar, pixelsize=pixelsize, - ticks=False, - pixelunits=r"$\AA^{-1}$", + ticks=ticks, + pixelunits=pixelunits, + vmin=vmin, + vmax=vmax, **kwargs, ) @@ -391,6 +443,8 @@ def _return_projected_cropped_potential( def _return_object_fft( self, obj=None, + apply_hanning_window=False, + **kwargs, ): """ Returns obj fft shifted to center of array @@ -399,6 +453,13 @@ def _return_object_fft( ---------- obj: array, optional if None is specified, uses self._object + apply_hanning_window: bool, optional + If True, a 2D Hann window is applied to the object before FFT + + Returns + ------- + object_fft_amplitude: np.ndarray + Amplitude of Fourier-transformed and center-shifted obj. 
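apply_hanning_window, added here (and reverted in the next commit), tapers the object to zero at its edges before the FFT so the hard crop boundary does not streak the spectrum, and crop_to_min_frequency keeps a square region set by the smaller axis. A NumPy-only sketch of both steps on a made-up object:

    import numpy as np

    obj = np.random.rand(128, 96)                 # stand-in for the cropped object phase

    sx, sy = obj.shape
    wx = np.hanning(sx)
    wy = np.hanning(sy)
    windowed = obj * wx[:, None] * wy[None, :]    # 2D Hann taper

    object_fft = np.abs(np.fft.fftshift(np.fft.fft2(windowed)))

    # crop to the smallest axis, as in crop_to_min_frequency
    s = min(object_fft.shape)
    start_x = object_fft.shape[0] // 2 - s // 2
    start_y = object_fft.shape[1] // 2 - s // 2
    object_fft = object_fft[start_x : start_x + s, start_y : start_y + s]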
""" xp = self._xp @@ -409,6 +470,13 @@ def _return_object_fft( obj = xp.angle(obj) obj = self._crop_rotate_object_fov(obj.sum(axis=0)) + + if apply_hanning_window: + sx, sy = obj.shape + wx = np.hanning(sx) + wy = np.hanning(sy) + obj *= wx[:, None] * wy[None, :] + return np.abs(np.fft.fftshift(np.fft.fft2(obj))) def show_depth_section( @@ -880,10 +948,12 @@ def _return_projected_cropped_potential( def _return_object_fft( self, obj=None, + apply_hanning_window=False, projection_angle_deg: float = None, projection_axes: Tuple[int, int] = (0, 2), x_lims: Tuple[int, int] = (None, None), y_lims: Tuple[int, int] = (None, None), + **kwargs, ): """ Returns obj fft shifted to center of array @@ -892,6 +962,8 @@ def _return_object_fft( ---------- obj: array, optional if None is specified, uses self._object + apply_hanning_window: bool, optional + If True, a 2D Hann window is applied to the object before FFT projection_angle_deg: float Angle in degrees to rotate 3D array around prior to projection projection_axes: tuple(int,int) @@ -900,6 +972,11 @@ def _return_object_fft( min/max x indices y_lims: tuple(float,float) min/max y indices + + Returns + ------- + object_fft_amplitude: np.ndarray + Amplitude of Fourier-transformed and center-shifted obj. """ xp = self._xp @@ -926,63 +1003,13 @@ def _return_object_fft( rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims ) - return np.abs(np.fft.fftshift(np.fft.fft2(rotated_object))) - - def show_object_fft( - self, - obj=None, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims: Tuple[int, int] = (None, None), - y_lims: Tuple[int, int] = (None, None), - **kwargs, - ): - """ - Plot FFT of reconstructed object + if apply_hanning_window: + sx, sy = rotated_object.shape + wx = np.hanning(sx) + wy = np.hanning(sy) + rotated_object *= wx[:, None] * wy[None, :] - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - if obj is None: - object_fft = self._return_object_fft( - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - ) - else: - object_fft = self._return_object_fft( - obj, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - ) - - figsize = kwargs.pop("figsize", (6, 6)) - cmap = kwargs.pop("cmap", "magma") - - pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) - show( - object_fft, - figsize=figsize, - cmap=cmap, - scalebar=True, - pixelsize=pixelsize, - ticks=False, - pixelunits=r"$\AA^{-1}$", - **kwargs, - ) + return np.abs(np.fft.fftshift(np.fft.fft2(rotated_object))) @property def object_supersliced(self): diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 1e1cb62b7..165dd115e 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -612,7 +612,7 @@ def preprocess( ) # propagated - propagated_probe = self._probe.copy() + propagated_probe = self._probes_all[0].copy() for s in range(self._num_slices - 1): propagated_probe = self._propagate_array( From 5e416c39af09546f0923d4abe86bc068b00aa4d9 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides 
Date: Sat, 6 Jan 2024 17:09:01 -0800 Subject: [PATCH 060/128] Revert "FFT plotting improvements" This reverts commit 66ad5ae3d52d5cf7263241c5896587b3b33b82fc [formerly 6669195c0d439d9d41fee6d8c6a6fc0860f0091e]. Former-commit-id: 2c0f60acf5210c0bfc966897fe2fe08e56f00898 --- .../process/phase/ptychographic_methods.py | 161 ++++++++---------- .../process/phase/ptychographic_tomography.py | 2 +- 2 files changed, 68 insertions(+), 95 deletions(-) diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index b809c3d1b..d2a63ce9a 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -129,8 +129,6 @@ def _return_projected_cropped_potential( def _return_object_fft( self, obj=None, - apply_hanning_window=False, - **kwargs, ): """ Returns absolute value of obj fft shifted to center of array @@ -139,8 +137,6 @@ def _return_object_fft( ---------- obj: array, optional if None is specified, uses self._object - apply_hanning_window: bool, optional - If True, a 2D Hann window is applied to the object before FFT Returns ------- @@ -157,82 +153,34 @@ def _return_object_fft( obj = xp.angle(obj) obj = self._crop_rotate_object_fov(asnumpy(obj)) - - if apply_hanning_window: - sx, sy = obj.shape - wx = np.hanning(sx) - wy = np.hanning(sy) - obj *= wx[:, None] * wy[None, :] - return np.abs(np.fft.fftshift(np.fft.fft2(obj))) - def show_object_fft( - self, - obj=None, - apply_hanning_window=True, - crop_to_min_frequency=True, - scalebar=True, - pixelsize=None, - pixelunits=None, - **kwargs, - ): + def show_object_fft(self, obj=None, **kwargs): """ Plot FFT of reconstructed object Parameters ---------- obj: complex array, optional - If None is specified, uses the `object_fft` property - apply_hanning_window: bool, optional - If True, a 2D Hann window is applied to the object before FFT - crop_to_min_frequency: bool, optional - If True, a square FFT is plotted, cropping to the smallest axis - scalebar: bool, optional - if True, adds scalebar to probe - pixelunits: str, optional - units for scalebar, default is A^-1 - pixelsize: float, optional - default is object FFT sampling + if None is specified, uses the `object_fft` property """ + if obj is None: + object_fft = self.object_fft + else: + object_fft = self._return_object_fft(obj) - object_fft = self._return_object_fft( - obj, apply_hanning_window=apply_hanning_window, **kwargs - ) - - if pixelsize is None: - pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) - if pixelunits is None: - pixelunits = r"$\AA^{-1}$" - - if crop_to_min_frequency: - sx, sy = object_fft.shape - s = min(sx, sy) - start_x = sx // 2 - (s // 2) - start_y = sy // 2 - (s // 2) - object_fft = object_fft[start_x : start_x + s, start_y : start_y + s] - - figsize = kwargs.pop("figsize", (4, 4)) + figsize = kwargs.pop("figsize", (6, 6)) cmap = kwargs.pop("cmap", "magma") - ticks = kwargs.pop("ticks", False) - vmin = kwargs.pop("vmin", 0.001) - vmax = kwargs.pop("vmax", 0.999) - - # remove additional 3D FFT parameters before passing to show - kwargs.pop("projection_angle_deg", None) - kwargs.pop("projection_axes", None) - kwargs.pop("x_lims", None) - kwargs.pop("y_lims", None) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, cmap=cmap, - scalebar=scalebar, + scalebar=True, pixelsize=pixelsize, - ticks=ticks, - pixelunits=pixelunits, - vmin=vmin, - vmax=vmax, + ticks=False, + pixelunits=r"$\AA^{-1}$", **kwargs, ) @@ -443,8 +391,6 @@ 
def _return_projected_cropped_potential( def _return_object_fft( self, obj=None, - apply_hanning_window=False, - **kwargs, ): """ Returns obj fft shifted to center of array @@ -453,13 +399,6 @@ def _return_object_fft( ---------- obj: array, optional if None is specified, uses self._object - apply_hanning_window: bool, optional - If True, a 2D Hann window is applied to the object before FFT - - Returns - ------- - object_fft_amplitude: np.ndarray - Amplitude of Fourier-transformed and center-shifted obj. """ xp = self._xp @@ -470,13 +409,6 @@ def _return_object_fft( obj = xp.angle(obj) obj = self._crop_rotate_object_fov(obj.sum(axis=0)) - - if apply_hanning_window: - sx, sy = obj.shape - wx = np.hanning(sx) - wy = np.hanning(sy) - obj *= wx[:, None] * wy[None, :] - return np.abs(np.fft.fftshift(np.fft.fft2(obj))) def show_depth_section( @@ -948,12 +880,10 @@ def _return_projected_cropped_potential( def _return_object_fft( self, obj=None, - apply_hanning_window=False, projection_angle_deg: float = None, projection_axes: Tuple[int, int] = (0, 2), x_lims: Tuple[int, int] = (None, None), y_lims: Tuple[int, int] = (None, None), - **kwargs, ): """ Returns obj fft shifted to center of array @@ -962,8 +892,6 @@ def _return_object_fft( ---------- obj: array, optional if None is specified, uses self._object - apply_hanning_window: bool, optional - If True, a 2D Hann window is applied to the object before FFT projection_angle_deg: float Angle in degrees to rotate 3D array around prior to projection projection_axes: tuple(int,int) @@ -972,11 +900,6 @@ def _return_object_fft( min/max x indices y_lims: tuple(float,float) min/max y indices - - Returns - ------- - object_fft_amplitude: np.ndarray - Amplitude of Fourier-transformed and center-shifted obj. """ xp = self._xp @@ -1003,14 +926,64 @@ def _return_object_fft( rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims ) - if apply_hanning_window: - sx, sy = rotated_object.shape - wx = np.hanning(sx) - wy = np.hanning(sy) - rotated_object *= wx[:, None] * wy[None, :] - return np.abs(np.fft.fftshift(np.fft.fft2(rotated_object))) + def show_object_fft( + self, + obj=None, + projection_angle_deg: float = None, + projection_axes: Tuple[int, int] = (0, 2), + x_lims: Tuple[int, int] = (None, None), + y_lims: Tuple[int, int] = (None, None), + **kwargs, + ): + """ + Plot FFT of reconstructed object + + Parameters + ---------- + obj: array, optional + if None is specified, uses self._object + projection_angle_deg: float + Angle in degrees to rotate 3D array around prior to projection + projection_axes: tuple(int,int) + Axes defining projection plane + x_lims: tuple(float,float) + min/max x indices + y_lims: tuple(float,float) + min/max y indices + """ + if obj is None: + object_fft = self._return_object_fft( + projection_angle_deg=projection_angle_deg, + projection_axes=projection_axes, + x_lims=x_lims, + y_lims=y_lims, + ) + else: + object_fft = self._return_object_fft( + obj, + projection_angle_deg=projection_angle_deg, + projection_axes=projection_axes, + x_lims=x_lims, + y_lims=y_lims, + ) + + figsize = kwargs.pop("figsize", (6, 6)) + cmap = kwargs.pop("cmap", "magma") + + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) + show( + object_fft, + figsize=figsize, + cmap=cmap, + scalebar=True, + pixelsize=pixelsize, + ticks=False, + pixelunits=r"$\AA^{-1}$", + **kwargs, + ) + @property def object_supersliced(self): """Returns super-sliced object""" diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py 
b/py4DSTEM/process/phase/ptychographic_tomography.py index 165dd115e..1e1cb62b7 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -612,7 +612,7 @@ def preprocess( ) # propagated - propagated_probe = self._probes_all[0].copy() + propagated_probe = self._probe.copy() for s in range(self._num_slices - 1): propagated_probe = self._propagate_array( From ed4faf6e0ac05eb1d5f174403463ee21efd5865c Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 6 Jan 2024 17:09:17 -0800 Subject: [PATCH 061/128] Revert "typo" This reverts commit aafd8f79967517e19552e26bf9c2b8b7443a5ae8 [formerly 5990942213ceb4ee0b8491c19a3c7dc06985c88f]. Former-commit-id: bb8b70fae5d08ff85c1021564d9947acc60c3238 --- py4DSTEM/process/phase/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index ecfeaa1d2..59da42559 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -10,7 +10,7 @@ from py4DSTEM.process.phase.multislice_ptychography import MultislicePtychography from py4DSTEM.process.phase.parallax import Parallax from py4DSTEM.process.phase.ptychographic_tomography import PtychographicTomography -from py4DSTEM.process.phase.singleslice_ptychography import SingleslicePtychography +from py4DSTEM.process.phase.singleslice_ptychography import SingleslicePtychographic from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer # fmt: on From e5a86a54e03fba5a952b9cffc0f88fbe1a4974a7 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 6 Jan 2024 17:10:32 -0800 Subject: [PATCH 062/128] Revert "removed leading iterative_ from file names and trailing Reconstruction from class names" This reverts commit b6d188e2ebbeb4f84dbf364f9e9ae3cb53d07c83 [formerly 9114df505727362c515c5115860c9deb42dbaa1c]. Former-commit-id: 923ae77241dd83fbc3805ccacf42d872ad22567a --- py4DSTEM/process/phase/__init__.py | 18 +++++++++--------- ...e_base_class.py => iterative_base_class.py} | 0 .../process/phase/{dpc.py => iterative_dpc.py} | 4 ++-- ...ative_magnetic_ptychographic_tomography.py} | 8 ++++---- ...y.py => iterative_magnetic_ptychography.py} | 12 +++++++----- ...tive_mixedstate_multislice_ptychography.py} | 12 +++++++----- ...py => iterative_mixedstate_ptychography.py} | 12 +++++++----- ...py => iterative_multislice_ptychography.py} | 12 +++++++----- .../{parallax.py => iterative_parallax.py} | 4 ++-- ... => iterative_ptychographic_constraints.py} | 0 ...s.py => iterative_ptychographic_methods.py} | 0 ...y => iterative_ptychographic_tomography.py} | 12 +++++++----- ... 
iterative_ptychographic_visualizations.py} | 0 ...y => iterative_singleslice_ptychography.py} | 12 +++++++----- py4DSTEM/process/phase/parameter_optimize.py | 2 +- 15 files changed, 60 insertions(+), 48 deletions(-) rename py4DSTEM/process/phase/{phase_base_class.py => iterative_base_class.py} (100%) rename py4DSTEM/process/phase/{dpc.py => iterative_dpc.py} (99%) rename py4DSTEM/process/phase/{magnetic_ptychographic_tomography.py => iterative_magnetic_ptychographic_tomography.py} (99%) rename py4DSTEM/process/phase/{magnetic_ptychography.py => iterative_magnetic_ptychography.py} (99%) rename py4DSTEM/process/phase/{mixedstate_multislice_ptychography.py => iterative_mixedstate_multislice_ptychography.py} (99%) rename py4DSTEM/process/phase/{mixedstate_ptychography.py => iterative_mixedstate_ptychography.py} (98%) rename py4DSTEM/process/phase/{multislice_ptychography.py => iterative_multislice_ptychography.py} (99%) rename py4DSTEM/process/phase/{parallax.py => iterative_parallax.py} (99%) rename py4DSTEM/process/phase/{ptychographic_constraints.py => iterative_ptychographic_constraints.py} (100%) rename py4DSTEM/process/phase/{ptychographic_methods.py => iterative_ptychographic_methods.py} (100%) rename py4DSTEM/process/phase/{ptychographic_tomography.py => iterative_ptychographic_tomography.py} (99%) rename py4DSTEM/process/phase/{ptychographic_visualizations.py => iterative_ptychographic_visualizations.py} (100%) rename py4DSTEM/process/phase/{singleslice_ptychography.py => iterative_singleslice_ptychography.py} (98%) diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index 59da42559..2069ffebf 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -2,15 +2,15 @@ _emd_hook = True -from py4DSTEM.process.phase.dpc import DPC -from py4DSTEM.process.phase.magnetic_ptychographic_tomography import MagneticPtychographicTomography -from py4DSTEM.process.phase.magnetic_ptychography import MagneticPtychography -from py4DSTEM.process.phase.mixedstate_multislice_ptychography import MixedstateMultislicePtychography -from py4DSTEM.process.phase.mixedstate_ptychography import MixedstatePtychography -from py4DSTEM.process.phase.multislice_ptychography import MultislicePtychography -from py4DSTEM.process.phase.parallax import Parallax -from py4DSTEM.process.phase.ptychographic_tomography import PtychographicTomography -from py4DSTEM.process.phase.singleslice_ptychography import SingleslicePtychographic +from py4DSTEM.process.phase.iterative_dpc import DPCReconstruction +from py4DSTEM.process.phase.iterative_magnetic_ptychographic_tomography import MagneticPtychographicTomographyReconstruction +from py4DSTEM.process.phase.iterative_magnetic_ptychography import MagneticPtychographicReconstruction +from py4DSTEM.process.phase.iterative_mixedstate_multislice_ptychography import MixedstateMultislicePtychographicReconstruction +from py4DSTEM.process.phase.iterative_mixedstate_ptychography import MixedstatePtychographicReconstruction +from py4DSTEM.process.phase.iterative_multislice_ptychography import MultislicePtychographicReconstruction +from py4DSTEM.process.phase.iterative_parallax import ParallaxReconstruction +from py4DSTEM.process.phase.iterative_ptychographic_tomography import PtychographicTomographyReconstruction +from py4DSTEM.process.phase.iterative_singleslice_ptychography import SingleslicePtychographicReconstruction from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer # fmt: on 
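The rename reverted in the __init__.py hunk above only changes the import spelling of the public classes; the reconstruction code itself is untouched (the file moves below are 98-100% similarity renames). As a minimal sketch, not part of the patch, downstream code written against the two states of the tree would import, respectively:

# state restored by this revert (iterative_* modules, *Reconstruction class names)
from py4DSTEM.process.phase import (
    DPCReconstruction,
    ParallaxReconstruction,
    SingleslicePtychographicReconstruction,
)

# state after patches 063/064 below (rename re-applied, singleslice import typo fixed):
# from py4DSTEM.process.phase import DPC, Parallax, SingleslicePtychography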
diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py similarity index 100% rename from py4DSTEM/process/phase/phase_base_class.py rename to py4DSTEM/process/phase/iterative_base_class.py diff --git a/py4DSTEM/process/phase/dpc.py b/py4DSTEM/process/phase/iterative_dpc.py similarity index 99% rename from py4DSTEM/process/phase/dpc.py rename to py4DSTEM/process/phase/iterative_dpc.py index 7043afc9b..11adc0c70 100644 --- a/py4DSTEM/process/phase/dpc.py +++ b/py4DSTEM/process/phase/iterative_dpc.py @@ -19,12 +19,12 @@ from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube -from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction +from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction warnings.simplefilter(action="always", category=UserWarning) -class DPC(PhaseReconstruction): +class DPCReconstruction(PhaseReconstruction): """ Iterative Differential Phase Constrast Reconstruction Class. diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/iterative_magnetic_ptychographic_tomography.py similarity index 99% rename from py4DSTEM/process/phase/magnetic_ptychographic_tomography.py rename to py4DSTEM/process/phase/iterative_magnetic_ptychographic_tomography.py index 090201897..816f7185b 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/iterative_magnetic_ptychographic_tomography.py @@ -19,14 +19,14 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.ptychographic_constraints import ( +from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( Object3DConstraintsMixin, ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ) -from py4DSTEM.process.phase.ptychographic_methods import ( +from py4DSTEM.process.phase.iterative_ptychographic_methods import ( Object3DMethodsMixin, ObjectNDMethodsMixin, ProbeMethodsMixin, @@ -45,7 +45,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class MagneticPtychographicTomography( +class MagneticPtychographicTomographyReconstruction( PositionsConstraintsMixin, ProbeConstraintsMixin, Object3DConstraintsMixin, diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/iterative_magnetic_ptychography.py similarity index 99% rename from py4DSTEM/process/phase/magnetic_ptychography.py rename to py4DSTEM/process/phase/iterative_magnetic_ptychography.py index 0e633ae71..4868f7917 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/iterative_magnetic_ptychography.py @@ -23,19 +23,21 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.ptychographic_constraints import ( +from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ) -from py4DSTEM.process.phase.ptychographic_methods import ( +from py4DSTEM.process.phase.iterative_ptychographic_methods import ( 
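The apply_hanning_window flag that patch 060 above strips out of _return_object_fft (and that patch 065 later in this series restores) is plain separable Hann windowing applied before the FFT, which suppresses the streaking that the sharp, non-periodic edges of the cropped object would otherwise add to the power spectrum. A standalone NumPy sketch of the same operation, on a random placeholder array rather than a reconstructed object:

import numpy as np

obj = np.random.rand(192, 256)                 # placeholder for a cropped object phase
sx, sy = obj.shape

wx = np.hanning(sx)                            # separable 2D Hann window
wy = np.hanning(sy)
obj_windowed = obj * wx[:, None] * wy[None, :]

# amplitude of the centered FFT, as returned by _return_object_fft
object_fft = np.abs(np.fft.fftshift(np.fft.fft2(obj_windowed)))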
MultipleMeasurementsMethodsMixin, ObjectNDMethodsMixin, ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) -from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( + VisualizationsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -47,7 +49,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class MagneticPtychography( +class MagneticPtychographicReconstruction( VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py similarity index 99% rename from py4DSTEM/process/phase/mixedstate_multislice_ptychography.py rename to py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 931985224..aa6f4fe56 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -18,15 +18,15 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.ptychographic_constraints import ( +from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( Object2p5DConstraintsMixin, ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ProbeMixedConstraintsMixin, ) -from py4DSTEM.process.phase.ptychographic_methods import ( +from py4DSTEM.process.phase.iterative_ptychographic_methods import ( Object2p5DMethodsMixin, Object2p5DProbeMixedMethodsMixin, ObjectNDMethodsMixin, @@ -35,7 +35,9 @@ ProbeMethodsMixin, ProbeMixedMethodsMixin, ) -from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( + VisualizationsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -47,7 +49,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class MixedstateMultislicePtychography( +class MixedstateMultislicePtychographicReconstruction( VisualizationsMixin, PositionsConstraintsMixin, ProbeMixedConstraintsMixin, diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py similarity index 98% rename from py4DSTEM/process/phase/mixedstate_ptychography.py rename to py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index dbcb62e97..a57236bc5 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -18,21 +18,23 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.ptychographic_constraints import ( +from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ProbeMixedConstraintsMixin, ) -from py4DSTEM.process.phase.ptychographic_methods import ( +from py4DSTEM.process.phase.iterative_ptychographic_methods import ( ObjectNDMethodsMixin, ObjectNDProbeMethodsMixin, ObjectNDProbeMixedMethodsMixin, ProbeMethodsMixin, 
ProbeMixedMethodsMixin, ) -from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( + VisualizationsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -44,7 +46,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class MixedstatePtychography( +class MixedstatePtychographicReconstruction( VisualizationsMixin, PositionsConstraintsMixin, ProbeMixedConstraintsMixin, diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py similarity index 99% rename from py4DSTEM/process/phase/multislice_ptychography.py rename to py4DSTEM/process/phase/iterative_multislice_ptychography.py index 69ad11330..cd622f5e7 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -18,21 +18,23 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.ptychographic_constraints import ( +from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( Object2p5DConstraintsMixin, ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ) -from py4DSTEM.process.phase.ptychographic_methods import ( +from py4DSTEM.process.phase.iterative_ptychographic_methods import ( Object2p5DMethodsMixin, Object2p5DProbeMethodsMixin, ObjectNDMethodsMixin, ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) -from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( + VisualizationsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -44,7 +46,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class MultislicePtychography( +class MultislicePtychographicReconstruction( VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/iterative_parallax.py similarity index 99% rename from py4DSTEM/process/phase/parallax.py rename to py4DSTEM/process/phase/iterative_parallax.py index 6f29d9ca1..a7251c9c7 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -14,7 +14,7 @@ from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable from py4DSTEM import Calibration, DataCube from py4DSTEM.preprocess.utils import get_shifted_ar -from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction +from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction from py4DSTEM.process.phase.utils import ( AffineTransform, bilinear_kernel_density_estimate, @@ -56,7 +56,7 @@ } -class Parallax(PhaseReconstruction): +class ParallaxReconstruction(PhaseReconstruction): """ Iterative parallax reconstruction class. 
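All of the class statements touched by this revert follow the same composition: each concrete reconstruction class only stacks visualization, constraint, and method mixins on top of PtychographicReconstruction, which is why the patch can rename them mechanically without changing behaviour. A toy sketch of that mixin pattern, with purely illustrative names (not py4DSTEM API):

class ToyVisualizationsMixin:
    def show(self):
        print(f"showing {type(self).__name__}")

class ToyConstraintsMixin:
    def constrain(self, x):
        return max(x, 0.0)                     # e.g. a positivity constraint

class ToyBaseReconstruction:
    def __init__(self, energy):
        self.energy = energy

class ToyPtychography(ToyVisualizationsMixin, ToyConstraintsMixin, ToyBaseReconstruction):
    pass

recon = ToyPtychography(energy=300e3)
recon.show()                                   # method supplied by the mixin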
diff --git a/py4DSTEM/process/phase/ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py similarity index 100% rename from py4DSTEM/process/phase/ptychographic_constraints.py rename to py4DSTEM/process/phase/iterative_ptychographic_constraints.py diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py similarity index 100% rename from py4DSTEM/process/phase/ptychographic_methods.py rename to py4DSTEM/process/phase/iterative_ptychographic_methods.py diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py similarity index 99% rename from py4DSTEM/process/phase/ptychographic_tomography.py rename to py4DSTEM/process/phase/iterative_ptychographic_tomography.py index 1e1cb62b7..a76095d1e 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py @@ -18,15 +18,15 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.ptychographic_constraints import ( +from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( Object2p5DConstraintsMixin, Object3DConstraintsMixin, ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ) -from py4DSTEM.process.phase.ptychographic_methods import ( +from py4DSTEM.process.phase.iterative_ptychographic_methods import ( MultipleMeasurementsMethodsMixin, Object2p5DMethodsMixin, Object2p5DProbeMethodsMixin, @@ -35,7 +35,9 @@ ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) -from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( + VisualizationsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -47,7 +49,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class PtychographicTomography( +class PtychographicTomographyReconstruction( VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, diff --git a/py4DSTEM/process/phase/ptychographic_visualizations.py b/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py similarity index 100% rename from py4DSTEM/process/phase/ptychographic_visualizations.py rename to py4DSTEM/process/phase/iterative_ptychographic_visualizations.py diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py similarity index 98% rename from py4DSTEM/process/phase/singleslice_ptychography.py rename to py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 57b8ec93c..b60157d35 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -18,18 +18,20 @@ from emdfile import Custom, tqdmnd from py4DSTEM.datacube import DataCube -from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.ptychographic_constraints import ( +from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ) -from py4DSTEM.process.phase.ptychographic_methods 
import ( +from py4DSTEM.process.phase.iterative_ptychographic_methods import ( ObjectNDMethodsMixin, ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) -from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( + VisualizationsMixin, +) from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -41,7 +43,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class SingleslicePtychography( +class SingleslicePtychographicReconstruction( VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index ff9982b44..8744ec792 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt import numpy as np from matplotlib.gridspec import GridSpec -from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction +from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction from py4DSTEM.process.phase.utils import AffineTransform from skopt import gp_minimize from skopt.plots import plot_convergence as skopt_plot_convergence From 1210c918bd1abf1411703a6517c18bacdca9a043 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 6 Jan 2024 16:20:21 -0800 Subject: [PATCH 063/128] removed leading iterative_ from file names and trailing Reconstruction from class names Former-commit-id: 9c4c51f89c51954b8626c3eaaa3de7fb67c65f34 --- py4DSTEM/process/phase/__init__.py | 18 +++++++++--------- .../process/phase/{iterative_dpc.py => dpc.py} | 4 ++-- ...py => magnetic_ptychographic_tomography.py} | 8 ++++---- ...tychography.py => magnetic_ptychography.py} | 12 +++++------- ...y => mixedstate_multislice_ptychography.py} | 12 +++++------- ...chography.py => mixedstate_ptychography.py} | 12 +++++------- ...chography.py => multislice_ptychography.py} | 12 +++++------- .../{iterative_parallax.py => parallax.py} | 4 ++-- py4DSTEM/process/phase/parameter_optimize.py | 2 +- ...ative_base_class.py => phase_base_class.py} | 0 ...traints.py => ptychographic_constraints.py} | 0 ...hic_methods.py => ptychographic_methods.py} | 0 ...mography.py => ptychographic_tomography.py} | 12 +++++------- ...ions.py => ptychographic_visualizations.py} | 0 ...hography.py => singleslice_ptychography.py} | 12 +++++------- 15 files changed, 48 insertions(+), 60 deletions(-) rename py4DSTEM/process/phase/{iterative_dpc.py => dpc.py} (99%) rename py4DSTEM/process/phase/{iterative_magnetic_ptychographic_tomography.py => magnetic_ptychographic_tomography.py} (99%) rename py4DSTEM/process/phase/{iterative_magnetic_ptychography.py => magnetic_ptychography.py} (99%) rename py4DSTEM/process/phase/{iterative_mixedstate_multislice_ptychography.py => mixedstate_multislice_ptychography.py} (99%) rename py4DSTEM/process/phase/{iterative_mixedstate_ptychography.py => mixedstate_ptychography.py} (98%) rename py4DSTEM/process/phase/{iterative_multislice_ptychography.py => multislice_ptychography.py} (99%) rename py4DSTEM/process/phase/{iterative_parallax.py => parallax.py} (99%) rename py4DSTEM/process/phase/{iterative_base_class.py => phase_base_class.py} (100%) rename py4DSTEM/process/phase/{iterative_ptychographic_constraints.py => ptychographic_constraints.py} (100%) rename py4DSTEM/process/phase/{iterative_ptychographic_methods.py => ptychographic_methods.py} (100%) rename 
py4DSTEM/process/phase/{iterative_ptychographic_tomography.py => ptychographic_tomography.py} (99%) rename py4DSTEM/process/phase/{iterative_ptychographic_visualizations.py => ptychographic_visualizations.py} (100%) rename py4DSTEM/process/phase/{iterative_singleslice_ptychography.py => singleslice_ptychography.py} (98%) diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index 2069ffebf..59da42559 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -2,15 +2,15 @@ _emd_hook = True -from py4DSTEM.process.phase.iterative_dpc import DPCReconstruction -from py4DSTEM.process.phase.iterative_magnetic_ptychographic_tomography import MagneticPtychographicTomographyReconstruction -from py4DSTEM.process.phase.iterative_magnetic_ptychography import MagneticPtychographicReconstruction -from py4DSTEM.process.phase.iterative_mixedstate_multislice_ptychography import MixedstateMultislicePtychographicReconstruction -from py4DSTEM.process.phase.iterative_mixedstate_ptychography import MixedstatePtychographicReconstruction -from py4DSTEM.process.phase.iterative_multislice_ptychography import MultislicePtychographicReconstruction -from py4DSTEM.process.phase.iterative_parallax import ParallaxReconstruction -from py4DSTEM.process.phase.iterative_ptychographic_tomography import PtychographicTomographyReconstruction -from py4DSTEM.process.phase.iterative_singleslice_ptychography import SingleslicePtychographicReconstruction +from py4DSTEM.process.phase.dpc import DPC +from py4DSTEM.process.phase.magnetic_ptychographic_tomography import MagneticPtychographicTomography +from py4DSTEM.process.phase.magnetic_ptychography import MagneticPtychography +from py4DSTEM.process.phase.mixedstate_multislice_ptychography import MixedstateMultislicePtychography +from py4DSTEM.process.phase.mixedstate_ptychography import MixedstatePtychography +from py4DSTEM.process.phase.multislice_ptychography import MultislicePtychography +from py4DSTEM.process.phase.parallax import Parallax +from py4DSTEM.process.phase.ptychographic_tomography import PtychographicTomography +from py4DSTEM.process.phase.singleslice_ptychography import SingleslicePtychographic from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer # fmt: on diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/dpc.py similarity index 99% rename from py4DSTEM/process/phase/iterative_dpc.py rename to py4DSTEM/process/phase/dpc.py index 11adc0c70..7043afc9b 100644 --- a/py4DSTEM/process/phase/iterative_dpc.py +++ b/py4DSTEM/process/phase/dpc.py @@ -19,12 +19,12 @@ from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube -from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction warnings.simplefilter(action="always", category=UserWarning) -class DPCReconstruction(PhaseReconstruction): +class DPC(PhaseReconstruction): """ Iterative Differential Phase Constrast Reconstruction Class. 
diff --git a/py4DSTEM/process/phase/iterative_magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py similarity index 99% rename from py4DSTEM/process/phase/iterative_magnetic_ptychographic_tomography.py rename to py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 816f7185b..090201897 100644 --- a/py4DSTEM/process/phase/iterative_magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -19,14 +19,14 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( Object3DConstraintsMixin, ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_methods import ( +from py4DSTEM.process.phase.ptychographic_methods import ( Object3DMethodsMixin, ObjectNDMethodsMixin, ProbeMethodsMixin, @@ -45,7 +45,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class MagneticPtychographicTomographyReconstruction( +class MagneticPtychographicTomography( PositionsConstraintsMixin, ProbeConstraintsMixin, Object3DConstraintsMixin, diff --git a/py4DSTEM/process/phase/iterative_magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py similarity index 99% rename from py4DSTEM/process/phase/iterative_magnetic_ptychography.py rename to py4DSTEM/process/phase/magnetic_ptychography.py index 4868f7917..0e633ae71 100644 --- a/py4DSTEM/process/phase/iterative_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -23,21 +23,19 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_methods import ( +from py4DSTEM.process.phase.ptychographic_methods import ( MultipleMeasurementsMethodsMixin, ObjectNDMethodsMixin, ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( - VisualizationsMixin, -) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -49,7 +47,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class MagneticPtychographicReconstruction( +class MagneticPtychography( VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py similarity index 99% rename from py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py rename to py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index aa6f4fe56..931985224 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -18,15 +18,15 @@ 
from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( Object2p5DConstraintsMixin, ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ProbeMixedConstraintsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_methods import ( +from py4DSTEM.process.phase.ptychographic_methods import ( Object2p5DMethodsMixin, Object2p5DProbeMixedMethodsMixin, ObjectNDMethodsMixin, @@ -35,9 +35,7 @@ ProbeMethodsMixin, ProbeMixedMethodsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( - VisualizationsMixin, -) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -49,7 +47,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class MixedstateMultislicePtychographicReconstruction( +class MixedstateMultislicePtychography( VisualizationsMixin, PositionsConstraintsMixin, ProbeMixedConstraintsMixin, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py similarity index 98% rename from py4DSTEM/process/phase/iterative_mixedstate_ptychography.py rename to py4DSTEM/process/phase/mixedstate_ptychography.py index a57236bc5..dbcb62e97 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -18,23 +18,21 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ProbeMixedConstraintsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_methods import ( +from py4DSTEM.process.phase.ptychographic_methods import ( ObjectNDMethodsMixin, ObjectNDProbeMethodsMixin, ObjectNDProbeMixedMethodsMixin, ProbeMethodsMixin, ProbeMixedMethodsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( - VisualizationsMixin, -) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -46,7 +44,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class MixedstatePtychographicReconstruction( +class MixedstatePtychography( VisualizationsMixin, PositionsConstraintsMixin, ProbeMixedConstraintsMixin, diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py similarity index 99% rename from py4DSTEM/process/phase/iterative_multislice_ptychography.py rename to py4DSTEM/process/phase/multislice_ptychography.py index cd622f5e7..69ad11330 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -18,23 +18,21 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import 
PtychographicReconstruction -from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( Object2p5DConstraintsMixin, ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_methods import ( +from py4DSTEM.process.phase.ptychographic_methods import ( Object2p5DMethodsMixin, Object2p5DProbeMethodsMixin, ObjectNDMethodsMixin, ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( - VisualizationsMixin, -) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -46,7 +44,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class MultislicePtychographicReconstruction( +class MultislicePtychography( VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/parallax.py similarity index 99% rename from py4DSTEM/process/phase/iterative_parallax.py rename to py4DSTEM/process/phase/parallax.py index ba8fab238..0f32470d6 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -14,7 +14,7 @@ from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable from py4DSTEM import Calibration, DataCube from py4DSTEM.preprocess.utils import get_shifted_ar -from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction from py4DSTEM.process.phase.utils import ( AffineTransform, bilinear_kernel_density_estimate, @@ -56,7 +56,7 @@ } -class ParallaxReconstruction(PhaseReconstruction): +class Parallax(PhaseReconstruction): """ Iterative parallax reconstruction class. 
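Patch 066, later in this series, switches the divergence-projection helpers in utils.py from the DST-based solver to periodic boundary conditions: the divergence of the vector field is taken with periodic centered differences, a Poisson equation is solved in Fourier space with the zero-frequency component pinned for gauge invariance, and the gradient of the resulting potential is subtracted to leave the solenoidal part. The following standalone NumPy sketch illustrates that projection on a small random field with unit spacing; it uses the discrete Fourier symbol matching the centered differences rather than py4DSTEM's helpers, so all names here are local to the sketch:

import numpy as np

def centered_diff(a, axis):
    # periodic second-order centered difference, unit spacing
    return (np.roll(a, -1, axis=axis) - np.roll(a, 1, axis=axis)) / 2

shape = (16, 16, 16)
rng = np.random.default_rng(0)
v = rng.standard_normal((3,) + shape)          # arbitrary 3D vector field

div_v = sum(centered_diff(v[i], axis=i) for i in range(3))

# Fourier symbol of div(grad(.)) for these centered differences
freqs = np.meshgrid(*[np.fft.fftfreq(n) for n in shape], indexing="ij", sparse=True)
symbol = -sum(np.sin(2 * np.pi * f) ** 2 for f in freqs)

rhs_hat = np.fft.fftn(div_v)
p_hat = np.zeros_like(rhs_hat)
mask = np.abs(symbol) > 1e-12                  # pin zero/Nyquist modes of the operator (gauge)
p_hat[mask] = rhs_hat[mask] / symbol[mask]
p = np.fft.ifftn(p_hat).real

v_solenoidal = v - np.stack([centered_diff(p, axis=i) for i in range(3)])

# the projected field is divergence-free for the same discrete operator
residual = sum(centered_diff(v_solenoidal[i], axis=i) for i in range(3))
print(np.abs(residual).max())                  # ~0 (machine precision)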
diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index 8744ec792..ff9982b44 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt import numpy as np from matplotlib.gridspec import GridSpec -from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction from py4DSTEM.process.phase.utils import AffineTransform from skopt import gp_minimize from skopt.plots import plot_convergence as skopt_plot_convergence diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/phase_base_class.py similarity index 100% rename from py4DSTEM/process/phase/iterative_base_class.py rename to py4DSTEM/process/phase/phase_base_class.py diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/ptychographic_constraints.py similarity index 100% rename from py4DSTEM/process/phase/iterative_ptychographic_constraints.py rename to py4DSTEM/process/phase/ptychographic_constraints.py diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py similarity index 100% rename from py4DSTEM/process/phase/iterative_ptychographic_methods.py rename to py4DSTEM/process/phase/ptychographic_methods.py diff --git a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py similarity index 99% rename from py4DSTEM/process/phase/iterative_ptychographic_tomography.py rename to py4DSTEM/process/phase/ptychographic_tomography.py index a76095d1e..1e1cb62b7 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -18,15 +18,15 @@ from emdfile import Custom, tqdmnd from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( Object2p5DConstraintsMixin, Object3DConstraintsMixin, ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_methods import ( +from py4DSTEM.process.phase.ptychographic_methods import ( MultipleMeasurementsMethodsMixin, Object2p5DMethodsMixin, Object2p5DProbeMethodsMixin, @@ -35,9 +35,7 @@ ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( - VisualizationsMixin, -) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -49,7 +47,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class PtychographicTomographyReconstruction( +class PtychographicTomography( VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_visualizations.py b/py4DSTEM/process/phase/ptychographic_visualizations.py similarity index 100% rename from py4DSTEM/process/phase/iterative_ptychographic_visualizations.py rename to py4DSTEM/process/phase/ptychographic_visualizations.py diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py 
b/py4DSTEM/process/phase/singleslice_ptychography.py similarity index 98% rename from py4DSTEM/process/phase/iterative_singleslice_ptychography.py rename to py4DSTEM/process/phase/singleslice_ptychography.py index b60157d35..57b8ec93c 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -18,20 +18,18 @@ from emdfile import Custom, tqdmnd from py4DSTEM.datacube import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_methods import ( +from py4DSTEM.process.phase.ptychographic_methods import ( ObjectNDMethodsMixin, ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) -from py4DSTEM.process.phase.iterative_ptychographic_visualizations import ( - VisualizationsMixin, -) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, @@ -43,7 +41,7 @@ warnings.simplefilter(action="always", category=UserWarning) -class SingleslicePtychographicReconstruction( +class SingleslicePtychography( VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, From b6cac14222161de6bbbe318a4d1570d81697485f Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 6 Jan 2024 16:23:09 -0800 Subject: [PATCH 064/128] typo Former-commit-id: 040330acafec2f13300c4b1d8b051e400bed16a6 --- py4DSTEM/process/phase/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index 59da42559..ecfeaa1d2 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -10,7 +10,7 @@ from py4DSTEM.process.phase.multislice_ptychography import MultislicePtychography from py4DSTEM.process.phase.parallax import Parallax from py4DSTEM.process.phase.ptychographic_tomography import PtychographicTomography -from py4DSTEM.process.phase.singleslice_ptychography import SingleslicePtychographic +from py4DSTEM.process.phase.singleslice_ptychography import SingleslicePtychography from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer # fmt: on From 8842819a7669e3d1ce4b354f9ae5cb4e91fb4a5d Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 6 Jan 2024 17:04:19 -0800 Subject: [PATCH 065/128] FFT plotting improvements Former-commit-id: 6d7fd3cd0c7a190f5cdf3efef81c1484c27cfa92 --- .../process/phase/ptychographic_methods.py | 161 ++++++++++-------- .../process/phase/ptychographic_tomography.py | 2 +- 2 files changed, 95 insertions(+), 68 deletions(-) diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index d2a63ce9a..b809c3d1b 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -129,6 +129,8 @@ def _return_projected_cropped_potential( def _return_object_fft( self, obj=None, + apply_hanning_window=False, + **kwargs, ): """ Returns absolute value of obj fft shifted to center of array @@ -137,6 +139,8 @@ def _return_object_fft( ---------- obj: array, optional if None is specified, uses 
self._object + apply_hanning_window: bool, optional + If True, a 2D Hann window is applied to the object before FFT Returns ------- @@ -153,34 +157,82 @@ def _return_object_fft( obj = xp.angle(obj) obj = self._crop_rotate_object_fov(asnumpy(obj)) + + if apply_hanning_window: + sx, sy = obj.shape + wx = np.hanning(sx) + wy = np.hanning(sy) + obj *= wx[:, None] * wy[None, :] + return np.abs(np.fft.fftshift(np.fft.fft2(obj))) - def show_object_fft(self, obj=None, **kwargs): + def show_object_fft( + self, + obj=None, + apply_hanning_window=True, + crop_to_min_frequency=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, + ): """ Plot FFT of reconstructed object Parameters ---------- obj: complex array, optional - if None is specified, uses the `object_fft` property + If None is specified, uses the `object_fft` property + apply_hanning_window: bool, optional + If True, a 2D Hann window is applied to the object before FFT + crop_to_min_frequency: bool, optional + If True, a square FFT is plotted, cropping to the smallest axis + scalebar: bool, optional + if True, adds scalebar to probe + pixelunits: str, optional + units for scalebar, default is A^-1 + pixelsize: float, optional + default is object FFT sampling """ - if obj is None: - object_fft = self.object_fft - else: - object_fft = self._return_object_fft(obj) - figsize = kwargs.pop("figsize", (6, 6)) + object_fft = self._return_object_fft( + obj, apply_hanning_window=apply_hanning_window, **kwargs + ) + + if pixelsize is None: + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) + if pixelunits is None: + pixelunits = r"$\AA^{-1}$" + + if crop_to_min_frequency: + sx, sy = object_fft.shape + s = min(sx, sy) + start_x = sx // 2 - (s // 2) + start_y = sy // 2 - (s // 2) + object_fft = object_fft[start_x : start_x + s, start_y : start_y + s] + + figsize = kwargs.pop("figsize", (4, 4)) cmap = kwargs.pop("cmap", "magma") + ticks = kwargs.pop("ticks", False) + vmin = kwargs.pop("vmin", 0.001) + vmax = kwargs.pop("vmax", 0.999) + + # remove additional 3D FFT parameters before passing to show + kwargs.pop("projection_angle_deg", None) + kwargs.pop("projection_axes", None) + kwargs.pop("x_lims", None) + kwargs.pop("y_lims", None) - pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, cmap=cmap, - scalebar=True, + scalebar=scalebar, pixelsize=pixelsize, - ticks=False, - pixelunits=r"$\AA^{-1}$", + ticks=ticks, + pixelunits=pixelunits, + vmin=vmin, + vmax=vmax, **kwargs, ) @@ -391,6 +443,8 @@ def _return_projected_cropped_potential( def _return_object_fft( self, obj=None, + apply_hanning_window=False, + **kwargs, ): """ Returns obj fft shifted to center of array @@ -399,6 +453,13 @@ def _return_object_fft( ---------- obj: array, optional if None is specified, uses self._object + apply_hanning_window: bool, optional + If True, a 2D Hann window is applied to the object before FFT + + Returns + ------- + object_fft_amplitude: np.ndarray + Amplitude of Fourier-transformed and center-shifted obj. 
""" xp = self._xp @@ -409,6 +470,13 @@ def _return_object_fft( obj = xp.angle(obj) obj = self._crop_rotate_object_fov(obj.sum(axis=0)) + + if apply_hanning_window: + sx, sy = obj.shape + wx = np.hanning(sx) + wy = np.hanning(sy) + obj *= wx[:, None] * wy[None, :] + return np.abs(np.fft.fftshift(np.fft.fft2(obj))) def show_depth_section( @@ -880,10 +948,12 @@ def _return_projected_cropped_potential( def _return_object_fft( self, obj=None, + apply_hanning_window=False, projection_angle_deg: float = None, projection_axes: Tuple[int, int] = (0, 2), x_lims: Tuple[int, int] = (None, None), y_lims: Tuple[int, int] = (None, None), + **kwargs, ): """ Returns obj fft shifted to center of array @@ -892,6 +962,8 @@ def _return_object_fft( ---------- obj: array, optional if None is specified, uses self._object + apply_hanning_window: bool, optional + If True, a 2D Hann window is applied to the object before FFT projection_angle_deg: float Angle in degrees to rotate 3D array around prior to projection projection_axes: tuple(int,int) @@ -900,6 +972,11 @@ def _return_object_fft( min/max x indices y_lims: tuple(float,float) min/max y indices + + Returns + ------- + object_fft_amplitude: np.ndarray + Amplitude of Fourier-transformed and center-shifted obj. """ xp = self._xp @@ -926,63 +1003,13 @@ def _return_object_fft( rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims ) - return np.abs(np.fft.fftshift(np.fft.fft2(rotated_object))) - - def show_object_fft( - self, - obj=None, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims: Tuple[int, int] = (None, None), - y_lims: Tuple[int, int] = (None, None), - **kwargs, - ): - """ - Plot FFT of reconstructed object + if apply_hanning_window: + sx, sy = rotated_object.shape + wx = np.hanning(sx) + wy = np.hanning(sy) + rotated_object *= wx[:, None] * wy[None, :] - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - if obj is None: - object_fft = self._return_object_fft( - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - ) - else: - object_fft = self._return_object_fft( - obj, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - ) - - figsize = kwargs.pop("figsize", (6, 6)) - cmap = kwargs.pop("cmap", "magma") - - pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) - show( - object_fft, - figsize=figsize, - cmap=cmap, - scalebar=True, - pixelsize=pixelsize, - ticks=False, - pixelunits=r"$\AA^{-1}$", - **kwargs, - ) + return np.abs(np.fft.fftshift(np.fft.fft2(rotated_object))) @property def object_supersliced(self): diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 1e1cb62b7..165dd115e 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -612,7 +612,7 @@ def preprocess( ) # propagated - propagated_probe = self._probe.copy() + propagated_probe = self._probes_all[0].copy() for s in range(self._num_slices - 1): propagated_probe = self._propagate_array( From f5446cde48488e4c87b1ab451a35ec7af051b14c Mon Sep 17 00:00:00 2001 From: Georgios Varnavides 
Date: Sat, 6 Jan 2024 20:25:35 -0800 Subject: [PATCH 066/128] switched divergence field functions to be periodic Former-commit-id: 8d85db84ee6e5f63dda7336df7ceb73bcbae4d46 --- py4DSTEM/process/phase/utils.py | 122 +++++++++++--------------------- 1 file changed, 43 insertions(+), 79 deletions(-) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index a9b46bc9c..fc1715ab2 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -3,14 +3,13 @@ import matplotlib.pyplot as plt import numpy as np -from scipy.fft import dctn, dstn, idctn, idstn +from scipy.fft import dctn, idctn from scipy.optimize import curve_fit try: import cupy as cp from cupyx.scipy.fft import dctn as dctn_cp from cupyx.scipy.fft import idctn as idctn_cp - from cupyx.scipy.fft import rfft except (ImportError, ModuleNotFoundError): cp = None @@ -1152,114 +1151,79 @@ def fourier_rotate_real_volume(array, angle, axes=(0, 1), xp=np): return output_arr +def array_slice(axis, ndim, start, end, step=1): + """Returns array slice along dynamic axis""" + return (slice(None),) * (axis % ndim) + (slice(start, end, step),) + + ### Divergence Projection Functions -def compute_divergence(vector_field, spacings, xp=np): +def periodic_centered_difference(array, spacing, axis, xp=np): + """Computes second-order centered difference with periodic BCs""" + return (xp.roll(array, -1, axis=axis) - xp.roll(array, 1, axis=axis)) / ( + 2 * spacing + ) + + +def compute_divergence_periodic(vector_field, spacings, xp=np): """Computes divergence of vector_field""" num_dims = len(spacings) div = xp.zeros_like(vector_field[0]) for i in range(num_dims): - div += xp.gradient(vector_field[i], spacings[i], axis=i) + div += periodic_centered_difference(vector_field[i], spacings[i], axis=i, xp=xp) return div -def compute_gradient(scalar_field, spacings, xp=np): +def compute_gradient_periodic(scalar_field, spacings, xp=np): """Computes gradient of scalar_field""" num_dims = len(spacings) grad = xp.zeros((num_dims,) + scalar_field.shape) for i in range(num_dims): - grad[i] = xp.gradient(scalar_field, spacings[i], axis=i) + grad[i] = periodic_centered_difference(scalar_field, spacings[i], axis=i, xp=xp) return grad -def array_slice(axis, ndim, start, end, step=1): - """Returns array slice along dynamic axis""" - return (slice(None),) * (axis % ndim) + (slice(start, end, step),) - - -def make_array_rfft_compatible(array_nd, axis=0, xp=np): - """Expand array to be rfft compatible""" - array_shape = np.array(array_nd.shape) - d = array_nd.ndim - n = array_shape[axis] - array_shape[axis] = (n + 1) * 2 - - dtype = array_nd.dtype - padded_array = xp.zeros(array_shape, dtype=dtype) - - padded_array[array_slice(axis, d, 1, n + 1)] = -array_nd - padded_array[array_slice(axis, d, None, -n - 1, -1)] = array_nd - - return padded_array +def preconditioned_laplacian_periodic_3D(shape, xp=np): + """FFT eigenvalues""" + n, m, p = shape + i, j, k = xp.ogrid[0:n, 0:m, 0:p] - -def dst_I(array_nd, xp=np): - """1D rfft-based DST-I""" - d = array_nd.ndim - for axis in range(d): - crop_slice = array_slice(axis, d, 1, -1) - array_nd = rfft( - make_array_rfft_compatible(array_nd, axis=axis, xp=xp), axis=axis - )[crop_slice].imag - - return array_nd - - -def idst_I(array_nd, xp=np): - """1D rfft-based iDST-I""" - scaling = np.prod((np.array(array_nd.shape) + 1) * 2) - return dst_I(array_nd, xp=xp) / scaling - - -def preconditioned_laplacian(num_exterior, spacing=1, xp=np): - """DST-I eigenvalues""" - n = num_exterior - 1 - evals_1d 
= 2 - 2 * xp.cos(np.pi * xp.arange(1, num_exterior) / num_exterior) - - op = ( - xp.repeat(evals_1d, n**2) - + xp.tile(evals_1d, n**2) - + xp.tile(xp.repeat(evals_1d, n), n) + op = 6 - 2 * xp.cos(2 * np.pi * i / n) * xp.cos(2 * np.pi * j / m) * xp.cos( + 2 * np.pi * k / p ) - - return -op / spacing**2 + op[0, 0, 0] = 1 # gauge invariance + return -op -def preconditioned_poisson_solver(rhs_interior, spacing=1, xp=np): - """DST-I based poisson solver""" - nx, ny, nz = rhs_interior.shape - if nx != ny or nx != nz: - raise ValueError() +def preconditioned_poisson_solver_periodic_3D(rhs, gauge=None, xp=np): + """FFT based poisson solver""" + op = preconditioned_laplacian_periodic_3D(rhs.shape, xp=xp) - op = preconditioned_laplacian(nx + 1, spacing=spacing, xp=xp) - if xp is np: - dst_rhs = dstn(rhs_interior, type=1).ravel() - dst_u = (dst_rhs / op).reshape((nx, ny, nz)) - sol = idstn(dst_u, type=1) - else: - dst_rhs = dst_I(rhs_interior, xp=xp).ravel() - dst_u = (dst_rhs / op).reshape((nx, ny, nz)) - sol = idst_I(dst_u, xp=xp) + if gauge is None: + gauge = xp.mean(rhs) + fft_rhs = xp.fft.fftn(rhs) + fft_rhs[0, 0, 0] = gauge # gauge invariance + sol = xp.fft.ifftn(fft_rhs / op).real return sol -def project_vector_field_divergence(vector_field, spacings=(1, 1, 1), xp=np): +def project_vector_field_divergence_periodic_3D(vector_field, xp=np): """ Returns solenoidal part of vector field using projection: f - \\grad{p} s.t. \\laplacian{p} = \\div{f} """ - - div_v = compute_divergence(vector_field, spacings, xp=xp) - p = preconditioned_poisson_solver(div_v, spacings[0], xp=xp) - grad_p = compute_gradient(p, spacings, xp=xp) + spacings = (1, 1, 1) + div_v = compute_divergence_periodic(vector_field, spacings, xp=xp) + p = preconditioned_poisson_solver_periodic_3D(div_v, xp=xp) + grad_p = compute_gradient_periodic(p, spacings, xp=xp) return vector_field - grad_p @@ -1612,7 +1576,7 @@ def aberrations_basis_function( return aberrations_basis, aberrations_mn -def preconditioned_laplacian_dct(shape, xp=np): +def preconditioned_laplacian_neumann_2D(shape, xp=np): """DCT eigenvalues""" n, m = shape i, j = xp.ogrid[0:n, 0:m] @@ -1622,9 +1586,9 @@ def preconditioned_laplacian_dct(shape, xp=np): return -op -def preconditioned_poisson_solver_dct(rhs, gauge=None, xp=np): +def preconditioned_poisson_solver_neumann_2D(rhs, gauge=None, xp=np): """DCT based poisson solver""" - op = preconditioned_laplacian_dct(rhs.shape, xp=xp) + op = preconditioned_laplacian_neumann_2D(rhs.shape, xp=xp) if gauge is None: gauge = xp.mean(rhs) @@ -1638,7 +1602,7 @@ def preconditioned_poisson_solver_dct(rhs, gauge=None, xp=np): fft_rhs = dctn_xp(rhs, type=2) fft_rhs[0, 0] = gauge # gauge invariance - sol = idctn_xp(fft_rhs / op, type=2) + sol = idctn_xp(fft_rhs / op, type=2).real return sol @@ -1668,7 +1632,7 @@ def unwrap_phase_2d(array, weights=None, gauge=None, corner_centered=True, xp=np rho = xp.diff(dx, axis=0, prepend=0, append=0) rho += xp.diff(dy, axis=1, prepend=0, append=0) - unwrapped_array = preconditioned_poisson_solver_dct(rho, gauge=gauge, xp=xp).real + unwrapped_array = preconditioned_poisson_solver_neumann_2D(rho, gauge=gauge, xp=xp) unwrapped_array -= unwrapped_array.min() if corner_centered: From eb1de01d3136fffe963b2268f019e442c611fec7 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 6 Jan 2024 20:27:50 -0800 Subject: [PATCH 067/128] magnetic ptycho-tomo preprocessing Former-commit-id: f4bc14819ba99a645a856d2f1302868dc23bae71 --- .../magnetic_ptychographic_tomography.py | 1756 ++--------------- 
.../process/phase/ptychographic_tomography.py | 14 +- 2 files changed, 174 insertions(+), 1596 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 090201897..831edd904 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -8,7 +8,6 @@ import matplotlib.pyplot as plt import numpy as np -from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import make_axes_locatable from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg @@ -21,37 +20,47 @@ from py4DSTEM import DataCube from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction from py4DSTEM.process.phase.ptychographic_constraints import ( + Object2p5DConstraintsMixin, Object3DConstraintsMixin, ObjectNDConstraintsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, ) from py4DSTEM.process.phase.ptychographic_methods import ( + MultipleMeasurementsMethodsMixin, + Object2p5DMethodsMixin, + Object2p5DProbeMethodsMixin, Object3DMethodsMixin, ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, fft_shift, generate_batches, polar_aliases, polar_symbols, - project_vector_field_divergence, - spatial_frequencies, + project_vector_field_divergence_periodic_3D, ) -from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar warnings.simplefilter(action="always", category=UserWarning) class MagneticPtychographicTomography( + VisualizationsMixin, PositionsConstraintsMixin, ProbeConstraintsMixin, Object3DConstraintsMixin, + Object2p5DConstraintsMixin, ObjectNDConstraintsMixin, + MultipleMeasurementsMethodsMixin, + Object2p5DProbeMethodsMixin, + ObjectNDProbeMethodsMixin, ProbeMethodsMixin, Object3DMethodsMixin, + Object2p5DMethodsMixin, ObjectNDMethodsMixin, PtychographicReconstruction, ): @@ -73,13 +82,9 @@ class MagneticPtychographicTomography( energy: float The electron energy of the wave functions in eV num_slices: int - Number of slices to use in the forward model - tilt_angles_deg: Sequence[float] - List of (\alpha, \beta) tilt angle tuple in degrees, - with the following Euler-angle convention: - - \alpha tilt around z-axis - - \beta tilt around x-axis - - -\alpha tilt around z-axis + Number of super-slices to use in the forward model + tilt_orientation_matrices: Sequence[np.ndarray] + List of orientation matrices for each tilt semiangle_cutoff: float, optional Semiangle cutoff for the initial probe guess in mrad semiangle_cutoff_pixels: float, optional @@ -119,13 +124,13 @@ class MagneticPtychographicTomography( """ # Class-specific Metadata - _class_specific_metadata = ("_num_slices", "_tilt_angles_deg") + _class_specific_metadata = ("_num_slices", "_tilt_orientation_matrices") def __init__( self, energy: float, num_slices: int, - tilt_angles_deg: Sequence[Tuple[float, float]], + tilt_orientation_matrices: Sequence[np.ndarray], datacube: Sequence[DataCube] = None, semiangle_cutoff: float = None, semiangle_cutoff_pixels: float = None, @@ -148,19 +153,32 @@ def __init__( if device == "cpu": self._xp = np self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter, rotate, zoom + from scipy.ndimage import affine_transform, gaussian_filter, rotate, zoom self._gaussian_filter = gaussian_filter self._zoom = zoom 
self._rotate = rotate + self._affine_transform = affine_transform + from scipy.special import erf + + self._erf = erf elif device == "gpu": self._xp = cp self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter, rotate, zoom + from cupyx.scipy.ndimage import ( + affine_transform, + gaussian_filter, + rotate, + zoom, + ) self._gaussian_filter = gaussian_filter self._zoom = zoom self._rotate = rotate + self._affine_transform = affine_transform + from cupyx.scipy.special import erf + + self._erf = erf else: raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") @@ -176,7 +194,7 @@ def __init__( polar_parameters.update(kwargs) self._set_polar_parameters(polar_parameters) - num_tilts = len(tilt_angles_deg) + num_tilts = len(tilt_orientation_matrices) if initial_scan_positions is None: initial_scan_positions = [None] * num_tilts @@ -188,7 +206,7 @@ def __init__( # Data self._datacube = datacube self._object = initial_object_guess - self._probe = initial_probe_guess + self._probe_init = initial_probe_guess # Common Metadata self._vacuum_probe_intensity = vacuum_probe_intensity @@ -206,234 +224,8 @@ def __init__( # Class-specific Metadata self._num_slices = num_slices - self._tilt_angles_deg = tuple(tilt_angles_deg) - self._num_tilts = num_tilts - - def _precompute_propagator_arrays( - self, - gpts: Tuple[int, int], - sampling: Tuple[float, float], - energy: float, - slice_thicknesses: Sequence[float], - ): - """ - Precomputes propagator arrays complex wave-function will be convolved by, - for all slice thicknesses. - - Parameters - ---------- - gpts: Tuple[int,int] - Wavefunction pixel dimensions - sampling: Tuple[float,float] - Wavefunction sampling in A - energy: float - The electron energy of the wave functions in eV - slice_thicknesses: Sequence[float] - Array of slice thicknesses in A - - Returns - ------- - propagator_arrays: np.ndarray - (T,Sx,Sy) shape array storing propagator arrays - """ - xp = self._xp - - # Frequencies - kx, ky = spatial_frequencies(gpts, sampling) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) - - # Propagators - wavelength = electron_wavelength_angstrom(energy) - num_slices = slice_thicknesses.shape[0] - propagators = xp.empty( - (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 - ) - for i, dz in enumerate(slice_thicknesses): - propagators[i] = xp.exp( - 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) - ) - - return propagators - - def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): - """ - Propagates array by Fourier convolving array with propagator_array. - - Parameters - ---------- - array: np.ndarray - Wavefunction array to be convolved - propagator_array: np.ndarray - Propagator array to convolve array with - - Returns - ------- - propagated_array: np.ndarray - Fourier-convolved array - """ - xp = self._xp - - return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) - - def _project_sliced_object(self, array: np.ndarray, output_z): - """ - Expands supersliced object or projects voxel-sliced object. - - Parameters - ---------- - array: np.ndarray - 3D array to expand/project - output_z: int - Output_dimension to expand/project array to. 
- If output_z > array.shape[0] array is expanded, else it's projected - - Returns - ------- - expanded_or_projected_array: np.ndarray - expanded or projected array - """ - xp = self._xp - input_z = array.shape[0] - - voxels_per_slice = np.ceil(input_z / output_z).astype("int") - pad_size = voxels_per_slice * output_z - input_z - - padded_array = xp.pad(array, ((0, pad_size), (0, 0), (0, 0))) - - return xp.sum( - padded_array.reshape( - ( - -1, - voxels_per_slice, - ) - + array.shape[1:] - ), - axis=1, - ) - - def _expand_sliced_object(self, array: np.ndarray, output_z): - """ - Expands supersliced object or projects voxel-sliced object. - - Parameters - ---------- - array: np.ndarray - 3D array to expand/project - output_z: int - Output_dimension to expand/project array to. - If output_z > array.shape[0] array is expanded, else it's projected - - Returns - ------- - expanded_or_projected_array: np.ndarray - expanded or projected array - """ - xp = self._xp - input_z = array.shape[0] - - voxels_per_slice = np.ceil(output_z / input_z).astype("int") - remainder_size = voxels_per_slice - (voxels_per_slice * input_z - output_z) - - voxels_in_slice = xp.repeat(voxels_per_slice, input_z) - voxels_in_slice[-1] = remainder_size if remainder_size > 0 else voxels_per_slice - - normalized_array = array / xp.asarray(voxels_in_slice)[:, None, None] - return xp.repeat(normalized_array, voxels_per_slice, axis=0)[:output_z] - - def _euler_angle_rotate_volume( - self, - volume_array, - alpha_deg, - beta_deg, - ): - """ - Rotate 3D volume using alpha, beta, gamma Euler angles according to convention: - - - \\-alpha tilt around first axis (z) - - \\beta tilt around second axis (x) - - \\alpha tilt around first axis (z) - - Note: since we store array as zxy, the x- and y-axis rotations flip sign below. - - """ - - rotate = self._rotate - volume = volume_array.copy() - - alpha_deg, beta_deg = np.mod(np.array([alpha_deg, beta_deg]) + 180, 360) - 180 - - if alpha_deg == -180: - # print(f"rotation of {-beta_deg} around x") - volume = rotate( - volume, - beta_deg, - axes=(0, 2), - reshape=False, - order=3, - ) - elif alpha_deg == -90: - # print(f"rotation of {beta_deg} around y") - volume = rotate( - volume, - -beta_deg, - axes=(0, 1), - reshape=False, - order=3, - ) - elif alpha_deg == 0: - # print(f"rotation of {beta_deg} around x") - volume = rotate( - volume, - -beta_deg, - axes=(0, 2), - reshape=False, - order=3, - ) - elif alpha_deg == 90: - # print(f"rotation of {-beta_deg} around y") - volume = rotate( - volume, - beta_deg, - axes=(0, 1), - reshape=False, - order=3, - ) - else: - # print(( - # f"rotation of {-alpha_deg} around z, " - # f"rotation of {beta_deg} around x, " - # f"rotation of {alpha_deg} around z." 
- # )) - - volume = rotate( - volume, - -alpha_deg, - axes=(1, 2), - reshape=False, - order=3, - ) - - volume = rotate( - volume, - -beta_deg, - axes=(0, 2), - reshape=False, - order=3, - ) - - volume = rotate( - volume, - alpha_deg, - axes=(1, 2), - reshape=False, - order=3, - ) - - return volume + self._tilt_orientation_matrices = tuple(tilt_orientation_matrices) + self._num_measurements = num_tilts def preprocess( self, @@ -523,7 +315,7 @@ def preprocess( ) if self._positions_mask is not None: - self._positions_mask = np.asarray(self._positions_mask) + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") if self._positions_mask.ndim == 2: warnings.warn( @@ -531,97 +323,91 @@ def preprocess( UserWarning, ) self._positions_mask = np.tile( - self._positions_mask, (self._num_tilts, 1, 1) - ) - - if self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array."), - UserWarning, + self._positions_mask, (self._num_measurements, 1, 1) ) - self._positions_mask = self._positions_mask.astype("bool") - else: - self._positions_mask = [None] * self._num_tilts - - # Prepopulate various arrays - - if self._positions_mask[0] is None: - num_probes_per_tilt = [0] - for dc in self._datacube: - rx, ry = dc.Rshape - num_probes_per_tilt.append(rx * ry) - num_probes_per_tilt = np.array(num_probes_per_tilt) - else: - num_probes_per_tilt = np.insert( + num_probes_per_measurement = np.insert( self._positions_mask.sum(axis=(-2, -1)), 0, 0 ) - self._num_diffraction_patterns = num_probes_per_tilt.sum() - self._cum_probes_per_tilt = np.cumsum(num_probes_per_tilt) + else: + self._positions_mask = [None] * self._num_measurements + num_probes_per_measurement = [0] + [dc.R_N for dc in self._datacube] + num_probes_per_measurement = np.array(num_probes_per_measurement) + # prepopulate relevant arrays self._mean_diffraction_intensity = [] + self._num_diffraction_patterns = num_probes_per_measurement.sum() + self._cum_probes_per_measurement = np.cumsum(num_probes_per_measurement) self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) - self._rotation_best_rad = np.deg2rad(diffraction_patterns_rotate_degrees) - self._rotation_best_transpose = diffraction_patterns_transpose + # calculate roi_shape + roi_shape = self._datacube[0].Qshape + if diffraction_intensities_shape is not None: + roi_shape = diffraction_intensities_shape + if probe_roi_shape is not None: + roi_shape = tuple(max(q, s) for q, s in zip(roi_shape, probe_roi_shape)) + self._amplitudes = xp.empty((self._num_diffraction_patterns,) + roi_shape) + self._region_of_interest_shape = np.array(roi_shape) + + # TO-DO: generalize this if force_com_shifts is None: - force_com_shifts = [None] * self._num_tilts + force_com_shifts = [None] * self._num_measurements + + self._rotation_best_rad = np.deg2rad(diffraction_patterns_rotate_degrees) + self._rotation_best_transpose = diffraction_patterns_transpose - for tilt_index in tqdmnd( - self._num_tilts, + # loop over DPs for preprocessing + for index in tqdmnd( + self._num_measurements, desc="Preprocessing data", unit="tilt", disable=not progress_bar, ): - if tilt_index == 0: + # preprocess datacube, vacuum and masks only for first tilt + if index == 0: ( - self._datacube[tilt_index], + self._datacube[index], self._vacuum_probe_intensity, self._dp_mask, - force_com_shifts[tilt_index], + force_com_shifts[index], ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[tilt_index], + self._datacube[index], 
diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, probe_roi_shape=self._probe_roi_shape, vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, - com_shifts=force_com_shifts[tilt_index], - ) - - self._amplitudes = xp.empty( - (self._num_diffraction_patterns,) + self._datacube[0].Qshape - ) - self._region_of_interest_shape = np.array( - self._amplitudes[0].shape[-2:] + com_shifts=force_com_shifts[index], ) else: ( - self._datacube[tilt_index], + self._datacube[index], _, _, - force_com_shifts[tilt_index], + force_com_shifts[index], ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[tilt_index], + self._datacube[index], diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, probe_roi_shape=self._probe_roi_shape, vacuum_probe_intensity=None, dp_mask=None, - com_shifts=force_com_shifts[tilt_index], + com_shifts=force_com_shifts[index], ) + # calibrations intensities = self._extract_intensities_and_calibrations_from_datacube( - self._datacube[tilt_index], + self._datacube[index], require_calibrations=True, force_scan_sampling=force_scan_sampling, force_angular_sampling=force_angular_sampling, force_reciprocal_sampling=force_reciprocal_sampling, ) + # calculate CoM ( com_measured_x, com_measured_y, @@ -633,22 +419,22 @@ def preprocess( intensities, dp_mask=self._dp_mask, fit_function=fit_function, - com_shifts=force_com_shifts[tilt_index], + com_shifts=force_com_shifts[index], ) + # corner-center amplitudes + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] ( - self._amplitudes[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ], + self._amplitudes[idx_start:idx_end], mean_diffraction_intensity_temp, + self._crop_mask, ) = self._normalize_diffraction_intensities( intensities, com_fitted_x, com_fitted_y, + self._positions_mask[index], crop_patterns, - self._positions_mask[tilt_index], ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) @@ -663,12 +449,14 @@ def preprocess( com_normalized_y, ) - self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] = self._calculate_scan_positions_in_pixels( - self._scan_positions[tilt_index], self._positions_mask[tilt_index] + # initialize probe positions + ( + self._positions_px_all[idx_start:idx_end], + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions[index], + self._positions_mask[index], + self._object_padding_px, ) # handle semiangle specified in pixels @@ -677,119 +465,64 @@ def preprocess( self._semiangle_cutoff_pixels * self._angular_sampling[0] ) - # Object Initialization + # initialize object + obj = self._initialize_object( + self._object, + self._positions_px_all, + self._object_type, + main_tilt_axis=None, + ) + if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px_all, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - self._object = xp.zeros((4, q, p, q), dtype=xp.float32) + self._object = xp.full((4,) + obj.shape, obj) else: - self._object = xp.asarray(self._object, dtype=xp.float32) + self._object = obj self._object_initial = 
self._object.copy() self._object_type_initial = self._object_type self._object_shape = self._object.shape[-2:] self._num_voxels = self._object.shape[1] - # Center Probes + # center probe positions self._positions_px_all = xp.asarray(self._positions_px_all, dtype=xp.float32) - for tilt_index in range(self._num_tilts): - self._positions_px = self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + self._positions_px = self._positions_px_all[idx_start:idx_end] self._positions_px_com = xp.mean(self._positions_px, axis=0) self._positions_px -= ( self._positions_px_com - xp.array(self._object_shape) / 2 ) - - self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] = self._positions_px.copy() + self._positions_px_all[idx_start:idx_end] = self._positions_px.copy() self._positions_px_initial_all = self._positions_px_all.copy() self._positions_initial_all = self._positions_px_initial_all.copy() self._positions_initial_all[:, 0] *= self.sampling[0] self._positions_initial_all[:, 1] *= self.sampling[1] - # Probe Initialization - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, device=self._device - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - self._probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - sum(self._mean_diffraction_intensity) - / self._num_tilts - / probe_intensity + # initialize probe + self._probes_all = [] + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + list_Q = isinstance(self._probe_init, (list, tuple)) + + for index in range(self._num_measurements): + _probe, self._semiangle_cutoff = self._initialize_probe( + self._probe_init[index] if list_Q else self._probe_init, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity[index], + self._semiangle_cutoff, + crop_patterns, ) - else: - if isinstance(self._probe, ComplexProbe): - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - self._probe = self._probe._array - else: - self._probe._xp = xp - self._probe = self._probe.build()._array - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - sum(self._mean_diffraction_intensity) - / self._num_tilts - / probe_intensity - ) - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) + self._probes_all.append(_probe) + 
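A minimal sketch of the probe normalization appearing in the removed block above (and assumed to be handled inside _initialize_probe): the probe is scaled so its summed Fourier intensity matches the mean measured diffraction intensity. The function name is illustrative:

import numpy as np


def scale_probe_to_mean_intensity(probe, mean_diffraction_intensity):
    # probe: complex 2D array; mean_diffraction_intensity: scalar mean of the measured patterns.
    probe_intensity = np.sum(np.abs(np.fft.fft2(probe)) ** 2)
    return probe * np.sqrt(mean_diffraction_intensity / probe_intensity)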
self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + del self._probe_init + # initialize aberrations self._known_aberrations_array = ComplexProbe( energy=self._energy, gpts=self._region_of_interest_shape, @@ -799,9 +532,12 @@ def preprocess( )._evaluate_ctf() # Precomputed propagator arrays + thickness_h = self._object_shape[1] * self.sampling[1] + thickness_v = self._object_shape[0] * self.sampling[0] + thickness = max(thickness_h, thickness_v) + self._slice_thicknesses = np.tile( - self._object_shape[1] * self.sampling[1] / self._num_slices, - self._num_slices - 1, + thickness / self._num_slices, self._num_slices - 1 ) self._propagator_arrays = self._precompute_propagator_arrays( self._region_of_interest_shape, @@ -813,26 +549,24 @@ def preprocess( # overlaps if object_fov_mask is None: probe_overlap_3D = xp.zeros_like(self._object[0]) + old_rot_matrix = np.eye(3) # identity - for tilt_index in np.arange(self._num_tilts): - alpha_deg, beta_deg = self._tilt_angles_deg[tilt_index] + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + rot_matrix = self._tilt_orientation_matrices[index] - probe_overlap_3D = self._euler_angle_rotate_volume( + probe_overlap_3D = self._rotate_zxy_volume( probe_overlap_3D, - alpha_deg, - beta_deg, + rot_matrix @ old_rot_matrix.T, ) - self._positions_px = self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] + self._positions_px = self._positions_px_all[idx_start:idx_end] self._positions_px_fractional = self._positions_px - xp.round( self._positions_px ) shifted_probes = fft_shift( - self._probe, self._positions_px_fractional, xp + self._probes_all[index], self._positions_px_fractional, xp ) probe_intensities = xp.abs(shifted_probes) ** 2 probe_overlap = self._sum_overlapping_patches_bincounts( @@ -840,43 +574,49 @@ def preprocess( ) probe_overlap_3D += probe_overlap[None] + old_rot_matrix = rot_matrix - probe_overlap_3D = self._euler_angle_rotate_volume( - probe_overlap_3D, - alpha_deg, - -beta_deg, - ) + probe_overlap_3D = self._rotate_zxy_volume( + probe_overlap_3D, + old_rot_matrix.T, + ) - probe_overlap_3D = self._gaussian_filter(probe_overlap_3D, 1.0) + probe_overlap_3D_blurred = self._gaussian_filter(probe_overlap_3D, 1.0) self._object_fov_mask = asnumpy( - probe_overlap_3D > 0.25 * probe_overlap_3D.max() + probe_overlap_3D_blurred > 0.25 * probe_overlap_3D_blurred.max() ) + else: self._object_fov_mask = np.asarray(object_fov_mask) - self._positions_px = self._positions_px_all[: self._cum_probes_per_tilt[1]] + + self._positions_px = self._positions_px_all[ + : self._cum_probes_per_measurement[1] + ] self._positions_px_fractional = self._positions_px - xp.round( self._positions_px ) - shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) + shifted_probes = fft_shift( + self._probes_all[0], self._positions_px_fractional, xp + ) probe_intensities = xp.abs(shifted_probes) ** 2 probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) self._object_fov_mask_inverse = np.invert(self._object_fov_mask) if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) chroma_boost = kwargs.pop("chroma_boost", 1) + power = 
kwargs.pop("power", 2) # initial probe complex_probe_rgb = Complex2RGB( - self.probe_centered, - power=2, + self.probe_centered[0], + power=power, chroma_boost=chroma_boost, ) # propagated - propagated_probe = self._probe.copy() + propagated_probe = self._probes_all[0].copy() for s in range(self._num_slices - 1): propagated_probe = self._propagate_array( @@ -884,7 +624,7 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - power=2, + power=power, chroma_boost=chroma_boost, ) @@ -961,685 +701,6 @@ def preprocess( return self - def _overlap_projection( - self, current_object_V, current_object_A_projected, current_probe - ): - """ - Ptychographic overlap projection method. - - Parameters - -------- - current_object_V: np.ndarray - Current electrostatic object estimate - current_object_A_projected: np.ndarray - Current projected magnetic object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - """ - - xp = self._xp - - complex_object = xp.exp(1j * (current_object_V + current_object_A_projected)) - object_patches = complex_object[ - :, self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - propagated_probes = xp.empty_like(object_patches) - propagated_probes[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes = object_patches[s] * propagated_probes[s] - - # propagate - if s + 1 < self._num_slices: - propagated_probes[s + 1] = self._propagate_array( - transmitted_probes, self._propagator_arrays[s] - ) - - return propagated_probes, object_patches, transmitted_probes - - def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): - """ - Ptychographic fourier projection method for GD method. - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - modified_exit_wave = xp.fft.ifft2( - amplitudes * xp.exp(1j * xp.angle(fourier_exit_waves)) - ) - - exit_waves = modified_exit_wave - transmitted_probes - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. 
- Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = transmitted_probes.copy() - - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - factor_to_be_projected = ( - projection_c * transmitted_probes + projection_y * exit_waves - ) - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amplitudes * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * transmitted_probes - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object_V, - current_object_A_projected, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). - - Parameters - -------- - current_object_V: np.ndarray - Current electrostatic object estimate - current_object_A_projected: np.ndarray - Current projected magnetic object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - propagated_probes:np.ndarray - Prop[object^n*probe^n] - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - ( - propagated_probes, - object_patches, - transmitted_probes, - ) = self._overlap_projection( - current_object_V, - current_object_A_projected, - current_probe, - ) - - if use_projection_scheme: - ( - exit_waves[self._active_tilt_index], - error, - ) = self._projection_sets_fourier_projection( - amplitudes, - transmitted_probes, - exit_waves[self._active_tilt_index], - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, transmitted_probes - ) - - return propagated_probes, object_patches, transmitted_probes, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object_V, - current_object_A_projected, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
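The projection-set updates removed above (presumably now provided by the shared method mixins) all reduce to one rule parameterized by (a, b, c). A compact sketch, assuming corner-centered arrays of matching shape and an illustrative function name:

import numpy as np


def generalized_fourier_projection(amplitudes, transmitted_probes, exit_waves, a, b, c):
    # Generalized projection update parameterized by (a, b, c):
    #   DM_AP(alpha): a = -alpha, b = 1, c = 1 + alpha   (DM: alpha = 1, AP: alpha = 0)
    #   RAAR(beta):   a = 1 - 2 * beta, b = beta, c = 2  (DM: beta = 1)
    x, y = 1 - a - b, 1 - c
    factor = c * transmitted_probes + y * exit_waves
    factor_hat = np.fft.fft2(factor)
    projected = np.fft.ifft2(amplitudes * np.exp(1j * np.angle(factor_hat)))
    return x * exit_waves + a * transmitted_probes + b * projected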
- - Parameters - -------- - current_object_V: np.ndarray - Current electrostatic object estimate - current_object_A_projected: np.ndarray - Current projected magnetic object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object_V: np.ndarray - Updated electrostatic object estimate - updated_object_A_projected: np.ndarray - Updated projected magnetic object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - object_update = step_size * ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves) - ) - * probe_normalization - ) - - current_object_V[s] += object_update - current_object_A_projected[s] += object_update - - # back-transmit - exit_waves *= xp.conj(obj) - - if s > 0: - # back-propagate - exit_waves = self._propagate_array( - exit_waves, xp.conj(self._propagator_arrays[s - 1]) - ) - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += ( - step_size - * xp.sum( - exit_waves, - axis=0, - ) - * object_normalization - ) - - return current_object_V, current_object_A_projected, current_probe - - def _projection_sets_adjoint( - self, - current_object_V, - current_object_A_projected, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. 
- - Parameters - -------- - current_object_V: np.ndarray - Current electrostatic object estimate - current_object_A_projected: np.ndarray - Current projected magnetic object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object_V: np.ndarray - Updated electrostatic object estimate - updated_object_A_projected: np.ndarray - Updated projected magnetic object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - # careful not to modify exit_waves in-place for projection set methods - exit_waves_copy = exit_waves.copy() - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - object_update = ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves) - ) - * probe_normalization - ) - - current_object_V[s] = object_update - current_object_A_projected[s] = object_update - - # back-transmit - exit_waves_copy *= xp.conj(obj) - - if s > 0: - # back-propagate - exit_waves_copy = self._propagate_array( - exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) - ) - - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - exit_waves_copy, - axis=0, - ) - * object_normalization - ) - - return current_object_V, current_object_A_projected, current_probe - - def _adjoint( - self, - current_object_V, - current_object_A_projected, - current_probe, - object_patches, - propagated_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object_V: np.ndarray - Current electrostatic object estimate - current_object_A_projected: np.ndarray - Current projected magnetic object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object_V: np.ndarray - Updated electrostatic object estimate - updated_object_A_projected: np.ndarray - Updated projected magnetic object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - ( - current_object_V, - current_object_A_projected, - current_probe, - ) = self._projection_sets_adjoint( - current_object_V, - current_object_A_projected, - current_probe, - object_patches, - propagated_probes, - exit_waves[self._active_tilt_index], - normalization_min, - fix_probe, - ) - else: - ( - current_object_V, - current_object_A_projected, - current_probe, - ) = self._gradient_descent_adjoint( - current_object_V, - current_object_A_projected, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object_V, current_object_A_projected, current_probe - - def _position_correction( - self, - current_object, - current_probe, - transmitted_probes, - amplitudes, - current_positions, - positions_step_size, - constrain_position_distance, - ): - """ - Position correction using estimated intensity gradient. 
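Both adjoint operators above normalize their object and probe updates with the same regularized reciprocal of the overlap intensity. A minimal standalone sketch (the name is illustrative):

import numpy as np


def smooth_inverse(overlap, normalization_min, eps=1e-16):
    # Behaves like 1 / overlap where the overlap is strong, and is clamped by a
    # fraction (normalization_min) of the maximum overlap elsewhere, avoiding
    # division by near-zero intensities at the edges of the scanned region.
    return 1 / np.sqrt(
        eps
        + ((1 - normalization_min) * overlap) ** 2
        + (normalization_min * np.max(overlap)) ** 2
    )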
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe:np.ndarray - fractionally-shifted probes - transmitted_probes: np.ndarray - Transmitted probes at each layer - amplitudes: np.ndarray - Measured amplitudes - current_positions: np.ndarray - Current positions estimate - positions_step_size: float - Positions step size - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A - - Returns - -------- - updated_positions: np.ndarray - Updated positions estimate - """ - - xp = self._xp - - # Intensity gradient - exit_waves_fft = xp.fft.fft2(transmitted_probes[-1]) - exit_waves_fft_conj = xp.conj(exit_waves_fft) - estimated_intensity = xp.abs(exit_waves_fft) ** 2 - measured_intensity = amplitudes**2 - - flat_shape = (transmitted_probes[-1].shape[0], -1) - difference_intensity = (measured_intensity - estimated_intensity).reshape( - flat_shape - ) - - # Computing perturbed exit waves one at a time to save on memory - - complex_object = xp.exp(1j * current_object) - - # dx - propagated_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - obj_rolled_patches = complex_object[ - :, - (self._vectorized_patch_indices_row + 1) % self._object_shape[0], - self._vectorized_patch_indices_col, - ] - - transmitted_probes_perturbed = xp.empty_like(obj_rolled_patches) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed[s] = obj_rolled_patches[s] * propagated_probes - - # propagate - if s + 1 < self._num_slices: - propagated_probes = self._propagate_array( - transmitted_probes_perturbed[s], self._propagator_arrays[s] - ) - - exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2( - transmitted_probes_perturbed[-1] - ) - - # dy - propagated_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - obj_rolled_patches = complex_object[ - :, - self._vectorized_patch_indices_row, - (self._vectorized_patch_indices_col + 1) % self._object_shape[1], - ] - - transmitted_probes_perturbed = xp.empty_like(obj_rolled_patches) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed[s] = obj_rolled_patches[s] * propagated_probes - - # propagate - if s + 1 < self._num_slices: - propagated_probes = self._propagate_array( - transmitted_probes_perturbed[s], self._propagator_arrays[s] - ) - - exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2( - transmitted_probes_perturbed[-1] - ) - - partial_intensity_dx = 2 * xp.real( - exit_waves_dx_fft * exit_waves_fft_conj - ).reshape(flat_shape) - partial_intensity_dy = 2 * xp.real( - exit_waves_dy_fft * exit_waves_fft_conj - ).reshape(flat_shape) - - coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) - - # positions_update = xp.einsum( - # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity - # ) - - coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) - positions_update = ( - xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) - @ coefficients_matrix_T - @ difference_intensity[..., None] - ) - - if constrain_position_distance is not None: - constrain_position_distance /= xp.sqrt( - self.sampling[0] ** 2 + self.sampling[1] ** 2 - ) - x1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 0 - ] - y1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 1 - ] - x0 = self._positions_px_initial[:, 0] - y0 = self._positions_px_initial[:, 1] - if self._rotation_best_transpose: - x0, y0 = xp.array([y0, 
x0]) - x1, y1 = xp.array([y1, x1]) - - if self._rotation_best_rad is not None: - rotation_angle = self._rotation_best_rad - x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( - -rotation_angle - ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) - x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( - -rotation_angle - ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) - - outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( - x1 < (xp.min(x0) - constrain_position_distance) - ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( - y1 < (xp.min(y0) - constrain_position_distance) - ) > 0 - - positions_update[..., 0][outlier_ind] = 0 - - current_positions -= positions_step_size * positions_update[..., 0] - - return current_positions - def _divergence_free_constraint(self, vector_field): """ Leray projection operator @@ -1647,19 +708,16 @@ def _divergence_free_constraint(self, vector_field): Parameters -------- vector_field: np.ndarray - Current object vector as Az, Ax, Ay + Current object vector as Ax, Ay, Az Returns -------- projected_vector_field: np.ndarray - Divergence-less object vector as Az, Ax, Ay + Divergence-less object vector as Ax, Ay, Az """ xp = self._xp - spacings = (self.sampling[1],) + self.sampling - vector_field = project_vector_field_divergence( - vector_field, spacings=spacings, xp=xp - ) + vector_field = project_vector_field_divergence_periodic_3D(vector_field, xp=xp) return vector_field @@ -2588,491 +1646,3 @@ def reconstruct( xp.clear_memo() return self - - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax, - cbar: bool, - projection_angle_deg: float, - projection_axes: Tuple[int, int], - x_lims: Tuple[int, int], - y_lims: Tuple[int, int], - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. 
- - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - - cmap = kwargs.pop("cmap", "magma") - - asnumpy = self._asnumpy - - if projection_angle_deg is not None: - rotated_3d_obj = self._rotate( - self._object[0], - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - rotated_3d_obj = asnumpy(rotated_3d_obj) - else: - rotated_3d_obj = self.object[0] - - rotated_object = self._crop_rotate_object_manually( - rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - errors = np.array(self.error_iterations) - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - convergence_ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - projection_angle_deg: float, - projection_axes: Tuple[int, int], - x_lims: Tuple[int, int], - y_lims: Tuple[int, int], - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - figsize = kwargs.pop("figsize", (14, 10) if cbar else (12, 10)) - cmap_e = kwargs.pop("cmap_e", "magma") - cmap_m = kwargs.pop("cmap_m", "PuOr") - - asnumpy = self._asnumpy - - if projection_angle_deg is not None: - rotated_3d_obj_V = self._rotate( - self._object[0], - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - - rotated_3d_obj_Az = self._rotate( - self._object[1], - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - - rotated_3d_obj_Ax = self._rotate( - self._object[2], - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - - rotated_3d_obj_Ay = self._rotate( - self._object[3], - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - - rotated_3d_obj_V = asnumpy(rotated_3d_obj_V) - rotated_3d_obj_Az = asnumpy(rotated_3d_obj_Az) - rotated_3d_obj_Ax = asnumpy(rotated_3d_obj_Ax) - rotated_3d_obj_Ay = asnumpy(rotated_3d_obj_Ay) - else: - ( - rotated_3d_obj_V, - rotated_3d_obj_Az, - rotated_3d_obj_Ax, - rotated_3d_obj_Ay, - ) = self.object - - rotated_object_Vx = self._crop_rotate_object_manually( - rotated_3d_obj_V.sum(1).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Vy = self._crop_rotate_object_manually( - rotated_3d_obj_V.sum(2).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Vz = self._crop_rotate_object_manually( - rotated_3d_obj_V.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - rotated_object_Azx = self._crop_rotate_object_manually( - rotated_3d_obj_Az.sum(1).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Azy = self._crop_rotate_object_manually( - rotated_3d_obj_Az.sum(2).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Azz = self._crop_rotate_object_manually( - rotated_3d_obj_Az.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - rotated_object_Axx = self._crop_rotate_object_manually( - rotated_3d_obj_Ax.sum(1).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Axy = self._crop_rotate_object_manually( - rotated_3d_obj_Ax.sum(2).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Axz = self._crop_rotate_object_manually( - rotated_3d_obj_Ax.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - rotated_object_Ayx = self._crop_rotate_object_manually( - rotated_3d_obj_Ay.sum(1).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Ayy = self._crop_rotate_object_manually( - rotated_3d_obj_Ay.sum(2).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Ayz = self._crop_rotate_object_manually( - rotated_3d_obj_Ay.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - rotated_shape = rotated_object_Vx.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - arrays = [ - [ - rotated_object_Vx, - rotated_object_Axx, - rotated_object_Ayx, - rotated_object_Azx, - ], - [ - rotated_object_Vy, - rotated_object_Axy, - rotated_object_Ayy, - rotated_object_Azy, - ], - [ - rotated_object_Vz, - 
rotated_object_Axz, - rotated_object_Ayz, - rotated_object_Azz, - ], - ] - - titles = [ - [ - "V projected along x", - "Ax projected along x", - "Ay projected along x", - "Az projected along x", - ], - [ - "V projected along y", - "Ax projected along y", - "Ay projected along y", - "Az projected along y", - ], - [ - "V projected along z", - "Ax projected along z", - "Ay projected along z", - "Az projected along z", - ], - ] - - max_e = np.array( - [rotated_object_Vx.max(), rotated_object_Vy.max(), rotated_object_Vz.max()] - ).max() - max_m = np.array( - [ - [ - np.abs(rotated_object_Axx).max(), - np.abs(rotated_object_Ayx).max(), - np.abs(rotated_object_Azx).max(), - ], - [ - np.abs(rotated_object_Axy).max(), - np.abs(rotated_object_Ayy).max(), - np.abs(rotated_object_Azy).max(), - ], - [ - np.abs(rotated_object_Axz).max(), - np.abs(rotated_object_Ayz).max(), - np.abs(rotated_object_Azz).max(), - ], - ] - ).max() - - vmin_e = kwargs.pop("vmin_e", 0.0) - vmax_e = kwargs.pop("vmax_e", max_e) - vmin_m = kwargs.pop("vmin_m", -max_m) - vmax_m = kwargs.pop("vmax_m", max_m) - - if plot_convergence: - spec = GridSpec( - ncols=4, nrows=4, height_ratios=[4, 4, 4, 1], hspace=0.15, wspace=0.35 - ) - else: - spec = GridSpec(ncols=4, nrows=3, hspace=0.15, wspace=0.35) - - if fig is None: - fig = plt.figure(figsize=figsize) - - for sp in spec: - row, col = np.unravel_index(sp.num1, (4, 4)) - - if row < 3: - ax = fig.add_subplot(sp) - if sp.is_first_col(): - cmap = cmap_e - vmin = vmin_e - vmax = vmax_e - else: - cmap = cmap_m - vmin = vmin_m - vmax = vmax_m - - im = ax.imshow( - arrays[row][col], - cmap=cmap, - vmin=vmin, - vmax=vmax, - extent=extent, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - ax.set_title(titles[row][col]) - - if row < 2: - ax.set_xticks([]) - else: - ax.set_xlabel("y [A]") - - if col > 0: - ax.set_yticks([]) - else: - ax.set_ylabel("x [A]") - - if plot_convergence and hasattr(self, "error_iterations"): - errors = np.array(self.error_iterations) - - ax = fig.add_subplot(spec[-1, :]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - spec.tight_layout(fig) - - def _visualize_all_iterations( - self, - fig, - plot_convergence: bool, - iterations_grid: Tuple[int, int], - projection_angle_deg: float, - projection_axes: Tuple[int, int], - x_lims: Tuple[int, int], - y_lims: Tuple[int, int], - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - """ - raise NotImplementedError() - - def visualize( - self, - fig=None, - cbar: bool = True, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims=(None, None), - y_lims=(None, None), - **kwargs, - ): - """ - Displays reconstructed object and probe. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - - Returns - -------- - self: OverlapMagneticTomographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - cbar=cbar, - x_lims=x_lims, - y_lims=y_lims, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - cbar=cbar, - x_lims=x_lims, - y_lims=y_lims, - **kwargs, - ) - - return self - - @property - def positions(self): - """Probe positions [A]""" - - if self.angular_sampling is None: - return None - - asnumpy = self._asnumpy - positions_all = [] - for tilt_index in range(self._num_tilts): - positions = self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ].copy() - positions[:, 0] *= self.sampling[0] - positions[:, 1] *= self.sampling[1] - positions_all.append(asnumpy(positions)) - - return np.asarray(positions_all) - - def _return_self_consistency_errors( - self, - max_batch_size=None, - ): - """Compute the self-consistency errors for each probe position""" - raise NotImplementedError() - - def show_uncertainty_visualization( - self, - errors=None, - max_batch_size=None, - projected_cropped_potential=None, - kde_sigma=None, - plot_histogram=True, - plot_contours=False, - **kwargs, - ): - """Plot uncertainty visualization using self-consistency errors""" - raise NotImplementedError() diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 165dd115e..7d4964fa4 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -81,7 +81,7 @@ class PtychographicTomography( energy: float The electron energy of the wave functions in eV 
num_slices: int - Number of slices to use in the forward model + Number of super-slices to use in the forward model tilt_orientation_matrices: Sequence[np.ndarray] List of orientation matrices for each tilt semiangle_cutoff: float, optional @@ -531,9 +531,17 @@ def preprocess( )._evaluate_ctf() # Precomputed propagator arrays + if main_tilt_axis == "vertical": + thickness = self._object_shape[1] * self.sampling[1] + elif main_tilt_axis == "horizontal": + thickness = self._object_shape[0] * self.sampling[0] + else: + thickness_h = self._object_shape[1] * self.sampling[1] + thickness_v = self._object_shape[0] * self.sampling[0] + thickness = max(thickness_h, thickness_v) + self._slice_thicknesses = np.tile( - self._object_shape[1] * self.sampling[1] / self._num_slices, - self._num_slices - 1, + thickness / self._num_slices, self._num_slices - 1 ) self._propagator_arrays = self._precompute_propagator_arrays( self._region_of_interest_shape, From 9ec45b5d27c3fa458b2a9f27ba7ca4b35ba871e1 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 6 Jan 2024 21:00:01 -0800 Subject: [PATCH 068/128] magnetic ptycho tomo constraints Former-commit-id: aac0cd5600bcab6fe925c2f54a722d397dd98e60 --- .../magnetic_ptychographic_tomography.py | 270 ++++-------------- 1 file changed, 48 insertions(+), 222 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 831edd904..b17eb075e 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -701,254 +701,80 @@ def preprocess( return self - def _divergence_free_constraint(self, vector_field): - """ - Leray projection operator - - Parameters - -------- - vector_field: np.ndarray - Current object vector as Ax, Ay, Az - - Returns - -------- - projected_vector_field: np.ndarray - Divergence-less object vector as Ax, Ay, Az - """ - xp = self._xp - - vector_field = project_vector_field_divergence_periodic_3D(vector_field, xp=xp) - - return vector_field - - def _constraints( + def _object_constraints_vector( self, current_object, - current_probe, - current_positions, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, + pure_phase_object, gaussian_filter, gaussian_filter_sigma_e, gaussian_filter_sigma_m, butterworth_filter, + butterworth_order, q_lowpass_e, q_lowpass_m, q_highpass_e, q_highpass_m, - butterworth_order, - object_positivity, - shrinkage_rad, - object_mask, tv_denoise, tv_denoise_weights, tv_denoise_inner_iter, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, ): - """ - Ptychographic constraints operator. 
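A standalone sketch of the slice-thickness logic used in these preprocess methods: the propagation thickness is taken from one in-plane object dimension when a main tilt axis is specified, and from the larger of the two otherwise, then divided into equal super-slices (the function name is illustrative):

import numpy as np


def estimate_slice_thicknesses(object_shape, sampling, num_slices, main_tilt_axis=None):
    # object_shape: (rows, cols) in pixels; sampling: (row, col) sampling in A.
    # num_slices super-slices require num_slices - 1 propagation steps.
    if main_tilt_axis == "vertical":
        thickness = object_shape[1] * sampling[1]
    elif main_tilt_axis == "horizontal":
        thickness = object_shape[0] * sampling[0]
    else:
        thickness = max(
            object_shape[0] * sampling[0], object_shape[1] * sampling[1]
        )
    return np.tile(thickness / num_slices, num_slices - 1)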
- Calls _threshold_object_constraint() and _probe_center_of_mass_constraint() - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe Fourier amplitude is replaced by initial probe aperture. - initial_probe_aperture: np.ndarray, - Initial probe aperture to use in replacing probe Fourier amplitude. - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma_e: float - Standard deviation of gaussian kernel for electrostatic object in A - gaussian_filter_sigma_m: float - Standard deviation of gaussian kernel for magnetic object in A - butterworth_filter: bool - If True, applies high-pass butteworth filter - q_lowpass_e: float - Cut-off frequency in A^-1 for low-pass filtering electrostatic object - q_lowpass_m: float - Cut-off frequency in A^-1 for low-pass filtering magnetic object - q_highpass_e: float - Cut-off frequency in A^-1 for high-pass filtering electrostatic object - q_highpass_m: float - Cut-off frequency in A^-1 for high-pass filtering magnetic object - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - object_positivity: bool - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. 
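
The method applies the same scalar-object constraints to every component, using the electrostatic smoothing and filtering strengths for component 0 and the magnetic strengths for components 1 to 3, and finishes by projecting the magnetic part onto its divergence-free subspace. A compact sketch of that pattern, reduced here to Gaussian smoothing plus an FFT-based Leray projection; the helper names below are illustrative, not the package API:

import numpy as np
from scipy.ndimage import gaussian_filter

def fourier_wavevectors(shape):
    nz, nx, ny = shape
    kz = np.fft.fftfreq(nz)[:, None, None]
    kx = np.fft.fftfreq(nx)[None, :, None]
    ky = np.fft.fftfreq(ny)[None, None, :]
    return kz, kx, ky

def leray_project_periodic(vector_field):
    # Remove the curl-free part of a periodic (Az, Ax, Ay) field by subtracting
    # k (k . A_hat) / |k|^2 in Fourier space.
    kz, kx, ky = fourier_wavevectors(vector_field.shape[1:])
    k2 = kz**2 + kx**2 + ky**2
    k2[0, 0, 0] = 1.0  # leave the DC (mean) components untouched
    A_hat = np.fft.fftn(vector_field, axes=(1, 2, 3))
    div_hat = kz * A_hat[0] + kx * A_hat[1] + ky * A_hat[2]
    A_hat[0] -= kz * div_hat / k2
    A_hat[1] -= kx * div_hat / k2
    A_hat[2] -= ky * div_hat / k2
    return np.fft.ifftn(A_hat, axes=(1, 2, 3)).real

rng = np.random.default_rng(0)
obj = rng.standard_normal((4, 16, 16, 16))  # (V, Az, Ax, Ay) volumes

# component-wise smoothing: electrostatic and magnetic strengths differ
sigma_e, sigma_m = 2.0, 1.0
obj[0] = gaussian_filter(obj[0], sigma_e)
for index in range(1, 4):
    obj[index] = gaussian_filter(obj[index], sigma_m)

# divergence-free projection of the magnetic components only
obj[1:] = leray_project_periodic(obj[1:])

# spectral divergence of the projected field is numerically zero
kz, kx, ky = fourier_wavevectors(obj.shape[1:])
A_hat = np.fft.fftn(obj[1:], axes=(1, 2, 3))
print(np.abs(kz * A_hat[0] + kx * A_hat[1] + ky * A_hat[2]).max())  # machine precision
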
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ + """Calls Object3DConstraints _object_constraints for each object.""" + xp = self._xp - if gaussian_filter: - current_object[0] = self._object_gaussian_constraint( - current_object[0], gaussian_filter_sigma_e, pure_phase_object=False - ) - current_object[1] = self._object_gaussian_constraint( - current_object[1], gaussian_filter_sigma_m, pure_phase_object=False - ) - current_object[2] = self._object_gaussian_constraint( - current_object[2], gaussian_filter_sigma_m, pure_phase_object=False - ) - current_object[3] = self._object_gaussian_constraint( - current_object[3], gaussian_filter_sigma_m, pure_phase_object=False - ) + # electrostatic + current_object[0] = self._object_constraints( + current_object[0], + gaussian_filter, + gaussian_filter_sigma_e, + butterworth_filter, + butterworth_order, + q_lowpass_e, + q_highpass_e, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ) - if butterworth_filter: - current_object[0] = self._object_butterworth_constraint( - current_object[0], - q_lowpass_e, - q_highpass_e, - butterworth_order, - ) - current_object[1] = self._object_butterworth_constraint( - current_object[1], - q_lowpass_m, - q_highpass_m, - butterworth_order, - ) - current_object[2] = self._object_butterworth_constraint( - current_object[2], - q_lowpass_m, - q_highpass_m, + # magnetic + for index in range(1, 4): + current_object[index] = self._object_constraints( + current_object[index], + gaussian_filter, + gaussian_filter_sigma_m, + butterworth_filter, butterworth_order, - ) - current_object[3] = self._object_butterworth_constraint( - current_object[3], q_lowpass_m, q_highpass_m, - butterworth_order, - ) - - elif tv_denoise: - current_object[0] = self._object_denoise_tv_pylops( - current_object[0], - tv_denoise_weights, - tv_denoise_inner_iter, - ) - - current_object[1] = self._object_denoise_tv_pylops( - current_object[1], - tv_denoise_weights, - tv_denoise_inner_iter, - ) - - current_object[2] = self._object_denoise_tv_pylops( - current_object[2], - tv_denoise_weights, - tv_denoise_inner_iter, - ) - - current_object[3] = self._object_denoise_tv_pylops( - current_object[3], + tv_denoise, tv_denoise_weights, tv_denoise_inner_iter, + False, + 0.0, + None, + **kwargs, ) - if shrinkage_rad > 0.0 or object_mask is not None: - current_object[0] = self._object_shrinkage_constraint( - current_object[0], - shrinkage_rad, - object_mask, - ) - - if object_positivity: - current_object[0] = self._object_positivity_constraint(current_object[0]) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) + # divergence-free + current_object[1:] 
= project_vector_field_divergence_periodic_3D( + current_object[1:], xp=xp + ) - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) + return current_object - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) + def _constraints(self, current_object, current_probe, current_positions, **kwargs): + """Wrapper function to bypass _object_constraints""" - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) + current_object = self._object_constraints_vector(current_object, **kwargs) + current_probe = self._probe_constraints(current_probe, **kwargs) + current_positions = self._positions_constraints(current_positions, **kwargs) return current_object, current_probe, current_positions From cc6c1938ed204decf17022112ed18704e1191d9a Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 6 Jan 2024 22:24:48 -0800 Subject: [PATCH 069/128] magnetic ptycho tomo works Former-commit-id: 4ab72fe557baa30766fd2a9d43797fd8c54af9cb --- .../magnetic_ptychographic_tomography.py | 528 +++++++----------- 1 file changed, 187 insertions(+), 341 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index b17eb075e..3fceb6367 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -704,7 +704,6 @@ def preprocess( def _object_constraints_vector( self, current_object, - pure_phase_object, gaussian_filter, gaussian_filter_sigma_e, gaussian_filter_sigma_m, @@ -778,9 +777,22 @@ def _constraints(self, current_object, current_probe, current_positions, **kwarg return current_object, current_probe, current_positions + def _rotate_zxy_volume_vector( + self, + current_object, + rot_matrix, + ): + """ """ + for index in range(4): + current_object[index] = self._rotate_zxy_volume( + current_object[index], rot_matrix + ) + + return current_object + def reconstruct( self, - max_iter: int = 64, + max_iter: int = 8, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, reconstruction_parameter_a: float = None, @@ -801,7 +813,8 @@ def reconstruct( constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, + max_position_update_distance: float = None, + max_position_total_distance: float = None, global_affine_transformation: bool = True, gaussian_filter_sigma_e: float = None, gaussian_filter_sigma_m: float = None, @@ -809,6 +822,7 @@ def reconstruct( fit_probe_aberrations_iter: int = 0, fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, butterworth_filter_iter: int = np.inf, q_lowpass_e: float = None, q_lowpass_m: float = None, @@ -821,7 +835,7 @@ def reconstruct( tv_denoise_iter=np.inf, tv_denoise_weights=None, tv_denoise_inner_iter=40, - collective_tilt_updates: bool = False, + collective_measurement_updates: bool = True, store_iterations: bool = False, progress_bar: bool = True, reset: bool = None, @@ -879,9 +893,10 @@ def reconstruct( If True, the probe aperture is 
additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float, optional - Distance to constrain position correction within original - field of view in A + max_position_update_distance: float, optional + Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma_e: float @@ -896,6 +911,8 @@ def reconstruct( Max angular order of probe aberrations basis functions fit_probe_aberrations_max_radial_order: bool Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass: float @@ -913,7 +930,7 @@ def reconstruct( the more denoising. tv_denoise_inner_iter: float Number of iterations to run in inner loop of TV denoising - collective_tilt_updates: bool + collective_measurement_updates: bool if True perform collective tilt updates shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration @@ -932,184 +949,56 @@ def reconstruct( asnumpy = self._asnumpy xp = self._xp - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
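
Each named method resolves to a fixed (a, b, c) triple of generalized-projection coefficients for a relaxation parameter p, with gradient descent handled separately; `_set_reconstruction_method_parameters` consolidates the branch structure shown below. A condensed sketch of that mapping (hypothetical helper, same coefficient values and allowed ranges):

def projection_parameters(method, p):
    # (a, b, c) generalized-projection coefficients for relaxation parameter p;
    # gradient descent does not use a projection triple.
    mapping = {
        "DM_AP": (-p, 1, 1 + p),     # difference-map / alternating projections, 0 <= p <= 1
        "RAAR": (1 - 2 * p, p, 2),   # relaxed-averaged-alternating-reflections, 0 <= p <= 1
        "RRR": (-p, p, 2),           # relax-reflect-reflect, 0 <= p <= 2
        "SUPERFLIP": (0, 1, 2),      # charge flipping, p unused
    }
    if method in ("GD", "gradient-descent"):
        return None
    return mapping[method]

print(projection_parameters("RAAR", 0.75))  # (-0.5, 0.75, 2)
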
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - "'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) + if not collective_measurement_updates and self._verbose: + warnings.warn( + "Magnetic ptychography is much more robust with `collective_measurement_updates=True`.", + UserWarning, ) - if self._verbose: - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." 
- ) - ) - else: - if step_size is not None: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) - # Position Correction + Collective Updates not yet implemented - if fix_positions_iter < max_iter: + if use_projection_scheme: raise NotImplementedError( - "Position correction is currently incompatible with collective updates." + "Magnetic ptychographic tomography is currently only implemented for gradient descent." ) - # Batching + if self._verbose: + self._report_reconstruction_summary( + max_iter, + np.inf, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + max_batch_size, + step_size, + ) if max_batch_size is not None: xp.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self._object = self._object_initial.copy() - self.error_iterations = [] - self._probe = self._probe_initial.copy() - self._positions_px_all = self._positions_px_initial_all.copy() - if hasattr(self, "_tf"): - del self._tf - - if use_projection_scheme: - self._exit_waves = [None] * self._num_tilts - else: - self._exit_waves = None - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." 
- ), - UserWarning, - ) - else: - self.error_iterations = [] - if use_projection_scheme: - self._exit_waves = [None] * self._num_tilts - else: - self._exit_waves = None + self._reset_reconstruction(store_iterations, reset, use_projection_scheme) if gaussian_filter_sigma_m is None: gaussian_filter_sigma_m = gaussian_filter_sigma_e @@ -1117,6 +1006,9 @@ def reconstruct( if q_lowpass_m is None: q_lowpass_m = q_lowpass_e + if fix_positions_iter < 1: + fix_positions_iter = 1 # give position correction a chance + # main loop for a0 in tqdmnd( max_iter, @@ -1126,75 +1018,59 @@ def reconstruct( ): error = 0.0 - if collective_tilt_updates: + if collective_measurement_updates: collective_object = xp.zeros_like(self._object) - tilt_indices = np.arange(self._num_tilts) - np.random.shuffle(tilt_indices) - - for tilt_index in tilt_indices: - tilt_error = 0.0 - self._active_tilt_index = tilt_index + indices = np.arange(self._num_measurements) + np.random.shuffle(indices) - alpha_deg, beta_deg = self._tilt_angles_deg[self._active_tilt_index] - alpha, beta = np.deg2rad([alpha_deg, beta_deg]) - - # V - self._object[0] = self._euler_angle_rotate_volume( - self._object[0], - alpha_deg, - beta_deg, - ) + old_rot_matrix = np.eye(3) # identity - # Az - self._object[1] = self._euler_angle_rotate_volume( - self._object[1], - alpha_deg, - beta_deg, - ) + for index in indices: + self._active_measurement_index = index - # Ax - self._object[2] = self._euler_angle_rotate_volume( - self._object[2], - alpha_deg, - beta_deg, - ) + measurement_error = 0.0 - # Ay - self._object[3] = self._euler_angle_rotate_volume( - self._object[3], - alpha_deg, - beta_deg, + rot_matrix = self._tilt_orientation_matrices[ + self._active_measurement_index + ] + self._object = self._rotate_zxy_volume_vector( + self._object, + rot_matrix @ old_rot_matrix.T, ) - - object_A = self._object[1] * np.cos(beta) + np.sin(beta) * ( - self._object[3] * np.cos(alpha) - self._object[2] * np.sin(alpha) + object_V = self._object[0] + + # last transformation matrix row + weight_x, weight_y, weight_z = rot_matrix[-1] + object_A = ( + weight_x * self._object[2] + + weight_y * self._object[3] + + weight_z * self._object[1] ) - object_sliced_V = self._project_sliced_object( - self._object[0], self._num_slices + object_sliced = self._project_sliced_object( + object_V + object_A, self._num_slices ) - object_sliced_A = self._project_sliced_object( - object_A, self._num_slices - ) + _probe = self._probes_all[self._active_measurement_index] + _probe_initial_aperture = self._probes_all_initial_aperture[ + self._active_measurement_index + ] if not use_projection_scheme: - object_sliced_old_V = object_sliced_V.copy() - object_sliced_old_A = object_sliced_A.copy() + object_sliced_old = object_sliced.copy() - start_tilt = self._cum_probes_per_tilt[self._active_tilt_index] - end_tilt = self._cum_probes_per_tilt[self._active_tilt_index + 1] + start_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + ] + end_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + 1 + ] - num_diffraction_patterns = end_tilt - start_tilt + num_diffraction_patterns = end_idx - start_idx shuffled_indices = np.arange(num_diffraction_patterns) unshuffled_indices = np.zeros_like(shuffled_indices) - if max_batch_size is None: - current_max_batch_size = num_diffraction_patterns - else: - current_max_batch_size = max_batch_size - # randomize if not use_projection_scheme: np.random.shuffle(shuffled_indices) @@ -1203,15 +1079,15 @@ def reconstruct( 
num_diffraction_patterns ) - positions_px = self._positions_px_all[start_tilt:end_tilt].copy()[ + positions_px = self._positions_px_all[start_idx:end_idx].copy()[ shuffled_indices ] initial_positions_px = self._positions_px_initial_all[ - start_tilt:end_tilt + start_idx:end_idx ].copy()[shuffled_indices] for start, end in generate_batches( - num_diffraction_patterns, max_batch=current_max_batch_size + num_diffraction_patterns, max_batch=max_batch_size ): # batch indices self._positions_px = positions_px[start:end] @@ -1226,21 +1102,20 @@ def reconstruct( self._vectorized_patch_indices_col, ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[start_tilt:end_tilt][ + amplitudes = self._amplitudes[start_idx:end_idx][ shuffled_indices[start:end] ] # forward operator ( - propagated_probes, + shifted_probes, object_patches, - transmitted_probes, + overlap, self._exit_waves, batch_error, ) = self._forward( - object_sliced_V, - object_sliced_A, - self._probe, + object_sliced, + _probe, amplitudes, self._exit_waves, use_projection_scheme, @@ -1250,12 +1125,11 @@ def reconstruct( ) # adjoint operator - object_sliced_V, object_sliced_A, self._probe = self._adjoint( - object_sliced_V, - object_sliced_A, - self._probe, + object_sliced, _probe = self._adjoint( + object_sliced, + _probe, object_patches, - propagated_probes, + shifted_probes, self._exit_waves, use_projection_scheme=use_projection_scheme, step_size=step_size, @@ -1266,100 +1140,91 @@ def reconstruct( # position correction if a0 >= fix_positions_iter: positions_px[start:end] = self._position_correction( - object_sliced_V, - self._probe, - transmitted_probes, + object_sliced, + shifted_probes, + overlap, amplitudes, self._positions_px, + self._positions_px_initial, positions_step_size, - constrain_position_distance, + max_position_update_distance, + max_position_total_distance, ) - tilt_error += batch_error + measurement_error += batch_error if not use_projection_scheme: - object_sliced_V -= object_sliced_old_V - object_sliced_A -= object_sliced_old_A + object_sliced -= object_sliced_old - object_update_V = self._expand_sliced_object( - object_sliced_V, self._num_voxels - ) - object_update_A = self._expand_sliced_object( - object_sliced_A, self._num_voxels + object_update = self._expand_sliced_object( + object_sliced, self._num_voxels ) - if collective_tilt_updates: - collective_object[0] += self._euler_angle_rotate_volume( - object_update_V, - alpha_deg, - -beta_deg, - ) - collective_object[1] += self._euler_angle_rotate_volume( - object_update_A * np.cos(beta), - alpha_deg, - -beta_deg, - ) - collective_object[2] -= self._euler_angle_rotate_volume( - object_update_A * np.sin(alpha) * np.sin(beta), - alpha_deg, - -beta_deg, - ) - collective_object[3] += self._euler_angle_rotate_volume( - object_update_A * np.cos(alpha) * np.sin(beta), - alpha_deg, - -beta_deg, - ) - else: - self._object[0] += object_update_V - self._object[1] += object_update_A * np.cos(beta) - self._object[2] -= object_update_A * np.sin(alpha) * np.sin(beta) - self._object[3] += object_update_A * np.cos(alpha) * np.sin(beta) - - self._object[0] = self._euler_angle_rotate_volume( - self._object[0], - alpha_deg, - -beta_deg, - ) - - self._object[1] = self._euler_angle_rotate_volume( - self._object[1], - alpha_deg, - -beta_deg, - ) - - self._object[2] = self._euler_angle_rotate_volume( - self._object[2], - alpha_deg, - -beta_deg, - ) + weights = (1, weight_z, weight_x, weight_y) + for index, weight in zip(range(4), weights): + if 
collective_measurement_updates: + collective_object[index] += self._rotate_zxy_volume( + object_update * weight, + rot_matrix.T, + ) + else: + self._object[index] += object_update * weight - self._object[3] = self._euler_angle_rotate_volume( - self._object[3], - alpha_deg, - -beta_deg, - ) + old_rot_matrix = rot_matrix # Normalize Error - tilt_error /= ( - self._mean_diffraction_intensity[self._active_tilt_index] + measurement_error /= ( + self._mean_diffraction_intensity[self._active_measurement_index] * num_diffraction_patterns ) - error += tilt_error + error += measurement_error # constraints - self._positions_px_all[start_tilt:end_tilt] = positions_px.copy()[ + self._positions_px_all[start_idx:end_idx] = positions_px.copy()[ unshuffled_indices ] - if not collective_tilt_updates: + if collective_measurement_updates: + # probe and positions + _probe = self._probe_constraints( + _probe, + fix_com=fix_com and a0 >= fix_probe_iter, + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter + and a0 >= fix_probe_iter, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=_probe_initial_aperture, + ) + + self._positions_px_all[ + start_idx:end_idx + ] = self._positions_constraints( + self._positions_px_all[start_idx:end_idx], + fix_positions=a0 < fix_positions_iter, + global_affine_transformation=global_affine_transformation, + ) + + else: + # object, probe, and positions ( self._object, - self._probe, - self._positions_px_all[start_tilt:end_tilt], + _probe, + self._positions_px_all[start_idx:end_idx], ) = self._constraints( self._object, - self._probe, - self._positions_px_all[start_tilt:end_tilt], + _probe, + self._positions_px_all[start_idx:end_idx], fix_com=fix_com and a0 >= fix_probe_iter, constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter and a0 >= fix_probe_iter, @@ -1374,8 +1239,9 @@ def reconstruct( and a0 >= fix_probe_iter, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, + initial_probe_aperture=_probe_initial_aperture, fix_positions=a0 < fix_positions_iter, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter @@ -1401,40 +1267,19 @@ def reconstruct( tv_denoise_inner_iter=tv_denoise_inner_iter, ) - # Normalize Error Over Tilts - error /= self._num_tilts + self._object = self._rotate_zxy_volume_vector( + self._object, old_rot_matrix.T + ) - self._object[1:] = 
self._divergence_free_constraint(self._object[1:]) + # Normalize Error Over Tilts + error /= self._num_measurements - if collective_tilt_updates: - self._object += collective_object / self._num_tilts + if collective_measurement_updates: + self._object += collective_object / self._num_measurements - ( + # object only + self._object = self._object_constraints_vector( self._object, - self._probe, - _, - ) = self._constraints( - self._object, - self._probe, - None, - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=True, - global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter and gaussian_filter_sigma_m is not None, gaussian_filter_sigma_e=gaussian_filter_sigma_e, @@ -1458,6 +1303,7 @@ def reconstruct( ) self.error_iterations.append(error.item()) + if store_iterations: self.object_iterations.append(asnumpy(self._object.copy())) self.probe_iterations.append(self.probe_centered) From 75b2d5b4ee0d269476ad1767c0c5c7f7eac57316 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 7 Jan 2024 18:09:56 -0800 Subject: [PATCH 070/128] improved 3D visualizations Former-commit-id: 31b2daee4e436ad0e678d6220e41f9d31c4c2164 --- .../magnetic_ptychographic_tomography.py | 267 ++++++++++++++++-- .../process/phase/ptychographic_methods.py | 114 +++----- 2 files changed, 281 insertions(+), 100 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 3fceb6367..903a0363c 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -8,8 +8,13 @@ import matplotlib.pyplot as plt import numpy as np +from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import make_axes_locatable -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg +from py4DSTEM.visualize.vis_special import ( + Complex2RGB, + add_colorbar_arg, + return_scaled_histogram_ordering, +) try: import cupy as cp @@ -777,19 +782,6 @@ def _constraints(self, current_object, current_probe, current_positions, **kwarg return current_object, current_probe, current_positions - def _rotate_zxy_volume_vector( - self, - current_object, - rot_matrix, - ): - """ """ - for index in range(4): - current_object[index] = self._rotate_zxy_volume( - current_object[index], rot_matrix - ) - - return current_object - def reconstruct( self, max_iter: int = 8, @@ -1034,7 +1026,7 @@ def reconstruct( rot_matrix = self._tilt_orientation_matrices[ self._active_measurement_index ] - 
self._object = self._rotate_zxy_volume_vector( + self._object = self._rotate_zxy_volume_util( self._object, rot_matrix @ old_rot_matrix.T, ) @@ -1267,9 +1259,7 @@ def reconstruct( tv_denoise_inner_iter=tv_denoise_inner_iter, ) - self._object = self._rotate_zxy_volume_vector( - self._object, old_rot_matrix.T - ) + self._object = self._rotate_zxy_volume_util(self._object, old_rot_matrix.T) # Normalize Error Over Tilts error /= self._num_measurements @@ -1318,3 +1308,244 @@ def reconstruct( xp.clear_memo() return self + + def _visualize_all_iterations(self, **kwargs): + raise NotImplementedError() + + def _visualize_last_iteration( + self, + fig, + cbar: bool, + plot_convergence: bool, + orientation_matrix=None, + **kwargs, + ): + """ + Displays last reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool, optional + If true, the reconstructed complex probe is displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + """ + + asnumpy = self._asnumpy + + # get scaled arrays + + if orientation_matrix is not None: + ordered_obj = self._rotate_zxy_volume_vector( + self._object, + orientation_matrix, + ) + + # V(z,x,y), Ax(z,x,y), Ay(z,x,y), Az(z,x,y) + ordered_obj = asnumpy(ordered_obj) + ordered_obj[1:] = np.roll(ordered_obj[1:], -1, axis=0) + + else: + # V(z,x,y), Ax(z,x,y), Ay(z,x,y), Az(z,x,y) + ordered_obj = self.object.copy() + ordered_obj[1:] = np.roll(ordered_obj[1:], -1, axis=0) + + _, nz, nx, ny = ordered_obj.shape + img_array = np.zeros((nx + nx + nz, ny * 4), dtype=ordered_obj.dtype) + + axes = [1, 2, 0] + transposes = [False, True, False] + labels = [("z [A]", "y [A]"), ("x [A]", "z [A]"), ("x [A]", "y [A]")] + limits_v = [(0, nz), (nz, nz + nx), (nz + nx, nz + nx + nx)] + limits_h = [(0, ny), (0, nz), (0, ny)] + + titles = [ + [ + r"$V$ projected along $\hat{x}$", + r"$A_x$ projected along $\hat{x}$", + r"$A_y$ projected along $\hat{x}$", + r"$A_z$ projected along $\hat{x}$", + ], + [ + r"$V$ projected along $\hat{y}$", + r"$A_x$ projected along $\hat{y}$", + r"$A_y$ projected along $\hat{y}$", + r"$A_z$ projected along $\hat{y}$", + ], + [ + r"$V$ projected along $\hat{z}$", + r"$A_x$ projected along $\hat{z}$", + r"$A_y$ projected along $\hat{z}$", + r"$A_z$ projected along $\hat{z}$", + ], + ] + + for index in range(4): + for axis, transpose, limit_v, limit_h in zip( + axes, transposes, limits_v, limits_h + ): + start_v, end_v = limit_v + start_h, end_h = np.array(limit_h) + index * ny + + subarray = ordered_obj[index].sum(axis) + if transpose: + subarray = subarray.T + + img_array[start_v:end_v, start_h:end_h] = subarray + + if plot_convergence: + auto_figsize = (ny * 4 * 4 / nx, (nx + nx + nz) * 3.5 / nx + 1) + else: + auto_figsize = (ny * 4 * 4 / nx, (nx + nx + nz) * 3.5 / nx) + + figsize = kwargs.pop("figsize", auto_figsize) + cmap_e = kwargs.pop("cmap_e", "magma") + cmap_m = kwargs.pop("cmap_m", "PuOr") + vmin_e = kwargs.pop("vmin_e", None) + vmax_e = kwargs.pop("vmax_e", None) + + # remove common unused kwargs + kwargs.pop("plot_probe", None) + kwargs.pop("plot_fourier_probe", None) + kwargs.pop("remove_initial_probe_aberrations", None) + 
kwargs.pop("vertical_lims", None) + kwargs.pop("horizontal_lims", None) + + _, vmin_e, vmax_e = return_scaled_histogram_ordering( + img_array[:, :ny], vmin_e, vmax_e + ) + + _, _, _vmax_m = return_scaled_histogram_ordering(np.abs(img_array[:, ny:])) + vmin_m = kwargs.pop("vmin_m", -_vmax_m) + vmax_m = kwargs.pop("vmax_m", _vmax_m) + + if plot_convergence: + spec = GridSpec( + ncols=4, + nrows=4, + height_ratios=[nx, nz, nx, nx / 4], + hspace=0.15, + wspace=0.35, + ) + else: + spec = GridSpec( + ncols=4, nrows=3, height_ratios=[nx, nz, nx], hspace=0.15, wspace=0.35 + ) + + if fig is None: + fig = plt.figure(figsize=figsize) + + for sp in spec: + row, col = np.unravel_index(sp.num1, (4, 4)) + + if row < 3: + ax = fig.add_subplot(sp) + + start_v, end_v = limits_v[row] + start_h, end_h = np.array(limits_h[row]) + col * ny + subarray = img_array[start_v:end_v, start_h:end_h] + + extent = [ + 0, + self.sampling[1] * subarray.shape[1], + self.sampling[0] * subarray.shape[0], + 0, + ] + + im = ax.imshow( + subarray, + cmap=cmap_e if sp.is_first_col() else cmap_m, + vmin=vmin_e if sp.is_first_col() else vmin_m, + vmax=vmax_e if sp.is_first_col() else vmax_m, + extent=extent, + **kwargs, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + ax.set_title(titles[row][col]) + + y_label, x_label = labels[row] + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + + if plot_convergence and hasattr(self, "error_iterations"): + errors = np.array(self.error_iterations) + + ax = fig.add_subplot(spec[-1, :]) + ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax.set_ylabel("NMSE") + ax.set_xlabel("Iteration number") + ax.yaxis.tick_right() + + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") + spec.tight_layout(fig) + + def _rotate_zxy_volume_util( + self, + current_object, + rot_matrix, + ): + """ """ + for index in range(4): + current_object[index] = self._rotate_zxy_volume( + current_object[index], rot_matrix + ) + + return current_object + + def _rotate_zxy_volume_vector(self, current_object, rot_matrix): + """Rotates vector field consistently. 
Note this is very expensive""" + + xp = self._xp + swap_zxy_to_xyz = self._swap_zxy_to_xyz + + if xp is np or int(xp.__version__.split(".")[0]) < 12: + from scipy.interpolate import RegularGridInterpolator + + xp = np # ensure np is enforced for cupy < 12 + current_object = self._asnumpy(current_object) + else: + from cupyx.scipy.interpolate import RegularGridInterpolator + + _, nz, nx, ny = current_object.shape + + z, x, y = [xp.linspace(-1, 1, s, endpoint=False) for s in (nx, ny, nz)] + Z, X, Y = xp.meshgrid(z, x, y, indexing="ij") + coords = xp.array([Z.ravel(), X.ravel(), Y.ravel()]) + + tf = xp.asarray(swap_zxy_to_xyz.T @ rot_matrix @ swap_zxy_to_xyz) + rotated_vecs = tf.T.dot(coords).T + + Az = RegularGridInterpolator( + (z, x, y), current_object[1], bounds_error=False, fill_value=0 + ) + Ax = RegularGridInterpolator( + (z, x, y), current_object[2], bounds_error=False, fill_value=0 + ) + Ay = RegularGridInterpolator( + (z, x, y), current_object[3], bounds_error=False, fill_value=0 + ) + + xp = self._xp # switch back to device + obj = xp.zeros_like(current_object) + obj[0] = self._rotate_zxy_volume(xp.asarray(current_object[0]), rot_matrix) + + obj[1] = xp.asarray(Az(rotated_vecs).reshape(nz, nx, ny)) + obj[2] = xp.asarray(Ax(rotated_vecs).reshape(nz, nx, ny)) + obj[3] = xp.asarray(Ay(rotated_vecs).reshape(nz, nx, ny)) + + return obj diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index b809c3d1b..74ab621d3 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -218,10 +218,9 @@ def show_object_fft( vmax = kwargs.pop("vmax", 0.999) # remove additional 3D FFT parameters before passing to show - kwargs.pop("projection_angle_deg", None) - kwargs.pop("projection_axes", None) - kwargs.pop("x_lims", None) - kwargs.pop("y_lims", None) + kwargs.pop("orientation_matrix", None) + kwargs.pop("vertical_lims", None) + kwargs.pop("horizontal_lims", None) show( object_fft, @@ -873,44 +872,6 @@ def _initialize_object( return _object - def _crop_rotate_object_manually( - self, - array, - angle, - x_lims, - y_lims, - ): - """ - Crops and rotates rotates object manually. - - Parameters - ---------- - array: np.ndarray - Object array to crop and rotate. Only operates on numpy arrays for compatibility. 
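
The pattern above (normalized grid coordinates, a 3x3 rotation applied to those coordinates, and RegularGridInterpolator resampling per component) can be exercised on a single scalar volume. A small sketch with a 90-degree in-plane rotation; the normalization and axis ordering here are assumptions for illustration, not the class's exact conventions:

import numpy as np
from scipy.interpolate import RegularGridInterpolator

def rotate_volume(volume, rot_matrix):
    # Resample `volume` on a grid whose coordinates are rotated by rot_matrix.
    nz, nx, ny = volume.shape
    z, x, y = [np.linspace(-1, 1, s, endpoint=False) for s in (nz, nx, ny)]
    Z, X, Y = np.meshgrid(z, x, y, indexing="ij")
    coords = np.array([Z.ravel(), X.ravel(), Y.ravel()])  # (3, nz*nx*ny)

    rotated = rot_matrix.T.dot(coords).T                   # (nz*nx*ny, 3)
    interp = RegularGridInterpolator(
        (z, x, y), volume, bounds_error=False, fill_value=0
    )
    return interp(rotated).reshape(nz, nx, ny)

# 90-degree rotation in the (x, y) plane, leaving z untouched
theta = np.pi / 2
rot = np.array(
    [
        [1, 0, 0],
        [0, np.cos(theta), -np.sin(theta)],
        [0, np.sin(theta), np.cos(theta)],
    ]
)

volume = np.zeros((8, 16, 16))
volume[:, 4, 8] = 1.0  # a line feature along z
rotated_volume = rotate_volume(volume, rot)
print(np.unravel_index(rotated_volume[0].argmax(), (16, 16)))  # (8, 4) with this convention
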
- angle: float - In-plane angle in degrees to rotate by - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - - Returns - ------- - cropped_rotated_array: np.ndarray - Cropped and rotated object array - """ - - asnumpy = self._asnumpy - min_x, max_x = x_lims - min_y, max_y = y_lims - - if angle is not None: - rotated_array = rotate(asnumpy(array), angle, reshape=False, axes=(-2, -1)) - else: - rotated_array = asnumpy(array) - - return rotated_array[..., min_x:max_x, min_y:max_y] - def _return_projected_cropped_potential( self, obj=None, @@ -919,26 +880,24 @@ def _return_projected_cropped_potential( ): """Utility function to accommodate multiple classes""" - projection_angle_deg = kwargs.pop("projection_angle_deg", None) - projection_axes = kwargs.pop("projection_axes", (0, 2)) - x_lims = kwargs.pop("x_lims", (None, None)) - y_lims = kwargs.pop("y_lims", (None, None)) + asnumpy = self._asnumpy + + rot_matrix = kwargs.pop("orientation_matrix", None) + v_lims = kwargs.pop("vertical_lims", (None, None)) + h_lims = kwargs.pop("horizontal_lims", (None, None)) if obj is None: obj = self._object - if projection_angle_deg is not None: - obj = self._rotate( + if rot_matrix is not None: + obj = self._rotate_zxy_volume( obj, - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, + rot_matrix=rot_matrix, ) - obj = self._crop_rotate_object_manually( - obj, angle=None, x_lims=x_lims, y_lims=y_lims - ).sum(0) + start_v, end_v = v_lims + start_h, end_h = h_lims + obj = asnumpy(obj.sum(0)[start_v:end_v, start_h:end_h]) if return_kwargs: return obj, kwargs @@ -949,10 +908,9 @@ def _return_object_fft( self, obj=None, apply_hanning_window=False, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims: Tuple[int, int] = (None, None), - y_lims: Tuple[int, int] = (None, None), + orientation_matrix=None, + vertical_lims: Tuple[int, int] = (None, None), + horizontal_lims: Tuple[int, int] = (None, None), **kwargs, ): """ @@ -964,14 +922,12 @@ def _return_object_fft( if None is specified, uses self._object apply_hanning_window: bool, optional If True, a 2D Hann window is applied to the object before FFT - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices + orientation_matrix: np.ndarray, optional + orientation matrix to rotate zone-axis + vertical_lims: tuple(int,int), optional + min/max vertical indices + horizontal_lims: tuple(int,int), optional + min/max horizontal indices Returns ------- @@ -987,29 +943,23 @@ def _return_object_fft( else: obj = xp.asarray(obj, dtype=xp.float32) - if projection_angle_deg is not None: - rotated_3d_obj = self._rotate( + if orientation_matrix is not None: + obj = self._rotate_zxy_volume( obj, - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, + rot_matrix=orientation_matrix, ) - rotated_3d_obj = asnumpy(rotated_3d_obj) - else: - rotated_3d_obj = asnumpy(obj) - rotated_object = self._crop_rotate_object_manually( - rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) + start_v, end_v = vertical_lims + start_h, end_h = horizontal_lims + obj = asnumpy(obj.sum(0)[start_v:end_v, start_h:end_h]) if apply_hanning_window: - sx, sy = rotated_object.shape + sx, sy = obj.shape wx = np.hanning(sx) wy = np.hanning(sy) - rotated_object *= wx[:, None] * 
wy[None, :] + obj *= wx[:, None] * wy[None, :] - return np.abs(np.fft.fftshift(np.fft.fft2(rotated_object))) + return np.abs(np.fft.fftshift(np.fft.fft2(obj))) @property def object_supersliced(self): From e764d01fe8210d152d0481fd7623471a1e6a79e6 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 7 Jan 2024 18:42:33 -0800 Subject: [PATCH 071/128] ragged list partitions support for complex plotting Former-commit-id: faef19ed43b5e3ed69d7d7099ede49098f9afa37 --- .../process/phase/ptychographic_methods.py | 19 +++++++++++++++---- py4DSTEM/process/phase/utils.py | 6 ++++++ py4DSTEM/visualize/vis_grid.py | 2 +- py4DSTEM/visualize/vis_special.py | 10 +++++----- 4 files changed, 27 insertions(+), 10 deletions(-) diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 74ab621d3..e8c52e781 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -10,6 +10,7 @@ ComplexProbe, fft_shift, generate_batches, + partition_list, rotate_point, spatial_frequencies, ) @@ -1158,6 +1159,7 @@ def show_probe( scalebar=True, pixelsize=None, pixelunits=None, + W=6, **kwargs, ): """ @@ -1173,10 +1175,12 @@ def show_probe( if True, adds colorbar scalebar: bool, optional if True, adds scalebar to probe - pixelunits: str, optional - units for scalebar, default is A^-1 pixelsize: float, optional default is probe reciprocal sampling + pixelunits: str, optional + units for scalebar, default is A^-1 + W: int, optional + if not None, sets the width of the image grid """ asnumpy = self._asnumpy @@ -1210,6 +1214,8 @@ def show_probe( for pr in probe ] + probe = list(partition_list(probe, W)) + show_complex( probe, cbar=cbar, @@ -1231,6 +1237,7 @@ def show_fourier_probe( scalebar=True, pixelsize=None, pixelunits=None, + W=6, **kwargs, ): """ @@ -1246,10 +1253,12 @@ def show_fourier_probe( if True, adds colorbar scalebar: bool, optional if True, adds scalebar to probe - pixelunits: str, optional - units for scalebar, default is A^-1 pixelsize: float, optional default is probe reciprocal sampling + pixelunits: str, optional + units for scalebar, default is A^-1 + W: int, optional + if not None, sets the width of the image grid """ asnumpy = self._asnumpy @@ -1288,6 +1297,8 @@ def show_fourier_probe( for pr in probe ] + probe = list(partition_list(probe, W)) + show_complex( probe, cbar=cbar, diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index fc1715ab2..8894faaf3 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -2275,3 +2275,9 @@ def vectorized_fourier_resample( array_resize *= scale_output return array_resize + + +def partition_list(lst, size): + """Partitions lst into chunks of size. 
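
partition_list simply chunks the flat probe list into rows of at most W entries, so the final row of the image grid can be ragged; the grid-plotting changes above tolerate that by indexing only the first len(rgb) axes and titles. A tiny usage sketch:

def partition_list(lst, size):
    # consecutive chunks of at most `size` elements
    for i in range(0, len(lst), size):
        yield lst[i : i + size]

probes = [f"probe_{i}" for i in range(7)]
rows = list(partition_list(probes, 3))
print([len(r) for r in rows])  # [3, 3, 1], so the last grid row is ragged
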
Returns a generator.""" + for i in range(0, len(lst), size): + yield lst[i : i + size] diff --git a/py4DSTEM/visualize/vis_grid.py b/py4DSTEM/visualize/vis_grid.py index d24b0b8d8..9be754689 100644 --- a/py4DSTEM/visualize/vis_grid.py +++ b/py4DSTEM/visualize/vis_grid.py @@ -205,7 +205,7 @@ def show_image_grid( ax = axs[i, j] N = i * W + j # make titles - if type(title) == list: + if type(title) == list and N < len(title): print_title = title[N] else: print_title = None diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 9ddfab372..c8b8a8b12 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -801,8 +801,8 @@ def show_complex( ) if scalebar is True: scalebar = { - "Nx": ar_complex[0].shape[0], - "Ny": ar_complex[0].shape[1], + "Nx": rgb[0].shape[0], + "Ny": rgb[0].shape[1], "pixelsize": pixelsize, "pixelunits": pixelunits, } @@ -824,8 +824,8 @@ def show_complex( if scalebar is True: scalebar = { - "Nx": ar_complex.shape[0], - "Ny": ar_complex.shape[1], + "Nx": rgb.shape[0], + "Ny": rgb.shape[1], "pixelsize": pixelsize, "pixelunits": pixelunits, } @@ -835,7 +835,7 @@ def show_complex( # add color bar if cbar: if is_grid: - for ax_flat in ax.flatten(): + for ax_flat in ax.flatten()[: len(rgb)]: divider = make_axes_locatable(ax_flat) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) From 706abe491e733b9fd2f44e3b56267269a107c5cb Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 7 Jan 2024 22:30:32 -0800 Subject: [PATCH 072/128] added detector plane resampling Former-commit-id: c9156dc9123f61aebc79ac3a5f50dd720ef5a17a --- py4DSTEM/preprocess/preprocess.py | 4 +- .../magnetic_ptychographic_tomography.py | 31 ++++++-- .../process/phase/magnetic_ptychography.py | 33 +++++--- .../mixedstate_multislice_ptychography.py | 22 ++++-- .../process/phase/mixedstate_ptychography.py | 17 +++- .../process/phase/multislice_ptychography.py | 22 ++++-- py4DSTEM/process/phase/phase_base_class.py | 21 +++-- .../process/phase/ptychographic_methods.py | 78 +++++++++++++++---- .../process/phase/ptychographic_tomography.py | 31 ++++++-- .../phase/ptychographic_visualizations.py | 1 + .../process/phase/singleslice_ptychography.py | 22 ++++-- py4DSTEM/process/phase/utils.py | 64 +++++++++++++++ 12 files changed, 274 insertions(+), 72 deletions(-) diff --git a/py4DSTEM/preprocess/preprocess.py b/py4DSTEM/preprocess/preprocess.py index fb4983622..9db7895d3 100644 --- a/py4DSTEM/preprocess/preprocess.py +++ b/py4DSTEM/preprocess/preprocess.py @@ -576,7 +576,9 @@ def resample_data_diffraction( resampling_factor = np.array(output_size) / np.array(datacube.shape[-2:]) resampling_factor = np.concatenate(((1, 1), resampling_factor)) - datacube.data = zoom(datacube.data, resampling_factor, order=1) + datacube.data = zoom( + datacube.data, resampling_factor, order=1, mode="grid-wrap", grid_mode=True + ) datacube.calibration.set_Q_pixel_size( datacube.calibration.get_Q_pixel_size() / resampling_factor[2] ) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 903a0363c..2d98b6834 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -236,7 +236,8 @@ def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, + 
padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, fit_function: str = "plane", plot_probe_overlaps: bool = True, @@ -265,9 +266,12 @@ def preprocess( If None, no resampling of diffraction intenstities is performed reshaping_method: str, optional Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional + padded_diffraction_intensities_shape: (int,int), optional Padded diffraction intensities shape. If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) fit_function: str, optional @@ -308,7 +312,9 @@ def preprocess( # set additional metadata self._diffraction_intensities_shape = diffraction_intensities_shape self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) self._dp_mask = dp_mask if self._datacube is None: @@ -350,11 +356,20 @@ def preprocess( roi_shape = self._datacube[0].Qshape if diffraction_intensities_shape is not None: roi_shape = diffraction_intensities_shape - if probe_roi_shape is not None: - roi_shape = tuple(max(q, s) for q, s in zip(roi_shape, probe_roi_shape)) + if padded_diffraction_intensities_shape is not None: + roi_shape = tuple( + max(q, s) + for q, s in zip(roi_shape, padded_diffraction_intensities_shape) + ) self._amplitudes = xp.empty((self._num_diffraction_patterns,) + roi_shape) - self._region_of_interest_shape = np.array(roi_shape) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) # TO-DO: generalize this if force_com_shifts is None: @@ -381,7 +396,7 @@ def preprocess( self._datacube[index], diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, com_shifts=force_com_shifts[index], @@ -397,7 +412,7 @@ def preprocess( self._datacube[index], diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=None, dp_mask=None, com_shifts=force_com_shifts[index], diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 0e633ae71..5333960e9 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -207,7 +207,8 @@ def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, fit_function: str = "plane", plot_rotation: bool = True, @@ -244,10 +245,13 @@ def 
preprocess( Pixel dimensions (Qx',Qy') of the resampled diffraction intensities If None, no resampling of diffraction intenstities is performed reshaping_method: str, optional - Method to use for reshaping, either 'bin', 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + padded_diffraction_intensities_shape: (int,int), optional Padded diffraction intensities shape. If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) fit_function: str, optional @@ -291,7 +295,9 @@ def preprocess( # set additional metadata self._diffraction_intensities_shape = diffraction_intensities_shape self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) self._dp_mask = dp_mask if self._datacube is None: @@ -374,11 +380,20 @@ def preprocess( roi_shape = self._datacube[0].Qshape if diffraction_intensities_shape is not None: roi_shape = diffraction_intensities_shape - if probe_roi_shape is not None: - roi_shape = tuple(max(q, s) for q, s in zip(roi_shape, probe_roi_shape)) + if padded_diffraction_intensities_shape is not None: + roi_shape = tuple( + max(q, s) + for q, s in zip(roi_shape, padded_diffraction_intensities_shape) + ) self._amplitudes = xp.empty((self._num_diffraction_patterns,) + roi_shape) - self._region_of_interest_shape = np.array(roi_shape) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) # TO-DO: generalize this if force_com_shifts is None: @@ -408,7 +423,7 @@ def preprocess( self._datacube[index], diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, com_shifts=force_com_shifts[index], @@ -424,7 +439,7 @@ def preprocess( self._datacube[index], diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=None, dp_mask=None, com_shifts=force_com_shifts[index], diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index 931985224..630608013 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -276,7 +276,8 @@ def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, fit_function: str = "plane", plot_center_of_mass: str = "default", @@ -314,9 +315,12 @@ def preprocess( If None, no 
resampling of diffraction intenstities is performed reshaping_method: str, optional Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional + padded_diffraction_intensities_shape: (int,int), optional Padded diffraction intensities shape. If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) fit_function: str, optional @@ -363,7 +367,9 @@ def preprocess( # set additional metadata self._diffraction_intensities_shape = diffraction_intensities_shape self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) self._dp_mask = dp_mask if self._datacube is None: @@ -387,7 +393,7 @@ def preprocess( self._datacube, diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, com_shifts=force_com_shifts, @@ -460,7 +466,13 @@ def preprocess( # explicitly delete namespace self._num_diffraction_patterns = self._amplitudes.shape[0] - self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) del self._intensities # initialize probe positions diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index dbcb62e97..177898a0f 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -217,7 +217,8 @@ def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, fit_function: str = "plane", plot_center_of_mass: str = "default", @@ -304,7 +305,9 @@ def preprocess( # set additional metadata self._diffraction_intensities_shape = diffraction_intensities_shape self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) self._dp_mask = dp_mask if self._datacube is None: @@ -328,7 +331,7 @@ def preprocess( self._datacube, diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, com_shifts=force_com_shifts, @@ -401,7 +404,13 @@ def preprocess( # explicitly delete namespace self._num_diffraction_patterns = self._amplitudes.shape[0] - self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + if region_of_interest_shape is not None: + self._resample_exit_waves = 
True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) del self._intensities # initialize probe positions diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index 69ad11330..3dd5c34e2 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -251,7 +251,8 @@ def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, fit_function: str = "plane", plot_center_of_mass: str = "default", @@ -289,9 +290,12 @@ def preprocess( If None, no resampling of diffraction intenstities is performed reshaping_method: str, optional Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional + padded_diffraction_intensities_shape: (int,int), optional Padded diffraction intensities shape. If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) fit_function: str, optional @@ -338,7 +342,9 @@ def preprocess( # set additional metadata self._diffraction_intensities_shape = diffraction_intensities_shape self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) self._dp_mask = dp_mask if self._datacube is None: @@ -362,7 +368,7 @@ def preprocess( self._datacube, diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, com_shifts=force_com_shifts, @@ -435,7 +441,13 @@ def preprocess( # explicitly delete namespace self._num_diffraction_patterns = self._amplitudes.shape[0] - self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) del self._intensities # initialize probe positions diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 7109e89a1..546dd7c38 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -136,7 +136,7 @@ def _preprocess_datacube_and_vacuum_probe( datacube, diffraction_intensities_shape=None, reshaping_method="fourier", - probe_roi_shape=None, + padded_diffraction_intensities_shape=None, vacuum_probe_intensity=None, dp_mask=None, com_shifts=None, @@ -153,13 +153,10 @@ def _preprocess_datacube_and_vacuum_probe( Note this does not affect the maximum scattering wavevector (Qx*dkx,Qy*dky) = (Sx*dkx',Sy*dky'), and thus the real-space sampling stays fixed. 
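# Illustrative sketch (not part of the applied diffs): how the diffraction ROI
# shape is chosen by the padding logic above. Padding never shrinks the ROI,
# each axis is simply grown to the requested padded size, and zero-padding in
# reciprocal space leaves (dkx, dky) fixed, so dx = 1 / (Sx * dkx) becomes
# finer -- the case the new docstring targets (scan steps much smaller than dx).
# The shapes below are hypothetical.
import numpy as np

diffraction_intensities_shape = (128, 128)          # resampled (Sx', Sy')
padded_diffraction_intensities_shape = (192, 160)   # requested padded shape

roi_shape = tuple(
    max(q, s)
    for q, s in zip(
        diffraction_intensities_shape, padded_diffraction_intensities_shape
    )
)
assert roi_shape == (192, 160)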
- The real space sampling, (dx, dy), combined with the resampled diffraction_intensities_shape, - sets the real-space probe region of interest (ROI) extent (dx*Sx, dy*Sy). - Occasionally, one may also want to specify a larger probe ROI extent, e.g when the probe - does not comfortably fit without self-ovelap artifacts, or when the scan step sizes are much - smaller than the real-space sampling (dx,dy). This can be achieved by specifying a - probe_roi_shape, which is larger than diffraction_intensities_shape, which will result in - zero-padding of the diffraction intensities. + Additionally, one may wish to zero-pad the diffraction intensity data. Note this does not increase + the information or resolution, but might be beneficial in a limited number of cases, e.g. when the + scan step sizes are much smaller than the real-space sampling (dx,dy). This can be achieved by specifying + a padded_diffraction_intensities_shape which is larger than diffraction_intensities_shape. Parameters ---------- @@ -170,7 +167,7 @@ def _preprocess_datacube_and_vacuum_probe( If None, no resamping is performed reshaping method: str, optional Reshaping method to use, one of 'bin', 'bilinear' or 'fourier' (default) - probe_roi_shape, (int,int), optional + padded_diffraction_intensities_shape, (int,int), optional Padded diffraction intensities shape. If None, no padding is performed vacuum_probe_intensity, np.ndarray, optional @@ -284,10 +281,10 @@ def _preprocess_datacube_and_vacuum_probe( ) ) - if probe_roi_shape is not None: + if padded_diffraction_intensities_shape is not None: Qx, Qy = datacube.shape[-2:] - Sx, Sy = probe_roi_shape - datacube = datacube.pad_Q(output_size=probe_roi_shape) + Sx, Sy = padded_diffraction_intensities_shape + datacube = datacube.pad_Q(output_size=padded_diffraction_intensities_shape) if vacuum_probe_intensity is not None or dp_mask is not None: pad_kx = Sx - Qx diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index e8c52e781..81b6c7946 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -13,6 +13,7 @@ partition_list, rotate_point, spatial_frequencies, + vectorized_bilinear_resample, ) from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex @@ -1215,6 +1216,7 @@ def show_probe( ] probe = list(partition_list(probe, W)) + probe = probe if len(probe) > 1 else probe[0] show_complex( probe, @@ -1298,6 +1300,7 @@ def show_fourier_probe( ] probe = list(partition_list(probe, W)) + probe = probe if len(probe) > 1 else probe[0] show_complex( probe, @@ -1497,15 +1500,28 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap): """ xp = self._xp + + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + overlap = vectorized_bilinear_resample( + overlap, output_size=amplitudes.shape[-2:], xp=xp + ) + fourier_overlap = xp.fft.fft2(overlap) farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) fourier_modified_overlap = amplitudes * xp.exp(1j * xp.angle(fourier_overlap)) - modified_overlap = xp.fft.ifft2(fourier_modified_overlap) + modified_overlap = xp.fft.ifft2(fourier_modified_overlap) exit_waves = modified_overlap - overlap + # resample back to region_of_interest_shape, note: this needs to happen in real-space + if 
self._resample_exit_waves: + exit_waves = vectorized_bilinear_resample( + exit_waves, output_size=self._region_of_interest_shape, xp=xp + ) + return exit_waves, error def _projection_sets_fourier_projection( @@ -1553,18 +1569,30 @@ def _projection_sets_fourier_projection( if exit_waves is None: exit_waves = overlap.copy() - fourier_overlap = xp.fft.fft2(overlap) - farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) - error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) - factor_to_be_projected = projection_c * overlap + projection_y * exit_waves + + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + factor_to_be_projected = vectorized_bilinear_resample( + factor_to_be_projected, output_size=amplitudes.shape[-2:], xp=xp + ) + fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) + farfield_amplitudes = self._return_farfield_amplitudes(fourier_projected_factor) + error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) fourier_projected_factor = amplitudes * xp.exp( 1j * xp.angle(fourier_projected_factor) ) + projected_factor = xp.fft.ifft2(fourier_projected_factor) + # resample back to region_of_interest_shape, note: this needs to happen in real-space + if self._resample_exit_waves: + projected_factor = vectorized_bilinear_resample( + projected_factor, output_size=self._region_of_interest_shape, xp=xp + ) + exit_waves = ( projection_x * exit_waves + projection_a * overlap @@ -2503,6 +2531,13 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap): """ xp = self._xp + + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + overlap = vectorized_bilinear_resample( + overlap, output_size=amplitudes.shape[-2:], xp=xp + ) + fourier_overlap = xp.fft.fft2(overlap) farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) @@ -2515,6 +2550,12 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap): exit_waves = modified_overlap - overlap + # resample back to region_of_interest_shape, note: this needs to happen in real-space + if self._resample_exit_waves: + exit_waves = vectorized_bilinear_resample( + exit_waves, output_size=self._region_of_interest_shape, xp=xp + ) + return exit_waves, error def _projection_sets_fourier_projection( @@ -2562,23 +2603,30 @@ def _projection_sets_fourier_projection( if exit_waves is None: exit_waves = overlap.copy() - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) - error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) - factor_to_be_projected = projection_c * overlap + projection_y * exit_waves + + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + factor_to_be_projected = vectorized_bilinear_resample( + factor_to_be_projected, output_size=amplitudes.shape[-2:], xp=xp + ) + fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) + farfield_amplitudes = self._return_farfield_amplitudes(fourier_projected_factor) + error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) - intensity_norm_projected = xp.sqrt( - xp.sum(xp.abs(fourier_projected_factor) ** 2, axis=1) - ) - intensity_norm_projected[intensity_norm_projected == 0.0] = np.inf + farfield_amplitudes[farfield_amplitudes == 0.0] = np.inf + amplitude_modification = amplitudes / farfield_amplitudes - amplitude_modification = amplitudes / 
intensity_norm_projected fourier_projected_factor *= amplitude_modification[:, None] - projected_factor = xp.fft.ifft2(fourier_projected_factor) + # resample back to region_of_interest_shape, note: this needs to happen in real-space + if self._resample_exit_waves: + projected_factor = vectorized_bilinear_resample( + projected_factor, output_size=self._region_of_interest_shape, xp=xp + ) + exit_waves = ( projection_x * exit_waves + projection_a * overlap diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 7d4964fa4..13bdf035d 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -230,7 +230,8 @@ def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, fit_function: str = "plane", plot_probe_overlaps: bool = True, @@ -260,9 +261,12 @@ def preprocess( If None, no resampling of diffraction intenstities is performed reshaping_method: str, optional Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional + padded_diffraction_intensities_shape: (int,int), optional Padded diffraction intensities shape. If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) fit_function: str, optional @@ -307,7 +311,9 @@ def preprocess( # set additional metadata self._diffraction_intensities_shape = diffraction_intensities_shape self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) self._dp_mask = dp_mask if self._datacube is None: @@ -349,11 +355,20 @@ def preprocess( roi_shape = self._datacube[0].Qshape if diffraction_intensities_shape is not None: roi_shape = diffraction_intensities_shape - if probe_roi_shape is not None: - roi_shape = tuple(max(q, s) for q, s in zip(roi_shape, probe_roi_shape)) + if padded_diffraction_intensities_shape is not None: + roi_shape = tuple( + max(q, s) + for q, s in zip(roi_shape, padded_diffraction_intensities_shape) + ) self._amplitudes = xp.empty((self._num_diffraction_patterns,) + roi_shape) - self._region_of_interest_shape = np.array(roi_shape) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) # TO-DO: generalize this if force_com_shifts is None: @@ -380,7 +395,7 @@ def preprocess( self._datacube[index], diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, com_shifts=force_com_shifts[index], @@ -396,7 +411,7 @@ def preprocess( self._datacube[index], diffraction_intensities_shape=self._diffraction_intensities_shape, 
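# Illustrative sketch (not part of the applied diffs): a minimal single-pattern
# version of the exit-wave resampling pattern added to the Fourier projections
# above. The real-space overlap is bilinearly resampled to the measured
# amplitude shape before the amplitude projection, and the exit wave is
# resampled back to region_of_interest_shape afterwards. Assumes numpy/scipy
# only; real and imaginary parts are resampled separately for compatibility.
import numpy as np
from scipy.ndimage import zoom

def _resample(array, output_size):
    scale = np.array(output_size) / np.array(array.shape)
    return zoom(array.real, scale, order=1) + 1j * zoom(array.imag, scale, order=1)

def fourier_projection_with_resampling(overlap, amplitudes, roi_shape):
    overlap_data = _resample(overlap, amplitudes.shape)  # match data shape
    fourier_overlap = np.fft.fft2(overlap_data)
    error = np.sum(np.abs(amplitudes - np.abs(fourier_overlap)) ** 2)

    # amplitude projection: keep measured amplitudes, keep estimated phases
    modified = np.fft.ifft2(amplitudes * np.exp(1j * np.angle(fourier_overlap)))
    exit_wave = modified - overlap_data
    return _resample(exit_wave, roi_shape), error  # back to ROI shape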
reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=None, dp_mask=None, com_shifts=force_com_shifts[index], diff --git a/py4DSTEM/process/phase/ptychographic_visualizations.py b/py4DSTEM/process/phase/ptychographic_visualizations.py index 916ef0352..014c380f0 100644 --- a/py4DSTEM/process/phase/ptychographic_visualizations.py +++ b/py4DSTEM/process/phase/ptychographic_visualizations.py @@ -284,6 +284,7 @@ def _visualize_all_iterations( plot_convergence=plot_convergence, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, **kwargs, ) diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index 57b8ec93c..59535a19d 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -195,7 +195,8 @@ def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, dp_mask: np.ndarray = None, fit_function: str = "plane", plot_center_of_mass: str = "default", @@ -233,9 +234,12 @@ def preprocess( If None, no resampling of diffraction intenstities is performed reshaping_method: str, optional Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional + padded_diffraction_intensities_shape: (int,int), optional Padded diffraction intensities shape. If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) fit_function: str, optional @@ -282,7 +286,9 @@ def preprocess( # set additional metadata self._diffraction_intensities_shape = diffraction_intensities_shape self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) self._dp_mask = dp_mask if self._datacube is None: @@ -306,7 +312,7 @@ def preprocess( self._datacube, diffraction_intensities_shape=self._diffraction_intensities_shape, reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, com_shifts=force_com_shifts, @@ -379,7 +385,13 @@ def preprocess( # explicitly delete intensities namespace self._num_diffraction_patterns = self._amplitudes.shape[0] - self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) del self._intensities # initialize probe positions diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 8894faaf3..c9db5f86b 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ 
-4,12 +4,14 @@ import matplotlib.pyplot as plt import numpy as np from scipy.fft import dctn, idctn +from scipy.ndimage import zoom from scipy.optimize import curve_fit try: import cupy as cp from cupyx.scipy.fft import dctn as dctn_cp from cupyx.scipy.fft import idctn as idctn_cp + from cupyx.scipy.ndimage import zoom as zoom_cp except (ImportError, ModuleNotFoundError): cp = None @@ -2108,6 +2110,67 @@ def lanczos_kernel_density_estimate( return pix_output +def vectorized_bilinear_resample( + array, + scale=None, + output_size=None, + mode="grid-wrap", + grid_mode=True, + xp=np, +): + """ + Resize an array along its final two axes. + Note, this is vectorized and thus very memory-intensive. + + The scaling of the array can be specified by passing either `scale`, which sets + the scaling factor along both axes to be scaled; or by passing `output_size`, + which specifies the final dimensions of the scaled axes. + + Parameters + ---------- + array: np.ndarray + Input array to be resampled + scale: float + Scalar value giving the scaling factor for all dimensions + output_size: (int,int) + Tuple of two values giving the output size for the final two axes + xp: Callable + Array computing module + + Returns + ------- + resampled_array: np.ndarray + Resampled array + """ + + array_size = np.array(array.shape) + input_size = array_size[-2:].copy() + + if scale is not None: + scale = np.array(scale) + if scale.size == 1: + scale = np.tile(scale, 2) + + output_size = (input_size * scale).astype("int") + else: + if output_size is None: + raise ValueError("One of `scale` or `output_size` must be provided.") + output_size = np.array(output_size) + if output_size.size != 2: + raise ValueError("`output_size` must contain exactly two values.") + output_size = np.array(output_size) + + scale_output = tuple(output_size / input_size) + scale_output = (1,) * (array_size.size - input_size.size) + scale_output + + if xp is np: + array = zoom(array, scale_output, order=1, mode=mode, grid_mode=grid_mode) + else: + array = zoom_cp(array, scale_output, order=1, mode=mode, grid_mode=grid_mode) + + return array + + def vectorized_fourier_resample( array, scale=None, @@ -2153,6 +2216,7 @@ def vectorized_fourier_resample( else: if output_size is None: raise ValueError("One of `scale` or `output_size` must be provided.") + output_size = np.array(output_size) if output_size.size != 2: raise ValueError("`output_size` must contain exactly two values.") output_size = np.array(output_size) From 06776ce53a43217d140b5aeaa26fd590dcb82420 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 9 Jan 2024 16:22:02 -0800 Subject: [PATCH 073/128] cleaned up cupyx.scipy imports Former-commit-id: 6470fd8f32f8a14aa0a2fa12446ef217ec210562 --- py4DSTEM/process/phase/dpc.py | 12 +++--- .../magnetic_ptychographic_tomography.py | 30 ++++---------- .../process/phase/magnetic_ptychography.py | 19 ++++----- .../mixedstate_multislice_ptychography.py | 41 +++++++++---------- .../process/phase/mixedstate_ptychography.py | 41 +++++++++---------- .../process/phase/multislice_ptychography.py | 19 ++++----- py4DSTEM/process/phase/parallax.py | 14 ++++--- py4DSTEM/process/phase/phase_base_class.py | 16 +++----- .../phase/ptychographic_constraints.py | 4 +- .../process/phase/ptychographic_methods.py | 2 +- .../process/phase/ptychographic_tomography.py | 30 ++++---------- .../phase/ptychographic_visualizations.py | 2 +- .../process/phase/singleslice_ptychography.py | 19 ++++----- 13 files changed, 103 insertions(+), 146 deletions(-) diff --git 
a/py4DSTEM/process/phase/dpc.py b/py4DSTEM/process/phase/dpc.py index 7043afc9b..f379e50d3 100644 --- a/py4DSTEM/process/phase/dpc.py +++ b/py4DSTEM/process/phase/dpc.py @@ -59,17 +59,19 @@ def __init__( Custom.__init__(self, name=name) if device == "cpu": + import scipy + self._xp = np self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter + self._scipy = scipy - self._gaussian_filter = gaussian_filter elif device == "gpu": + from cupyx import scipy + self._xp = cp self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter + self._scipy = scipy - self._gaussian_filter = gaussian_filter else: raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") @@ -516,7 +518,7 @@ def _object_gaussian_constraint(self, current_object, gaussian_filter_sigma): constrained_object: np.ndarray Constrained object estimate """ - gaussian_filter = self._gaussian_filter + gaussian_filter = self._scipy.ndimage.gaussian_filter gaussian_filter_sigma /= self.sampling[0] current_object = gaussian_filter(current_object, gaussian_filter_sigma) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 2d98b6834..2c9e8f0b2 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -156,34 +156,19 @@ def __init__( Custom.__init__(self, name=name) if device == "cpu": + import scipy + self._xp = np self._asnumpy = np.asarray - from scipy.ndimage import affine_transform, gaussian_filter, rotate, zoom - - self._gaussian_filter = gaussian_filter - self._zoom = zoom - self._rotate = rotate - self._affine_transform = affine_transform - from scipy.special import erf + self._scipy = scipy - self._erf = erf elif device == "gpu": + from cupyx import scipy + self._xp = cp self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import ( - affine_transform, - gaussian_filter, - rotate, - zoom, - ) - - self._gaussian_filter = gaussian_filter - self._zoom = zoom - self._rotate = rotate - self._affine_transform = affine_transform - from cupyx.scipy.special import erf + self._scipy = scipy - self._erf = erf else: raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") @@ -601,7 +586,8 @@ def preprocess( old_rot_matrix.T, ) - probe_overlap_3D_blurred = self._gaussian_filter(probe_overlap_3D, 1.0) + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_3D_blurred = gaussian_filter(probe_overlap_3D, 1.0) self._object_fov_mask = asnumpy( probe_overlap_3D_blurred > 0.25 * probe_overlap_3D_blurred.max() ) diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 5333960e9..2bf556d21 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -142,23 +142,19 @@ def __init__( Custom.__init__(self, name=name) if device == "cpu": + import scipy + self._xp = np self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from scipy.special import erf + self._scipy = scipy - self._erf = erf elif device == "gpu": + from cupyx import scipy + self._xp = cp self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from cupyx.scipy.special import erf + self._scipy = scipy - self._erf = erf else: raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") @@ -623,7 +619,8 
@@ def preprocess( # initialize object_fov_mask if object_fov_mask is None: - probe_overlap_blurred = self._gaussian_filter(probe_overlap, 1.0) + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) self._object_fov_mask = asnumpy( probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() ) diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index 630608013..a2372a772 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -161,6 +161,23 @@ def __init__( ): Custom.__init__(self, name=name) + if device == "cpu": + import scipy + + self._xp = np + self._asnumpy = np.asarray + self._scipy = scipy + + elif device == "gpu": + from cupyx import scipy + + self._xp = cp + self._asnumpy = cp.asnumpy + self._scipy = scipy + + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe): if num_probes is None: raise ValueError( @@ -176,27 +193,6 @@ def __init__( ) num_probes = initial_probe_guess.shape[0] - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from scipy.special import erf - - self._erf = erf - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from cupyx.scipy.special import erf - - self._erf = erf - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - for key in kwargs.keys(): if (key not in polar_symbols) and (key not in polar_aliases.keys()): raise ValueError("{} not a recognized parameter".format(key)) @@ -554,7 +550,8 @@ def preprocess( probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) if object_fov_mask is None: - probe_overlap_blurred = self._gaussian_filter(probe_overlap, 1.0) + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) self._object_fov_mask = asnumpy( probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() ) diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index 177898a0f..9ebe802ba 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -136,6 +136,23 @@ def __init__( ): Custom.__init__(self, name=name) + if device == "cpu": + import scipy + + self._xp = np + self._asnumpy = np.asarray + self._scipy = scipy + + elif device == "gpu": + from cupyx import scipy + + self._xp = cp + self._asnumpy = cp.asnumpy + self._scipy = scipy + + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe): if num_probes is None: raise ValueError( @@ -151,27 +168,6 @@ def __init__( ) num_probes = initial_probe_guess.shape[0] - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from scipy.special import erf - - self._erf = erf - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - 
from cupyx.scipy.special import erf - - self._erf = erf - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - for key in kwargs.keys(): if (key not in polar_symbols) and (key not in polar_aliases.keys()): raise ValueError("{} not a recognized parameter".format(key)) @@ -481,7 +477,8 @@ def preprocess( probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) if object_fov_mask is None: - probe_overlap_blurred = self._gaussian_filter(probe_overlap, 1.0) + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) self._object_fov_mask = asnumpy( probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() ) diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index 3dd5c34e2..8b6a267a6 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -153,23 +153,19 @@ def __init__( Custom.__init__(self, name=name) if device == "cpu": + import scipy + self._xp = np self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from scipy.special import erf + self._scipy = scipy - self._erf = erf elif device == "gpu": + from cupyx import scipy + self._xp = cp self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from cupyx.scipy.special import erf + self._scipy = scipy - self._erf = erf else: raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") @@ -529,7 +525,8 @@ def preprocess( probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) if object_fov_mask is None: - probe_overlap_blurred = self._gaussian_filter(probe_overlap, 1.0) + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) self._object_fov_mask = asnumpy( probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() ) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 0f32470d6..ee1c6f8b5 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -87,17 +87,19 @@ def __init__( Custom.__init__(self, name=name) if device == "cpu": + import scipy + self._xp = np self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter + self._scipy = scipy - self._gaussian_filter = gaussian_filter elif device == "gpu": + from cupyx import scipy + self._xp = cp self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter + self._scipy = scipy - self._gaussian_filter = gaussian_filter else: raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") @@ -1109,7 +1111,7 @@ def subpixel_alignment( """ xp = self._xp asnumpy = self._asnumpy - gaussian_filter = self._gaussian_filter + gaussian_filter = self._scipy.ndimage.gaussian_filter BF_sampling = 1 / asnumpy(self._kr).max() / 2 DF_sampling = 1 / ( @@ -1774,7 +1776,7 @@ def _kernel_density_estimate( """ """ xp = self._xp - gaussian_filter = self._gaussian_filter + gaussian_filter = self._scipy.ndimage.gaussian_filter if lanczos_alpha is not None: return lanczos_kernel_density_estimate( diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 546dd7c38..0c78fc5bd 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -73,23 +73,19 @@ def 
reinitialize_parameters(self, device: str = None, verbose: bool = None): if device is not None: if device == "cpu": + import scipy + self._xp = np self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from scipy.special import erf + self._scipy = scipy - self._erf = erf elif device == "gpu": + from cupyx import scipy + self._xp = cp self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from cupyx.scipy.special import erf + self._scipy = scipy - self._erf = erf else: raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") self._device = device diff --git a/py4DSTEM/process/phase/ptychographic_constraints.py b/py4DSTEM/process/phase/ptychographic_constraints.py index 8d8a5d158..9f19bf835 100644 --- a/py4DSTEM/process/phase/ptychographic_constraints.py +++ b/py4DSTEM/process/phase/ptychographic_constraints.py @@ -138,7 +138,7 @@ def _object_gaussian_constraint( Constrained object estimate """ xp = self._xp - gaussian_filter = self._gaussian_filter + gaussian_filter = self._scipy.ndimage.gaussian_filter gaussian_filter_sigma /= self.sampling[0] if not pure_phase_object or self._object_type == "potential": @@ -938,7 +938,7 @@ def _probe_amplitude_constraint( Constrained probe estimate """ xp = self._xp - erf = self._erf + erf = self._scipy.special.erf probe_intensity = xp.abs(current_probe) ** 2 current_probe_sum = xp.sum(probe_intensity) diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 81b6c7946..3d4d5bb00 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -828,7 +828,7 @@ def _rotate_zxy_volume( """ """ xp = self._xp - affine_transform = self._affine_transform + affine_transform = self._scipy.ndimage.affine_transform swap_zxy_to_xyz = self._swap_zxy_to_xyz volume = volume_array.copy() diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 13bdf035d..78444b39b 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -150,34 +150,19 @@ def __init__( Custom.__init__(self, name=name) if device == "cpu": + import scipy + self._xp = np self._asnumpy = np.asarray - from scipy.ndimage import affine_transform, gaussian_filter, rotate, zoom - - self._gaussian_filter = gaussian_filter - self._zoom = zoom - self._rotate = rotate - self._affine_transform = affine_transform - from scipy.special import erf + self._scipy = scipy - self._erf = erf elif device == "gpu": + from cupyx import scipy + self._xp = cp self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import ( - affine_transform, - gaussian_filter, - rotate, - zoom, - ) - - self._gaussian_filter = gaussian_filter - self._zoom = zoom - self._rotate = rotate - self._affine_transform = affine_transform - from cupyx.scipy.special import erf + self._scipy = scipy - self._erf = erf else: raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") @@ -600,7 +585,8 @@ def preprocess( old_rot_matrix.T, ) - probe_overlap_3D_blurred = self._gaussian_filter(probe_overlap_3D, 1.0) + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_3D_blurred = gaussian_filter(probe_overlap_3D, 1.0) self._object_fov_mask = asnumpy( probe_overlap_3D_blurred > 0.25 * probe_overlap_3D_blurred.max() ) diff --git 
a/py4DSTEM/process/phase/ptychographic_visualizations.py b/py4DSTEM/process/phase/ptychographic_visualizations.py index 014c380f0..ab783cc29 100644 --- a/py4DSTEM/process/phase/ptychographic_visualizations.py +++ b/py4DSTEM/process/phase/ptychographic_visualizations.py @@ -662,7 +662,7 @@ def show_uncertainty_visualization( xp = self._xp asnumpy = self._asnumpy - gaussian_filter = self._gaussian_filter + gaussian_filter = self._scipy.ndimage.gaussian_filter if errors is None: errors = self._return_self_consistency_errors(max_batch_size=max_batch_size) diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index 59535a19d..13151f02a 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -131,23 +131,19 @@ def __init__( Custom.__init__(self, name=name) if device == "cpu": + import scipy + self._xp = np self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from scipy.special import erf + self._scipy = scipy - self._erf = erf elif device == "gpu": + from cupyx import scipy + self._xp = cp self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from cupyx.scipy.special import erf + self._scipy = scipy - self._erf = erf else: raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") @@ -463,7 +459,8 @@ def preprocess( # initialize object_fov_mask if object_fov_mask is None: - probe_overlap_blurred = self._gaussian_filter(probe_overlap, 1.0) + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) self._object_fov_mask = asnumpy( probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() ) From 7cbb432050f833035a37e0d1bdc621567e3a21f8 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 10 Jan 2024 22:03:02 -0800 Subject: [PATCH 074/128] added storage support for singleslice. other classes likely broken, will fix tomorrow Former-commit-id: 951162e2b10d934db47af83acaa8321504f81e15 --- py4DSTEM/process/phase/phase_base_class.py | 176 ++++++++++++---- .../phase/ptychographic_constraints.py | 42 ++-- .../process/phase/ptychographic_methods.py | 143 ++++++++----- .../phase/ptychographic_visualizations.py | 10 +- .../process/phase/singleslice_ptychography.py | 195 ++++++++++-------- py4DSTEM/process/phase/utils.py | 52 +++-- 6 files changed, 397 insertions(+), 221 deletions(-) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 0c78fc5bd..2ce1d5bec 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -12,6 +12,7 @@ try: import cupy as cp + from cupy.fft.config import get_plan_cache except (ModuleNotFoundError, ImportError): cp = np @@ -19,7 +20,12 @@ from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube from py4DSTEM.process.calibration import fit_origin -from py4DSTEM.process.phase.utils import AffineTransform, polar_aliases +from py4DSTEM.process.phase.utils import ( + AffineTransform, + copy_to_device, + get_array_module, + polar_aliases, +) from py4DSTEM.process.utils import ( electron_wavelength_angstrom, fourier_resample, @@ -35,6 +41,81 @@ class PhaseReconstruction(Custom): Defines various common functions and properties for subclasses to inherit. 
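# Illustrative sketch (not part of the applied diffs): the import cleanup above
# converges on one dispatch pattern -- store the scipy namespace (scipy or
# cupyx.scipy) once per device, then pull functions at call time, e.g.
# gaussian_filter = self._scipy.ndimage.gaussian_filter. The helper name below
# is ours; the sketch runs on cpu without cupy installed.
import numpy as np
import scipy.ndimage
import scipy.special

try:
    import cupy as cp
    from cupyx import scipy as cupyx_scipy
except (ImportError, ModuleNotFoundError):
    cp = None

def get_array_namespaces(device="cpu"):
    if device == "cpu":
        return np, scipy
    elif device == "gpu" and cp is not None:
        return cp, cupyx_scipy
    raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}")

xp, sp = get_array_namespaces("cpu")
blurred = sp.ndimage.gaussian_filter(xp.random.rand(32, 32), sigma=1.0)
erf_values = sp.special.erf(xp.linspace(-2, 2, 5))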
""" + def set_device(self, device, clear_fft_cache): + """ + Sets calculation device. + + Parameters + ---------- + device: str + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + + Returns + -------- + self: PhaseReconstruction + Self to enable chaining + """ + + if device == "cpu": + import scipy + + self._xp = np + self._scipy = scipy + + elif device == "gpu": + from cupyx import scipy + + self._xp = cp + self._scipy = scipy + + elif device is not None: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + + self._clear_fft_cache = clear_fft_cache + + return self + + def set_storage(self, storage): + """ + Sets storage device. + + Parameters + ---------- + storage: str + Device arrays will be stored on. Must be 'cpu' or 'gpu' + + Returns + -------- + self: PhaseReconstruction + Self to enable chaining + """ + + if storage == "cpu": + self._xp_storage = np + + elif storage == "gpu": + if self._xp is np: + raise ValueError("storage='gpu' and device='cpu' is not supported") + self._xp_storage = cp + + else: + raise ValueError(f"storage must be either 'cpu' or 'gpu', not {storage}") + + self._asnumpy = copy_to_device + + return self + + def clear_device_mem(self, device, clear_fft_cache): + """ """ + if device == "gpu": + if clear_fft_cache is True: + cache = get_plan_cache() + cache.clear() + + xp = self._xp + xp._default_memory_pool.free_all_blocks() + xp._default_pinned_memory_pool.free_all_blocks() + def attach_datacube(self, datacube: DataCube): """ Attaches a datacube to a class initialized without one. @@ -357,11 +438,10 @@ def _extract_intensities_and_calibrations_from_datacube( """ # explicit read-only self attributes up-front - xp = self._xp verbose = self._verbose energy = self._energy - intensities = xp.asarray(datacube.data, dtype=xp.float32) + intensities = np.asarray(datacube.data, dtype=np.float32) self._grid_scan_shape = intensities.shape[:2] # Extracts calibrations @@ -531,13 +611,18 @@ def _calculate_intensities_center_of_mass( # explicit read-only self attributes up-front xp = self._xp + device = self._device asnumpy = self._asnumpy + reciprocal_sampling = self._reciprocal_sampling if com_measured: com_measured_x, com_measured_y = com_measured else: + # copy to device + intensities = copy_to_device(intensities, device) + # Coordinates kx = xp.arange(intensities.shape[-2], dtype=xp.float32) ky = xp.arange(intensities.shape[-1], dtype=xp.float32) @@ -1087,10 +1172,6 @@ def _solve_for_center_of_mass_relative_rotation( + xp.cos(_rotation_best_rad) * _com_normalized_y ) - # 'Public'-facing attributes as numpy arrays - com_x = asnumpy(_com_x) - com_y = asnumpy(_com_y) - # Optionally, plot CoM if plot_center_of_mass == "all": figsize = kwargs.pop("figsize", (8, 12)) @@ -1112,8 +1193,8 @@ def _solve_for_center_of_mass_relative_rotation( _com_measured_y, _com_normalized_x, _com_normalized_y, - com_x, - com_y, + _com_x, + _com_y, ], [ "CoM_x", @@ -1135,8 +1216,8 @@ def _solve_for_center_of_mass_relative_rotation( extent = [ 0, - scan_sampling[1] * com_x.shape[1], - scan_sampling[0] * com_x.shape[0], + scan_sampling[1] * _com_x.shape[1], + scan_sampling[0] * _com_x.shape[0], 0, ] @@ -1146,15 +1227,15 @@ def _solve_for_center_of_mass_relative_rotation( for ax, arr, title in zip( grid, [ - com_x, - com_y, + _com_x, + _com_y, ], [ "Corrected CoM_x", "Corrected CoM_y", ], ): - ax.imshow(arr, extent=extent, cmap=cmap, **kwargs) + ax.imshow(asnumpy(arr), extent=extent, cmap=cmap, **kwargs) ax.set_ylabel(f"x [{scan_units[0]}]") ax.set_xlabel(f"y 
[{scan_units[1]}]") ax.set_title(title) @@ -1164,8 +1245,6 @@ def _solve_for_center_of_mass_relative_rotation( _rotation_best_transpose, _com_x, _com_y, - com_x, - com_y, ) def _normalize_diffraction_intensities( @@ -1202,12 +1281,14 @@ def _normalize_diffraction_intensities( """ # explicit read-only self attributes up-front - xp = self._xp asnumpy = self._asnumpy mean_intensity = 0 diffraction_intensities = asnumpy(diffraction_intensities) + com_fitted_x = asnumpy(com_fitted_x) + com_fitted_y = asnumpy(com_fitted_y) + if positions_mask is not None: number_of_patterns = np.count_nonzero(positions_mask.ravel()) else: @@ -1252,9 +1333,6 @@ def _normalize_diffraction_intensities( (number_of_patterns,) + region_of_interest_shape, dtype=np.float32 ) - com_fitted_x = asnumpy(com_fitted_x) - com_fitted_y = asnumpy(com_fitted_y) - counter = 0 for rx in range(diffraction_intensities.shape[0]): for ry in range(diffraction_intensities.shape[1]): @@ -1278,7 +1356,6 @@ def _normalize_diffraction_intensities( amplitudes[counter] = np.sqrt(np.maximum(intensities, 0)) counter += 1 - amplitudes = xp.asarray(amplitudes) mean_intensity /= amplitudes.shape[0] return amplitudes, mean_intensity, crop_mask @@ -1311,11 +1388,12 @@ def show_complex_CoM( """ # explicit read-only self attributes up-front + asnumpy = self._asnumpy scan_sampling = self._scan_sampling scan_units = self._scan_units if com is None: - com = (self.com_x, self.com_y) + com = (self._com_x, self._com_y) if pixelsize is None: pixelsize = scan_sampling[0] @@ -1328,7 +1406,7 @@ def show_complex_CoM( complex_com = com[0] + 1j * com[1] show_complex( - complex_com, + asnumpy(complex_com), cbar=cbar, figax=(fig, ax), scalebar=scalebar, @@ -1699,7 +1777,9 @@ def _calculate_scan_positions_in_pixels( return positions, object_padding_px - def _sum_overlapping_patches_bincounts_base(self, patches: np.ndarray): + def _sum_overlapping_patches_bincounts_base( + self, patches: np.ndarray, positions_px + ): """ Base bincouts overlapping patches sum function, operating on real-valued arrays. Note this assumes the probe is corner-centered. @@ -1715,8 +1795,7 @@ def _sum_overlapping_patches_bincounts_base(self, patches: np.ndarray): Summed array """ # explicit read-only self attributes up-front - xp = self._xp - positions_px = self._positions_px + xp = get_array_module(patches) roi_shape = self._region_of_interest_shape object_shape = self._object_shape @@ -1735,7 +1814,7 @@ def _sum_overlapping_patches_bincounts_base(self, patches: np.ndarray): ) return xp.reshape(counts, object_shape) - def _sum_overlapping_patches_bincounts(self, patches: np.ndarray): + def _sum_overlapping_patches_bincounts(self, patches: np.ndarray, positions_px): """ Sum overlapping patches defined into object shaped array using bincounts. Calls _sum_overlapping_patches_bincounts_base on Real and Imaginary parts. 
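# Illustrative sketch (not part of the applied diffs): the storage/device split
# this commit introduces keeps large stacks (e.g. self._amplitudes) on the
# storage device and copies only the current batch to the compute device.
# copy_to_device below is a minimal stand-in with the same intent as the helper
# imported from py4DSTEM.process.phase.utils; the array sizes are made up.
import numpy as np

def copy_to_device(array, device="cpu"):
    if device == "gpu":
        import cupy as cp
        return cp.asarray(array)
    try:
        import cupy as cp
        if isinstance(array, cp.ndarray):
            return cp.asnumpy(array)
    except (ImportError, ModuleNotFoundError):
        pass
    return np.asarray(array)

amplitudes = np.random.rand(1024, 64, 64).astype(np.float32)  # stays on storage
for start in range(0, amplitudes.shape[0], 256):
    batch = copy_to_device(amplitudes[start : start + 256], device="cpu")
    # ... overlap projection / error evaluation runs on `batch` ...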
@@ -1751,15 +1830,21 @@ def _sum_overlapping_patches_bincounts(self, patches: np.ndarray): Summed array """ - xp = self._xp - if xp.iscomplexobj(patches): - real = self._sum_overlapping_patches_bincounts_base(xp.real(patches)) - imag = self._sum_overlapping_patches_bincounts_base(xp.imag(patches)) + if np.iscomplexobj(patches): + real = self._sum_overlapping_patches_bincounts_base( + patches.real, positions_px + ) + imag = self._sum_overlapping_patches_bincounts_base( + patches.imag, positions_px + ) return real + 1.0j * imag else: - return self._sum_overlapping_patches_bincounts_base(patches) + return self._sum_overlapping_patches_bincounts_base(patches, positions_px) - def _extract_vectorized_patch_indices(self): + def _extract_vectorized_patch_indices( + self, + positions_px, + ): """ Sets the vectorized row/col indices used for the overlap projection Note this assumes the probe is corner-centered. @@ -1772,17 +1857,15 @@ def _extract_vectorized_patch_indices(self): Column indices for probe patches inside object array """ # explicit read-only self attributes up-front - xp = self._xp - positions_px = self._positions_px - positions_px = self._positions_px + xp_storage = self._xp_storage roi_shape = self._region_of_interest_shape obj_shape = self._object_shape - x0 = xp.round(positions_px[:, 0]).astype("int") - y0 = xp.round(positions_px[:, 1]).astype("int") + x0 = xp_storage.round(positions_px[:, 0]).astype("int") + y0 = xp_storage.round(positions_px[:, 1]).astype("int") - x_ind = xp.fft.fftfreq(roi_shape[0], d=1 / roi_shape[0]).astype("int") - y_ind = xp.fft.fftfreq(roi_shape[1], d=1 / roi_shape[1]).astype("int") + x_ind = xp_storage.fft.fftfreq(roi_shape[0], d=1 / roi_shape[0]).astype("int") + y_ind = xp_storage.fft.fftfreq(roi_shape[1], d=1 / roi_shape[1]).astype("int") vectorized_patch_indices_row = ( x0[:, None, None] + x_ind[None, :, None] @@ -1973,12 +2056,21 @@ def _report_reconstruction_summary( ) ) - def _constraints(self, current_object, current_probe, current_positions, **kwargs): + def _constraints( + self, + current_object, + current_probe, + current_positions, + initial_positions, + **kwargs, + ): """Wrapper function for all classes to inherit""" current_object = self._object_constraints(current_object, **kwargs) current_probe = self._probe_constraints(current_probe, **kwargs) - current_positions = self._positions_constraints(current_positions, **kwargs) + current_positions = self._positions_constraints( + current_positions, initial_positions, **kwargs + ) return current_object, current_probe, current_positions diff --git a/py4DSTEM/process/phase/ptychographic_constraints.py b/py4DSTEM/process/phase/ptychographic_constraints.py index 9f19bf835..eb6e95444 100644 --- a/py4DSTEM/process/phase/ptychographic_constraints.py +++ b/py4DSTEM/process/phase/ptychographic_constraints.py @@ -1089,7 +1089,7 @@ def _probe_aberration_fitting_constraint( def _probe_constraints( self, current_probe, - fix_com, + fix_probe_com, fit_probe_aberrations, fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order, @@ -1107,7 +1107,7 @@ def _probe_constraints( """ProbeConstraints wrapper function""" # CoM corner-centering - if fix_com: + if fix_probe_com: current_probe = self._probe_center_of_mass_constraint(current_probe) # Fourier phase (aberrations) fitting @@ -1280,7 +1280,9 @@ class PositionsConstraintsMixin: Mixin class for probe positions constraints. 
""" - def _positions_center_of_mass_constraint(self, current_positions): + def _positions_center_of_mass_constraint( + self, current_positions, initial_positions_com + ): """ Ptychographic position center of mass constraint. Additionally updates vectorized indices used in _overlap_projection. @@ -1295,15 +1297,7 @@ def _positions_center_of_mass_constraint(self, current_positions): constrained_positions: np.ndarray CoM constrained positions estimate """ - xp = self._xp - - current_positions -= xp.mean(current_positions, axis=0) - self._positions_px_com - self._positions_px_fractional = current_positions - xp.round(current_positions) - - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() + current_positions -= current_positions.mean(0) - initial_positions_com return current_positions @@ -1330,39 +1324,45 @@ def _positions_affine_transformation_constraint( Affine-transform constrained positions estimate """ - xp = self._xp + xp_storage = self._xp_storage + initial_positions_com = initial_positions.mean(0) tf, _ = estimate_global_transformation_ransac( positions0=initial_positions, positions1=current_positions, - origin=self._positions_px_com, + origin=initial_positions_com, translation_allowed=True, - min_sample=self._num_diffraction_patterns // 10, - xp=xp, + min_sample=initial_positions.shape[0] // 10, + xp=xp_storage, ) + current_positions = tf( + initial_positions, origin=initial_positions_com, xp=xp_storage + ) self._tf = tf - current_positions = tf(initial_positions, origin=self._positions_px_com, xp=xp) return current_positions def _positions_constraints( self, current_positions, + initial_positions, fix_positions, + fix_positions_com, global_affine_transformation, **kwargs, ): """PositionsConstraints wrapper function""" if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) + if not fix_positions_com: + current_positions = self._positions_center_of_mass_constraint( + current_positions, initial_positions.mean(0) + ) if global_affine_transformation: current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions + initial_positions, current_positions ) return current_positions diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 3d4d5bb00..c59221166 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -8,6 +8,7 @@ from py4DSTEM.process.phase.utils import ( AffineTransform, ComplexProbe, + copy_to_device, fft_shift, generate_batches, partition_list, @@ -41,6 +42,7 @@ def _initialize_object( """ """ # explicit read-only self attributes up-front xp = self._xp + object_padding_px = self._object_padding_px region_of_interest_shape = self._region_of_interest_shape @@ -65,6 +67,7 @@ def _initialize_object( def _crop_rotate_object_fov( self, array, + positions_px=None, padding=0, ): """ @@ -83,16 +86,20 @@ def _crop_rotate_object_fov( """ asnumpy = self._asnumpy + angle = ( self._rotation_best_rad if self._rotation_best_transpose else -self._rotation_best_rad ) + if positions_px is None: + positions_px = asnumpy(self._positions_px) + else: + positions_px = asnumpy(positions_px) + tf = AffineTransform(angle=angle) - rotated_points = tf( - asnumpy(self._positions_px), origin=asnumpy(self._positions_px_com), xp=np - ) + rotated_points = tf(positions_px, origin=positions_px.mean(0), xp=np) 
min_x, min_y = np.floor(np.amin(rotated_points, axis=0) - padding).astype("int") min_x = min_x if min_x > 0 else 0 @@ -986,6 +993,7 @@ def _initialize_probe( # explicit read-only self attributes up-front xp = self._xp device = self._device + crop_mask = self._crop_mask region_of_interest_shape = self._region_of_interest_shape sampling = self.sampling @@ -1431,14 +1439,20 @@ class ObjectNDProbeMethodsMixin: Mixin class for methods applicable to 2D, 2.5D, and 3D objects using a single probe. """ - def _return_shifted_probes(self, current_probe): - """Simple utlity to de-duplicate _overlap_projection""" + def _return_shifted_probes(self, current_probe, positions_px_fractional): + """Simple utility to de-duplicate _overlap_projection""" xp = self._xp - shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) + shifted_probes = fft_shift(current_probe, positions_px_fractional, xp) return shifted_probes - def _overlap_projection(self, current_object, shifted_probes): + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ): """ Ptychographic overlap projection method. @@ -1461,15 +1475,13 @@ def _overlap_projection(self, current_object, shifted_probes): xp = self._xp - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - object_patches = complex_object[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col + object_patches = current_object[ + vectorized_patch_indices_row, vectorized_patch_indices_col ] + if self._object_type == "potential": + object_patches = xp.exp(1j * object_patches) + overlap = shifted_probes * object_patches return shifted_probes, object_patches, overlap @@ -1604,7 +1616,10 @@ def _projection_sets_fourier_projection( def _forward( self, current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, current_probe, + positions_px_fractional, amplitudes, exit_waves, use_projection_scheme, @@ -1645,10 +1660,15 @@ def _forward( error: float Reconstruction error """ + shifted_probes = self._return_shifted_probes( + current_probe, positions_px_fractional + ) - shifted_probes = self._return_shifted_probes(current_probe) shifted_probes, object_patches, overlap = self._overlap_projection( - current_object, shifted_probes + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, ) if use_projection_scheme: @@ -1674,6 +1694,7 @@ def _gradient_descent_adjoint( current_probe, object_patches, shifted_probes, + positions_px, exit_waves, step_size, normalization_min, @@ -1685,10 +1706,6 @@ def _gradient_descent_adjoint( Parameters -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate object_patches: np.ndarray Patched object view shifted_probes:np.ndarray @@ -1704,15 +1721,16 @@ def _gradient_descent_adjoint( Returns -------- - updated_object: np.ndarray + object_update: np.ndarray Updated object estimate - updated_probe: np.ndarray + probe_update: np.ndarray Updated probe estimate """ xp = self._xp probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2 + xp.abs(shifted_probes) ** 2, + positions_px, ) probe_normalization = 1 / xp.sqrt( 1e-16 @@ -1728,14 +1746,15 @@ def _gradient_descent_adjoint( * xp.conj(object_patches) * xp.conj(shifted_probes) * exit_waves - ) + ), + positions_px, ) * probe_normalization ) else: current_object += 
step_size * ( self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes) * exit_waves + xp.conj(shifted_probes) * exit_waves, positions_px ) * probe_normalization ) @@ -1777,10 +1796,6 @@ def _projection_sets_adjoint( Parameters -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate object_patches: np.ndarray Patched object view shifted_probes:np.ndarray @@ -1857,6 +1872,7 @@ def _adjoint( current_probe, object_patches, shifted_probes, + positions_px, exit_waves, use_projection_scheme: bool, step_size: float, @@ -1895,7 +1911,6 @@ def _adjoint( updated_probe: np.ndarray Updated probe estimate """ - if use_projection_scheme: current_object, current_probe = self._projection_sets_adjoint( current_object, @@ -1912,6 +1927,7 @@ def _adjoint( current_probe, object_patches, shifted_probes, + positions_px, exit_waves, step_size, normalization_min, @@ -1923,6 +1939,8 @@ def _adjoint( def _position_correction( self, current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, shifted_probes, overlap, amplitudes, @@ -1961,6 +1979,7 @@ def _position_correction( """ xp = self._xp + asnumpy = self._asnumpy # unperturbed overlap_fft = xp.fft.fft2(overlap) @@ -1973,22 +1992,22 @@ def _position_correction( difference_intensity = (measured_intensity - estimated_intensity).reshape( flat_shape ) - vectorized_patch_indices_row = self._vectorized_patch_indices_row.copy() - vectorized_patch_indices_col = self._vectorized_patch_indices_col.copy() # dx overlap projection perturbation - self._vectorized_patch_indices_row = ( - vectorized_patch_indices_row + 1 - ) % self._object_shape[0] - _, _, overlap_dx = self._overlap_projection(current_object, shifted_probes) - self._vectorized_patch_indices_row = vectorized_patch_indices_row.copy() + _, _, overlap_dx = self._overlap_projection( + current_object, + (vectorized_patch_indices_row + 1) % self._object_shape[0], + vectorized_patch_indices_col, + shifted_probes, + ) # dy overlap projection perturbation - self._vectorized_patch_indices_col = ( - vectorized_patch_indices_col + 1 - ) % self._object_shape[1] - _, _, overlap_dy = self._overlap_projection(current_object, shifted_probes) - self._vectorized_patch_indices_col = vectorized_patch_indices_col.copy() + _, _, overlap_dy = self._overlap_projection( + current_object, + vectorized_patch_indices_row, + (vectorized_patch_indices_col + 1) % self._object_shape[1], + shifted_probes, + ) # partial intensities overlap_dx_fft = overlap_fft - xp.fft.fft2(overlap_dx) @@ -2029,12 +2048,16 @@ def _position_correction( max_position_total_distance /= xp.sqrt( self.sampling[0] ** 2 + self.sampling[1] ** 2 ) - deltas = current_positions - positions_update - current_positions_initial + deltas = ( + xp.asarray(current_positions - current_positions_initial) + - positions_update + ) dsts = xp.linalg.norm(deltas, axis=1) outlier_ind = dsts > max_position_total_distance positions_update[outlier_ind] = 0 - current_positions -= positions_update + current_positions -= asnumpy(positions_update) + return current_positions def _return_self_consistency_errors( @@ -2044,6 +2067,8 @@ def _return_self_consistency_errors( """Compute the self-consistency errors for each probe position""" xp = self._xp + xp_storage = self._xp_storage + device = self._device asnumpy = self._asnumpy # Batch-size @@ -2052,36 +2077,40 @@ def _return_self_consistency_errors( # Re-initialize fractional positions and vector patches errors = np.array([]) - positions_px = 
self._positions_px.copy() for start, end in generate_batches( self._num_diffraction_patterns, max_batch=max_batch_size ): # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[start:end] + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device(self._amplitudes[start:end], device) # Overlaps - shifted_probes = self._return_shifted_probes(self._probe) - _, _, overlap = self._overlap_projection(self._object, shifted_probes) + shifted_probes = self._return_shifted_probes( + self._probe, positions_px_fractional + ) + _, _, overlap = self._overlap_projection( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ) fourier_overlap = xp.fft.fft2(overlap) farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) # Normalized mean-squared errors batch_errors = xp.sum( - xp.abs(amplitudes - farfield_amplitudes) ** 2, axis=(-2, -1) + xp.abs(amplitudes_device - farfield_amplitudes) ** 2, axis=(-2, -1) ) errors = np.hstack((errors, batch_errors)) - self._positions_px = positions_px.copy() errors /= self._mean_diffraction_intensity return asnumpy(errors) diff --git a/py4DSTEM/process/phase/ptychographic_visualizations.py b/py4DSTEM/process/phase/ptychographic_visualizations.py index ab783cc29..d45f9b69f 100644 --- a/py4DSTEM/process/phase/ptychographic_visualizations.py +++ b/py4DSTEM/process/phase/ptychographic_visualizations.py @@ -5,7 +5,7 @@ import numpy as np from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable -from py4DSTEM.process.phase.utils import AffineTransform +from py4DSTEM.process.phase.utils import AffineTransform, copy_to_device from py4DSTEM.visualize.vis_special import ( Complex2RGB, add_colorbar_arg, @@ -534,6 +534,8 @@ def visualize( **kwargs, ) + self.clear_device_mem(self._device, self._clear_fft_cache) + return self def show_updated_positions( @@ -661,6 +663,7 @@ def show_uncertainty_visualization( """Plot uncertainty visualization using self-consistency errors""" xp = self._xp + device = self._device asnumpy = self._asnumpy gaussian_filter = self._scipy.ndimage.gaussian_filter @@ -684,7 +687,8 @@ def show_uncertainty_visualization( ) tf = AffineTransform(angle=angle) - rotated_points = tf(self._positions_px, origin=self._positions_px_com, xp=xp) + positions_px = copy_to_device(self._positions_px, device) + rotated_points = tf(positions_px, origin=positions_px.mean(0), xp=xp) padding = xp.min(rotated_points, axis=0).astype("int") @@ -839,3 +843,5 @@ def show_uncertainty_visualization( ax.xaxis.set_ticks_position("bottom") spec.tight_layout(fig) + + self.clear_device_mem(device, self._clear_fft_cache) diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index 13151f02a..177451a25 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -32,6 +32,7 @@ from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils 
import ( ComplexProbe, + copy_to_device, fft_shift, generate_batches, polar_aliases, @@ -125,27 +126,18 @@ def __init__( positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, name: str = "ptychographic_reconstruction", **kwargs, ): Custom.__init__(self, name=name) - if device == "cpu": - import scipy + if storage is None: + storage = device - self._xp = np - self._asnumpy = np.asarray - self._scipy = scipy - - elif device == "gpu": - from cupyx import scipy - - self._xp = cp - self._asnumpy = cp.asnumpy - self._scipy = scipy - - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_device(device, clear_fft_cache) + self.set_storage(storage) for key in kwargs.keys(): if (key not in polar_symbols) and (key not in polar_aliases.keys()): @@ -183,6 +175,7 @@ def __init__( self._positions_mask = positions_mask self._verbose = verbose self._device = device + self._storage = storage self._preprocessed = False # Class-specific Metadata @@ -208,6 +201,8 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + device: str = None, + clear_fft_cache: bool = True, **kwargs, ): """ @@ -276,7 +271,14 @@ def preprocess( self: PtychographicReconstruction Self to accommodate chaining """ + + # handle device/storage + self.set_device(device, clear_fft_cache) + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage asnumpy = self._asnumpy # set additional metadata @@ -299,6 +301,7 @@ def preprocess( self._positions_mask = np.asarray(self._positions_mask, dtype="bool") # preprocess datacube + # all arrays computed/returned on 'cpu' ( self._datacube, self._vacuum_probe_intensity, @@ -315,7 +318,8 @@ def preprocess( ) # calibrations - self._intensities = self._extract_intensities_and_calibrations_from_datacube( + # all arrays computed/returned on 'cpu' + _intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=True, force_scan_sampling=force_scan_sampling, @@ -330,6 +334,7 @@ def preprocess( ) # calculate CoM + # arrays computed/returned on device ( self._com_measured_x, self._com_measured_y, @@ -338,20 +343,19 @@ def preprocess( self._com_normalized_x, self._com_normalized_y, ) = self._calculate_intensities_center_of_mass( - self._intensities, + _intensities, dp_mask=self._dp_mask, fit_function=fit_function, com_shifts=force_com_shifts, ) # estimate rotation / transpose + # arrays computed/returned on device ( self._rotation_best_rad, self._rotation_best_transpose, self._com_x, self._com_y, - self.com_x, - self.com_y, ) = self._solve_for_center_of_mass_relative_rotation( self._com_measured_x, self._com_measured_y, @@ -366,20 +370,34 @@ def preprocess( **kwargs, ) + # explicitly transfer arrays to storage + self._com_measured_x = copy_to_device(self._com_measured_x, storage) + self._com_measured_y = copy_to_device(self._com_measured_y, storage) + self._com_fitted_x = copy_to_device(self._com_fitted_x, storage) + self._com_fitted_y = copy_to_device(self._com_fitted_y, storage) + self._com_normalized_x = copy_to_device(self._com_normalized_x, storage) + self._com_normalized_y = copy_to_device(self._com_normalized_y, storage) + self._com_x = copy_to_device(self._com_x, storage) + self._com_y = copy_to_device(self._com_y, storage) + # corner-center amplitudes + # arrays computed/returned on 'cpu' ( - self._amplitudes, + _amplitudes, 
self._mean_diffraction_intensity, self._crop_mask, ) = self._normalize_diffraction_intensities( - self._intensities, + _intensities, self._com_fitted_x, self._com_fitted_y, self._positions_mask, crop_patterns, ) - # explicitly delete intensities namespace + # explicitly transfer arrays to storage + self._amplitudes = copy_to_device(_amplitudes, storage) + del _intensities + self._num_diffraction_patterns = self._amplitudes.shape[0] if region_of_interest_shape is not None: @@ -388,9 +406,9 @@ def preprocess( else: self._resample_exit_waves = False self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) - del self._intensities # initialize probe positions + # arrays computed/returned on 'cpu' ( self._positions_px, self._object_padding_px, @@ -401,6 +419,7 @@ def preprocess( ) # initialize object + # arrays computed/returned on device directly self._object = self._initialize_object( self._object, self._positions_px, @@ -412,12 +431,11 @@ def preprocess( self._object_shape = self._object.shape # center probe positions - self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px + # arrays computed/returned on storage directly + self._positions_px = xp_storage.asarray(self._positions_px, dtype=xp.float32) + self._positions_px_initial_com = self._positions_px.mean(0) + self._positions_px -= ( + self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 ) self._positions_px_initial = self._positions_px.copy() @@ -425,13 +443,8 @@ def preprocess( self._positions_initial[:, 0] *= self.sampling[0] self._positions_initial[:, 1] *= self.sampling[1] - # set vectorized patches - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - # initialize probe + # arrays computed/returned on device directly self._probe, self._semiangle_cutoff = self._initialize_probe( self._probe, self._vacuum_probe_intensity, @@ -440,22 +453,28 @@ def preprocess( crop_patterns, ) - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) - # initialize aberrations + # arrays computed/returned on device directly self._known_aberrations_array = ComplexProbe( energy=self._energy, gpts=self._region_of_interest_shape, sampling=self.sampling, parameters=self._polar_parameters, - device=self._device, + device=device, )._evaluate_ctf() + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + # overlaps - shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) + positions_px_fractional = self._positions_px - xp_storage.round( + self._positions_px + ) + shifted_probes = fft_shift(self._probe, positions_px_fractional, xp) + probe_overlap = self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, self._positions_px + ) + del shifted_probes # initialize object_fov_mask if object_fov_mask is None: @@ -464,10 +483,13 @@ def preprocess( self._object_fov_mask = asnumpy( probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() ) + del probe_overlap_blurred else: self._object_fov_mask 
= np.asarray(object_fov_mask) self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + probe_overlap = asnumpy(probe_overlap) + # plot probe overlaps if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) @@ -510,7 +532,7 @@ def preprocess( ax1.set_title("Initial probe intensity") ax2.imshow( - asnumpy(probe_overlap), + probe_overlap, extent=extent, cmap="gray", ) @@ -529,10 +551,7 @@ def preprocess( fig.tight_layout() self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(device, self._clear_fft_cache) return self @@ -548,9 +567,9 @@ def reconstruct( seed_random: int = None, step_size: float = 0.5, normalization_min: float = 1, - positions_step_size: float = 0.9, + positions_step_size: float = 0.5, pure_phase_object_iter: int = 0, - fix_com: bool = True, + fix_probe_com: bool = True, fix_probe_iter: int = 0, fix_probe_aperture_iter: int = 0, constrain_probe_amplitude_iter: int = 0, @@ -560,6 +579,7 @@ def reconstruct( constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, + fix_positions_com: bool = True, max_position_update_distance: float = None, max_position_total_distance: float = None, global_affine_transformation: bool = True, @@ -583,6 +603,8 @@ def reconstruct( store_iterations: bool = False, progress_bar: bool = True, reset: bool = None, + device: str = None, + clear_fft_cache: bool = True, ): """ Ptychographic reconstruction main method. @@ -619,7 +641,7 @@ def reconstruct( Positions update step size pure_phase_object_iter: int, optional Number of iterations where object amplitude is set to unity - fix_com: bool, optional + fix_probe_com: bool, optional If True, fixes center of mass of probe fix_probe_iter: int, optional Number of iterations to run with a fixed probe before updating probe estimate @@ -639,6 +661,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. 
fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov max_position_update_distance: float, optional Maximum allowed distance for update in A max_position_total_distance: float, optional @@ -692,8 +716,13 @@ def reconstruct( self: PtychographicReconstruction Self to accommodate chaining """ - asnumpy = self._asnumpy + # handle device/storage + self.set_device(device, clear_fft_cache) + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy # set and report reconstruction method ( @@ -729,10 +758,9 @@ def reconstruct( # batching shuffled_indices = np.arange(self._num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) if max_batch_size is not None: - xp.random.seed(seed_random) + np.random.seed(seed_random) else: max_batch_size = self._num_diffraction_patterns @@ -751,7 +779,7 @@ def reconstruct( if a0 == switch_object_iter: if self._object_type == "potential": self._object_type = "complex" - self._object = xp.exp(1j * self._object) + self._object = xp.exp(1j * self._object, dtype=xp.complex64) else: self._object_type = "potential" self._object = xp.angle(self._object) @@ -760,40 +788,39 @@ def reconstruct( if not use_projection_scheme: np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - self._num_diffraction_patterns - ) - - positions_px = self._positions_px.copy()[shuffled_indices] - positions_px_initial = self._positions_px_initial.copy()[shuffled_indices] - for start, end in generate_batches( self._num_diffraction_patterns, max_batch=max_batch_size ): # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px[batch_indices] + positions_px_initial = self._positions_px_initial[batch_indices] + positions_px_fractional = positions_px - xp_storage.round(positions_px) ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[shuffled_indices[start:end]] + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) # forward operator ( shifted_probes, object_patches, overlap, - self._exit_waves, + exit_waves, batch_error, ) = self._forward( self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, self._probe, - amplitudes, - self._exit_waves, + positions_px_fractional, + amplitudes_device, + None, use_projection_scheme, projection_a, projection_b, @@ -806,7 +833,8 @@ def reconstruct( self._probe, object_patches, shifted_probes, - self._exit_waves, + positions_px, + exit_waves, use_projection_scheme=use_projection_scheme, step_size=step_size, normalization_min=normalization_min, @@ -815,13 +843,15 @@ def reconstruct( # position correction if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( + self._positions_px[batch_indices] = self._position_correction( self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, shifted_probes, overlap, - amplitudes, - self._positions_px, - positions_px_initial[start:end], + amplitudes_device, + positions_px, + 
positions_px_initial, positions_step_size, max_position_update_distance, max_position_total_distance, @@ -833,12 +863,12 @@ def reconstruct( error /= self._mean_diffraction_intensity * self._num_diffraction_patterns # constraints - self._positions_px = positions_px.copy()[unshuffled_indices] self._object, self._probe, self._positions_px = self._constraints( self._object, self._probe, self._positions_px, - fix_com=fix_com and a0 >= fix_probe_iter, + self._positions_px_initial, + fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter and a0 >= fix_probe_iter, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, @@ -856,6 +886,7 @@ def reconstruct( fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, + fix_positions_com=fix_positions_com and a0 >= fix_positions_iter, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter and gaussian_filter_sigma is not None, @@ -880,7 +911,7 @@ def reconstruct( self.error_iterations.append(error.item()) if store_iterations: - self.object_iterations.append(asnumpy(self._object.copy())) + self.object_iterations.append(asnumpy(self._object).copy()) self.probe_iterations.append(self.probe_centered) # store result @@ -888,8 +919,6 @@ def reconstruct( self.probe = self.probe_centered self.error = error.item() - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(device, self._clear_fft_cache) return self diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index c9db5f86b..32d20e812 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt import numpy as np from scipy.fft import dctn, idctn -from scipy.ndimage import zoom +from scipy.ndimage import gaussian_filter, uniform_filter1d, zoom from scipy.optimize import curve_fit try: @@ -12,13 +12,18 @@ from cupyx.scipy.fft import dctn as dctn_cp from cupyx.scipy.fft import idctn as idctn_cp from cupyx.scipy.ndimage import zoom as zoom_cp + + get_array_module = cp.get_array_module except (ImportError, ModuleNotFoundError): cp = None + def get_array_module(*args): + return np + + from py4DSTEM.process.utils import get_CoM from py4DSTEM.process.utils.cross_correlate import align_and_shift_images from py4DSTEM.process.utils.utils import electron_wavelength_angstrom -from scipy.ndimage import gaussian_filter, uniform_filter1d # fmt: off @@ -406,16 +411,13 @@ def get_scattering_angles(self): def get_spatial_frequencies(self): xp = self._xp - kx, ky = spatial_frequencies(self._gpts, self._sampling) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) + kx, ky = spatial_frequencies(self._gpts, self._sampling, xp) return kx, ky def polar_coordinates(self, x, y): """Calculate a polar grid for a given Cartesian grid.""" xp = self._xp alpha = xp.sqrt(x[:, None] ** 2 + y[None, :] ** 2) - # phi = xp.arctan2(x.reshape((-1, 1)), y.reshape((1, -1))) # bug in abtem-legacy and py4DSTEM<=0.14.9 phi = xp.arctan2(y[None, :], x[:, None]) return alpha, phi @@ -443,7 +445,7 @@ def visualize(self, **kwargs): return self -def spatial_frequencies(gpts: Tuple[int, int], sampling: Tuple[float, float]): +def spatial_frequencies(gpts: Tuple[int, int], sampling: Tuple[float, float], xp=np): """ Calculate spatial frequencies of a grid. 
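For reference, the backend dispatch that these utils.py hunks rely on can be sketched as a small standalone example (hedged: example_spatial_frequencies is a hypothetical stand-in mirroring the refactored helper, not the library code itself):

import numpy as np

try:
    import cupy as cp

    get_array_module = cp.get_array_module
except (ImportError, ModuleNotFoundError):
    cp = None

    def get_array_module(*args):
        # fallback: without cupy, every array is treated as a numpy array
        return np


def example_spatial_frequencies(gpts, sampling, xp=np):
    # the caller selects the backend via xp, so the same code runs on cpu or gpu
    return tuple(
        xp.fft.fftfreq(n, d).astype(xp.float32) for n, d in zip(gpts, sampling)
    )


# kx, ky = example_spatial_frequencies((256, 256), (0.1, 0.1), xp=get_array_module(array))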
@@ -460,7 +462,7 @@ def spatial_frequencies(gpts: Tuple[int, int], sampling: Tuple[float, float]): """ return tuple( - np.fft.fftfreq(n, d).astype(np.float32) for n, d in zip(gpts, sampling) + xp.fft.fftfreq(n, d).astype(xp.float32) for n, d in zip(gpts, sampling) ) @@ -492,16 +494,14 @@ def fourier_translation_operator( if len(positions_shape) == 1: positions = positions[None] - kx, ky = spatial_frequencies(shape, (1.0, 1.0)) - kx = kx.reshape((1, -1, 1)) - ky = ky.reshape((1, 1, -1)) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) + kx, ky = spatial_frequencies(shape, (1.0, 1.0), xp=xp) positions = xp.asarray(positions, dtype=xp.float32) - x = positions[:, 0].reshape((-1,) + (1, 1)) - y = positions[:, 1].reshape((-1,) + (1, 1)) + x = positions[:, 0].ravel()[:, None, None] + y = positions[:, 1].ravel()[:, None, None] - result = xp.exp(-2.0j * np.pi * kx * x) * xp.exp(-2.0j * np.pi * ky * y) + result = xp.exp(-2.0j * np.pi * kx[None, :, None] * x) * xp.exp( + -2.0j * np.pi * ky[None, None, :] * y + ) if len(positions_shape) == 1: return result[0] @@ -2345,3 +2345,23 @@ def partition_list(lst, size): """Partitions lst into chunks of size. Returns a generator.""" for i in range(0, len(lst), size): yield lst[i : i + size] + + +def copy_to_device(array, device="cpu"): + """Copies array to device. Default allows one to use this as asnumpy()""" + xp = get_array_module(array) + + if xp is np: + if device == "cpu": + return np.asarray(array) + elif device == "gpu": + return cp.asarray(array) + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + else: + if device == "cpu": + return cp.asnumpy(array) + elif device == "gpu": + return cp.asarray(array) + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") From c424d8df24ca9d32652e5d1fbb0627c30efc54e3 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 10 Jan 2024 22:17:47 -0800 Subject: [PATCH 075/128] actually enable overwriting device Former-commit-id: 5f0da128dacf285db9c18d621e671b3a81700041 --- py4DSTEM/process/phase/phase_base_class.py | 8 +++++++- py4DSTEM/process/phase/singleslice_ptychography.py | 14 ++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 2ce1d5bec..7506e941b 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -56,6 +56,10 @@ def set_device(self, device, clear_fft_cache): Self to enable chaining """ + if device is None: + self._clear_fft_cache = clear_fft_cache + return self + if device == "cpu": import scipy @@ -68,9 +72,10 @@ def set_device(self, device, clear_fft_cache): self._xp = cp self._scipy = scipy - elif device is not None: + else: raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self._device = device self._clear_fft_cache = clear_fft_cache return self @@ -102,6 +107,7 @@ def set_storage(self, storage): raise ValueError(f"storage must be either 'cpu' or 'gpu', not {storage}") self._asnumpy = copy_to_device + self._storage = storage return self diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index 177451a25..6bfaecbe4 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -174,8 +174,6 @@ def __init__( self._object_padding_px = object_padding_px self._positions_mask = positions_mask 
self._verbose = verbose - self._device = device - self._storage = storage self._preprocessed = False # Class-specific Metadata @@ -719,6 +717,18 @@ def reconstruct( # handle device/storage self.set_device(device, clear_fft_cache) + if device is not None: # TO-DO: abstract away + self._known_aberrations_array = copy_to_device( + self._known_aberrations_array, device + ) + self._object = copy_to_device(self._object, device) + self._object_initial = copy_to_device(self._object_initial, device) + self._probe = copy_to_device(self._probe, device) + self._probe_initial = copy_to_device(self._probe_initial, device) + self._probe_initial_aperture = copy_to_device( + self._probe_initial_aperture, device + ) + xp = self._xp xp_storage = self._xp_storage device = self._device From c8eaae5703cbeefba3874d50a2450ecb77d95061 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 10:20:38 -0800 Subject: [PATCH 076/128] fixing single-slice projection sets bugs with storage refactor Former-commit-id: cc6619f6fea0ca6f7ad65f8f4d526cd397b4a6d4 --- py4DSTEM/process/phase/phase_base_class.py | 9 ++- .../process/phase/ptychographic_methods.py | 11 +++- .../process/phase/singleslice_ptychography.py | 63 ++++++++----------- 3 files changed, 43 insertions(+), 40 deletions(-) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 7506e941b..0826d679e 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -122,6 +122,12 @@ def clear_device_mem(self, device, clear_fft_cache): xp._default_memory_pool.free_all_blocks() xp._default_pinned_memory_pool.free_all_blocks() + def copy_attributes_to_device(self, attrs, device): + """Utility function to copy a set of attrs to device""" + for attr in attrs: + array = copy_to_device(getattr(self, attr), device) + setattr(self, attr, array) + def attach_datacube(self, datacube: DataCube): """ Attaches a datacube to a class initialized without one. 
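Taken together, set_device, set_storage, and copy_attributes_to_device let frequently accessed arrays live on `device` while bulky, rarely touched arrays stay on `storage`, with `device` overridable per call. A usage sketch (hedged: assumes SingleslicePtychography is importable from py4DSTEM.process.phase, that `dcube` is a calibrated DataCube, and the usual datacube/energy/semiangle_cutoff constructor arguments):

from py4DSTEM.process.phase import SingleslicePtychography

ptycho = SingleslicePtychography(
    datacube=dcube,          # assumed: a calibrated py4DSTEM DataCube
    energy=300e3,
    semiangle_cutoff=21.4,
    device="cpu",            # compute device for frequently accessed arrays
    storage="cpu",           # amplitudes, positions, CoM arrays are kept here
)
ptycho.preprocess()
ptycho.reconstruct(device="gpu", clear_fft_cache=True)  # override compute device per call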
@@ -1818,7 +1824,8 @@ def _sum_overlapping_patches_bincounts_base( counts = xp.bincount( indices.ravel(), weights=flat_weights, minlength=np.prod(object_shape) ) - return xp.reshape(counts, object_shape) + counts = xp.reshape(counts, object_shape).astype(xp.float32) + return counts def _sum_overlapping_patches_bincounts(self, patches: np.ndarray, positions_px): """ diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index c59221166..48945c22a 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -1786,6 +1786,7 @@ def _projection_sets_adjoint( current_probe, object_patches, shifted_probes, + positions_px, exit_waves, normalization_min, fix_probe, @@ -1817,7 +1818,8 @@ def _projection_sets_adjoint( xp = self._xp probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2 + xp.abs(shifted_probes) ** 2, + positions_px, ) probe_normalization = 1 / xp.sqrt( 1e-16 @@ -1833,14 +1835,16 @@ def _projection_sets_adjoint( * xp.conj(object_patches) * xp.conj(shifted_probes) * exit_waves - ) + ), + positions_px, ) * probe_normalization ) else: current_object = ( self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes) * exit_waves + xp.conj(shifted_probes) * exit_waves, + positions_px, ) * probe_normalization ) @@ -1917,6 +1921,7 @@ def _adjoint( current_probe, object_patches, shifted_probes, + positions_px, exit_waves, normalization_min, fix_probe, diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index 6bfaecbe4..f017a1f4f 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -93,13 +93,17 @@ class SingleslicePtychography( If None, initialized to a grid scan verbose: bool, optional If True, class methods will inherit this and print additional information - device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') positions_mask: np.ndarray, optional Boolean real space mask to select positions in datacube to skip for reconstruction + device: str, optional + Device calculation will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls name: str, optional Class name kwargs: @@ -205,16 +209,6 @@ def preprocess( ): """ Ptychographic preprocessing step. - Calls the base class methods: - - _extract_intensities_and_calibrations_from_datacube, - _compute_center_of_mass(), - _solve_CoM_rotation(), - _normalize_diffraction_intensities() - _calculate_scan_positions_in_px() - - Additionally, it initializes an (Px,Py) array of 1.0j - and a complex probe using the specified polar parameters. Parameters ---------- @@ -263,6 +257,10 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + device: str, optional + If not None, overwrites self._device to set device preprocess will be perfomed on. 
+ clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- @@ -299,7 +297,6 @@ def preprocess( self._positions_mask = np.asarray(self._positions_mask, dtype="bool") # preprocess datacube - # all arrays computed/returned on 'cpu' ( self._datacube, self._vacuum_probe_intensity, @@ -316,7 +313,6 @@ def preprocess( ) # calibrations - # all arrays computed/returned on 'cpu' _intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=True, @@ -332,7 +328,6 @@ def preprocess( ) # calculate CoM - # arrays computed/returned on device ( self._com_measured_x, self._com_measured_y, @@ -348,7 +343,6 @@ def preprocess( ) # estimate rotation / transpose - # arrays computed/returned on device ( self._rotation_best_rad, self._rotation_best_transpose, @@ -379,7 +373,6 @@ def preprocess( self._com_y = copy_to_device(self._com_y, storage) # corner-center amplitudes - # arrays computed/returned on 'cpu' ( _amplitudes, self._mean_diffraction_intensity, @@ -406,7 +399,6 @@ def preprocess( self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) # initialize probe positions - # arrays computed/returned on 'cpu' ( self._positions_px, self._object_padding_px, @@ -417,7 +409,6 @@ def preprocess( ) # initialize object - # arrays computed/returned on device directly self._object = self._initialize_object( self._object, self._positions_px, @@ -429,7 +420,6 @@ def preprocess( self._object_shape = self._object.shape # center probe positions - # arrays computed/returned on storage directly self._positions_px = xp_storage.asarray(self._positions_px, dtype=xp.float32) self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px -= ( @@ -442,7 +432,6 @@ def preprocess( self._positions_initial[:, 1] *= self.sampling[1] # initialize probe - # arrays computed/returned on device directly self._probe, self._semiangle_cutoff = self._initialize_probe( self._probe, self._vacuum_probe_intensity, @@ -452,7 +441,6 @@ def preprocess( ) # initialize aberrations - # arrays computed/returned on device directly self._known_aberrations_array = ComplexProbe( energy=self._energy, gpts=self._region_of_interest_shape, @@ -717,17 +705,16 @@ def reconstruct( # handle device/storage self.set_device(device, clear_fft_cache) - if device is not None: # TO-DO: abstract away - self._known_aberrations_array = copy_to_device( - self._known_aberrations_array, device - ) - self._object = copy_to_device(self._object, device) - self._object_initial = copy_to_device(self._object_initial, device) - self._probe = copy_to_device(self._probe, device) - self._probe_initial = copy_to_device(self._probe_initial, device) - self._probe_initial_aperture = copy_to_device( - self._probe_initial_aperture, device - ) + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probe", + "_probe_initial", + "_probe_initial_aperture", + ] + self.copy_attributes_to_device(attrs, device) xp = self._xp xp_storage = self._xp_storage @@ -821,7 +808,7 @@ def reconstruct( shifted_probes, object_patches, overlap, - exit_waves, + self._exit_waves, batch_error, ) = self._forward( self._object, @@ -830,7 +817,7 @@ def reconstruct( self._probe, positions_px_fractional, amplitudes_device, - None, + self._exit_waves, use_projection_scheme, projection_a, projection_b, @@ -844,7 +831,7 @@ def reconstruct( object_patches, shifted_probes, positions_px, - exit_waves, + self._exit_waves, 
use_projection_scheme=use_projection_scheme, step_size=step_size, normalization_min=normalization_min, @@ -929,6 +916,10 @@ def reconstruct( self.probe = self.probe_centered self.error = error.item() + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + self.clear_device_mem(device, self._clear_fft_cache) return self From c95a7e05b4aa7769a038e14e49d190b87d6e7194 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 11:13:08 -0800 Subject: [PATCH 077/128] cleaning up multislice for device. changed propagator tilt convention to negative mrad Former-commit-id: 4677df5913e27e2d7af4c071a710478359d2b63e --- .../process/phase/multislice_ptychography.py | 79 ++++++++----------- .../process/phase/ptychographic_methods.py | 10 +-- 2 files changed, 36 insertions(+), 53 deletions(-) diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index 8b6a267a6..b2b4ce4d3 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -35,6 +35,7 @@ from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, + copy_to_device, fft_shift, generate_batches, polar_aliases, @@ -102,9 +103,9 @@ class MultislicePtychography( Probe positions in Å for each diffraction intensity If None, initialized to a grid scan theta_x: float - x tilt of propagator (in degrees) + x tilt of propagator in mrad theta_y: float - y tilt of propagator (in degrees) + y tilt of propagator in mrad middle_focus: bool if True, adds half the sample thickness to the defocus object_type: str, optional @@ -115,7 +116,11 @@ class MultislicePtychography( verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + Device calculation will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. 
Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls name: str, optional Class name kwargs: @@ -123,7 +128,12 @@ class MultislicePtychography( """ # Class-specific Metadata - _class_specific_metadata = ("_num_slices", "_slice_thicknesses") + _class_specific_metadata = ( + "_num_slices", + "_slice_thicknesses", + "_theta_x", + "_theta_y", + ) def __init__( self, @@ -147,59 +157,23 @@ def __init__( positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, name: str = "multi-slice_ptychographic_reconstruction", **kwargs, ): Custom.__init__(self, name=name) - if device == "cpu": - import scipy + if storage is None: + storage = device - self._xp = np - self._asnumpy = np.asarray - self._scipy = scipy - - elif device == "gpu": - from cupyx import scipy - - self._xp = cp - self._asnumpy = cp.asnumpy - self._scipy = scipy - - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_device(device, clear_fft_cache) + self.set_storage(storage) for key in kwargs.keys(): if (key not in polar_symbols) and (key not in polar_aliases.keys()): raise ValueError("{} not a recognized parameter".format(key)) - if np.isscalar(slice_thicknesses): - mean_slice_thickness = slice_thicknesses - else: - mean_slice_thickness = np.mean(slice_thicknesses) - - if middle_focus: - if "defocus" in kwargs: - kwargs["defocus"] += mean_slice_thickness * num_slices / 2 - elif "C10" in kwargs: - kwargs["C10"] -= mean_slice_thickness * num_slices / 2 - elif polar_parameters is not None and "defocus" in polar_parameters: - polar_parameters["defocus"] = ( - polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2 - ) - elif polar_parameters is not None and "C10" in polar_parameters: - polar_parameters["C10"] = ( - polar_parameters["C10"] - mean_slice_thickness * num_slices / 2 - ) - - self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) - - if polar_parameters is None: - polar_parameters = {} - - polar_parameters.update(kwargs) - self._set_polar_parameters(polar_parameters) - slice_thicknesses = np.array(slice_thicknesses) if slice_thicknesses.shape == (): slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1) @@ -211,6 +185,18 @@ def __init__( ) ) + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + if middle_focus: + half_thickness = slice_thicknesses.mean() * num_slices / 2 + self._polar_parameters["C10"] -= half_thickness + if object_type != "potential" and object_type != "complex": raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" @@ -234,7 +220,6 @@ def __init__( self._positions_mask = positions_mask self._object_padding_px = object_padding_px self._verbose = verbose - self._device = device self._preprocessed = False # Class-specific Metadata diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 48945c22a..9d14ae667 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -329,9 +329,9 @@ def _precompute_propagator_arrays( slice_thicknesses: Sequence[float] Array of slice thicknesses in A theta_x: float, optional - x tilt of propagator (in degrees) + x tilt 
of propagator in mrad theta_y: float, optional - y tilt of propagator (in degrees) + y tilt of propagator in mrad Returns ------- @@ -361,15 +361,13 @@ def _precompute_propagator_arrays( ) if theta_x is not None: - theta_x = np.deg2rad(theta_x) propagators[i] *= xp.exp( - 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) + 1.0j * (-2 * kx[:, None] * np.pi * dz * np.tan(theta_x / 1e3)) ) if theta_y is not None: - theta_y = np.deg2rad(theta_y) propagators[i] *= xp.exp( - 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) + 1.0j * (-2 * ky[None] * np.pi * dz * np.tan(theta_y / 1e3)) ) return propagators From 1ab99136f86f52ad11af39764d6052457c18415f Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 11:52:40 -0800 Subject: [PATCH 078/128] added storage support to multislice Former-commit-id: d8c3c299e24bb538fbabd18f333cdee9eb7afcc1 --- .../process/phase/multislice_ptychography.py | 173 ++++++++++++------ .../process/phase/ptychographic_methods.py | 75 +++++--- .../process/phase/singleslice_ptychography.py | 9 +- 3 files changed, 167 insertions(+), 90 deletions(-) diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index b2b4ce4d3..3306a0f87 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -249,6 +249,8 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + device: str = None, + clear_fft_cache: bool = True, **kwargs, ): """ @@ -311,13 +313,23 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + device: str, optional + If not None, overwrites self._device to set device preprocess will be perfomed on. 
+ clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- self: MultislicePtychographicReconstruction Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage asnumpy = self._asnumpy # set additional metadata @@ -356,7 +368,7 @@ def preprocess( ) # calibrations - self._intensities = self._extract_intensities_and_calibrations_from_datacube( + _intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=True, force_scan_sampling=force_scan_sampling, @@ -379,7 +391,7 @@ def preprocess( self._com_normalized_x, self._com_normalized_y, ) = self._calculate_intensities_center_of_mass( - self._intensities, + _intensities, dp_mask=self._dp_mask, fit_function=fit_function, com_shifts=force_com_shifts, @@ -391,8 +403,6 @@ def preprocess( self._rotation_best_transpose, self._com_x, self._com_y, - self.com_x, - self.com_y, ) = self._solve_for_center_of_mass_relative_rotation( self._com_measured_x, self._com_measured_y, @@ -407,20 +417,33 @@ def preprocess( **kwargs, ) + # explicitly transfer arrays to storage + self._com_measured_x = copy_to_device(self._com_measured_x, storage) + self._com_measured_y = copy_to_device(self._com_measured_y, storage) + self._com_fitted_x = copy_to_device(self._com_fitted_x, storage) + self._com_fitted_y = copy_to_device(self._com_fitted_y, storage) + self._com_normalized_x = copy_to_device(self._com_normalized_x, storage) + self._com_normalized_y = copy_to_device(self._com_normalized_y, storage) + self._com_x = copy_to_device(self._com_x, storage) + self._com_y = copy_to_device(self._com_y, storage) + # corner-center amplitudes ( - self._amplitudes, + _amplitudes, self._mean_diffraction_intensity, self._crop_mask, ) = self._normalize_diffraction_intensities( - self._intensities, + _intensities, self._com_fitted_x, self._com_fitted_y, self._positions_mask, crop_patterns, ) - # explicitly delete namespace + # explicitly transfer arrays to storage + self._amplitudes = copy_to_device(_amplitudes, storage) + del _intensities + self._num_diffraction_patterns = self._amplitudes.shape[0] if region_of_interest_shape is not None: @@ -429,7 +452,6 @@ def preprocess( else: self._resample_exit_waves = False self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) - del self._intensities # initialize probe positions ( @@ -454,25 +476,18 @@ def preprocess( self._object_shape = self._object.shape[-2:] # center probe positions - self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px + self._positions_px = xp_storage.asarray(self._positions_px, dtype=xp.float32) + self._positions_px_initial_com = self._positions_px.mean(0) + self._positions_px -= ( + self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 ) + self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px_initial = self._positions_px.copy() self._positions_initial = self._positions_px_initial.copy() self._positions_initial[:, 0] *= self.sampling[0] self._positions_initial[:, 1] *= self.sampling[1] - # set 
vectorized patches - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - # initialize probe self._probe, self._semiangle_cutoff = self._initialize_probe( self._probe, @@ -482,9 +497,6 @@ def preprocess( crop_patterns, ) - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) - # initialize aberrations self._known_aberrations_array = ComplexProbe( energy=self._energy, @@ -494,6 +506,9 @@ def preprocess( device=self._device, )._evaluate_ctf() + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + # precompute propagator arrays self._propagator_arrays = self._precompute_propagator_arrays( self._region_of_interest_shape, @@ -505,9 +520,14 @@ def preprocess( ) # overlaps - shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) + positions_px_fractional = self._positions_px - xp_storage.round( + self._positions_px + ) + shifted_probes = fft_shift(self._probe, positions_px_fractional, xp) + probe_overlap = self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, self._positions_px + ) + del shifted_probes if object_fov_mask is None: gaussian_filter = self._scipy.ndimage.gaussian_filter @@ -515,10 +535,14 @@ def preprocess( self._object_fov_mask = asnumpy( probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() ) + del probe_overlap_blurred else: self._object_fov_mask = np.asarray(object_fov_mask) self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + probe_overlap = asnumpy(probe_overlap) + + # plot probe overlaps if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) chroma_boost = kwargs.pop("chroma_boost", 1) @@ -588,7 +612,7 @@ def preprocess( ax2.set_title("Propagated probe intensity") ax3.imshow( - asnumpy(probe_overlap), + probe_overlap, extent=extent, cmap="Greys_r", ) @@ -607,10 +631,7 @@ def preprocess( fig.tight_layout() self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(device, self._clear_fft_cache) return self @@ -627,7 +648,7 @@ def reconstruct( step_size: float = 0.5, normalization_min: float = 1, positions_step_size: float = 0.9, - fix_com: bool = True, + fix_probe_com: bool = True, fix_probe_iter: int = 0, fix_probe_aperture_iter: int = 0, constrain_probe_amplitude_iter: int = 0, @@ -637,6 +658,7 @@ def reconstruct( constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, + fix_positions_com: bool = True, max_position_update_distance: float = None, max_position_total_distance: float = None, global_affine_transformation: bool = True, @@ -667,6 +689,8 @@ def reconstruct( store_iterations: bool = False, progress_bar: bool = True, reset: bool = None, + device: str = None, + clear_fft_cache: bool = True, ): """ Ptychographic reconstruction main method. 
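The probe-overlap map built in preprocess above uses the bincount-based patch accumulation, which now takes the scan positions explicitly. A simplified, self-contained numpy sketch of that technique (not the library implementation; corner-centered patches assumed):

import numpy as np


def sum_overlapping_patches(patches, positions_px, object_shape):
    # patches: (N, sx, sy) real weights; positions_px: (N, 2) scan positions in pixels
    sx, sy = patches.shape[-2:]
    x0 = np.round(positions_px[:, 0]).astype("int")
    y0 = np.round(positions_px[:, 1]).astype("int")

    # corner-centered patch coordinates, wrapped onto the object grid
    x_ind = np.fft.fftfreq(sx, d=1 / sx).astype("int")
    y_ind = np.fft.fftfreq(sy, d=1 / sy).astype("int")
    rows = (x0[:, None, None] + x_ind[None, :, None]) % object_shape[0]
    cols = (y0[:, None, None] + y_ind[None, None, :]) % object_shape[1]

    # scatter-add every patch pixel onto the object grid in one vectorized call
    flat_indices = (rows * object_shape[1] + cols).ravel()
    counts = np.bincount(
        flat_indices, weights=patches.ravel(), minlength=np.prod(object_shape)
    )
    return counts.reshape(object_shape).astype(np.float32)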
@@ -701,7 +725,7 @@ def reconstruct( Probe normalization minimum as a fraction of the maximum overlap intensity positions_step_size: float, optional Positions update step size - fix_com: bool, optional + fix_probe_com: bool, optional If True, fixes center of mass of probe fix_probe_iter: int, optional Number of iterations to run with a fixed probe before updating probe estimate @@ -721,6 +745,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov max_position_update_distance: float, optional Maximum allowed distance for update in A max_position_total_distance: float, optional @@ -783,14 +809,35 @@ def reconstruct( If True, reconstruction progress is displayed reset: bool, optional If True, previous reconstructions are ignored + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- self: MultislicePtychographicReconstruction Self to accommodate chaining """ - asnumpy = self._asnumpy + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probe", + "_probe_initial", + "_probe_initial_aperture", + "_propagator_arrays", + ] + self.copy_attributes_to_device(attrs, device) + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy # set and report reconstruction method ( @@ -826,10 +873,9 @@ def reconstruct( # Batching shuffled_indices = np.arange(self._num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) if max_batch_size is not None: - xp.random.seed(seed_random) + np.random.seed(seed_random) else: max_batch_size = self._num_diffraction_patterns @@ -857,27 +903,23 @@ def reconstruct( if not use_projection_scheme: np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - self._num_diffraction_patterns - ) - - positions_px = self._positions_px.copy()[shuffled_indices] - positions_px_initial = self._positions_px_initial.copy()[shuffled_indices] - for start, end in generate_batches( self._num_diffraction_patterns, max_batch=max_batch_size ): # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px[batch_indices] + positions_px_initial = self._positions_px_initial[batch_indices] + positions_px_fractional = positions_px - xp_storage.round(positions_px) ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[shuffled_indices[start:end]] + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) # forward operator ( @@ -888,8 +930,11 @@ def reconstruct( batch_error, ) = self._forward( self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, self._probe, - amplitudes, + positions_px_fractional, + 
amplitudes_device, self._exit_waves, use_projection_scheme, projection_a, @@ -903,6 +948,7 @@ def reconstruct( self._probe, object_patches, shifted_probes, + positions_px, self._exit_waves, use_projection_scheme=use_projection_scheme, step_size=step_size, @@ -912,13 +958,15 @@ def reconstruct( # position correction if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( + self._positions_px[batch_indices] = self._position_correction( self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, self._probe, overlap, - amplitudes, - self._positions_px, - positions_px_initial[start:end], + amplitudes_device, + positions_px, + positions_px_initial, positions_step_size, max_position_update_distance, max_position_total_distance, @@ -930,12 +978,12 @@ def reconstruct( error /= self._mean_diffraction_intensity * self._num_diffraction_patterns # constraints - self._positions_px = positions_px.copy()[unshuffled_indices] self._object, self._probe, self._positions_px = self._constraints( self._object, self._probe, self._positions_px, - fix_com=fix_com and a0 >= fix_probe_iter, + self._positions_px_initial, + fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter and a0 >= fix_probe_iter, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, @@ -953,6 +1001,7 @@ def reconstruct( fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, + fix_positions_com=fix_positions_com and a0 >= fix_positions_iter, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter and gaussian_filter_sigma is not None, @@ -993,8 +1042,10 @@ def reconstruct( self.probe = self.probe_centered self.error = error.item() - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(device, self._clear_fft_cache) return self diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 9d14ae667..286406dd6 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -402,6 +402,7 @@ def _initialize_object( """ """ # explicit read-only self attributes up-front xp = self._xp + object_padding_px = self._object_padding_px region_of_interest_shape = self._region_of_interest_shape @@ -2078,7 +2079,6 @@ def _return_self_consistency_errors( if max_batch_size is None: max_batch_size = self._num_diffraction_patterns - # Re-initialize fractional positions and vector patches errors = np.array([]) for start, end in generate_batches( @@ -2125,7 +2125,13 @@ class Object2p5DProbeMethodsMixin: Overwrites ObjectNDProbeMethodsMixin. """ - def _overlap_projection(self, current_object, shifted_probes_in): + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes_in, + ): """ Ptychographic overlap projection method. 
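The overlap-projection refactor (here and in the earlier single-probe hunks) gathers object patches with explicit row/column index arrays first, and only then exponentiates the much smaller patches when the object is stored as a real potential. A minimal numpy sketch of that pattern (not the library code):

import numpy as np


def extract_object_patches(current_object, row_indices, col_indices, object_is_potential):
    # row_indices, col_indices: (N, sx, sy) integer arrays of wrapped object coordinates
    object_patches = current_object[row_indices, col_indices]
    if object_is_potential:
        # convert real potential patches to complex transmission functions
        object_patches = np.exp(1j * object_patches)
    return object_patches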
@@ -2148,17 +2154,15 @@ def _overlap_projection(self, current_object, shifted_probes_in): xp = self._xp - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - object_patches = complex_object[ + object_patches = current_object[ :, - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, + vectorized_patch_indices_row, + vectorized_patch_indices_col, ] + if self._object_type == "potential": + object_patches = xp.exp(1j * object_patches) + shifted_probes = xp.empty_like(object_patches) shifted_probes[0] = shifted_probes_in @@ -2180,6 +2184,7 @@ def _gradient_descent_adjoint( current_probe, object_patches, shifted_probes, + positions_px, exit_waves, step_size, normalization_min, @@ -2223,7 +2228,8 @@ def _gradient_descent_adjoint( # object-update probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 + xp.abs(probe) ** 2, + positions_px, ) probe_normalization = 1 / xp.sqrt( @@ -2235,13 +2241,16 @@ def _gradient_descent_adjoint( if self._object_type == "potential": current_object[s] += step_size * ( self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves) + xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves), + positions_px, ) * probe_normalization ) else: current_object[s] += step_size * ( - self._sum_overlapping_patches_bincounts(xp.conj(probe) * exit_waves) + self._sum_overlapping_patches_bincounts( + xp.conj(probe) * exit_waves, positions_px + ) * probe_normalization ) @@ -2282,6 +2291,7 @@ def _projection_sets_adjoint( current_probe, object_patches, shifted_probes, + positions_px, exit_waves, normalization_min, fix_probe, @@ -2325,7 +2335,8 @@ def _projection_sets_adjoint( # object-update probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 + xp.abs(probe) ** 2, + positions_px, ) probe_normalization = 1 / xp.sqrt( 1e-16 @@ -2336,14 +2347,16 @@ def _projection_sets_adjoint( if self._object_type == "potential": current_object[s] = ( self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves_copy) + xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves_copy), + positions_px, ) * probe_normalization ) else: current_object[s] = ( self._sum_overlapping_patches_bincounts( - xp.conj(probe) * exit_waves_copy + xp.conj(probe) * exit_waves_copy, + positions_px, ) * probe_normalization ) @@ -2402,13 +2415,13 @@ def show_transmitted_probe( """ xp = self._xp + xp_storage = self._xp_storage + device = self._device asnumpy = self._asnumpy if max_batch_size is None: max_batch_size = self._num_diffraction_patterns - positions_px = self._positions_px.copy() - mean_transmitted = xp.zeros_like(self._probe) intensities_compare = [np.inf, 0] @@ -2416,18 +2429,24 @@ def show_transmitted_probe( self._num_diffraction_patterns, max_batch=max_batch_size ): # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) # overlaps - shifted_probes = self._return_shifted_probes(self._probe) - _, _, overlap = 
self._overlap_projection(self._object, shifted_probes) + shifted_probes = self._return_shifted_probes( + self._probe, positions_px_fractional + ) + _, _, overlap = self._overlap_projection( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ) # store relevant arrays mean_transmitted += overlap.sum(0) @@ -2494,6 +2513,8 @@ def show_transmitted_probe( **kwargs, ) + self.clear_device_mem(device, self._clear_fft_cache) + class ObjectNDProbeMixedMethodsMixin: """ diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index f017a1f4f..f6d6454de 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -258,9 +258,9 @@ def preprocess( crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering device: str, optional - If not None, overwrites self._device to set device preprocess will be perfomed on. + if not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional - If True, and device = 'gpu', clears the cached fft plan at the end of function calls + if true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- @@ -425,6 +425,7 @@ def preprocess( self._positions_px -= ( self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 ) + self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px_initial = self._positions_px.copy() self._positions_initial = self._positions_px_initial.copy() @@ -696,6 +697,10 @@ def reconstruct( If True, reconstruction progress is displayed reset: bool, optional If True, previous reconstructions are ignored + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. 
+ clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- From d82194ff0e03b1baa6b395992524dd98e01864c8 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 13:01:04 -0800 Subject: [PATCH 079/128] adding mixed-state storage support Former-commit-id: a5226689b91f5a0106a6dd7f2af0febc3b82efda --- .../process/phase/mixedstate_ptychography.py | 193 +++++++++++------- .../process/phase/multislice_ptychography.py | 4 +- .../phase/ptychographic_constraints.py | 4 +- .../process/phase/ptychographic_methods.py | 40 ++-- .../process/phase/singleslice_ptychography.py | 4 +- 5 files changed, 149 insertions(+), 96 deletions(-) diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index 9ebe802ba..fd8273a37 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -35,6 +35,7 @@ from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, + copy_to_device, fft_shift, generate_batches, polar_aliases, @@ -131,27 +132,18 @@ def __init__( positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, name: str = "mixed-state_ptychographic_reconstruction", **kwargs, ): Custom.__init__(self, name=name) - if device == "cpu": - import scipy + if storage is None: + storage = device - self._xp = np - self._asnumpy = np.asarray - self._scipy = scipy - - elif device == "gpu": - from cupyx import scipy - - self._xp = cp - self._asnumpy = cp.asnumpy - self._scipy = scipy - - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_device(device, clear_fft_cache) + self.set_storage(storage) if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe): if num_probes is None: @@ -203,7 +195,6 @@ def __init__( self._object_padding_px = object_padding_px self._positions_mask = positions_mask self._verbose = verbose - self._device = device self._preprocessed = False # Class-specific Metadata @@ -230,6 +221,8 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + device: str = None, + clear_fft_cache: bool = True, **kwargs, ): """ @@ -289,13 +282,23 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. 
+ clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- self: PtychographicReconstruction Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage asnumpy = self._asnumpy # set additional metadata @@ -334,7 +337,7 @@ def preprocess( ) # calibrations - self._intensities = self._extract_intensities_and_calibrations_from_datacube( + _intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=True, force_scan_sampling=force_scan_sampling, @@ -357,7 +360,7 @@ def preprocess( self._com_normalized_x, self._com_normalized_y, ) = self._calculate_intensities_center_of_mass( - self._intensities, + _intensities, dp_mask=self._dp_mask, fit_function=fit_function, com_shifts=force_com_shifts, @@ -369,8 +372,6 @@ def preprocess( self._rotation_best_transpose, self._com_x, self._com_y, - self.com_x, - self.com_y, ) = self._solve_for_center_of_mass_relative_rotation( self._com_measured_x, self._com_measured_y, @@ -385,20 +386,33 @@ def preprocess( **kwargs, ) + # explicitly transfer arrays to storage + self._com_measured_x = copy_to_device(self._com_measured_x, storage) + self._com_measured_y = copy_to_device(self._com_measured_y, storage) + self._com_fitted_x = copy_to_device(self._com_fitted_x, storage) + self._com_fitted_y = copy_to_device(self._com_fitted_y, storage) + self._com_normalized_x = copy_to_device(self._com_normalized_x, storage) + self._com_normalized_y = copy_to_device(self._com_normalized_y, storage) + self._com_x = copy_to_device(self._com_x, storage) + self._com_y = copy_to_device(self._com_y, storage) + # corner-center amplitudes ( self._amplitudes, self._mean_diffraction_intensity, self._crop_mask, ) = self._normalize_diffraction_intensities( - self._intensities, + _intensities, self._com_fitted_x, self._com_fitted_y, self._positions_mask, crop_patterns, ) - # explicitly delete namespace + # explicitly transfer arrays to storage + self._amplitudes = copy_to_device(self._amplitudes, storage) + del _intensities + self._num_diffraction_patterns = self._amplitudes.shape[0] if region_of_interest_shape is not None: @@ -407,7 +421,6 @@ def preprocess( else: self._resample_exit_waves = False self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) - del self._intensities # initialize probe positions ( @@ -431,25 +444,18 @@ def preprocess( self._object_shape = self._object.shape # center probe positions - self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px + self._positions_px = xp_storage.asarray(self._positions_px, dtype=xp.float32) + self._positions_px_initial_com = self._positions_px.mean(0) + self._positions_px -= ( + self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 ) + self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px_initial = self._positions_px.copy() self._positions_initial = self._positions_px_initial.copy() self._positions_initial[:, 0] *= self.sampling[0] self._positions_initial[:, 1] *= self.sampling[1] - # set vectorized patches - ( - 
self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - # initialize probe self._probe, self._semiangle_cutoff = self._initialize_probe( self._probe, @@ -459,9 +465,6 @@ def preprocess( crop_patterns, ) - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) - # initialize aberrations self._known_aberrations_array = ComplexProbe( energy=self._energy, @@ -471,10 +474,18 @@ def preprocess( device=self._device, )._evaluate_ctf() + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + # overlaps - shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) + positions_px_fractional = self._positions_px - xp_storage.round( + self._positions_px + ) + shifted_probes = fft_shift(self._probe[0], positions_px_fractional, xp) + probe_overlap = self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, self._positions_px + ) + del shifted_probes if object_fov_mask is None: gaussian_filter = self._scipy.ndimage.gaussian_filter @@ -482,10 +493,14 @@ def preprocess( self._object_fov_mask = asnumpy( probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() ) + del probe_overlap_blurred else: self._object_fov_mask = np.asarray(object_fov_mask) self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + probe_overlap = asnumpy(probe_overlap) + + # plot probe overlaps if plot_probe_overlaps: figsize = kwargs.pop("figsize", (4.5 * self._num_probes + 4, 4)) chroma_boost = kwargs.pop("chroma_boost", 1) @@ -528,7 +543,7 @@ def preprocess( add_colorbar_arg(cax, chroma_boost=chroma_boost) axs[-1].imshow( - asnumpy(probe_overlap), + probe_overlap, extent=extent, cmap="Greys_r", ) @@ -547,10 +562,7 @@ def preprocess( fig.tight_layout() self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(device, self._clear_fft_cache) return self @@ -568,7 +580,7 @@ def reconstruct( normalization_min: float = 1, positions_step_size: float = 0.9, pure_phase_object_iter: int = 0, - fix_com: bool = True, + fix_probe_com: bool = True, orthogonalize_probe: bool = True, fix_probe_iter: int = 0, fix_probe_aperture_iter: int = 0, @@ -579,6 +591,7 @@ def reconstruct( constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, + fix_positions_com: bool = True, global_affine_transformation: bool = True, max_position_update_distance: float = None, max_position_total_distance: float = None, @@ -602,6 +615,8 @@ def reconstruct( store_iterations: bool = False, progress_bar: bool = True, reset: bool = None, + device: str = None, + clear_fft_cache: bool = True, ): """ Ptychographic reconstruction main method. @@ -638,7 +653,7 @@ def reconstruct( Positions update step size pure_phase_object_iter: int, optional Number of iterations where object amplitude is set to unity - fix_com: bool, optional + fix_probe_com: bool, optional If True, fixes center of mass of probe fix_probe_iter: int, optional Number of iterations to run with a fixed probe before updating probe estimate @@ -658,6 +673,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. 
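As an aside on the batching refactor in these hunks: each scan position is split into an integer part, which selects the object pixels a probe-sized patch covers, and a fractional part, which is applied to the probe as a sub-pixel (Fourier) shift via fft_shift. A rough standalone sketch of that split, in toy NumPy rather than the actual py4DSTEM helpers, with roi_shape and obj_shape as assumed arguments:

import numpy as np

def split_scan_positions(positions_px, roi_shape, obj_shape):
    # fractional offsets are applied to the probe (cf. fft_shift above);
    # integer offsets become vectorized patch indices into the object
    positions_frac = positions_px - np.round(positions_px)
    positions_int = np.round(positions_px).astype(int)

    rows = np.arange(roi_shape[0])[None, :, None] + positions_int[:, 0, None, None]
    cols = np.arange(roi_shape[1])[None, None, :] + positions_int[:, 1, None, None]

    # wrap around the (padded) object field of view
    rows %= obj_shape[0]
    cols %= obj_shape[1]

    return rows, cols, positions_frac

# usage sketch: object_patches = complex_object[rows, cols]
# has shape (num_positions, roi_shape[0], roi_shape[1])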
fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov max_position_update_distance: float, optional Maximum allowed distance for update in A max_position_total_distance: float, optional @@ -705,14 +722,34 @@ def reconstruct( If True, reconstruction progress is displayed reset: bool, optional If True, previous reconstructions are ignored + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- self: PtychographicReconstruction Self to accommodate chaining """ - asnumpy = self._asnumpy + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probe", + "_probe_initial", + "_probe_initial_aperture", + ] + self.copy_attributes_to_device(attrs, device) + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy # set and report reconstruction method ( @@ -748,10 +785,9 @@ def reconstruct( # Batching shuffled_indices = np.arange(self._num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) if max_batch_size is not None: - xp.random.seed(seed_random) + np.random.seed(seed_random) else: max_batch_size = self._num_diffraction_patterns @@ -779,27 +815,23 @@ def reconstruct( if not use_projection_scheme: np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - self._num_diffraction_patterns - ) - - positions_px = self._positions_px.copy()[shuffled_indices] - positions_px_initial = self._positions_px_initial.copy()[shuffled_indices] - for start, end in generate_batches( self._num_diffraction_patterns, max_batch=max_batch_size ): # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px[batch_indices] + positions_px_initial = self._positions_px_initial[batch_indices] + positions_px_fractional = positions_px - xp_storage.round(positions_px) ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[shuffled_indices[start:end]] + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) # forward operator ( @@ -810,8 +842,11 @@ def reconstruct( batch_error, ) = self._forward( self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, self._probe, - amplitudes, + positions_px_fractional, + amplitudes_device, self._exit_waves, use_projection_scheme, projection_a, @@ -825,6 +860,7 @@ def reconstruct( self._probe, object_patches, shifted_probes, + positions_px, self._exit_waves, use_projection_scheme=use_projection_scheme, step_size=step_size, @@ -834,13 +870,15 @@ def reconstruct( # position correction if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( + self._positions_px[batch_indices] = self._position_correction( self._object, + vectorized_patch_indices_row, + 
vectorized_patch_indices_col, shifted_probes, overlap, - amplitudes, - self._positions_px, - positions_px_initial[start:end], + amplitudes_device, + positions_px, + positions_px_initial, positions_step_size, max_position_update_distance, max_position_total_distance, @@ -852,12 +890,12 @@ def reconstruct( error /= self._mean_diffraction_intensity * self._num_diffraction_patterns # constraints - self._positions_px = positions_px.copy()[unshuffled_indices] self._object, self._probe, self._positions_px = self._constraints( self._object, self._probe, self._positions_px, - fix_com=fix_com and a0 >= fix_probe_iter, + self._positions_px_initial, + fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter and a0 >= fix_probe_iter, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, @@ -875,6 +913,7 @@ def reconstruct( fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, + fix_positions_com=fix_positions_com and a0 >= fix_positions_iter, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter and gaussian_filter_sigma is not None, @@ -908,8 +947,10 @@ def reconstruct( self.probe = self.probe_centered self.error = error.item() - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(device, self._clear_fft_cache) return self diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index 3306a0f87..3482fc960 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -429,7 +429,7 @@ def preprocess( # corner-center amplitudes ( - _amplitudes, + self._amplitudes, self._mean_diffraction_intensity, self._crop_mask, ) = self._normalize_diffraction_intensities( @@ -441,7 +441,7 @@ def preprocess( ) # explicitly transfer arrays to storage - self._amplitudes = copy_to_device(_amplitudes, storage) + self._amplitudes = copy_to_device(self._amplitudes, storage) del _intensities self._num_diffraction_patterns = self._amplitudes.shape[0] diff --git a/py4DSTEM/process/phase/ptychographic_constraints.py b/py4DSTEM/process/phase/ptychographic_constraints.py index eb6e95444..eb1b69951 100644 --- a/py4DSTEM/process/phase/ptychographic_constraints.py +++ b/py4DSTEM/process/phase/ptychographic_constraints.py @@ -1214,7 +1214,7 @@ def _probe_orthogonalization_constraint(self, current_probe): def _probe_constraints( self, current_probe, - fix_com, + fix_probe_com, fit_probe_aberrations, fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order, @@ -1233,7 +1233,7 @@ def _probe_constraints( """ProbeMixedConstraints wrapper function""" # CoM corner-centering - if fix_com: + if fix_probe_com: current_probe = self._probe_center_of_mass_constraint(current_probe) # Fourier phase (aberrations) fitting diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 286406dd6..e6898e362 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -2522,7 +2522,13 @@ class ObjectNDProbeMixedMethodsMixin: Overwrites ObjectNDProbeMethodsMixin. 
""" - def _overlap_projection(self, current_object, shifted_probes): + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ): """ Ptychographic overlap projection method. @@ -2545,15 +2551,13 @@ def _overlap_projection(self, current_object, shifted_probes): xp = self._xp - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - object_patches = complex_object[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col + object_patches = current_object[ + vectorized_patch_indices_row, vectorized_patch_indices_col ] + if self._object_type == "potential": + object_patches = xp.exp(1j * object_patches) + overlap = shifted_probes * xp.expand_dims(object_patches, axis=1) return shifted_probes, object_patches, overlap @@ -2694,6 +2698,7 @@ def _gradient_descent_adjoint( current_probe, object_patches, shifted_probes, + positions_px, exit_waves, step_size, normalization_min, @@ -2736,7 +2741,8 @@ def _gradient_descent_adjoint( for i_probe in range(self._num_probes): probe_normalization += self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes[:, i_probe]) ** 2 + xp.abs(shifted_probes[:, i_probe]) ** 2, + positions_px, ) if self._object_type == "potential": object_update += step_size * self._sum_overlapping_patches_bincounts( @@ -2745,11 +2751,13 @@ def _gradient_descent_adjoint( * xp.conj(object_patches) * xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe] - ) + ), + positions_px, ) else: object_update += step_size * self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe] + xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe], + positions_px, ) probe_normalization = 1 / xp.sqrt( 1e-16 @@ -2786,6 +2794,7 @@ def _projection_sets_adjoint( current_probe, object_patches, shifted_probes, + positions_px, exit_waves, normalization_min, fix_probe, @@ -2825,7 +2834,8 @@ def _projection_sets_adjoint( for i_probe in range(self._num_probes): probe_normalization += self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes[:, i_probe]) ** 2 + xp.abs(shifted_probes[:, i_probe]) ** 2, + positions_px, ) if self._object_type == "potential": current_object += self._sum_overlapping_patches_bincounts( @@ -2834,11 +2844,13 @@ def _projection_sets_adjoint( * xp.conj(object_patches) * xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe] - ) + ), + positions_px, ) else: current_object += self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe] + xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe], + positions_px, ) probe_normalization = 1 / xp.sqrt( 1e-16 diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index f6d6454de..48a39714e 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -374,7 +374,7 @@ def preprocess( # corner-center amplitudes ( - _amplitudes, + self._amplitudes, self._mean_diffraction_intensity, self._crop_mask, ) = self._normalize_diffraction_intensities( @@ -386,7 +386,7 @@ def preprocess( ) # explicitly transfer arrays to storage - self._amplitudes = copy_to_device(_amplitudes, storage) + self._amplitudes = copy_to_device(self._amplitudes, storage) del _intensities self._num_diffraction_patterns = self._amplitudes.shape[0] From 
94d37096458a1ce699a5dbfac25056ad0d21740c Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 13:41:01 -0800 Subject: [PATCH 080/128] added storage to mixed-multislice Former-commit-id: 7dbb93177d6cbe5137483fe9459d99bc4d09e29d --- .../mixedstate_multislice_ptychography.py | 241 ++++++++++-------- .../process/phase/ptychographic_methods.py | 42 +-- 2 files changed, 163 insertions(+), 120 deletions(-) diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index a2372a772..2ce756168 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -38,6 +38,7 @@ from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, + copy_to_device, fft_shift, generate_batches, polar_aliases, @@ -110,9 +111,9 @@ class MixedstateMultislicePtychography( Probe positions in Å for each diffraction intensity If None, initialized to a grid scan theta_x: float - x tilt of propagator (in degrees) + x tilt of propagator in mrad theta_y: float - y tilt of propagator (in degrees) + y tilt of propagator in mrad middle_focus: bool if True, adds half the sample thickness to the defocus object_type: str, optional @@ -124,6 +125,10 @@ class MixedstateMultislicePtychography( If True, class methods will inherit this and print additional information device: str, optional Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls name: str, optional Class name kwargs: @@ -131,7 +136,13 @@ class MixedstateMultislicePtychography( """ # Class-specific Metadata - _class_specific_metadata = ("_num_probes", "_num_slices", "_slice_thicknesses") + _class_specific_metadata = ( + "_num_probes", + "_num_slices", + "_slice_thicknesses", + "_theta_x", + "_theta_y", + ) def __init__( self, @@ -156,27 +167,18 @@ def __init__( positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, name: str = "multi-slice_ptychographic_reconstruction", **kwargs, ): Custom.__init__(self, name=name) - if device == "cpu": - import scipy + if storage is None: + storage = device - self._xp = np - self._asnumpy = np.asarray - self._scipy = scipy - - elif device == "gpu": - from cupyx import scipy - - self._xp = cp - self._asnumpy = cp.asnumpy - self._scipy = scipy - - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_device(device, clear_fft_cache) + self.set_storage(storage) if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe): if num_probes is None: @@ -197,33 +199,6 @@ def __init__( if (key not in polar_symbols) and (key not in polar_aliases.keys()): raise ValueError("{} not a recognized parameter".format(key)) - if np.isscalar(slice_thicknesses): - mean_slice_thickness = slice_thicknesses - else: - mean_slice_thickness = np.mean(slice_thicknesses) - - if middle_focus: - if "defocus" in kwargs: - kwargs["defocus"] += mean_slice_thickness * num_slices / 2 - elif "C10" in kwargs: - kwargs["C10"] -= mean_slice_thickness * num_slices / 2 - elif polar_parameters is not None and "defocus" in polar_parameters: - polar_parameters["defocus"] = ( - 
polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2 - ) - elif polar_parameters is not None and "C10" in polar_parameters: - polar_parameters["C10"] = ( - polar_parameters["C10"] - mean_slice_thickness * num_slices / 2 - ) - - self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) - - if polar_parameters is None: - polar_parameters = {} - - polar_parameters.update(kwargs) - self._set_polar_parameters(polar_parameters) - slice_thicknesses = np.array(slice_thicknesses) if slice_thicknesses.shape == (): slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1) @@ -235,6 +210,18 @@ def __init__( ) ) + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + if middle_focus: + half_thickness = slice_thicknesses.mean() * num_slices / 2 + self._polar_parameters["C10"] -= half_thickness + if object_type != "potential" and object_type != "complex": raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" @@ -258,7 +245,6 @@ def __init__( self._positions_mask = positions_mask self._object_padding_px = object_padding_px self._verbose = verbose - self._device = device self._preprocessed = False # Class-specific Metadata @@ -289,6 +275,8 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + device: str = None, + clear_fft_cache: bool = True, **kwargs, ): """ @@ -351,13 +339,23 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + device: str, optional + If not None, overwrites self._device to set device preprocess will be perfomed on. 
+ clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- self: MixedstateMultislicePtychographicReconstruction Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage asnumpy = self._asnumpy # set additional metadata @@ -396,7 +394,7 @@ def preprocess( ) # calibrations - self._intensities = self._extract_intensities_and_calibrations_from_datacube( + _intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=True, force_scan_sampling=force_scan_sampling, @@ -419,7 +417,7 @@ def preprocess( self._com_normalized_x, self._com_normalized_y, ) = self._calculate_intensities_center_of_mass( - self._intensities, + _intensities, dp_mask=self._dp_mask, fit_function=fit_function, com_shifts=force_com_shifts, @@ -431,8 +429,6 @@ def preprocess( self._rotation_best_transpose, self._com_x, self._com_y, - self.com_x, - self.com_y, ) = self._solve_for_center_of_mass_relative_rotation( self._com_measured_x, self._com_measured_y, @@ -447,20 +443,33 @@ def preprocess( **kwargs, ) + # explicitly transfer arrays to storage + self._com_measured_x = copy_to_device(self._com_measured_x, storage) + self._com_measured_y = copy_to_device(self._com_measured_y, storage) + self._com_fitted_x = copy_to_device(self._com_fitted_x, storage) + self._com_fitted_y = copy_to_device(self._com_fitted_y, storage) + self._com_normalized_x = copy_to_device(self._com_normalized_x, storage) + self._com_normalized_y = copy_to_device(self._com_normalized_y, storage) + self._com_x = copy_to_device(self._com_x, storage) + self._com_y = copy_to_device(self._com_y, storage) + # corner-center amplitudes ( self._amplitudes, self._mean_diffraction_intensity, self._crop_mask, ) = self._normalize_diffraction_intensities( - self._intensities, + _intensities, self._com_fitted_x, self._com_fitted_y, self._positions_mask, crop_patterns, ) - # explicitly delete namespace + # explicitly transfer arrays to storage + self._amplitudes = copy_to_device(self._amplitudes, storage) + del _intensities + self._num_diffraction_patterns = self._amplitudes.shape[0] if region_of_interest_shape is not None: @@ -469,7 +478,6 @@ def preprocess( else: self._resample_exit_waves = False self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) - del self._intensities # initialize probe positions ( @@ -494,25 +502,18 @@ def preprocess( self._object_shape = self._object.shape[-2:] # center probe positions - self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px + self._positions_px = xp_storage.asarray(self._positions_px, dtype=xp.float32) + self._positions_px_initial_com = self._positions_px.mean(0) + self._positions_px -= ( + self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 ) + self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px_initial = self._positions_px.copy() self._positions_initial = self._positions_px_initial.copy() self._positions_initial[:, 0] *= self.sampling[0] self._positions_initial[:, 1] *= self.sampling[1] - # set 
vectorized patches - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - # initialize probe self._probe, self._semiangle_cutoff = self._initialize_probe( self._probe, @@ -522,9 +523,6 @@ def preprocess( crop_patterns, ) - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) - # initialize aberrations self._known_aberrations_array = ComplexProbe( energy=self._energy, @@ -534,6 +532,9 @@ def preprocess( device=self._device, )._evaluate_ctf() + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + # precompute propagator arrays self._propagator_arrays = self._precompute_propagator_arrays( self._region_of_interest_shape, @@ -545,9 +546,14 @@ def preprocess( ) # overlaps - shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) + positions_px_fractional = self._positions_px - xp_storage.round( + self._positions_px + ) + shifted_probes = fft_shift(self._probe[0], positions_px_fractional, xp) + probe_overlap = self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, self._positions_px + ) + del shifted_probes if object_fov_mask is None: gaussian_filter = self._scipy.ndimage.gaussian_filter @@ -555,10 +561,14 @@ def preprocess( self._object_fov_mask = asnumpy( probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() ) + del probe_overlap_blurred else: self._object_fov_mask = np.asarray(object_fov_mask) self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + probe_overlap = asnumpy(probe_overlap) + + # plot probe overlaps if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) chroma_boost = kwargs.pop("chroma_boost", 1) @@ -628,7 +638,7 @@ def preprocess( ax2.set_title("Propagated probe[0] intensity") ax3.imshow( - asnumpy(probe_overlap), + probe_overlap, extent=extent, cmap="Greys_r", ) @@ -647,10 +657,7 @@ def preprocess( fig.tight_layout() self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(device, self._clear_fft_cache) return self @@ -667,7 +674,7 @@ def reconstruct( step_size: float = 0.5, normalization_min: float = 1, positions_step_size: float = 0.9, - fix_com: bool = True, + fix_probe_com: bool = True, orthogonalize_probe: bool = True, fix_probe_iter: int = 0, fix_probe_aperture_iter: int = 0, @@ -678,6 +685,7 @@ def reconstruct( constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, + fix_positions_com: bool = True, max_position_update_distance: float = None, max_position_total_distance: float = None, global_affine_transformation: bool = True, @@ -708,6 +716,8 @@ def reconstruct( store_iterations: bool = False, progress_bar: bool = True, reset: bool = None, + device: str = None, + clear_fft_cache: bool = True, ): """ Ptychographic reconstruction main method. 
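For orientation, the adjoint of the patch extraction, which is what _sum_overlapping_patches_bincounts performs with the explicit positions_px argument threaded through these hunks, can be approximated by a bincount over flattened object indices. A simplified real-valued sketch follows; the library routine additionally handles complex inputs, GPU arrays, and other bookkeeping:

import numpy as np

def sum_overlapping_patches(patches, positions_px, obj_shape):
    # accumulate every (possibly overlapping) probe-sized patch back onto the
    # object grid at its integer scan position, e.g. |probe|**2 patches to
    # build the probe_normalization / probe_overlap arrays used above
    num, rx, ry = patches.shape
    positions_int = np.round(positions_px).astype(int)

    rows = (np.arange(rx)[None, :, None] + positions_int[:, 0, None, None]) % obj_shape[0]
    cols = (np.arange(ry)[None, None, :] + positions_int[:, 1, None, None]) % obj_shape[1]

    flat = (rows * obj_shape[1] + cols).ravel()  # broadcasts to (num, rx, ry)

    summed = np.bincount(
        flat,
        weights=patches.ravel(),
        minlength=obj_shape[0] * obj_shape[1],
    )
    return summed.reshape(obj_shape)

# usage sketch:
# probe_overlap = sum_overlapping_patches(np.abs(shifted_probes) ** 2, positions_px, obj.shape)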
@@ -830,8 +840,25 @@ def reconstruct( self: MultislicePtychographicReconstruction Self to accommodate chaining """ - asnumpy = self._asnumpy + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probe", + "_probe_initial", + "_probe_initial_aperture", + "_propagator_arrays", + ] + self.copy_attributes_to_device(attrs, device) + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy # set and report reconstruction method ( @@ -867,10 +894,9 @@ def reconstruct( # batching shuffled_indices = np.arange(self._num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) if max_batch_size is not None: - xp.random.seed(seed_random) + np.random.seed(seed_random) else: max_batch_size = self._num_diffraction_patterns @@ -898,27 +924,23 @@ def reconstruct( if not use_projection_scheme: np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - self._num_diffraction_patterns - ) - - positions_px = self._positions_px.copy()[shuffled_indices] - positions_px_initial = self._positions_px_initial.copy()[shuffled_indices] - for start, end in generate_batches( self._num_diffraction_patterns, max_batch=max_batch_size ): # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px[batch_indices] + positions_px_initial = self._positions_px_initial[batch_indices] + positions_px_fractional = positions_px - xp_storage.round(positions_px) ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[shuffled_indices[start:end]] + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) # forward operator ( @@ -929,8 +951,11 @@ def reconstruct( batch_error, ) = self._forward( self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, self._probe, - amplitudes, + positions_px_fractional, + amplitudes_device, self._exit_waves, use_projection_scheme, projection_a, @@ -944,6 +969,7 @@ def reconstruct( self._probe, object_patches, shifted_probes, + positions_px, self._exit_waves, use_projection_scheme=use_projection_scheme, step_size=step_size, @@ -953,13 +979,15 @@ def reconstruct( # position correction if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( + self._positions_px[batch_indices] = self._position_correction( self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, self._probe, overlap, - amplitudes, - self._positions_px, - positions_px_initial[start:end], + amplitudes_device, + positions_px, + positions_px_initial, positions_step_size, max_position_update_distance, max_position_total_distance, @@ -971,12 +999,12 @@ def reconstruct( error /= self._mean_diffraction_intensity * self._num_diffraction_patterns # constraints - self._positions_px = positions_px.copy()[unshuffled_indices] self._object, self._probe, self._positions_px = self._constraints( self._object, self._probe, self._positions_px, - fix_com=fix_com and a0 >= fix_probe_iter, + self._positions_px_initial, + fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, 
constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter and a0 >= fix_probe_iter, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, @@ -994,6 +1022,7 @@ def reconstruct( fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, + fix_positions_com=fix_positions_com and a0 >= fix_positions_iter, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter and gaussian_filter_sigma is not None, @@ -1038,8 +1067,10 @@ def reconstruct( self.probe = self.probe_centered self.error = error.item() - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(device, self._clear_fft_cache) return self diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index e6898e362..c1c10e6ab 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -2888,7 +2888,13 @@ class Object2p5DProbeMixedMethodsMixin: Overwrites ObjectNDProbeMethodsMixin and ObjectNDProbeMixedMethodsMixin. """ - def _overlap_projection(self, current_object, shifted_probes_in): + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes_in, + ): """ Ptychographic overlap projection method. @@ -2911,17 +2917,15 @@ def _overlap_projection(self, current_object, shifted_probes_in): xp = self._xp - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - object_patches = complex_object[ + object_patches = current_object[ :, - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, + vectorized_patch_indices_row, + vectorized_patch_indices_col, ] + if self._object_type == "potential": + object_patches = xp.exp(1j * object_patches) + num_probe_positions = object_patches.shape[1] shifted_shape = ( @@ -2953,6 +2957,7 @@ def _gradient_descent_adjoint( current_probe, object_patches, shifted_probes, + positions_px, exit_waves, step_size, normalization_min, @@ -3000,7 +3005,8 @@ def _gradient_descent_adjoint( for i_probe in range(self._num_probes): probe_normalization += self._sum_overlapping_patches_bincounts( - xp.abs(probe[:, i_probe]) ** 2 + xp.abs(probe[:, i_probe]) ** 2, + positions_px, ) if self._object_type == "potential": @@ -3012,14 +3018,16 @@ def _gradient_descent_adjoint( * xp.conj(obj) * xp.conj(probe[:, i_probe]) * exit_waves[:, i_probe] - ) + ), + positions_px, ) ) else: object_update += ( step_size * self._sum_overlapping_patches_bincounts( - xp.conj(probe[:, i_probe]) * exit_waves[:, i_probe] + xp.conj(probe[:, i_probe]) * exit_waves[:, i_probe], + positions_px, ) ) @@ -3068,6 +3076,7 @@ def _projection_sets_adjoint( current_probe, object_patches, shifted_probes, + positions_px, exit_waves, normalization_min, fix_probe, @@ -3114,7 +3123,8 @@ def _projection_sets_adjoint( for i_probe in range(self._num_probes): probe_normalization += self._sum_overlapping_patches_bincounts( - xp.abs(probe[:, i_probe]) ** 2 + xp.abs(probe[:, i_probe]) ** 2, + positions_px, ) if self._object_type == "potential": @@ -3124,11 +3134,13 @@ def _projection_sets_adjoint( * xp.conj(obj) * xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe] - ) + ), + 
positions_px, ) else: object_update += self._sum_overlapping_patches_bincounts( - xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe] + xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe], + positions_px, ) probe_normalization = 1 / xp.sqrt( From dab456fd9bf1613d6777453743ce4a9c1f4ce8fe Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 15:31:54 -0800 Subject: [PATCH 081/128] attrs copying cleanup Former-commit-id: c14966fb133289f72b82ec9eeaba4677a7e92079 --- .../mixedstate_multislice_ptychography.py | 19 ++++++++------- .../process/phase/mixedstate_ptychography.py | 24 ++++++++++++------- .../process/phase/multislice_ptychography.py | 19 ++++++++------- .../process/phase/singleslice_ptychography.py | 19 ++++++++------- 4 files changed, 48 insertions(+), 33 deletions(-) diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index 2ce756168..e44fc77de 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -444,14 +444,17 @@ def preprocess( ) # explicitly transfer arrays to storage - self._com_measured_x = copy_to_device(self._com_measured_x, storage) - self._com_measured_y = copy_to_device(self._com_measured_y, storage) - self._com_fitted_x = copy_to_device(self._com_fitted_x, storage) - self._com_fitted_y = copy_to_device(self._com_fitted_y, storage) - self._com_normalized_x = copy_to_device(self._com_normalized_x, storage) - self._com_normalized_y = copy_to_device(self._com_normalized_y, storage) - self._com_x = copy_to_device(self._com_x, storage) - self._com_y = copy_to_device(self._com_y, storage) + attrs = [ + "_com_measured_x", + "_com_measured_y", + "_com_fitted_x", + "_com_fitted_y", + "_com_normalized_x", + "_com_normalized_y", + "_com_x", + "_com_y", + ] + self.copy_attributes_to_device(attrs, storage) # corner-center amplitudes ( diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index fd8273a37..17e220e51 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -245,9 +245,12 @@ def preprocess( If None, no resampling of diffraction intenstities is performed reshaping_method: str, optional Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional + padded_diffraction_intensities_shape: (int,int), optional Padded diffraction intensities shape. 
If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data dp_mask: ndarray, optional Mask for datacube intensities (Qx,Qy) fit_function: str, optional @@ -387,14 +390,17 @@ def preprocess( ) # explicitly transfer arrays to storage - self._com_measured_x = copy_to_device(self._com_measured_x, storage) - self._com_measured_y = copy_to_device(self._com_measured_y, storage) - self._com_fitted_x = copy_to_device(self._com_fitted_x, storage) - self._com_fitted_y = copy_to_device(self._com_fitted_y, storage) - self._com_normalized_x = copy_to_device(self._com_normalized_x, storage) - self._com_normalized_y = copy_to_device(self._com_normalized_y, storage) - self._com_x = copy_to_device(self._com_x, storage) - self._com_y = copy_to_device(self._com_y, storage) + attrs = [ + "_com_measured_x", + "_com_measured_y", + "_com_fitted_x", + "_com_fitted_y", + "_com_normalized_x", + "_com_normalized_y", + "_com_x", + "_com_y", + ] + self.copy_attributes_to_device(attrs, storage) # corner-center amplitudes ( diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index 3482fc960..cd406139b 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -418,14 +418,17 @@ def preprocess( ) # explicitly transfer arrays to storage - self._com_measured_x = copy_to_device(self._com_measured_x, storage) - self._com_measured_y = copy_to_device(self._com_measured_y, storage) - self._com_fitted_x = copy_to_device(self._com_fitted_x, storage) - self._com_fitted_y = copy_to_device(self._com_fitted_y, storage) - self._com_normalized_x = copy_to_device(self._com_normalized_x, storage) - self._com_normalized_y = copy_to_device(self._com_normalized_y, storage) - self._com_x = copy_to_device(self._com_x, storage) - self._com_y = copy_to_device(self._com_y, storage) + attrs = [ + "_com_measured_x", + "_com_measured_y", + "_com_fitted_x", + "_com_fitted_y", + "_com_normalized_x", + "_com_normalized_y", + "_com_x", + "_com_y", + ] + self.copy_attributes_to_device(attrs, storage) # corner-center amplitudes ( diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index 48a39714e..9ba285475 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -363,14 +363,17 @@ def preprocess( ) # explicitly transfer arrays to storage - self._com_measured_x = copy_to_device(self._com_measured_x, storage) - self._com_measured_y = copy_to_device(self._com_measured_y, storage) - self._com_fitted_x = copy_to_device(self._com_fitted_x, storage) - self._com_fitted_y = copy_to_device(self._com_fitted_y, storage) - self._com_normalized_x = copy_to_device(self._com_normalized_x, storage) - self._com_normalized_y = copy_to_device(self._com_normalized_y, storage) - self._com_x = copy_to_device(self._com_x, storage) - self._com_y = copy_to_device(self._com_y, storage) + attrs = [ + "_com_measured_x", + "_com_measured_y", + "_com_fitted_x", + "_com_fitted_y", + "_com_normalized_x", + "_com_normalized_y", + "_com_x", + "_com_y", + ] + self.copy_attributes_to_device(attrs, storage) # corner-center amplitudes ( From 61ec677bf8e6f2de4188913129bceededa3726e2 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 15:54:01 -0800 
Subject: [PATCH 082/128] added storage to dpc Former-commit-id: 8a84e3cc2094fdd28e2fb7a511c1a49acc463e09 --- py4DSTEM/process/phase/dpc.py | 169 +++++++++++---------- py4DSTEM/process/phase/phase_base_class.py | 2 - 2 files changed, 87 insertions(+), 84 deletions(-) diff --git a/py4DSTEM/process/phase/dpc.py b/py4DSTEM/process/phase/dpc.py index f379e50d3..764aa4b5f 100644 --- a/py4DSTEM/process/phase/dpc.py +++ b/py4DSTEM/process/phase/dpc.py @@ -20,6 +20,8 @@ from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction +from py4DSTEM.process.phase.utils import copy_to_device +from py4DSTEM.visualize.vis_special import return_scaled_histogram_ordering warnings.simplefilter(action="always", category=UserWarning) @@ -42,7 +44,11 @@ class DPC(PhaseReconstruction): verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + Device calculation will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls name: str, optional Class name """ @@ -54,26 +60,17 @@ def __init__( energy: float = None, verbose: bool = True, device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, name: str = "dpc_reconstruction", ): Custom.__init__(self, name=name) - if device == "cpu": - import scipy + if storage is None: + storage = device - self._xp = np - self._asnumpy = np.asarray - self._scipy = scipy - - elif device == "gpu": - from cupyx import scipy - - self._xp = cp - self._asnumpy = cp.asnumpy - self._scipy = scipy - - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_device(device, clear_fft_cache) + self.set_storage(storage) self.set_save_defaults() @@ -84,7 +81,6 @@ def __init__( # Metadata self._energy = energy self._verbose = verbose - self._device = device self._preprocessed = False def to_h5(self, group): @@ -236,7 +232,7 @@ def preprocess( self, dp_mask: np.ndarray = None, padding_factor: float = 2, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + rotation_angles_deg: np.ndarray = None, maximize_divergence: bool = False, fit_function: str = "plane", force_com_rotation: float = None, @@ -245,6 +241,8 @@ def preprocess( force_com_measured: Sequence[np.ndarray] = None, plot_center_of_mass: str = "default", plot_rotation: bool = True, + device: str = None, + clear_fft_cache: bool = True, **kwargs, ): """ @@ -286,7 +284,14 @@ def preprocess( self: DPCReconstruction Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage + asnumpy = self._asnumpy # set additional metadata self._dp_mask = dp_mask @@ -305,7 +310,7 @@ def preprocess( data=np.empty(force_com_measured[0].shape + (1, 1)) ) - self._intensities = self._extract_intensities_and_calibrations_from_datacube( + _intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=False, ) @@ -318,7 +323,7 @@ def preprocess( self._com_normalized_x, self._com_normalized_y, ) = self._calculate_intensities_center_of_mass( - self._intensities, + _intensities, dp_mask=self._dp_mask, 
fit_function=fit_function, com_shifts=force_com_shifts, @@ -330,8 +335,6 @@ def preprocess( self._rotation_best_transpose, self._com_x, self._com_y, - self.com_x, - self.com_y, ) = self._solve_for_center_of_mass_relative_rotation( self._com_measured_x, self._com_measured_y, @@ -346,11 +349,23 @@ def preprocess( **kwargs, ) + # explicitly transfer arrays to storage + attrs = [ + "_com_measured_x", + "_com_measured_y", + "_com_fitted_x", + "_com_fitted_y", + "_com_normalized_x", + "_com_normalized_y", + ] + self.copy_attributes_to_device(attrs, storage) + # Object Initialization padded_object_shape = np.round( np.array(self._grid_scan_shape) * padding_factor ).astype("int") self._padded_object_phase = xp.zeros(padded_object_shape, dtype=xp.float32) + if self._object_phase is not None: self._padded_object_phase[ : self._grid_scan_shape[0], : self._grid_scan_shape[1] @@ -359,20 +374,23 @@ def preprocess( self._padded_object_phase_initial = self._padded_object_phase.copy() # Fourier coordinates and operators - kx = xp.fft.fftfreq(padded_object_shape[0], d=self._scan_sampling[0]) - ky = xp.fft.fftfreq(padded_object_shape[1], d=self._scan_sampling[1]) + kx = xp.fft.fftfreq(padded_object_shape[0], d=self._scan_sampling[0]).astype( + xp.float32 + ) + ky = xp.fft.fftfreq(padded_object_shape[1], d=self._scan_sampling[1]).astype( + xp.float32 + ) kya, kxa = xp.meshgrid(ky, kx) + k_den = kxa**2 + kya**2 k_den[0, 0] = np.inf k_den = 1 / k_den + self._kx_op = -1j * 0.25 * kxa * k_den self._ky_op = -1j * 0.25 * kya * k_den self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(device, self._clear_fft_cache) return self @@ -415,6 +433,7 @@ def _forward( """ xp = self._xp + asnumpy = self._asnumpy dx, dy = self._scan_sampling # centered finite-differences @@ -433,8 +452,9 @@ def _forward( obj_dx[mask_inv] = 0 obj_dy[mask_inv] = 0 - new_error = xp.mean(obj_dx[mask] ** 2 + obj_dy[mask] ** 2) / ( - xp.mean(self._com_x.ravel() ** 2 + self._com_y.ravel() ** 2) + new_error = asnumpy( + xp.mean(obj_dx[mask] ** 2 + obj_dy[mask] ** 2) + / (xp.mean(self._com_x.ravel() ** 2 + self._com_y.ravel() ** 2)) ) return obj_dx, obj_dy, new_error, step_size @@ -519,8 +539,8 @@ def _object_gaussian_constraint(self, current_object, gaussian_filter_sigma): Constrained object estimate """ gaussian_filter = self._scipy.ndimage.gaussian_filter - gaussian_filter_sigma /= self.sampling[0] + current_object = gaussian_filter(current_object, gaussian_filter_sigma) return current_object @@ -560,44 +580,13 @@ def _object_butterworth_constraint( if q_lowpass: env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - current_object_mean = xp.mean(current_object) + current_object_mean = xp.mean(current_object, axis=(-2, -1), keepdims=True) current_object -= current_object_mean current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env) current_object += current_object_mean return xp.real(current_object) - def _object_anti_gridding_contraint(self, current_object): - """ - Zero outer pixels of object fft to remove gridding artifacts - - Parameters - -------- - current_object: np.ndarray - Current object estimate - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - - # find indices to zero - width_x = current_object.shape[0] - width_y = current_object.shape[1] - ind_min_x = int(xp.floor(width_x / 2) - 2) - ind_max_x = int(xp.ceil(width_x / 2) + 2) - ind_min_y = int(xp.floor(width_y / 2) - 2) - 
ind_max_y = int(xp.ceil(width_y / 2) + 2) - - # zero pixels - object_fft = xp.fft.fft2(current_object) - object_fft[ind_min_x:ind_max_x] = 0 - object_fft[:, ind_min_y:ind_max_y] = 0 - - return xp.real(xp.fft.ifft2(object_fft)) - def _constraints( self, current_object, @@ -607,7 +596,6 @@ def _constraints( q_lowpass, q_highpass, butterworth_order, - anti_gridding, ): """ DPC constraints operator. @@ -628,9 +616,6 @@ def _constraints( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter - anti_gridding: bool - If true, zero outer pixels of object fft to remove - gridding artifacts Returns -------- @@ -650,11 +635,6 @@ def _constraints( butterworth_order, ) - if anti_gridding: - current_object = self._object_anti_gridding_contraint( - current_object, - ) - return current_object def reconstruct( @@ -671,8 +651,9 @@ def reconstruct( q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, - anti_gridding: float = True, store_iterations: bool = False, + device: str = None, + clear_fft_cache: bool = True, ): """ Performs Iterative DPC Reconstruction: @@ -705,11 +686,12 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter - anti_gridding: bool - If true, zero outer pixels of object fft to remove - gridding artifacts store_iterations: bool, optional If True, all reconstruction iterations will be stored + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- @@ -717,18 +699,35 @@ def reconstruct( Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probe", + "_probe_initial", + "_probe_initial_aperture", + ] + self.copy_attributes_to_device(attrs, device) + xp = self._xp + xp_storage = self._xp_storage + device = self._device asnumpy = self._asnumpy # Restart if store_iterations and (not hasattr(self, "object_phase_iterations") or reset): self.object_phase_iterations = [] - if reset: + if reset is True: self.error = np.inf self.error_iterations = [] self._step_size = step_size if step_size is not None else 0.5 self._padded_object_phase = self._padded_object_phase_initial.copy() + elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -798,10 +797,10 @@ def reconstruct( q_lowpass=q_lowpass, q_highpass=q_highpass, butterworth_order=butterworth_order, - anti_gridding=anti_gridding, ) self.error_iterations.append(self.error.item()) + if store_iterations: self.object_phase_iterations.append( asnumpy( @@ -824,9 +823,7 @@ def reconstruct( ] self.object_phase = asnumpy(self._object_phase) - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(device, self._clear_fft_cache) return self @@ -848,6 +845,8 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (5, 6)) cmap = kwargs.pop("cmap", "magma") + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) if plot_convergence: spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) @@ -865,7 +864,12 @@ def _visualize_last_iteration( ] ax1 = fig.add_subplot(spec[0]) - im = 
ax1.imshow(self.object_phase, extent=extent, cmap=cmap, **kwargs) + + obj, vmin, vmax = return_scaled_histogram_ordering( + self.object_phase, vmin, vmax + ) + im = ax1.imshow(obj, extent=extent, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs) + ax1.set_ylabel(f"x [{self._scan_units[0]}]") ax1.set_xlabel(f"y [{self._scan_units[1]}]") ax1.set_title(f"DPC phase reconstruction - NMSE error: {self.error:.3e}") @@ -880,6 +884,7 @@ def _visualize_last_iteration( errors = self.error_iterations ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(len(errors)), errors, **kwargs) + ax2.set_xlabel("Iteration number") ax2.set_ylabel("Log NMSE error") ax2.yaxis.tick_right() diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 0826d679e..5f93ec0e8 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -1093,8 +1093,6 @@ def _solve_for_center_of_mass_relative_rotation( # Minimize Curl ind_min = xp.argmin(rotation_curl).item() ind_trans_min = xp.argmin(rotation_curl_transpose).item() - self._rotation_curl = rotation_curl - self._rotation_curl_transpose = rotation_curl_transpose if rotation_curl[ind_min] <= rotation_curl_transpose[ind_trans_min]: rotation_best_deg = rotation_angles_deg[ind_min] _rotation_best_rad = rotation_angles_rad[ind_min] From e2d03e71fa6b0c72703f29c9b7be394feddb305b Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 16:41:27 -0800 Subject: [PATCH 083/128] more dpc viz tweaks Former-commit-id: 5ff1dcf5cbd8b237905795a3b9ad06b688846368 --- py4DSTEM/process/phase/dpc.py | 49 ++++++++++++++++++++++------------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/py4DSTEM/process/phase/dpc.py b/py4DSTEM/process/phase/dpc.py index 764aa4b5f..a78d823af 100644 --- a/py4DSTEM/process/phase/dpc.py +++ b/py4DSTEM/process/phase/dpc.py @@ -20,7 +20,6 @@ from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction -from py4DSTEM.process.phase.utils import copy_to_device from py4DSTEM.visualize.vis_special import return_scaled_histogram_ordering warnings.simplefilter(action="always", category=UserWarning) @@ -289,9 +288,7 @@ def preprocess( xp = self._xp device = self._device - xp_storage = self._xp_storage storage = self._storage - asnumpy = self._asnumpy # set additional metadata self._dp_mask = dp_mask @@ -714,7 +711,6 @@ def reconstruct( self.copy_attributes_to_device(attrs, device) xp = self._xp - xp_storage = self._xp_storage device = self._device asnumpy = self._asnumpy @@ -872,7 +868,7 @@ def _visualize_last_iteration( ax1.set_ylabel(f"x [{self._scan_units[0]}]") ax1.set_xlabel(f"y [{self._scan_units[1]}]") - ax1.set_title(f"DPC phase reconstruction - NMSE error: {self.error:.3e}") + ax1.set_title("Reconstructed object phase") if cbar: divider = make_axes_locatable(ax1) @@ -889,6 +885,7 @@ def _visualize_last_iteration( ax2.set_ylabel("Log NMSE error") ax2.yaxis.tick_right() + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -913,7 +910,6 @@ def _visualize_all_iterations( iterations_grid: Tuple[int,int] Grid dimensions to plot reconstruction iterations """ - if not hasattr(self, "object_phase_iterations"): raise ValueError( ( @@ -922,31 +918,41 @@ def _visualize_all_iterations( ) ) - if iterations_grid == "auto": - num_iter = len(self.error_iterations) + num_iter = len(self.object_phase_iterations) 
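# Aside: a toy sketch of the Fourier-space gradient integration that the DPC
# update performs with the kx_op / ky_op operators built in preprocess() above.
# The constants follow the operator definitions shown there; the actual
# py4DSTEM adjoint step may differ in details (padding, masking, step size),
# so treat this as illustrative only.
import numpy as np

def integrate_gradient_fourier(gx, gy, sampling=(1.0, 1.0)):
    # gx, gy: measured gradient components (e.g. residual CoM maps)
    kx = np.fft.fftfreq(gx.shape[0], d=sampling[0]).astype(np.float32)
    ky = np.fft.fftfreq(gx.shape[1], d=sampling[1]).astype(np.float32)
    kya, kxa = np.meshgrid(ky, kx)

    k_den = kxa**2 + kya**2
    k_den[0, 0] = np.inf  # drop the undefined mean (gauge) term
    k_den = 1 / k_den

    kx_op = -1j * 0.25 * kxa * k_den
    ky_op = -1j * 0.25 * kya * k_den

    update = np.fft.ifft2(kx_op * np.fft.fft2(gx) + ky_op * np.fft.fft2(gy))
    return np.real(update)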
+ if iterations_grid == "auto": if num_iter == 1: return self._visualize_last_iteration( + fig=fig, plot_convergence=plot_convergence, cbar=cbar, **kwargs, ) + else: iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) + else: + if iterations_grid[0] * iterations_grid[1] > num_iter: + raise ValueError() + auto_figsize = ( (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) if plot_convergence else (3 * iterations_grid[1], 3 * iterations_grid[0]) ) + figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + max_iter = num_iter - 1 total_grids = np.prod(iterations_grid) - errors = self.error_iterations - phases = self.object_phase_iterations - max_iter = len(phases) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) + grid_range = np.arange(0, max_iter + 1, max_iter // (total_grids - 1)) + + errors = np.array(self.error_iterations)[-num_iter:] + objects = [self.object_phase_iterations[n] for n in grid_range] extent = [ 0, @@ -973,25 +979,30 @@ def _visualize_all_iterations( ) for n, ax in enumerate(grid): + obj, vmin_n, vmax_n = return_scaled_histogram_ordering( + objects[n], vmin=vmin, vmax=vmax + ) im = ax.imshow( - phases[grid_range[n]], + obj, extent=extent, cmap=cmap, + vmin=vmin_n, + vmax=vmax_n, **kwargs, ) + ax.set_ylabel(f"x [{self._scan_units[0]}]") ax.set_xlabel(f"y [{self._scan_units[1]}]") + ax.set_title(f"Iter: {grid_range[n]} phase") + if cbar: grid.cbar_axes[n].colorbar(im) - ax.set_title( - f"Iteration: {grid_range[n]}\nNMSE error: {errors[grid_range[n]]:.3e}" - ) if plot_convergence: ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(len(errors)), errors, **kwargs) + ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_xlabel("Iteration number") - ax2.set_ylabel("Log NMSE error") + ax2.set_ylabel("NMSE error") ax2.yaxis.tick_right() spec.tight_layout(fig) @@ -1037,6 +1048,8 @@ def visualize( **kwargs, ) + self.clear_device_mem(self._device, self._clear_fft_cache) + return self @property From ca02678296b7f44a71c3d21733091f9e840cf44c Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 17:05:05 -0800 Subject: [PATCH 084/128] adding vectorized flag to CoM Former-commit-id: 71191c0720f4ba0b6f98ab149f545e3da5eb2341 --- py4DSTEM/process/phase/phase_base_class.py | 75 +++++++++++++------ .../process/phase/singleslice_ptychography.py | 4 + 2 files changed, 57 insertions(+), 22 deletions(-) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 5f93ec0e8..576b35c24 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -16,7 +16,7 @@ except (ModuleNotFoundError, ImportError): cp = np -from emdfile import Array, Custom, Metadata, _read_metadata +from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube from py4DSTEM.process.calibration import fit_origin @@ -588,6 +588,7 @@ def _calculate_intensities_center_of_mass( fit_function: str = "plane", com_shifts: np.ndarray = None, com_measured: np.ndarray = None, + vectorized_calculation=True, ): """ Common preprocessing function to compute and fit diffraction intensities CoM @@ -604,6 +605,8 @@ def _calculate_intensities_center_of_mass( If not None, com_shifts are fitted on the measured CoM values. 
com_measured: tuple of ndarrays (CoMx measured, CoMy measured) If not None, com_measured are passed as com_measured_x, com_measured_y + vectorized_calculation: bool, optional + If True (default), the calculation is vectorized Returns ------- @@ -632,15 +635,6 @@ def _calculate_intensities_center_of_mass( com_measured_x, com_measured_y = com_measured else: - # copy to device - intensities = copy_to_device(intensities, device) - - # Coordinates - kx = xp.arange(intensities.shape[-2], dtype=xp.float32) - ky = xp.arange(intensities.shape[-1], dtype=xp.float32) - kya, kxa = xp.meshgrid(ky, kx) - - # calculate CoM if dp_mask is not None: if dp_mask.shape != intensities.shape[-2:]: raise ValueError( @@ -649,19 +643,56 @@ def _calculate_intensities_center_of_mass( f"not {dp_mask.shape}" ) ) - intensities_mask = intensities * xp.asarray(dp_mask, dtype=xp.float32) - else: - intensities_mask = intensities + dp_mask = xp.asarray(dp_mask, dtype=xp.float32) - intensities_sum = xp.sum(intensities_mask, axis=(-2, -1)) - com_measured_x = ( - xp.sum(intensities_mask * kxa[None, None], axis=(-2, -1)) - / intensities_sum - ) - com_measured_y = ( - xp.sum(intensities_mask * kya[None, None], axis=(-2, -1)) - / intensities_sum - ) + # Coordinates + kx = xp.arange(intensities.shape[-2], dtype=xp.float32) + ky = xp.arange(intensities.shape[-1], dtype=xp.float32) + kya, kxa = xp.meshgrid(ky, kx) + + if vectorized_calculation: + # copy to device + intensities = copy_to_device(intensities, device) + + # calculate CoM + if dp_mask is not None: + intensities_mask = intensities * dp_mask + else: + intensities_mask = intensities + + intensities_sum = xp.sum(intensities_mask, axis=(-2, -1)) + + com_measured_x = ( + xp.sum(intensities_mask * kxa[None, None], axis=(-2, -1)) + / intensities_sum + ) + com_measured_y = ( + xp.sum(intensities_mask * kya[None, None], axis=(-2, -1)) + / intensities_sum + ) + + else: + sx, sy = intensities.shape[:2] + com_measured_x = xp.zeros((sx, sy), dtype=xp.float32) + com_measured_y = xp.zeros((sx, sy), dtype=xp.float32) + + # loop of dps + for rx, ry in tqdmnd( + sx, + sy, + desc="Fitting center of mass", + unit="probe position", + disable=not self._verbose, + ): + intensities_device = copy_to_device(intensities[rx, ry], device) + masked_intensity = intensities_device * dp_mask + summed_intensity = masked_intensity.sum() + com_measured_x[rx, ry] = ( + xp.sum(masked_intensity * kxa) / summed_intensity + ) + com_measured_y[rx, ry] = ( + xp.sum(masked_intensity * kya) / summed_intensity + ) if com_shifts is None: com_measured_x_np = asnumpy(com_measured_x) diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index 9ba285475..2485a48aa 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -198,6 +198,7 @@ def preprocess( force_com_rotation: float = None, force_com_transpose: float = None, force_com_shifts: float = None, + vectorized_com_calculation: bool = True, force_scan_sampling: float = None, force_angular_sampling: float = None, force_reciprocal_sampling: float = None, @@ -246,6 +247,8 @@ def preprocess( Amplitudes come from diffraction patterns shifted with the CoM in the upper left corner for each probe unless shift is overwritten. 
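# --- Editor's note: illustrative sketch, not part of the patch ----------------
# The vectorized_calculation=False branch added above trades speed for memory
# by computing the center of mass one diffraction pattern at a time. A minimal
# standalone NumPy version of that per-pattern loop (hypothetical helper name;
# `intensities` is assumed to have shape (Rx, Ry, Qx, Qy)):

import numpy as np

def com_per_pattern(intensities, dp_mask=None):
    sx, sy, qx, qy = intensities.shape
    kx = np.arange(qx, dtype=np.float32)
    ky = np.arange(qy, dtype=np.float32)
    kya, kxa = np.meshgrid(ky, kx)  # both (Qx, Qy), matching each pattern
    com_x = np.zeros((sx, sy), dtype=np.float32)
    com_y = np.zeros((sx, sy), dtype=np.float32)
    for rx in range(sx):
        for ry in range(sy):
            dp = intensities[rx, ry]
            if dp_mask is not None:
                dp = dp * dp_mask
            total = dp.sum()
            com_x[rx, ry] = (dp * kxa).sum() / total
            com_y[rx, ry] = (dp * kya).sum() / total
    return com_x, com_y
# ------------------------------------------------------------------------------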
+ vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized force_scan_sampling: float, optional Override DataCube real space scan pixel size calibrations, in Angstrom force_angular_sampling: float, optional @@ -340,6 +343,7 @@ def preprocess( dp_mask=self._dp_mask, fit_function=fit_function, com_shifts=force_com_shifts, + vectorized_calculation=vectorized_com_calculation, ) # estimate rotation / transpose From cf84534c863102ff90b2462636a97f56e456e68b Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 17:09:04 -0800 Subject: [PATCH 085/128] dp_mask can be None, dub Former-commit-id: f664f42fd390e1ec3b0e70138fee0544978086bf --- py4DSTEM/process/phase/phase_base_class.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 576b35c24..d71011026 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -684,8 +684,9 @@ def _calculate_intensities_center_of_mass( unit="probe position", disable=not self._verbose, ): - intensities_device = copy_to_device(intensities[rx, ry], device) - masked_intensity = intensities_device * dp_mask + masked_intensity = copy_to_device(intensities[rx, ry], device) + if dp_mask is not None: + masked_intensity *= dp_mask summed_intensity = masked_intensity.sum() com_measured_x[rx, ry] = ( xp.sum(masked_intensity * kxa) / summed_intensity From ffb6cba5eb69559cea4a3f97bfd26be70ca33694 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 18:09:47 -0800 Subject: [PATCH 086/128] adding batch size in single-slice preprocess Former-commit-id: 529f8fc89533ddd3ca9104a2334d84938d0cf67c --- .../process/phase/singleslice_ptychography.py | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index 2485a48aa..f599956f0 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -206,6 +206,7 @@ def preprocess( crop_patterns: bool = False, device: str = None, clear_fft_cache: bool = True, + max_batch_size: int = None, **kwargs, ): """ @@ -264,6 +265,8 @@ def preprocess( if not none, overwrites self._device to set device preprocess will be perfomed on. 
clear_fft_cache: bool, optional if true, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps Returns -------- @@ -461,13 +464,23 @@ def preprocess( self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) # overlaps - positions_px_fractional = self._positions_px - xp_storage.round( - self._positions_px - ) - shifted_probes = fft_shift(self._probe, positions_px_fractional, xp) - probe_overlap = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2, self._positions_px - ) + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift(self._probe, positions_px_fractional, xp) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + del shifted_probes # initialize object_fov_mask From e7b95c22806b567e8785856527fefe9f7a3d56ba Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 18:19:24 -0800 Subject: [PATCH 087/128] adding to dpc, multislice, mixedstate, multislice-mixedstate Former-commit-id: ca4d02e00f73dabbefcdd6defbb92c4a6c3ed243 --- py4DSTEM/process/phase/dpc.py | 4 +++ .../mixedstate_multislice_ptychography.py | 31 ++++++++++++++----- .../process/phase/mixedstate_ptychography.py | 27 +++++++++++----- .../process/phase/multislice_ptychography.py | 31 ++++++++++++++----- 4 files changed, 72 insertions(+), 21 deletions(-) diff --git a/py4DSTEM/process/phase/dpc.py b/py4DSTEM/process/phase/dpc.py index a78d823af..e6eed79f6 100644 --- a/py4DSTEM/process/phase/dpc.py +++ b/py4DSTEM/process/phase/dpc.py @@ -237,6 +237,7 @@ def preprocess( force_com_rotation: float = None, force_com_transpose: bool = None, force_com_shifts: Union[Sequence[np.ndarray], Sequence[float]] = None, + vectorized_com_calculation: bool = True, force_com_measured: Sequence[np.ndarray] = None, plot_center_of_mass: str = "default", plot_rotation: bool = True, @@ -270,6 +271,8 @@ def preprocess( Force whether diffraction intensities need to be transposed. 
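# --- Editor's note: illustrative sketch, not part of the patch ----------------
# The max_batch_size logic added in the preprocess methods bounds peak memory
# by accumulating the probe-overlap image over slices of scan positions rather
# than shifting every probe at once. A self-contained sketch of the same
# accumulation pattern (hypothetical names; `batches` mimics the (start, end)
# pairs that generate_batches yields in the patched code):

import numpy as np

def batches(num_items, max_batch):
    for start in range(0, num_items, max_batch):
        yield start, min(start + max_batch, num_items)

def accumulate_overlap(probe, positions_px, object_shape, max_batch=64):
    overlap = np.zeros(object_shape, dtype=np.float32)
    intensity = np.abs(probe) ** 2
    for start, end in batches(len(positions_px), max_batch):
        # only this slice of positions is processed at once
        for px, py in np.round(positions_px[start:end]).astype(int):
            # paste |probe|^2 at each rounded position; stands in for
            # fft_shift + _sum_overlapping_patches_bincounts, and assumes
            # the shifted probe stays inside the object array
            overlap[px : px + probe.shape[0], py : py + probe.shape[1]] += intensity
    return overlap
# ------------------------------------------------------------------------------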
force_com_shifts: tuple of ndarrays (CoMx, CoMy) Force CoM fitted shifts + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized force_com_measured: tuple of ndarrays (CoMx measured, CoMy measured) Force CoM measured shifts plot_center_of_mass: str, optional @@ -324,6 +327,7 @@ def preprocess( dp_mask=self._dp_mask, fit_function=fit_function, com_shifts=force_com_shifts, + vectorized_calculation=vectorized_com_calculation, com_measured=force_com_measured, ) diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index e44fc77de..99d1ff9c5 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -270,6 +270,7 @@ def preprocess( force_com_rotation: float = None, force_com_transpose: float = None, force_com_shifts: float = None, + vectorized_com_calculation: bool = True, force_scan_sampling: float = None, force_angular_sampling: float = None, force_reciprocal_sampling: float = None, @@ -277,6 +278,7 @@ def preprocess( crop_patterns: bool = False, device: str = None, clear_fft_cache: bool = True, + max_batch_size: int = None, **kwargs, ): """ @@ -328,6 +330,8 @@ def preprocess( Amplitudes come from diffraction patterns shifted with the CoM in the upper left corner for each probe unless shift is overwritten. + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized force_scan_sampling: float, optional Override DataCube real space scan pixel size calibrations, in Angstrom force_angular_sampling: float, optional @@ -343,6 +347,8 @@ def preprocess( If not None, overwrites self._device to set device preprocess will be perfomed on. 
clear_fft_cache: bool, optional If True, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps Returns -------- @@ -391,6 +397,7 @@ def preprocess( vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, com_shifts=force_com_shifts, + vectorized_calculation=vectorized_com_calculation, ) # calibrations @@ -549,13 +556,23 @@ def preprocess( ) # overlaps - positions_px_fractional = self._positions_px - xp_storage.round( - self._positions_px - ) - shifted_probes = fft_shift(self._probe[0], positions_px_fractional, xp) - probe_overlap = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2, self._positions_px - ) + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift(self._probe[0], positions_px_fractional, xp) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + del shifted_probes if object_fov_mask is None: diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index 17e220e51..3c838fff5 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -223,6 +223,7 @@ def preprocess( crop_patterns: bool = False, device: str = None, clear_fft_cache: bool = True, + max_batch_size: int = None, **kwargs, ): """ @@ -289,6 +290,8 @@ def preprocess( if not none, overwrites self._device to set device preprocess will be perfomed on. 
clear_fft_cache: bool, optional if true, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps Returns -------- @@ -484,13 +487,23 @@ def preprocess( self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) # overlaps - positions_px_fractional = self._positions_px - xp_storage.round( - self._positions_px - ) - shifted_probes = fft_shift(self._probe[0], positions_px_fractional, xp) - probe_overlap = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2, self._positions_px - ) + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift(self._probe[0], positions_px_fractional, xp) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + del shifted_probes if object_fov_mask is None: diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index cd406139b..c18890b85 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -244,6 +244,7 @@ def preprocess( force_com_rotation: float = None, force_com_transpose: float = None, force_com_shifts: float = None, + vectorized_com_calculation: bool = True, force_scan_sampling: float = None, force_angular_sampling: float = None, force_reciprocal_sampling: float = None, @@ -251,6 +252,7 @@ def preprocess( crop_patterns: bool = False, device: str = None, clear_fft_cache: bool = True, + max_batch_size: int = None, **kwargs, ): """ @@ -302,6 +304,8 @@ def preprocess( Amplitudes come from diffraction patterns shifted with the CoM in the upper left corner for each probe unless shift is overwritten. + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized force_scan_sampling: float, optional Override DataCube real space scan pixel size calibrations, in Angstrom force_angular_sampling: float, optional @@ -317,6 +321,8 @@ def preprocess( If not None, overwrites self._device to set device preprocess will be perfomed on. 
clear_fft_cache: bool, optional If True, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps Returns -------- @@ -395,6 +401,7 @@ def preprocess( dp_mask=self._dp_mask, fit_function=fit_function, com_shifts=force_com_shifts, + vectorized_calculation=vectorized_com_calculation, ) # estimate rotation / transpose @@ -523,13 +530,23 @@ def preprocess( ) # overlaps - positions_px_fractional = self._positions_px - xp_storage.round( - self._positions_px - ) - shifted_probes = fft_shift(self._probe, positions_px_fractional, xp) - probe_overlap = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2, self._positions_px - ) + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift(self._probe, positions_px_fractional, xp) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + del shifted_probes if object_fov_mask is None: From 1c82708bf36b9218ca107df61c8274e371456777 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 18:33:10 -0800 Subject: [PATCH 088/128] adding some basic device cleanup to parallax, no storage yer Former-commit-id: f15278a5484a3f3baf0cb920a619c6d311a75ff9 --- py4DSTEM/process/phase/parallax.py | 71 ++++++++++++++++++------------ 1 file changed, 44 insertions(+), 27 deletions(-) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index ee1c6f8b5..3547bb60e 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -82,26 +82,20 @@ def __init__( verbose: bool = False, object_padding_px: Tuple[int, int] = (32, 32), device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, name: str = "parallax_reconstruction", ): Custom.__init__(self, name=name) - if device == "cpu": - import scipy + if storage is None: + storage = device - self._xp = np - self._asnumpy = np.asarray - self._scipy = scipy + if storage != device: + raise NotImplementedError() - elif device == "gpu": - from cupyx import scipy - - self._xp = cp - self._asnumpy = cp.asnumpy - self._scipy = scipy - - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_device(device, clear_fft_cache) + self.set_storage(storage) self.set_save_defaults() @@ -111,7 +105,6 @@ def __init__( # Metadata self._energy = energy self._verbose = verbose - self._device = device self._object_padding_px = object_padding_px self._preprocessed = False @@ -273,6 +266,9 @@ def preprocess( plot_average_bf: bool = True, realspace_mask: np.ndarray = None, apply_realspace_mask_to_stack: bool = True, + vectorized_com_calculation: bool = True, + device: str = None, + clear_fft_cache: bool = True, **kwargs, ): """ @@ -308,6 +304,12 @@ def preprocess( apply_realspace_mask_to_stack: bool, optional If this value is set to true, output BF images will be masked by the edge filter and realspace_mask if it is passed in. 
+ vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- @@ -315,7 +317,11 @@ def preprocess( Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + xp = self._xp + device = self._device asnumpy = self._asnumpy if self._datacube is None: @@ -332,6 +338,8 @@ def preprocess( require_calibrations=True, ) + self._intensities = xp.asarray(self._intensities) + self._region_of_interest_shape = np.array(self._intensities.shape[-2:]) self._scan_shape = np.array(self._intensities.shape[:2]) @@ -350,6 +358,7 @@ def preprocess( fit_function=descan_correction_fit_function, com_shifts=None, com_measured=None, + vectorized_calculation=vectorized_com_calculation, ) com_fitted_x = asnumpy(com_fitted_x) @@ -357,7 +366,6 @@ def preprocess( intensities = asnumpy(self._intensities) intensities_shifted = np.zeros_like(intensities) - # center_x, center_y = self._region_of_interest_shape / 2 center_x = com_fitted_x.mean() center_y = com_fitted_y.mean() @@ -719,11 +727,9 @@ def preprocess( ax[1].set_ylabel(r"$k_x$ [$A^{-1}$]") ax[1].set_xlabel(r"$k_y$ [$A^{-1}$]") plt.tight_layout() - self._preprocessed = True - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self._preprocessed = True + self.clear_device_mem(device, self._clear_fft_cache) return self @@ -741,6 +747,8 @@ def reconstruct( plot_aligned_bf: bool = True, plot_convergence: bool = True, reset: bool = None, + device: str = None, + clear_fft_cache: bool = True, **kwargs, ): """ @@ -773,6 +781,10 @@ def reconstruct( If True, the convergence error is also plotted reset: bool, optional If True, the reconstruction is reset + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. 
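# --- Editor's note: illustrative sketch, not part of the patch ----------------
# clear_device_mem, called at the end of preprocess/reconstruct above, replaces
# the earlier ad-hoc cupy pool cleanup. A hedged sketch of what such a helper
# typically does (hypothetical standalone version; the py4DSTEM implementation
# may differ):

def clear_device_mem(device, clear_fft_cache):
    if device != "gpu":
        return
    import cupy as cp  # only needed on the GPU path

    if clear_fft_cache:
        # drop cached cuFFT plans
        cp.fft.config.get_plan_cache().clear()
    # release unused blocks held by the default memory pool
    cp.get_default_memory_pool().free_all_blocks()
# ------------------------------------------------------------------------------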
+ clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- @@ -780,6 +792,9 @@ def reconstruct( Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + xp = self._xp asnumpy = self._asnumpy @@ -790,6 +805,7 @@ def reconstruct( self._stack_mask = self._stack_mask_initial.copy() self._recon_mask = self._recon_mask_initial.copy() self._xy_shifts = self._xy_shifts_initial.copy() + elif reset is None: if hasattr(self, "error_iterations"): warnings.warn( @@ -1034,9 +1050,7 @@ def reconstruct( self.recon_BF = asnumpy(self._recon_BF) - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(device, self._clear_fft_cache) return self @@ -1742,6 +1756,8 @@ def subpixel_alignment( spec.tight_layout(fig) + self.clear_device_mem(self._device, self._clear_fft_cache) + def _interpolate_array( self, image, @@ -2375,9 +2391,7 @@ def score_CTF(coefs): + str(np.round(self._aberrations_coefs[a0]).astype("int")) ) - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(self._device, self._clear_fft_cache) def _calculate_CTF(self, alpha_shape, sampling, *coefs): xp = self._xp @@ -2576,6 +2590,8 @@ def aberration_correct( ax.set_xlabel("y [A]") ax.set_title("Parallax-Corrected Phase Image") + self.clear_device_mem(self._device, self._clear_fft_cache) + def depth_section( self, depth_angstroms=np.arange(-250, 260, 100), @@ -2853,7 +2869,8 @@ def show_shifts( dp_mask_ind = xp.nonzero(self._dp_mask) yy, xx = xp.meshgrid( - xp.arange(self._dp_mean.shape[1]), xp.arange(self._dp_mean.shape[0]) + xp.arange(self._region_of_interest_shape[1]), + xp.arange(self._region_of_interest_shape[0]), ) freq_mask = xp.logical_and(xx % plot_arrow_freq == 0, yy % plot_arrow_freq == 0) masked_ind = xp.logical_and(freq_mask, self._dp_mask) From f72d9a8fa61466cc3d67ffa1efb52ba327acd646 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 19:13:24 -0800 Subject: [PATCH 089/128] magnetic ptycho preprocess storage support Former-commit-id: 54cea8d5ce0a54200ed55746fa3bb18ed00d6e37 --- .../process/phase/magnetic_ptychography.py | 112 +++++++++++------- .../process/phase/ptychographic_methods.py | 8 +- 2 files changed, 75 insertions(+), 45 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 2bf556d21..833951f44 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -38,6 +38,7 @@ from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, + copy_to_device, fft_shift, generate_batches, polar_aliases, @@ -106,6 +107,10 @@ class MagneticPtychography( If True, class methods will inherit this and print additional information device: str, optional Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. 
Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') @@ -136,27 +141,18 @@ def __init__( object_type: str = "complex", verbose: bool = True, device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, name: str = "magnetic_ptychographic_reconstruction", **kwargs, ): Custom.__init__(self, name=name) - if device == "cpu": - import scipy + if storage is None: + storage = device - self._xp = np - self._asnumpy = np.asarray - self._scipy = scipy - - elif device == "gpu": - from cupyx import scipy - - self._xp = cp - self._asnumpy = cp.asnumpy - self._scipy = scipy - - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_device(device, clear_fft_cache) + self.set_storage(storage) for key in kwargs.keys(): if (key not in polar_symbols) and (key not in polar_aliases.keys()): @@ -193,7 +189,6 @@ def __init__( self._object_padding_px = object_padding_px self._positions_mask = positions_mask self._verbose = verbose - self._device = device self._preprocessed = False # Class-specific Metadata @@ -214,12 +209,16 @@ def preprocess( force_com_rotation: float = None, force_com_transpose: float = None, force_com_shifts: float = None, + vectorized_com_calculation: bool = True, force_scan_sampling: float = None, force_angular_sampling: float = None, force_reciprocal_sampling: float = None, progress_bar: bool = True, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + device: str = None, + clear_fft_cache: bool = True, + max_batch_size: int = None, **kwargs, ): """ @@ -268,6 +267,8 @@ def preprocess( Amplitudes come from diffraction patterns shifted with the CoM in the upper left corner for each probe unless shift is overwritten. + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized force_scan_sampling: float, optional Override DataCube real space scan pixel size calibrations, in Angstrom force_angular_sampling: float, optional @@ -279,13 +280,25 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. 
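# --- Editor's note: illustrative usage sketch, not part of the patch ----------
# With the device/storage split introduced above, bulky arrays (for example the
# normalized amplitudes) can stay in host memory while the math runs on GPU.
# A hedged example of how the new keywords compose (all other constructor and
# preprocess arguments are elided or placeholders):
#
#     ptycho = MagneticPtychography(
#         energy=300e3,                 # placeholder value
#         device="gpu",                 # compute device
#         storage="cpu",                # keep large arrays on the host
#         clear_fft_cache=True,
#         # ... datacubes, probe parameters, etc. elided ...
#     ).preprocess(
#         vectorized_com_calculation=False,  # loop over patterns to save memory
#         max_batch_size=128,                # batch the probe-overlap estimate
#     )
# ------------------------------------------------------------------------------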
+ clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps Returns -------- self: PtychographicReconstruction Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage asnumpy = self._asnumpy # set additional metadata @@ -382,7 +395,9 @@ def preprocess( for q, s in zip(roi_shape, padded_diffraction_intensities_shape) ) - self._amplitudes = xp.empty((self._num_diffraction_patterns,) + roi_shape) + self._amplitudes = xp_storage.empty( + (self._num_diffraction_patterns,) + roi_shape + ) if region_of_interest_shape is not None: self._resample_exit_waves = True @@ -463,6 +478,7 @@ def preprocess( dp_mask=self._dp_mask, fit_function=fit_function, com_shifts=force_com_shifts[index], + vectorized_calculation=vectorized_com_calculation, ) # estimate rotation / transpose using first measurement @@ -476,8 +492,6 @@ def preprocess( self._rotation_best_transpose, _com_x, _com_y, - com_x, - com_y, ) = self._solve_for_center_of_mass_relative_rotation( com_measured_x, com_measured_y, @@ -496,8 +510,9 @@ def preprocess( # corner-center amplitudes idx_start = self._cum_probes_per_measurement[index] idx_end = self._cum_probes_per_measurement[index + 1] + ( - self._amplitudes[idx_start:idx_end], + amplitudes, mean_diffraction_intensity_temp, self._crop_mask, ) = self._normalize_diffraction_intensities( @@ -510,8 +525,12 @@ def preprocess( self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) + # explicitly transfer arrays to storage + self._amplitudes[idx_start:idx_end] = copy_to_device(amplitudes, storage) + del ( intensities, + amplitudes, com_measured_x, com_measured_y, com_fitted_x, @@ -553,17 +572,18 @@ def preprocess( self._object_shape = self._object.shape[-2:] # center probe positions - self._positions_px_all = xp.asarray(self._positions_px_all, dtype=xp.float32) + self._positions_px_all = xp_storage.asarray( + self._positions_px_all, dtype=xp.float32 + ) for index in range(self._num_measurements): idx_start = self._cum_probes_per_measurement[index] idx_end = self._cum_probes_per_measurement[index + 1] - self._positions_px = self._positions_px_all[idx_start:idx_end] - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= ( - self._positions_px_com - xp.array(self._object_shape) / 2 - ) - self._positions_px_all[idx_start:idx_end] = self._positions_px.copy() + + positions_px = self._positions_px_all[idx_start:idx_end] + positions_px_com = positions_px.mean(0) + positions_px -= positions_px_com - xp_storage.array(self._object_shape) / 2 + self._positions_px_all[idx_start:idx_end] = positions_px.copy() self._positions_px_initial_all = self._positions_px_all.copy() self._positions_initial_all = self._positions_px_initial_all.copy() @@ -606,16 +626,24 @@ def preprocess( )._evaluate_ctf() # overlaps - idx_end = self._cum_probes_per_measurement[1] - self._positions_px = self._positions_px_all[0:idx_end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - shifted_probes = fft_shift( - self._probes_all[0], self._positions_px_fractional, xp - ) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) + if max_batch_size is None: + max_batch_size = 
self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._cum_probes_per_measurement[1], max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px_all[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift(self._probes_all[0], positions_px_fractional, xp) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes # initialize object_fov_mask if object_fov_mask is None: @@ -624,10 +652,13 @@ def preprocess( self._object_fov_mask = asnumpy( probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() ) + del probe_overlap_blurred else: self._object_fov_mask = np.asarray(object_fov_mask) self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + probe_overlap = asnumpy(probe_overlap) + # plot probe overlaps if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) @@ -670,7 +701,7 @@ def preprocess( ax1.set_title("Initial probe intensity") ax2.imshow( - asnumpy(probe_overlap), + probe_overlap, extent=extent, cmap="gray", ) @@ -689,10 +720,7 @@ def preprocess( fig.tight_layout() self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(device, self._clear_fft_cache) return self diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index c1c10e6ab..280a6a392 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -3264,10 +3264,10 @@ def _return_average_positions( self, positions=None, cum_probes_per_measurement=None ): """Average positions estimate""" - xp = self._xp + xp_storage = self._xp_storage if positions is not None: - _pos = xp.asarray(positions) + _pos = xp_storage.asarray(positions) else: if not hasattr(self, "_positions_px_all"): return None @@ -3282,7 +3282,9 @@ def _return_average_positions( if np.any(num_probes_per_measurement != num_probes_per_measurement[0]): return None - avg_positions = xp.zeros((num_probes_per_measurement[0], 2), dtype=xp.float32) + avg_positions = xp_storage.zeros( + (num_probes_per_measurement[0], 2), dtype=xp_storage.float32 + ) for index in range(num_measurements): start_idx = cum_probes_per_measurement[index] From ab456f548e5507573d5f8a7e9ff4e32ade277270 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 20:45:51 -0800 Subject: [PATCH 090/128] added full storage support for magnetic ptycho Former-commit-id: 3c94a2da8012b3a71b9f2cba5289d54085e28aca --- .../process/phase/magnetic_ptychography.py | 185 +++++++++++------- 1 file changed, 113 insertions(+), 72 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 833951f44..c835eaf01 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -724,7 +724,13 @@ def preprocess( return self - def _overlap_projection(self, current_object, shifted_probes): + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ): """ Ptychographic overlap projection method. 
@@ -747,21 +753,19 @@ def _overlap_projection(self, current_object, shifted_probes): xp = self._xp - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - object_patches = xp.empty( - (self._num_measurements,) + shifted_probes.shape, dtype=xp.complex64 + (self._num_measurements,) + shifted_probes.shape, dtype=current_object.dtype ) - object_patches[0] = complex_object[ - 0, self._vectorized_patch_indices_row, self._vectorized_patch_indices_col + object_patches[0] = current_object[ + 0, vectorized_patch_indices_row, vectorized_patch_indices_col ] - object_patches[1] = complex_object[ - 1, self._vectorized_patch_indices_row, self._vectorized_patch_indices_col + object_patches[1] = current_object[ + 1, vectorized_patch_indices_row, vectorized_patch_indices_col ] + if self._object_type == "potential": + object_patches = xp.exp(1j * object_patches) + overlap_base = shifted_probes * object_patches[0] match (self._recon_mode, self._active_measurement_index): @@ -782,6 +786,7 @@ def _gradient_descent_adjoint( current_probe, object_patches, shifted_probes, + positions_px, exit_waves, step_size, normalization_min, @@ -824,7 +829,8 @@ def _gradient_descent_adjoint( probe_electrostatic_abs = xp.abs(shifted_probes * object_patches[0]) probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts( - probe_electrostatic_abs**2 + probe_electrostatic_abs**2, + positions_px, ) probe_electrostatic_normalization = 1 / xp.sqrt( 1e-16 @@ -834,7 +840,8 @@ def _gradient_descent_adjoint( probe_magnetic_abs = xp.abs(shifted_probes * object_patches[1]) probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( - probe_magnetic_abs**2 + probe_magnetic_abs**2, + positions_px, ) probe_magnetic_normalization = 1 / xp.sqrt( 1e-16 @@ -878,7 +885,8 @@ def _gradient_descent_adjoint( * electrostatic_conj * probe_conj * exit_waves - ) + ), + positions_px, ) # i exp(-i v) exp(i m) P* @@ -887,13 +895,15 @@ def _gradient_descent_adjoint( else: # M P* electrostatic_update = self._sum_overlapping_patches_bincounts( - probe_conj * object_patches[1] * exit_waves + probe_conj * object_patches[1] * exit_waves, + positions_px, ) # V* P* magnetic_update = xp.conj( self._sum_overlapping_patches_bincounts( - probe_conj * electrostatic_conj * exit_waves + probe_conj * electrostatic_conj * exit_waves, + positions_px, ) ) @@ -926,7 +936,8 @@ def _gradient_descent_adjoint( * electrostatic_conj * probe_conj * exit_waves - ) + ), + positions_px, ) # -i exp(-i v) exp(-i m) P* @@ -935,12 +946,14 @@ def _gradient_descent_adjoint( else: # M* P* electrostatic_update = self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_conj * exit_waves + probe_conj * magnetic_conj * exit_waves, + positions_px, ) # V* P* magnetic_update = self._sum_overlapping_patches_bincounts( - probe_conj * electrostatic_conj * exit_waves + probe_conj * electrostatic_conj * exit_waves, + positions_px, ) current_object[0] += ( @@ -963,7 +976,8 @@ def _gradient_descent_adjoint( case (1, 1) | (2, 0): # neutral probe_abs = xp.abs(shifted_probes) probe_normalization = self._sum_overlapping_patches_bincounts( - probe_abs**2 + probe_abs**2, + positions_px, ) probe_normalization = 1 / xp.sqrt( 1e-16 @@ -974,13 +988,15 @@ def _gradient_descent_adjoint( if self._object_type == "potential": # -i exp(-i v) P* electrostatic_update = self._sum_overlapping_patches_bincounts( - xp.real(-1j * electrostatic_conj * probe_conj * exit_waves) + xp.real(-1j * electrostatic_conj * probe_conj 
* exit_waves), + positions_px, ) else: # P* electrostatic_update = self._sum_overlapping_patches_bincounts( - probe_conj * exit_waves + probe_conj * exit_waves, + positions_px, ) current_object[0] += ( @@ -1086,7 +1102,7 @@ def reconstruct( normalization_min: float = 1, positions_step_size: float = 0.9, pure_phase_object_iter: int = 0, - fix_com: bool = True, + fix_probe_com: bool = True, fix_probe_iter: int = 0, fix_probe_aperture_iter: int = 0, constrain_probe_amplitude_iter: int = 0, @@ -1096,6 +1112,7 @@ def reconstruct( constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, + fix_positions_com: bool = True, max_position_update_distance: float = None, max_position_total_distance: float = None, global_affine_transformation: bool = True, @@ -1123,6 +1140,8 @@ def reconstruct( collective_measurement_updates: bool = True, progress_bar: bool = True, reset: bool = None, + device: str = None, + clear_fft_cache: bool = True, ): """ Ptychographic reconstruction main method. @@ -1159,7 +1178,7 @@ def reconstruct( Positions update step size pure_phase_object_iter: float, optional Number of iterations where object amplitude is set to unity - fix_com: bool, optional + fix_probe_com: bool, optional If True, fixes center of mass of probe fix_probe_iter: int, optional Number of iterations to run with a fixed probe before updating probe estimate @@ -1179,6 +1198,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov max_position_update_distance: float, optional Maximum allowed distance for update in A max_position_total_distance: float, optional @@ -1234,14 +1255,34 @@ def reconstruct( If True, reconstruction progress is displayed reset: bool, optional If True, previous reconstructions are ignored + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. 
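# --- Editor's note: illustrative sketch, not part of the patch ----------------
# The reconstruct loop below keeps the full amplitude stack on the storage
# device and only moves the current batch to the compute device, mirroring the
# copy_to_device(self._amplitudes[batch_indices], device) call in the patched
# code. A minimal sketch of that pattern with a hypothetical copy helper:

import numpy as np

def to_device(array, device):
    if device == "gpu":
        import cupy as cp  # assumes CuPy is available on the GPU path
        return cp.asarray(array)
    return np.asarray(array)

def batched_amplitudes(amplitudes_host, batch_indices_list, device="cpu"):
    for batch_indices in batch_indices_list:
        # yield one small device-resident batch at a time
        yield to_device(amplitudes_host[batch_indices], device)
# ------------------------------------------------------------------------------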
+ clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- self: PtychographicReconstruction Self to accommodate chaining """ - asnumpy = self._asnumpy + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probes_all", + "_probes_all_initial", + "_probes_all_initial_aperture", + ] + self.copy_attributes_to_device(attrs, device) + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy if not collective_measurement_updates and self._verbose: warnings.warn( @@ -1287,7 +1328,7 @@ def reconstruct( ) if max_batch_size is not None: - xp.random.seed(seed_random) + np.random.seed(seed_random) else: max_batch_size = self._num_diffraction_patterns @@ -1323,6 +1364,7 @@ def reconstruct( if collective_measurement_updates: collective_object = xp.zeros_like(self._object) + # randomize measurement_indices = np.arange(self._num_measurements) np.random.shuffle(measurement_indices) @@ -1344,43 +1386,31 @@ def reconstruct( ] num_diffraction_patterns = end_idx - start_idx - shuffled_indices = np.arange(num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) + shuffled_indices = np.arange(start_idx, end_idx) # randomize if not use_projection_scheme: np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - num_diffraction_patterns - ) - - positions_px = self._positions_px_all[start_idx:end_idx].copy()[ - shuffled_indices - ] - initial_positions_px = self._positions_px_initial_all[ - start_idx:end_idx - ].copy()[shuffled_indices] - for start, end in generate_batches( num_diffraction_patterns, max_batch=max_batch_size ): # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_initial = initial_positions_px[start:end] - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px_all[batch_indices] + positions_px_initial = self._positions_px_initial_all[batch_indices] + positions_px_fractional = positions_px - xp_storage.round( + positions_px ) ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) - amplitudes = self._amplitudes[start_idx:end_idx][ - shuffled_indices[start:end] - ] + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) # forward operator ( @@ -1391,8 +1421,11 @@ def reconstruct( batch_error, ) = self._forward( self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, _probe, - amplitudes, + positions_px_fractional, + amplitudes_device, self._exit_waves, use_projection_scheme=use_projection_scheme, projection_a=projection_a, @@ -1406,6 +1439,7 @@ def reconstruct( _probe, object_patches, shifted_probes, + positions_px, self._exit_waves, use_projection_scheme=use_projection_scheme, step_size=step_size, @@ -1417,13 +1451,17 @@ def reconstruct( # position correction if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( + self._positions_px_all[ + batch_indices + ] = self._position_correction( self._object, + vectorized_patch_indices_row, + 
vectorized_patch_indices_col, shifted_probes, overlap, - amplitudes, - self._positions_px, - self._positions_px_initial, + amplitudes_device, + positions_px, + positions_px_initial, positions_step_size, max_position_update_distance, max_position_total_distance, @@ -1444,15 +1482,12 @@ def reconstruct( error += measurement_error # constraints - self._positions_px_all[start_idx:end_idx] = positions_px.copy()[ - unshuffled_indices - ] if collective_measurement_updates: # probe and positions _probe = self._probe_constraints( _probe, - fix_com=fix_com and a0 >= fix_probe_iter, + fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter and a0 >= fix_probe_iter, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, @@ -1471,11 +1506,12 @@ def reconstruct( initial_probe_aperture=_probe_initial_aperture, ) - self._positions_px_all[ - start_idx:end_idx - ] = self._positions_constraints( - self._positions_px_all[start_idx:end_idx], + self._positions_px_all[batch_indices] = self._positions_constraints( + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], fix_positions=a0 < fix_positions_iter, + fix_positions_com=fix_positions_com + and a0 >= fix_positions_iter, global_affine_transformation=global_affine_transformation, ) @@ -1484,12 +1520,13 @@ def reconstruct( ( self._object, _probe, - self._positions_px_all[start_idx:end_idx], + self._positions_px_all[batch_indices], ) = self._constraints( self._object, _probe, - self._positions_px_all[start_idx:end_idx], - fix_com=fix_com and a0 >= fix_probe_iter, + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter and a0 >= fix_probe_iter, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, @@ -1507,6 +1544,8 @@ def reconstruct( fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=_probe_initial_aperture, fix_positions=a0 < fix_positions_iter, + fix_positions_com=fix_positions_com + and a0 >= fix_positions_iter, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter and gaussian_filter_sigma_m is not None, @@ -1577,9 +1616,11 @@ def reconstruct( self.probe = self.probe_centered self.error = error.item() - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(device, self._clear_fft_cache) return self @@ -1850,8 +1891,8 @@ def _visualize_last_iteration( @property def object_cropped(self): """Cropped and rotated object""" - - cropped_e = self._crop_rotate_object_fov(self._object[0]) - cropped_m = self._crop_rotate_object_fov(self._object[1]) + avg_pos = self._return_average_positions() + cropped_e = self._crop_rotate_object_fov(self._object[0], positions_px=avg_pos) + cropped_m = self._crop_rotate_object_fov(self._object[1], positions_px=avg_pos) return np.array([cropped_e, cropped_m]) From 53a7e2fd7ab5c2165c0a0648a1f27ba38db90fa2 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 21:51:46 -0800 Subject: [PATCH 091/128] added storage to ptycho tomo Former-commit-id: 5f79bb18f87b11da659e2422f6d789ade251555a --- .../process/phase/ptychographic_tomography.py | 277 +++++++++++------- 1 file changed, 175 
insertions(+), 102 deletions(-) diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 78444b39b..18967c7f2 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -38,6 +38,7 @@ from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from py4DSTEM.process.phase.utils import ( ComplexProbe, + copy_to_device, fft_shift, generate_batches, polar_aliases, @@ -109,13 +110,17 @@ class PtychographicTomography( If None, initialized to a grid scan centered along tilt axis verbose: bool, optional If True, class methods will inherit this and print additional information - device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') positions_mask: np.ndarray, optional Boolean real space mask to select positions to ignore in reconstruction + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls name: str, optional Class name kwargs: @@ -144,27 +149,18 @@ def __init__( initial_scan_positions: Sequence[np.ndarray] = None, verbose: bool = True, device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, name: str = "ptychographic-tomography_reconstruction", **kwargs, ): Custom.__init__(self, name=name) - if device == "cpu": - import scipy + if storage is None: + storage = device - self._xp = np - self._asnumpy = np.asarray - self._scipy = scipy - - elif device == "gpu": - from cupyx import scipy - - self._xp = cp - self._asnumpy = cp.asnumpy - self._scipy = scipy - - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_device(device, clear_fft_cache) + self.set_storage(storage) for key in kwargs.keys(): if (key not in polar_symbols) and (key not in polar_aliases.keys()): @@ -203,7 +199,6 @@ def __init__( self._object_padding_px = object_padding_px self._positions_mask = positions_mask self._verbose = verbose - self._device = device self._preprocessed = False # Class-specific Metadata @@ -224,6 +219,7 @@ def preprocess( diffraction_patterns_rotate_degrees: float = None, diffraction_patterns_transpose: bool = None, force_com_shifts: Sequence[float] = None, + vectorized_com_calculation: bool = True, force_scan_sampling: float = None, force_angular_sampling: float = None, force_reciprocal_sampling: float = None, @@ -231,6 +227,9 @@ def preprocess( object_fov_mask: np.ndarray = None, crop_patterns: bool = False, main_tilt_axis: str = "vertical", + device: str = None, + clear_fft_cache: bool = True, + max_batch_size: int = None, **kwargs, ): """ @@ -269,6 +268,8 @@ def preprocess( Amplitudes come from diffraction patterns shifted with the CoM in the upper left corner for each probe unless shift is overwritten. One tuple per tilt. 
+ vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized force_scan_sampling: float, optional Override DataCube real space scan pixel size calibrations, in Angstrom force_angular_sampling: float, optional @@ -284,13 +285,25 @@ def preprocess( The default, 'vertical' (first scan dimension), results in object size (q,p,q), 'horizontal' (second scan dimension) results in object size (p,p,q), any other value (e.g. None) results in object size (max(p,q),p,q). + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps Returns -------- self: OverlapTomographicReconstruction Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage asnumpy = self._asnumpy # set additional metadata @@ -346,7 +359,9 @@ def preprocess( for q, s in zip(roi_shape, padded_diffraction_intensities_shape) ) - self._amplitudes = xp.empty((self._num_diffraction_patterns,) + roi_shape) + self._amplitudes = xp_storage.empty( + (self._num_diffraction_patterns,) + roi_shape + ) if region_of_interest_shape is not None: self._resample_exit_waves = True @@ -424,13 +439,14 @@ def preprocess( dp_mask=self._dp_mask, fit_function=fit_function, com_shifts=force_com_shifts[index], + vectorized_calculation=vectorized_com_calculation, ) # corner-center amplitudes idx_start = self._cum_probes_per_measurement[index] idx_end = self._cum_probes_per_measurement[index + 1] ( - self._amplitudes[idx_start:idx_end], + amplitudes, mean_diffraction_intensity_temp, self._crop_mask, ) = self._normalize_diffraction_intensities( @@ -443,8 +459,12 @@ def preprocess( self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) + # explicitly transfer arrays to storage + self._amplitudes[idx_start:idx_end] = copy_to_device(amplitudes, storage) + del ( intensities, + amplitudes, com_measured_x, com_measured_y, com_fitted_x, @@ -483,23 +503,29 @@ def preprocess( self._num_voxels = self._object.shape[0] # center probe positions - self._positions_px_all = xp.asarray(self._positions_px_all, dtype=xp.float32) + self._positions_px_all = xp_storage.asarray( + self._positions_px_all, dtype=xp.float32 + ) for index in range(self._num_measurements): idx_start = self._cum_probes_per_measurement[index] idx_end = self._cum_probes_per_measurement[index + 1] - self._positions_px = self._positions_px_all[idx_start:idx_end] - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= ( - self._positions_px_com - xp.array(self._object_shape) / 2 - ) - self._positions_px_all[idx_start:idx_end] = self._positions_px.copy() + + positions_px = self._positions_px_all[idx_start:idx_end] + positions_px_com = positions_px.mean(0) + positions_px -= positions_px_com - xp_storage.array(self._object_shape) / 2 + self._positions_px_all[idx_start:idx_end] = positions_px.copy() self._positions_px_initial_all = self._positions_px_all.copy() self._positions_initial_all = self._positions_px_initial_all.copy() self._positions_initial_all[:, 0] *= self.sampling[0] self._positions_initial_all[:, 1] *= self.sampling[1] + self._positions_initial = self._return_average_positions() + if 
self._positions_initial is not None: + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + # initialize probe self._probes_all = [] self._probes_all_initial = [] @@ -551,6 +577,9 @@ def preprocess( ) # overlaps + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + if object_fov_mask is None: probe_overlap_3D = xp.zeros_like(self._object) old_rot_matrix = np.eye(3) # identity @@ -558,6 +587,7 @@ def preprocess( for index in range(self._num_measurements): idx_start = self._cum_probes_per_measurement[index] idx_end = self._cum_probes_per_measurement[index + 1] + rot_matrix = self._tilt_orientation_matrices[index] probe_overlap_3D = self._rotate_zxy_volume( @@ -565,17 +595,29 @@ def preprocess( rot_matrix @ old_rot_matrix.T, ) - self._positions_px = self._positions_px_all[idx_start:idx_end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - shifted_probes = fft_shift( - self._probes_all[index], self._positions_px_fractional, xp - ) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts( - probe_intensities - ) + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + num_diffraction_patterns = idx_end - idx_start + shuffled_indices = np.arange(idx_start, idx_end) + + for start, end in generate_batches( + num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px_all[batch_indices] + positions_px_fractional = positions_px - xp_storage.round( + positions_px + ) + + shifted_probes = fft_shift( + self._probes_all[index], positions_px_fractional, xp + ) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes probe_overlap_3D += probe_overlap[None] old_rot_matrix = rot_matrix @@ -594,20 +636,29 @@ def preprocess( else: self._object_fov_mask = np.asarray(object_fov_mask) - self._positions_px = self._positions_px_all[ - : self._cum_probes_per_measurement[1] - ] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - shifted_probes = fft_shift( - self._probes_all[0], self._positions_px_fractional, xp - ) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._cum_probes_per_measurement[1], max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px_all[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift( + self._probes_all[0], positions_px_fractional, xp + ) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + probe_overlap = asnumpy(probe_overlap) + + # plot probe overlaps if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) chroma_boost = kwargs.pop("chroma_boost", 1) @@ -680,7 +731,7 @@ def preprocess( ax2.set_title("Propagated probe intensity") ax3.imshow( - asnumpy(probe_overlap), + probe_overlap, extent=extent, cmap="Greys_r", ) @@ -699,10 +750,7 @@ def preprocess( fig.tight_layout() self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - 
xp.clear_memo() + self.clear_device_mem(device, self._clear_fft_cache) return self @@ -719,7 +767,7 @@ def reconstruct( step_size: float = 0.5, normalization_min: float = 1, positions_step_size: float = 0.9, - fix_com: bool = True, + fix_probe_com: bool = True, fix_probe_iter: int = 0, fix_probe_aperture_iter: int = 0, constrain_probe_amplitude_iter: int = 0, @@ -729,6 +777,7 @@ def reconstruct( constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, + fix_positions_com: bool = True, max_position_update_distance: float = None, max_position_total_distance: float = None, global_affine_transformation: bool = True, @@ -752,6 +801,8 @@ def reconstruct( store_iterations: bool = False, progress_bar: bool = True, reset: bool = None, + device: str = None, + clear_fft_cache: bool = True, ): """ Ptychographic reconstruction main method. @@ -786,7 +837,7 @@ def reconstruct( Probe normalization minimum as a fraction of the maximum overlap intensity positions_step_size: float, optional Positions update step size - fix_com: bool, optional + fix_probe_com: bool, optional If True, fixes center of mass of probe fix_probe_iter: int, optional Number of iterations to run with a fixed probe before updating probe estimate @@ -806,6 +857,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov max_position_update_distance: float, optional Maximum allowed distance for update in A max_position_total_distance: float, optional @@ -851,14 +904,35 @@ def reconstruct( If True, reconstruction progress is displayed reset: bool, optional If True, previous reconstructions are ignored + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. 
+ clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- self: OverlapTomographicReconstruction Self to accommodate chaining """ - asnumpy = self._asnumpy + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probes_all", + "_probes_all_initial", + "_probes_all_initial_aperture", + "_propagator_arrays", + ] + self.copy_attributes_to_device(attrs, device) + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy # set and report reconstruction method ( @@ -893,7 +967,7 @@ def reconstruct( ) if max_batch_size is not None: - xp.random.seed(seed_random) + np.random.seed(seed_random) else: max_batch_size = self._num_diffraction_patterns @@ -950,43 +1024,31 @@ def reconstruct( ] num_diffraction_patterns = end_idx - start_idx - shuffled_indices = np.arange(num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) + shuffled_indices = np.arange(start_idx, end_idx) # randomize if not use_projection_scheme: np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - num_diffraction_patterns - ) - - positions_px = self._positions_px_all[start_idx:end_idx].copy()[ - shuffled_indices - ] - initial_positions_px = self._positions_px_initial_all[ - start_idx:end_idx - ].copy()[shuffled_indices] - for start, end in generate_batches( num_diffraction_patterns, max_batch=max_batch_size ): # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_initial = initial_positions_px[start:end] - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px_all[batch_indices] + positions_px_initial = self._positions_px_initial_all[batch_indices] + positions_px_fractional = positions_px - xp_storage.round( + positions_px ) ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) - amplitudes = self._amplitudes[start_idx:end_idx][ - shuffled_indices[start:end] - ] + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) # forward operator ( @@ -997,8 +1059,11 @@ def reconstruct( batch_error, ) = self._forward( object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, _probe, - amplitudes, + positions_px_fractional, + amplitudes_device, self._exit_waves, use_projection_scheme, projection_a, @@ -1012,6 +1077,7 @@ def reconstruct( _probe, object_patches, shifted_probes, + positions_px, self._exit_waves, use_projection_scheme=use_projection_scheme, step_size=step_size, @@ -1021,13 +1087,17 @@ def reconstruct( # position correction if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( - object_sliced, + self._positions_px_all[ + batch_indices + ] = self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, shifted_probes, overlap, - amplitudes, - self._positions_px, - self._positions_px_initial, + amplitudes_device, + positions_px, + positions_px_initial, positions_step_size, max_position_update_distance, max_position_total_distance, 
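A minimal, self-contained sketch (toy arrays, not the class internals) of the index bookkeeping the batched loop above relies on: gathering and scattering with the same global batch indices writes each updated position back to its original slot, which is why the old shuffle/unshuffle mapping can be dropped.

import numpy as np

# toy stand-in for self._positions_px_all: 5 probe positions, (row, col) in pixels
positions_px_all = np.arange(10, dtype=np.float32).reshape(5, 2)

shuffled_indices = np.arange(1, 4)            # global indices for one measurement
np.random.shuffle(shuffled_indices)

batch_indices = shuffled_indices[:2]               # one batch of the shuffled indices
positions_px = positions_px_all[batch_indices]     # gather (fancy indexing copies)
positions_px += 0.5                                # stand-in for a position update
positions_px_all[batch_indices] = positions_px     # scatter back to the original slots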
@@ -1059,15 +1129,12 @@ def reconstruct( error += measurement_error # constraints - self._positions_px_all[start_idx:end_idx] = positions_px.copy()[ - unshuffled_indices - ] if collective_measurement_updates: # probe and positions _probe = self._probe_constraints( _probe, - fix_com=fix_com and a0 >= fix_probe_iter, + fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter and a0 >= fix_probe_iter, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, @@ -1086,11 +1153,12 @@ def reconstruct( initial_probe_aperture=_probe_initial_aperture, ) - self._positions_px_all[ - start_idx:end_idx - ] = self._positions_constraints( - self._positions_px_all[start_idx:end_idx], + self._positions_px_all[batch_indices] = self._positions_constraints( + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], fix_positions=a0 < fix_positions_iter, + fix_positions_com=fix_positions_com + and a0 >= fix_positions_iter, global_affine_transformation=global_affine_transformation, ) @@ -1099,12 +1167,13 @@ def reconstruct( ( self._object, _probe, - self._positions_px_all[start_idx:end_idx], + self._positions_px_all[batch_indices], ) = self._constraints( self._object, _probe, - self._positions_px_all[start_idx:end_idx], - fix_com=fix_com and a0 >= fix_probe_iter, + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter and a0 >= fix_probe_iter, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, @@ -1122,6 +1191,8 @@ def reconstruct( fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=_probe_initial_aperture, fix_positions=a0 < fix_positions_iter, + fix_positions_com=fix_positions_com + and a0 >= fix_positions_iter, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter and gaussian_filter_sigma is not None, @@ -1184,8 +1255,10 @@ def reconstruct( self.probe = self.probe_centered self.error = error.item() - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(device, self._clear_fft_cache) return self From 73adabc818db8ba65b16897ca4afc204c5290af8 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 22:15:48 -0800 Subject: [PATCH 092/128] added storage support to magnetic ptycho-tomo Former-commit-id: d08202fc1114062bec0d7196d85ad1bc8b1ab9a7 --- .../magnetic_ptychographic_tomography.py | 281 +++++++++++------- .../process/phase/magnetic_ptychography.py | 2 +- .../mixedstate_multislice_ptychography.py | 4 +- .../process/phase/mixedstate_ptychography.py | 4 +- .../process/phase/multislice_ptychography.py | 4 +- .../process/phase/ptychographic_tomography.py | 10 +- .../process/phase/singleslice_ptychography.py | 4 +- 7 files changed, 199 insertions(+), 110 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 2c9e8f0b2..7f083de09 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -43,6 +43,7 @@ from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin from 
py4DSTEM.process.phase.utils import ( ComplexProbe, + copy_to_device, fft_shift, generate_batches, polar_aliases, @@ -115,8 +116,6 @@ class MagneticPtychographicTomography( If None, initialized to a grid scan centered along tilt axis verbose: bool, optional If True, class methods will inherit this and print additional information - device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') @@ -124,12 +123,22 @@ class MagneticPtychographicTomography( Boolean real space mask to select positions in datacube to skip for reconstruction name: str, optional Class name + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls kwargs: Provide the aberration coefficients as keyword arguments. """ # Class-specific Metadata - _class_specific_metadata = ("_num_slices", "_tilt_orientation_matrices") + _class_specific_metadata = ( + "_num_slices", + "_tilt_orientation_matrices", + "_num_measurements", + ) def __init__( self, @@ -150,27 +159,18 @@ def __init__( initial_scan_positions: Sequence[np.ndarray] = None, verbose: bool = True, device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, name: str = "magnetic-ptychographic-tomography_reconstruction", **kwargs, ): Custom.__init__(self, name=name) - if device == "cpu": - import scipy + if storage is None: + storage = device - self._xp = np - self._asnumpy = np.asarray - self._scipy = scipy - - elif device == "gpu": - from cupyx import scipy - - self._xp = cp - self._asnumpy = cp.asnumpy - self._scipy = scipy - - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_device(device, clear_fft_cache) + self.set_storage(storage) for key in kwargs.keys(): if (key not in polar_symbols) and (key not in polar_aliases.keys()): @@ -209,7 +209,6 @@ def __init__( self._object_padding_px = object_padding_px self._positions_mask = positions_mask self._verbose = verbose - self._device = device self._preprocessed = False # Class-specific Metadata @@ -230,12 +229,16 @@ def preprocess( diffraction_patterns_rotate_degrees: float = None, diffraction_patterns_transpose: bool = None, force_com_shifts: Sequence[float] = None, + vectorized_com_calculation: bool = True, progress_bar: bool = True, force_scan_sampling: float = None, force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + device: str = None, + clear_fft_cache: bool = True, + max_batch_size: int = None, **kwargs, ): """ @@ -274,6 +277,8 @@ def preprocess( Amplitudes come from diffraction patterns shifted with the CoM in the upper left corner for each probe unless shift is overwritten. One tuple per tilt. 
+ vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized force_scan_sampling: float, optional Override DataCube real space scan pixel size calibrations, in Angstrom force_angular_sampling: float, optional @@ -285,13 +290,25 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps Returns -------- self: OverlapTomographicReconstruction Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage asnumpy = self._asnumpy # set additional metadata @@ -347,7 +364,9 @@ def preprocess( for q, s in zip(roi_shape, padded_diffraction_intensities_shape) ) - self._amplitudes = xp.empty((self._num_diffraction_patterns,) + roi_shape) + self._amplitudes = xp_storage.empty( + (self._num_diffraction_patterns,) + roi_shape + ) if region_of_interest_shape is not None: self._resample_exit_waves = True @@ -425,13 +444,14 @@ def preprocess( dp_mask=self._dp_mask, fit_function=fit_function, com_shifts=force_com_shifts[index], + vectorized_calculation=vectorized_com_calculation, ) # corner-center amplitudes idx_start = self._cum_probes_per_measurement[index] idx_end = self._cum_probes_per_measurement[index + 1] ( - self._amplitudes[idx_start:idx_end], + amplitudes, mean_diffraction_intensity_temp, self._crop_mask, ) = self._normalize_diffraction_intensities( @@ -444,8 +464,12 @@ def preprocess( self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) + # explicitly transfer arrays to storage + self._amplitudes[idx_start:idx_end] = copy_to_device(amplitudes, storage) + del ( intensities, + amplitudes, com_measured_x, com_measured_y, com_fitted_x, @@ -489,23 +513,29 @@ def preprocess( self._num_voxels = self._object.shape[1] # center probe positions - self._positions_px_all = xp.asarray(self._positions_px_all, dtype=xp.float32) + self._positions_px_all = xp_storage.asarray( + self._positions_px_all, dtype=xp_storage.float32 + ) for index in range(self._num_measurements): idx_start = self._cum_probes_per_measurement[index] idx_end = self._cum_probes_per_measurement[index + 1] - self._positions_px = self._positions_px_all[idx_start:idx_end] - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= ( - self._positions_px_com - xp.array(self._object_shape) / 2 - ) - self._positions_px_all[idx_start:idx_end] = self._positions_px.copy() + + positions_px = self._positions_px_all[idx_start:idx_end] + positions_px_com = positions_px.mean(0) + positions_px -= positions_px_com - xp_storage.array(self._object_shape) / 2 + self._positions_px_all[idx_start:idx_end] = positions_px.copy() self._positions_px_initial_all = self._positions_px_all.copy() self._positions_initial_all = self._positions_px_initial_all.copy() self._positions_initial_all[:, 0] *= self.sampling[0] self._positions_initial_all[:, 1] *= self.sampling[1] + self._positions_initial = self._return_average_positions() + if self._positions_initial is not None: + self._positions_initial[:, 0] *= 
self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + # initialize probe self._probes_all = [] self._probes_all_initial = [] @@ -552,6 +582,9 @@ def preprocess( ) # overlaps + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + if object_fov_mask is None: probe_overlap_3D = xp.zeros_like(self._object[0]) old_rot_matrix = np.eye(3) # identity @@ -559,6 +592,7 @@ def preprocess( for index in range(self._num_measurements): idx_start = self._cum_probes_per_measurement[index] idx_end = self._cum_probes_per_measurement[index + 1] + rot_matrix = self._tilt_orientation_matrices[index] probe_overlap_3D = self._rotate_zxy_volume( @@ -566,17 +600,29 @@ def preprocess( rot_matrix @ old_rot_matrix.T, ) - self._positions_px = self._positions_px_all[idx_start:idx_end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - shifted_probes = fft_shift( - self._probes_all[index], self._positions_px_fractional, xp - ) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts( - probe_intensities - ) + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + num_diffraction_patterns = idx_end - idx_start + shuffled_indices = np.arange(idx_start, idx_end) + + for start, end in generate_batches( + num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px_all[batch_indices] + positions_px_fractional = positions_px - xp_storage.round( + positions_px + ) + + shifted_probes = fft_shift( + self._probes_all[index], positions_px_fractional, xp + ) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes probe_overlap_3D += probe_overlap[None] old_rot_matrix = rot_matrix @@ -595,20 +641,29 @@ def preprocess( else: self._object_fov_mask = np.asarray(object_fov_mask) - self._positions_px = self._positions_px_all[ - : self._cum_probes_per_measurement[1] - ] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - shifted_probes = fft_shift( - self._probes_all[0], self._positions_px_fractional, xp - ) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._cum_probes_per_measurement[1], max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px_all[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift( + self._probes_all[0], positions_px_fractional, xp + ) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + probe_overlap = asnumpy(probe_overlap) + + # plot probe overlaps if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) chroma_boost = kwargs.pop("chroma_boost", 1) @@ -681,7 +736,7 @@ def preprocess( ax2.set_title("Propagated probe intensity") ax3.imshow( - asnumpy(probe_overlap), + probe_overlap, extent=extent, cmap="Greys_r", ) @@ -700,10 +755,7 @@ def preprocess( fig.tight_layout() self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(device, self._clear_fft_cache) 
return self @@ -796,7 +848,7 @@ def reconstruct( step_size: float = 0.5, normalization_min: float = 1, positions_step_size: float = 0.9, - fix_com: bool = True, + fix_probe_com: bool = True, fix_probe_iter: int = 0, fix_probe_aperture_iter: int = 0, constrain_probe_amplitude_iter: int = 0, @@ -806,6 +858,7 @@ def reconstruct( constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, + fix_positions_com: bool = True, max_position_update_distance: float = None, max_position_total_distance: float = None, global_affine_transformation: bool = True, @@ -832,6 +885,8 @@ def reconstruct( store_iterations: bool = False, progress_bar: bool = True, reset: bool = None, + device: str = None, + clear_fft_cache: bool = True, ): """ Ptychographic reconstruction main method. @@ -866,7 +921,7 @@ def reconstruct( Probe normalization minimum as a fraction of the maximum overlap intensity positions_step_size: float, optional Positions update step size - fix_com: bool, optional + fix_probe_com: bool, optional If True, fixes center of mass of probe fix_probe_iter: int, optional Number of iterations to run with a fixed probe before updating probe estimate @@ -886,6 +941,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov max_position_update_distance: float, optional Maximum allowed distance for update in A max_position_total_distance: float, optional @@ -933,14 +990,35 @@ def reconstruct( If True, reconstruction progress is displayed reset: bool, optional If True, previous reconstructions are ignored + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. 
+ clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- self: OverlapMagneticTomographicReconstruction Self to accommodate chaining """ - asnumpy = self._asnumpy + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probes_all", + "_probes_all_initial", + "_probes_all_initial_aperture", + "_propagator_arrays", + ] + self.copy_attributes_to_device(attrs, device) + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy if not collective_measurement_updates and self._verbose: warnings.warn( @@ -986,7 +1064,7 @@ def reconstruct( ) if max_batch_size is not None: - xp.random.seed(seed_random) + np.random.seed(seed_random) else: max_batch_size = self._num_diffraction_patterns @@ -1061,43 +1139,31 @@ def reconstruct( ] num_diffraction_patterns = end_idx - start_idx - shuffled_indices = np.arange(num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) + shuffled_indices = np.arange(start_idx, end_idx) # randomize if not use_projection_scheme: np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - num_diffraction_patterns - ) - - positions_px = self._positions_px_all[start_idx:end_idx].copy()[ - shuffled_indices - ] - initial_positions_px = self._positions_px_initial_all[ - start_idx:end_idx - ].copy()[shuffled_indices] - for start, end in generate_batches( num_diffraction_patterns, max_batch=max_batch_size ): # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_initial = initial_positions_px[start:end] - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px_all[batch_indices] + positions_px_initial = self._positions_px_initial_all[batch_indices] + positions_px_fractional = positions_px - xp_storage.round( + positions_px ) ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) - amplitudes = self._amplitudes[start_idx:end_idx][ - shuffled_indices[start:end] - ] + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) # forward operator ( @@ -1108,8 +1174,11 @@ def reconstruct( batch_error, ) = self._forward( object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, _probe, - amplitudes, + positions_px_fractional, + amplitudes_device, self._exit_waves, use_projection_scheme, projection_a, @@ -1123,6 +1192,7 @@ def reconstruct( _probe, object_patches, shifted_probes, + positions_px, self._exit_waves, use_projection_scheme=use_projection_scheme, step_size=step_size, @@ -1132,13 +1202,17 @@ def reconstruct( # position correction if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( + self._positions_px_all[ + batch_indices + ] = self._position_correction( object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, shifted_probes, overlap, - amplitudes, - self._positions_px, - self._positions_px_initial, + amplitudes_device, + positions_px, + positions_px_initial, positions_step_size, max_position_update_distance, 
max_position_total_distance, @@ -1173,15 +1247,12 @@ def reconstruct( error += measurement_error # constraints - self._positions_px_all[start_idx:end_idx] = positions_px.copy()[ - unshuffled_indices - ] if collective_measurement_updates: # probe and positions _probe = self._probe_constraints( _probe, - fix_com=fix_com and a0 >= fix_probe_iter, + fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter and a0 >= fix_probe_iter, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, @@ -1200,11 +1271,12 @@ def reconstruct( initial_probe_aperture=_probe_initial_aperture, ) - self._positions_px_all[ - start_idx:end_idx - ] = self._positions_constraints( - self._positions_px_all[start_idx:end_idx], + self._positions_px_all[batch_indices] = self._positions_constraints( + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], fix_positions=a0 < fix_positions_iter, + fix_positions_com=fix_positions_com + and a0 >= fix_positions_iter, global_affine_transformation=global_affine_transformation, ) @@ -1213,12 +1285,13 @@ def reconstruct( ( self._object, _probe, - self._positions_px_all[start_idx:end_idx], + self._positions_px_all[batch_indices], ) = self._constraints( self._object, _probe, - self._positions_px_all[start_idx:end_idx], - fix_com=fix_com and a0 >= fix_probe_iter, + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter and a0 >= fix_probe_iter, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, @@ -1236,6 +1309,8 @@ def reconstruct( fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=_probe_initial_aperture, fix_positions=a0 < fix_positions_iter, + fix_positions_com=fix_positions_com + and a0 >= fix_positions_iter, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter and gaussian_filter_sigma_m is not None, @@ -1304,9 +1379,11 @@ def reconstruct( self.probe = self.probe_centered self.error = error.item() - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(device, self._clear_fft_cache) return self diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index c835eaf01..9812c1783 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -573,7 +573,7 @@ def preprocess( # center probe positions self._positions_px_all = xp_storage.asarray( - self._positions_px_all, dtype=xp.float32 + self._positions_px_all, dtype=xp_storage.float32 ) for index in range(self._num_measurements): diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index 99d1ff9c5..26bea876d 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -512,7 +512,9 @@ def preprocess( self._object_shape = self._object.shape[-2:] # center probe positions - self._positions_px = xp_storage.asarray(self._positions_px, dtype=xp.float32) + self._positions_px = xp_storage.asarray( + self._positions_px, dtype=xp_storage.float32 
+ ) self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px -= ( self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index 3c838fff5..ddd853d3d 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -453,7 +453,9 @@ def preprocess( self._object_shape = self._object.shape # center probe positions - self._positions_px = xp_storage.asarray(self._positions_px, dtype=xp.float32) + self._positions_px = xp_storage.asarray( + self._positions_px, dtype=xp_storage.float32 + ) self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px -= ( self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index c18890b85..02740dec9 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -486,7 +486,9 @@ def preprocess( self._object_shape = self._object.shape[-2:] # center probe positions - self._positions_px = xp_storage.asarray(self._positions_px, dtype=xp.float32) + self._positions_px = xp_storage.asarray( + self._positions_px, dtype=xp_storage.float32 + ) self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px -= ( self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 18967c7f2..ae4d31913 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -128,7 +128,11 @@ class PtychographicTomography( """ # Class-specific Metadata - _class_specific_metadata = ("_num_slices", "_tilt_orientation_matrices") + _class_specific_metadata = ( + "_num_slices", + "_tilt_orientation_matrices", + "_num_measurements", + ) def __init__( self, @@ -504,7 +508,7 @@ def preprocess( # center probe positions self._positions_px_all = xp_storage.asarray( - self._positions_px_all, dtype=xp.float32 + self._positions_px_all, dtype=xp_storage.float32 ) for index in range(self._num_measurements): @@ -1090,7 +1094,7 @@ def reconstruct( self._positions_px_all[ batch_indices ] = self._position_correction( - self._object, + object_sliced, vectorized_patch_indices_row, vectorized_patch_indices_col, shifted_probes, diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index f599956f0..66ee248e7 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -430,7 +430,9 @@ def preprocess( self._object_shape = self._object.shape # center probe positions - self._positions_px = xp_storage.asarray(self._positions_px, dtype=xp.float32) + self._positions_px = xp_storage.asarray( + self._positions_px, dtype=xp_storage.float32 + ) self._positions_px_initial_com = self._positions_px.mean(0) self._positions_px -= ( self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 From 66f20a3943f159c56e91808f8660a6ac580f6419 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 11 Jan 2024 22:28:08 -0800 Subject: [PATCH 093/128] transferring parallax bug Steph found in phase_contrast Former-commit-id: 53e7c0666b1f7acd0128caaecc57d5e7e8172082 --- 
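A minimal sketch of the failure mode this one-line fix addresses, using a hypothetical toy class: because the f-string is only evaluated inside the warning branch, the misspelled attribute raises an AttributeError at exactly the moment the warning should fire.

import warnings

class Toy:
    _DF_upsample_limit = 0.8   # the attribute that actually exists

toy = Toy()
try:
    if toy._DF_upsample_limit < 1:
        warnings.warn(
            f"Dark-field upsampling limit of {toy._DF_upsampling_limit:.2f} "  # misspelled
            "is less than 1"
        )
except AttributeError as err:
    print(err)   # 'Toy' object has no attribute '_DF_upsampling_limit'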
py4DSTEM/process/phase/parallax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 3547bb60e..2d3fc4f33 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -1138,7 +1138,7 @@ def subpixel_alignment( if self._DF_upsample_limit < 1: warnings.warn( ( - f"Dark-field upsampling limit of {self._DF_upsampling_limit:.2f} " + f"Dark-field upsampling limit of {self._DF_upsample_limit:.2f} " "is less than 1, implying a scan step-size smaller than Nyquist. " "setting to 1." ), From 50e6c737ff3325e8f4cb9d9409cd38981d7cdd2e Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 12 Jan 2024 11:47:50 -0800 Subject: [PATCH 094/128] tweaks to clear_fft_cache Former-commit-id: b906763859386d85fc5921769b0b26891cadfdff --- py4DSTEM/process/phase/dpc.py | 8 ++++---- .../phase/magnetic_ptychographic_tomography.py | 8 ++++---- py4DSTEM/process/phase/magnetic_ptychography.py | 8 ++++---- .../phase/mixedstate_multislice_ptychography.py | 8 ++++---- py4DSTEM/process/phase/mixedstate_ptychography.py | 8 ++++---- py4DSTEM/process/phase/multislice_ptychography.py | 8 ++++---- py4DSTEM/process/phase/parallax.py | 12 ++++-------- py4DSTEM/process/phase/phase_base_class.py | 7 ++++--- py4DSTEM/process/phase/ptychographic_methods.py | 3 +-- py4DSTEM/process/phase/ptychographic_tomography.py | 8 ++++---- .../process/phase/ptychographic_visualizations.py | 2 +- py4DSTEM/process/phase/singleslice_ptychography.py | 8 ++++---- 12 files changed, 42 insertions(+), 46 deletions(-) diff --git a/py4DSTEM/process/phase/dpc.py b/py4DSTEM/process/phase/dpc.py index e6eed79f6..4997f733a 100644 --- a/py4DSTEM/process/phase/dpc.py +++ b/py4DSTEM/process/phase/dpc.py @@ -242,7 +242,7 @@ def preprocess( plot_center_of_mass: str = "default", plot_rotation: bool = True, device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: bool = None, **kwargs, ): """ @@ -391,7 +391,7 @@ def preprocess( self._ky_op = -1j * 0.25 * kya * k_den self._preprocessed = True - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self @@ -654,7 +654,7 @@ def reconstruct( butterworth_order: float = 2, store_iterations: bool = False, device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: bool = None, ): """ Performs Iterative DPC Reconstruction: @@ -823,7 +823,7 @@ def reconstruct( ] self.object_phase = asnumpy(self._object_phase) - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 7f083de09..1fccee78b 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -237,7 +237,7 @@ def preprocess( object_fov_mask: np.ndarray = None, crop_patterns: bool = False, device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: bool = None, max_batch_size: int = None, **kwargs, ): @@ -755,7 +755,7 @@ def preprocess( fig.tight_layout() self._preprocessed = True - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self @@ -886,7 +886,7 @@ def reconstruct( progress_bar: bool = True, reset: bool = None, device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: 
bool = None, ): """ Ptychographic reconstruction main method. @@ -1383,7 +1383,7 @@ def reconstruct( if not use_projection_scheme: self._exit_waves = None - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 9812c1783..5a382406a 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -217,7 +217,7 @@ def preprocess( object_fov_mask: np.ndarray = None, crop_patterns: bool = False, device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: bool = None, max_batch_size: int = None, **kwargs, ): @@ -720,7 +720,7 @@ def preprocess( fig.tight_layout() self._preprocessed = True - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self @@ -1141,7 +1141,7 @@ def reconstruct( progress_bar: bool = True, reset: bool = None, device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: bool = None, ): """ Ptychographic reconstruction main method. @@ -1620,7 +1620,7 @@ def reconstruct( if not use_projection_scheme: self._exit_waves = None - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index 26bea876d..fd302bc35 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -277,7 +277,7 @@ def preprocess( object_fov_mask: np.ndarray = None, crop_patterns: bool = False, device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: bool = None, max_batch_size: int = None, **kwargs, ): @@ -679,7 +679,7 @@ def preprocess( fig.tight_layout() self._preprocessed = True - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self @@ -739,7 +739,7 @@ def reconstruct( progress_bar: bool = True, reset: bool = None, device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: bool = None, ): """ Ptychographic reconstruction main method. @@ -1093,6 +1093,6 @@ def reconstruct( if not use_projection_scheme: self._exit_waves = None - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index ddd853d3d..b49dc0c34 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -222,7 +222,7 @@ def preprocess( object_fov_mask: np.ndarray = None, crop_patterns: bool = False, device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: bool = None, max_batch_size: int = None, **kwargs, ): @@ -583,7 +583,7 @@ def preprocess( fig.tight_layout() self._preprocessed = True - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self @@ -637,7 +637,7 @@ def reconstruct( progress_bar: bool = True, reset: bool = None, device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: bool = None, ): """ Ptychographic reconstruction main method. 
@@ -972,6 +972,6 @@ def reconstruct( if not use_projection_scheme: self._exit_waves = None - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index 02740dec9..3624db7ac 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -251,7 +251,7 @@ def preprocess( object_fov_mask: np.ndarray = None, crop_patterns: bool = False, device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: bool = None, max_batch_size: int = None, **kwargs, ): @@ -653,7 +653,7 @@ def preprocess( fig.tight_layout() self._preprocessed = True - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self @@ -712,7 +712,7 @@ def reconstruct( progress_bar: bool = True, reset: bool = None, device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: bool = None, ): """ Ptychographic reconstruction main method. @@ -1068,6 +1068,6 @@ def reconstruct( if not use_projection_scheme: self._exit_waves = None - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 2d3fc4f33..a0eea59df 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -268,7 +268,7 @@ def preprocess( apply_realspace_mask_to_stack: bool = True, vectorized_com_calculation: bool = True, device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: bool = None, **kwargs, ): """ @@ -729,7 +729,7 @@ def preprocess( plt.tight_layout() self._preprocessed = True - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self @@ -748,7 +748,7 @@ def reconstruct( plot_convergence: bool = True, reset: bool = None, device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: bool = None, **kwargs, ): """ @@ -1050,7 +1050,7 @@ def reconstruct( self.recon_BF = asnumpy(self._recon_BF) - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self @@ -2557,10 +2557,6 @@ def aberration_correct( self._recon_phase_corrected = xp.real(xp.fft.ifft2(im_fft_corr)) self.recon_phase_corrected = asnumpy(self._recon_phase_corrected) - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - # plotting if plot_corrected_phase: figsize = kwargs.pop("figsize", (6, 6)) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index d71011026..103d8f467 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -56,8 +56,10 @@ def set_device(self, device, clear_fft_cache): Self to enable chaining """ - if device is None: + if clear_fft_cache is not None: self._clear_fft_cache = clear_fft_cache + + if device is None: return self if device == "cpu": @@ -76,7 +78,6 @@ def set_device(self, device, clear_fft_cache): raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") self._device = device - self._clear_fft_cache = clear_fft_cache return self @@ -114,7 +115,7 @@ def set_storage(self, storage): def clear_device_mem(self, device, clear_fft_cache): """ """ if device == "gpu": - if 
clear_fft_cache is True: + if clear_fft_cache: cache = get_plan_cache() cache.clear() diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 280a6a392..e5d90e3e5 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -2416,7 +2416,6 @@ def show_transmitted_probe( xp = self._xp xp_storage = self._xp_storage - device = self._device asnumpy = self._asnumpy if max_batch_size is None: @@ -2513,7 +2512,7 @@ def show_transmitted_probe( **kwargs, ) - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) class ObjectNDProbeMixedMethodsMixin: diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index ae4d31913..0b3bd9435 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -232,7 +232,7 @@ def preprocess( crop_patterns: bool = False, main_tilt_axis: str = "vertical", device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: bool = None, max_batch_size: int = None, **kwargs, ): @@ -754,7 +754,7 @@ def preprocess( fig.tight_layout() self._preprocessed = True - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self @@ -806,7 +806,7 @@ def reconstruct( progress_bar: bool = True, reset: bool = None, device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: bool = None, ): """ Ptychographic reconstruction main method. @@ -1263,6 +1263,6 @@ def reconstruct( if not use_projection_scheme: self._exit_waves = None - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self diff --git a/py4DSTEM/process/phase/ptychographic_visualizations.py b/py4DSTEM/process/phase/ptychographic_visualizations.py index d45f9b69f..14db5dc5c 100644 --- a/py4DSTEM/process/phase/ptychographic_visualizations.py +++ b/py4DSTEM/process/phase/ptychographic_visualizations.py @@ -844,4 +844,4 @@ def show_uncertainty_visualization( spec.tight_layout(fig) - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index 66ee248e7..90a2472f2 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -205,7 +205,7 @@ def preprocess( object_fov_mask: np.ndarray = None, crop_patterns: bool = False, device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: bool = None, max_batch_size: int = None, **kwargs, ): @@ -560,7 +560,7 @@ def preprocess( fig.tight_layout() self._preprocessed = True - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self @@ -613,7 +613,7 @@ def reconstruct( progress_bar: bool = True, reset: bool = None, device: str = None, - clear_fft_cache: bool = True, + clear_fft_cache: bool = None, ): """ Ptychographic reconstruction main method. 
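A behavioural sketch of the None-sentinel pattern introduced above (toy class, not the py4DSTEM base class): passing None for device or clear_fft_cache leaves the stored setting untouched, and the cleanup call consults the stored attributes rather than the per-call arguments.

class Toy:
    def __init__(self, device="cpu", clear_fft_cache=True):
        self._device = device
        self._clear_fft_cache = clear_fft_cache

    def set_device(self, device, clear_fft_cache):
        if clear_fft_cache is not None:        # None means "keep the current setting"
            self._clear_fft_cache = clear_fft_cache
        if device is None:                     # None means "keep the current device"
            return self
        self._device = device
        return self

    def clear_device_mem(self, device, clear_fft_cache):
        if device == "gpu" and clear_fft_cache:
            print("clearing cached FFT plans")

    def reconstruct(self, device=None, clear_fft_cache=None):
        self.set_device(device, clear_fft_cache)
        # cleanup uses the stored attributes, as in the diffs above
        self.clear_device_mem(self._device, self._clear_fft_cache)
        return self

t = Toy(device="gpu", clear_fft_cache=False)
t.reconstruct()                       # no overrides: stays on gpu, cache kept
t.reconstruct(clear_fft_cache=True)   # override: now clears cached plans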
@@ -947,6 +947,6 @@ def reconstruct( if not use_projection_scheme: self._exit_waves = None - self.clear_device_mem(device, self._clear_fft_cache) + self.clear_device_mem(self._device, self._clear_fft_cache) return self From 293dc2048c71558ba0ea54aa4811fd8a4c9bbd0f Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 12 Jan 2024 17:34:41 -0800 Subject: [PATCH 095/128] adding FFT-based DCT-II implementations, fixing cupy10.6 bug Former-commit-id: 1b8ad3aace6737ab04b8fb5b484674e138777498 --- py4DSTEM/process/phase/utils.py | 130 +++++++++++++++++++++++++++++--- 1 file changed, 121 insertions(+), 9 deletions(-) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 32d20e812..70e1841cf 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -9,8 +9,6 @@ try: import cupy as cp - from cupyx.scipy.fft import dctn as dctn_cp - from cupyx.scipy.fft import idctn as idctn_cp from cupyx.scipy.ndimage import zoom as zoom_cp get_array_module = cp.get_array_module @@ -1578,6 +1576,121 @@ def aberrations_basis_function( return aberrations_basis, aberrations_mn +def interleave_ndarray_symmetrically(array_nd, axis, xp=np): + """[a,b,c,d,e,f] -> [a,c,e,f,d,b]""" + array_shape = np.array(array_nd.shape) + d = array_nd.ndim + n = array_shape[axis] + + array = xp.empty_like(array_nd) + array[array_slice(axis, d, None, (n - 1) // 2 + 1)] = array_nd[ + array_slice(axis, d, None, None, 2) + ] + + if n % 2: # odd + array[array_slice(axis, d, (n - 1) // 2 + 1, None)] = array_nd[ + array_slice(axis, d, -2, None, -2) + ] + else: # even + array[array_slice(axis, d, (n - 1) // 2 + 1, None)] = array_nd[ + array_slice(axis, d, None, None, -2) + ] + + return array + + +def return_exp_factors(size, ndim, axis): + none_axes = [None] * ndim + none_axes[axis] = slice(None) + exp_factors = 2 * np.exp(-1j * np.pi * np.arange(size) / (2 * size)) + return exp_factors[tuple(none_axes)] + + +def dct_II_using_FFT_base(array_nd, xp=np): + """FFT-based DCT-II""" + d = array_nd.ndim + + for axis in range(d): + n = array_nd.shape[axis] + interleaved_array = interleave_ndarray_symmetrically(array_nd, axis=axis, xp=xp) + exp_factors = return_exp_factors(n, d, axis) + interleaved_array = xp.fft.fft(interleaved_array, axis=axis) + interleaved_array *= exp_factors + array_nd = interleaved_array.real + + return array_nd + + +def dct_II_using_FFT(array_nd, xp=np): + if xp.iscomplexobj(array_nd): + real = dct_II_using_FFT_base(array_nd.real, xp=xp) + imag = dct_II_using_FFT_base(array_nd.imag, xp=xp) + return real + 1j * imag + else: + return dct_II_using_FFT_base(array_nd, xp=xp) + + +def interleave_ndarray_symmetrically_inverse(array_nd, axis, xp=np): + """[a,c,e,f,d,b] -> [a,b,c,d,e,f]""" + array_shape = np.array(array_nd.shape) + d = array_nd.ndim + n = array_shape[axis] + + array = xp.empty_like(array_nd) + array[array_slice(axis, d, None, None, 2)] = array_nd[ + array_slice(axis, d, None, (n - 1) // 2 + 1) + ] + + if n % 2: # odd + array[array_slice(axis, d, -2, None, -2)] = array_nd[ + array_slice(axis, d, (n - 1) // 2 + 1, None) + ] + else: # even + array[array_slice(axis, d, None, None, -2)] = array_nd[ + array_slice(axis, d, (n - 1) // 2 + 1, None) + ] + + return array + + +def return_exp_factors_inverse(size, ndim, axis): + none_axes = [None] * ndim + none_axes[axis] = slice(None) + exp_factors = np.exp(1j * np.pi * np.arange(size) / (2 * size)) / 2 + return exp_factors[tuple(none_axes)] + + +def idct_II_using_FFT_base(array_nd, xp=np): + """FFT-based IDCT-II""" + d 
= array_nd.ndim + + for axis in range(d): + n = array_nd.shape[axis] + reversed_array = xp.roll( + array_nd[array_slice(axis, d, None, None, -1)], 1, axis=axis + ) # C(N-k) + reversed_array[array_slice(axis, d, 0, 1)] = 0 # set C(N) = 0 + + interleaved_array = array_nd - 1j * reversed_array + exp_factors = return_exp_factors_inverse(n, d, axis) + interleaved_array *= exp_factors + + array_nd = xp.fft.ifft(interleaved_array, axis=axis).real + array_nd = interleave_ndarray_symmetrically_inverse(array_nd, axis=axis, xp=xp) + + return array_nd + + +def idct_II_using_FFT(array_nd, xp=np): + """FFT-based IDCT-II""" + if xp.iscomplexobj(array_nd): + real = idct_II_using_FFT_base(array_nd.real, xp=xp) + imag = idct_II_using_FFT_base(array_nd.imag, xp=xp) + return real + 1j * imag + else: + return idct_II_using_FFT_base(array_nd, xp=xp) + + def preconditioned_laplacian_neumann_2D(shape, xp=np): """DCT eigenvalues""" n, m = shape @@ -1596,15 +1709,14 @@ def preconditioned_poisson_solver_neumann_2D(rhs, gauge=None, xp=np): gauge = xp.mean(rhs) if xp is np: - dctn_xp = dctn - idctn_xp = idctn + fft_rhs = dctn(rhs, type=2) + fft_rhs[0, 0] = gauge # gauge invariance + sol = idctn(fft_rhs / op, type=2).real else: - dctn_xp = dctn_cp - idctn_xp = idctn_cp + fft_rhs = dct_II_using_FFT(rhs, xp) + fft_rhs[0, 0] = gauge # gauge invariance + sol = idct_II_using_FFT(fft_rhs / op, xp) - fft_rhs = dctn_xp(rhs, type=2) - fft_rhs[0, 0] = gauge # gauge invariance - sol = idctn_xp(fft_rhs / op, type=2).real return sol From cfc542c03d24c4ab8d2145b193cc5d2676446970 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 12 Jan 2024 17:39:44 -0800 Subject: [PATCH 096/128] small numpy bug Former-commit-id: 6f496829d1d544e2874989194e5b6c83c9e684cb --- py4DSTEM/process/phase/utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 70e1841cf..b9fc4f8c9 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1599,10 +1599,10 @@ def interleave_ndarray_symmetrically(array_nd, axis, xp=np): return array -def return_exp_factors(size, ndim, axis): +def return_exp_factors(size, ndim, axis, xp=np): none_axes = [None] * ndim none_axes[axis] = slice(None) - exp_factors = 2 * np.exp(-1j * np.pi * np.arange(size) / (2 * size)) + exp_factors = 2 * xp.exp(-1j * np.pi * xp.arange(size) / (2 * size)) return exp_factors[tuple(none_axes)] @@ -1613,7 +1613,7 @@ def dct_II_using_FFT_base(array_nd, xp=np): for axis in range(d): n = array_nd.shape[axis] interleaved_array = interleave_ndarray_symmetrically(array_nd, axis=axis, xp=xp) - exp_factors = return_exp_factors(n, d, axis) + exp_factors = return_exp_factors(n, d, axis, xp) interleaved_array = xp.fft.fft(interleaved_array, axis=axis) interleaved_array *= exp_factors array_nd = interleaved_array.real @@ -1653,10 +1653,10 @@ def interleave_ndarray_symmetrically_inverse(array_nd, axis, xp=np): return array -def return_exp_factors_inverse(size, ndim, axis): +def return_exp_factors_inverse(size, ndim, axis, xp=np): none_axes = [None] * ndim none_axes[axis] = slice(None) - exp_factors = np.exp(1j * np.pi * np.arange(size) / (2 * size)) / 2 + exp_factors = xp.exp(1j * np.pi * xp.arange(size) / (2 * size)) / 2 return exp_factors[tuple(none_axes)] @@ -1672,7 +1672,7 @@ def idct_II_using_FFT_base(array_nd, xp=np): reversed_array[array_slice(axis, d, 0, 1)] = 0 # set C(N) = 0 interleaved_array = array_nd - 1j * reversed_array - exp_factors = 
return_exp_factors_inverse(n, d, axis) + exp_factors = return_exp_factors_inverse(n, d, axis, xp) interleaved_array *= exp_factors array_nd = xp.fft.ifft(interleaved_array, axis=axis).real From 40421b8e24d82bcf8f974827eedbfff5bc52904d Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 12 Jan 2024 17:44:30 -0800 Subject: [PATCH 097/128] cleaned up cupy 12 feature guarding using try-except Former-commit-id: 13f0f5c96b02e85ab5bbd344c4b0f7aa0c3f019e --- .../phase/magnetic_ptychographic_tomography.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 1fccee78b..66553ed49 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -1591,13 +1591,18 @@ def _rotate_zxy_volume_vector(self, current_object, rot_matrix): xp = self._xp swap_zxy_to_xyz = self._swap_zxy_to_xyz - if xp is np or int(xp.__version__.split(".")[0]) < 12: + if xp is np: from scipy.interpolate import RegularGridInterpolator - xp = np # ensure np is enforced for cupy < 12 current_object = self._asnumpy(current_object) else: - from cupyx.scipy.interpolate import RegularGridInterpolator + try: + from cupyx.scipy.interpolate import RegularGridInterpolator + except ModuleNotFoundError: + from scipy.interpolate import RegularGridInterpolator + + xp = np # force xp to np for cupy <12.0 + current_object = self._asnumpy(current_object) _, nz, nx, ny = current_object.shape From eec399294c919c547ab7607df26f412f6b11d64f Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 12 Jan 2024 17:56:45 -0800 Subject: [PATCH 098/128] restored scikit functionality Former-commit-id: c25a58ebcbd4f122357061068dd25252e3c5d973 --- py4DSTEM/process/phase/utils.py | 36 +++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index b9fc4f8c9..d36442722 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -22,6 +22,7 @@ def get_array_module(*args): from py4DSTEM.process.utils import get_CoM from py4DSTEM.process.utils.cross_correlate import align_and_shift_images from py4DSTEM.process.utils.utils import electron_wavelength_angstrom +from skimage.restoration import unwrap_phase # fmt: off @@ -1755,24 +1756,47 @@ def unwrap_phase_2d(array, weights=None, gauge=None, corner_centered=True, xp=np return unwrapped_array +def unwrap_phase_2d_skimage(array, corner_centered=True, xp=np): + if xp is np: + array = array.astype(np.float64) + unwrapped_array = unwrap_phase(array, wrap_around=corner_centered).astype( + xp.float32 + ) + else: + array = xp.asnumpy(array).astype(np.float64) + unwrapped_array = unwrap_phase(array, wrap_around=corner_centered) + unwrapped_array = xp.asarray(unwrapped_array).astype(xp.float32) + + return unwrapped_array + + def fit_aberration_surface( complex_probe, probe_sampling, energy, max_angular_order, max_radial_order, + use_scikit_image, xp=np, ): """ """ probe_amp = xp.abs(complex_probe) probe_angle = -xp.angle(complex_probe) - unwrapped_angle = unwrap_phase_2d( - probe_angle, - weights=probe_amp, - corner_centered=True, - xp=xp, - ) + if use_scikit_image: + unwrapped_angle = unwrap_phase_2d( + probe_angle, + corner_centered=True, + xp=xp, + ) + + else: + unwrapped_angle = unwrap_phase_2d( + probe_angle, + weights=probe_amp, + corner_centered=True, + 
xp=xp, + ) raveled_basis, _ = aberrations_basis_function( complex_probe.shape, From 9f4917008bb525f9fa0fda93c637f5e4595439e5 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 12 Jan 2024 18:13:45 -0800 Subject: [PATCH 099/128] scikit-image by default, poisson as flag Former-commit-id: fbc8436496a396d00a35a51098ef196f82a9388e --- .../process/phase/magnetic_ptychographic_tomography.py | 7 +++++++ py4DSTEM/process/phase/magnetic_ptychography.py | 7 +++++++ .../process/phase/mixedstate_multislice_ptychography.py | 6 ++++++ py4DSTEM/process/phase/mixedstate_ptychography.py | 6 ++++++ py4DSTEM/process/phase/multislice_ptychography.py | 6 ++++++ py4DSTEM/process/phase/ptychographic_constraints.py | 6 ++++++ py4DSTEM/process/phase/ptychographic_tomography.py | 7 +++++++ py4DSTEM/process/phase/singleslice_ptychography.py | 6 ++++++ py4DSTEM/process/phase/utils.py | 2 +- 9 files changed, 52 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 66553ed49..156552cc2 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -869,6 +869,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, butterworth_filter_iter: int = np.inf, q_lowpass_e: float = None, q_lowpass_m: float = None, @@ -963,6 +964,10 @@ def reconstruct( Max radial order of probe aberrations basis functions fit_probe_aberrations_remove_initial: bool If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. 
butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass: float @@ -1267,6 +1272,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=_probe_initial_aperture, ) @@ -1306,6 +1312,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=_probe_initial_aperture, fix_positions=a0 < fix_positions_iter, diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 5a382406a..d03e85632 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -1123,6 +1123,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, butterworth_filter_iter: int = np.inf, q_lowpass_e: float = None, q_lowpass_m: float = None, @@ -1220,6 +1221,10 @@ def reconstruct( Max radial order of probe aberrations basis functions fit_probe_aberrations_remove_initial: bool If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. 
butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass_e: float @@ -1502,6 +1507,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=_probe_initial_aperture, ) @@ -1541,6 +1547,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=_probe_initial_aperture, fix_positions=a0 < fix_positions_iter, diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index fd302bc35..a714c3713 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -717,6 +717,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, butterworth_filter_iter: int = np.inf, q_lowpass: float = None, q_highpass: float = None, @@ -812,6 +813,10 @@ def reconstruct( Max radial order of probe aberrations basis functions fit_probe_aberrations_remove_initial: bool If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. 
butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass: float @@ -1041,6 +1046,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index b49dc0c34..eb4b66617 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -622,6 +622,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, butterworth_filter_iter: int = np.inf, q_lowpass: float = None, q_highpass: float = None, @@ -714,6 +715,10 @@ def reconstruct( Max radial order of probe aberrations basis functions fit_probe_aberrations_remove_initial: bool If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass: float @@ -931,6 +936,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index 3624db7ac..b75fb60f0 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -690,6 +690,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, butterworth_filter_iter: int = np.inf, q_lowpass: float = None, q_highpass: float = None, @@ -787,6 +788,10 @@ def reconstruct( Max radial order of probe aberrations basis functions fit_probe_aberrations_remove_initial: bool If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. 
butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass: float @@ -1020,6 +1025,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, diff --git a/py4DSTEM/process/phase/ptychographic_constraints.py b/py4DSTEM/process/phase/ptychographic_constraints.py index eb1b69951..425c8b0ad 100644 --- a/py4DSTEM/process/phase/ptychographic_constraints.py +++ b/py4DSTEM/process/phase/ptychographic_constraints.py @@ -1038,6 +1038,7 @@ def _probe_aberration_fitting_constraint( max_angular_order, max_radial_order, remove_initial_probe_aberrations, + use_scikit_image, ): """ Ptychographic probe smoothing constraint. @@ -1075,6 +1076,7 @@ def _probe_aberration_fitting_constraint( energy, max_angular_order, max_radial_order, + use_scikit_image, xp=xp, ) @@ -1094,6 +1096,7 @@ def _probe_constraints( fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image, fix_probe_aperture, initial_probe_aperture, constrain_probe_fourier_amplitude, @@ -1117,6 +1120,7 @@ def _probe_constraints( fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image, ) # Fourier amplitude (aperture) constraints @@ -1219,6 +1223,7 @@ def _probe_constraints( fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image, fix_probe_aperture, initial_probe_aperture, constrain_probe_fourier_amplitude, @@ -1244,6 +1249,7 @@ def _probe_constraints( fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image, ) # Fourier amplitude (aperture) constraints diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 0b3bd9435..ff0eb785c 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -791,6 +791,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, butterworth_filter_iter: int = np.inf, q_lowpass: float = None, q_highpass: float = None, @@ -881,6 +882,10 @@ def reconstruct( Max radial order of probe aberrations basis functions fit_probe_aberrations_remove_initial: bool If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. 
butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass: float @@ -1153,6 +1158,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=_probe_initial_aperture, ) @@ -1192,6 +1198,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=_probe_initial_aperture, fix_positions=a0 < fix_positions_iter, diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index 90a2472f2..6082a9b42 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -598,6 +598,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, butterworth_filter_iter: int = np.inf, q_lowpass: float = None, q_highpass: float = None, @@ -690,6 +691,10 @@ def reconstruct( Max radial order of probe aberrations basis functions fit_probe_aberrations_remove_initial: bool If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. 
butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass: float @@ -907,6 +912,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, fix_probe_aperture=a0 < fix_probe_aperture_iter, initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index d36442722..a25b7acd3 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1784,7 +1784,7 @@ def fit_aberration_surface( probe_angle = -xp.angle(complex_probe) if use_scikit_image: - unwrapped_angle = unwrap_phase_2d( + unwrapped_angle = unwrap_phase_2d_skimage( probe_angle, corner_centered=True, xp=xp, From 903207849f5857acf3ca6dadcebe0d8af646cbd2 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 12 Jan 2024 20:47:23 -0800 Subject: [PATCH 100/128] various fixes discovered while making testing notebook Former-commit-id: bb5446528c91840464ec32dddb88bbb8d6d5ea04 --- .../mixedstate_multislice_ptychography.py | 2 +- .../process/phase/ptychographic_methods.py | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index a714c3713..b3bb0093a 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -397,7 +397,6 @@ def preprocess( vacuum_probe_intensity=self._vacuum_probe_intensity, dp_mask=self._dp_mask, com_shifts=force_com_shifts, - vectorized_calculation=vectorized_com_calculation, ) # calibrations @@ -428,6 +427,7 @@ def preprocess( dp_mask=self._dp_mask, fit_function=fit_function, com_shifts=force_com_shifts, + vectorized_calculation=vectorized_com_calculation, ) # estimate rotation / transpose diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index e5d90e3e5..d1f6fc96c 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -1985,9 +1985,16 @@ def _position_correction( xp = self._xp asnumpy = self._asnumpy + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + overlap = vectorized_bilinear_resample( + overlap, output_size=amplitudes.shape[-2:], xp=xp + ) + # unperturbed overlap_fft = xp.fft.fft2(overlap) overlap_fft_conj = xp.conj(overlap_fft) + estimated_intensity = self._return_farfield_amplitudes(overlap_fft) ** 2 measured_intensity = amplitudes**2 @@ -2013,6 +2020,15 @@ def _position_correction( shifted_probes, ) + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + overlap_dx = vectorized_bilinear_resample( + overlap_dx, output_size=amplitudes.shape[-2:], xp=xp + ) + overlap_dy = vectorized_bilinear_resample( + overlap_dy, output_size=amplitudes.shape[-2:], xp=xp + ) + # partial intensities overlap_dx_fft = overlap_fft - xp.fft.fft2(overlap_dx) overlap_dy_fft = overlap_fft - xp.fft.fft2(overlap_dy) @@ -2105,6 +2121,13 @@ def _return_self_consistency_errors( vectorized_patch_indices_col, 
shifted_probes, ) + + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + overlap = vectorized_bilinear_resample( + overlap, output_size=amplitudes_device.shape[-2:], xp=xp + ) + fourier_overlap = xp.fft.fft2(overlap) farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) From ecb863c244688bb3da86b3d12dba21aa270a4819 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 13 Jan 2024 08:57:06 -0800 Subject: [PATCH 101/128] position update bug fix Former-commit-id: 7dc1d980cf64f2088efca9f35f0e7652717fdf44 --- py4DSTEM/process/phase/ptychographic_methods.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index d1f6fc96c..a1afe275d 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -1984,6 +1984,7 @@ def _position_correction( xp = self._xp asnumpy = self._asnumpy + xp_storage = self._xp_storage # resample to match data, note: this needs to happen in real-space if self._resample_exit_waves: @@ -2076,7 +2077,10 @@ def _position_correction( outlier_ind = dsts > max_position_total_distance positions_update[outlier_ind] = 0 - current_positions -= asnumpy(positions_update) + if xp_storage == np: + current_positions -= asnumpy(positions_update) + else: + current_positions -= positions_update return current_positions From 116b7f6a3ce8c8e5582540e8a09fafcff1cac46d Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 13 Jan 2024 09:21:18 -0800 Subject: [PATCH 102/128] syntax update for bug fix Former-commit-id: e67109afcf4f93538f5b9a80d92ffa3ced2454db --- py4DSTEM/process/phase/ptychographic_methods.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index a1afe275d..68a2d8334 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -1984,7 +1984,7 @@ def _position_correction( xp = self._xp asnumpy = self._asnumpy - xp_storage = self._xp_storage + storage = self._storage # resample to match data, note: this needs to happen in real-space if self._resample_exit_waves: @@ -2077,10 +2077,7 @@ def _position_correction( outlier_ind = dsts > max_position_total_distance positions_update[outlier_ind] = 0 - if xp_storage == np: - current_positions -= asnumpy(positions_update) - else: - current_positions -= positions_update + current_positions -= copy_to_device(positions_update, storage) return current_positions From 141eb37ec62b8c148fedfbaecd9a35f67cd9edf2 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 13 Jan 2024 12:26:48 -0800 Subject: [PATCH 103/128] complex plotting grid search Former-commit-id: 00d49554ca0979fb48c2acaddda1056ea0d2f1f8 --- py4DSTEM/process/phase/parameter_optimize.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index ff9982b44..554ceb984 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -227,7 +227,10 @@ def evaluation_callback(ptycho): row_index, col_index = np.unravel_index(index, (nrows, ncols)) ax = fig.add_subplot(spec[row_index, col_index]) - ax.imshow(res[0], cmap=cmap) + if np.iscomplexobj(res[0]): + ax.imshow(np.angle(res[0]), cmap=cmap) + else: + 
ax.imshow(res[0], cmap=cmap) title_substrings = [ f"{param.name}: {val}" From a8fdc435fa89d1fc9b3ac6d555a1d8e666784b8e Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 13 Jan 2024 12:48:43 -0800 Subject: [PATCH 104/128] multislice grid search plotting Former-commit-id: 62cb64433763ee633218c5813159a288bdd4fa0f --- py4DSTEM/process/phase/parameter_optimize.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index 554ceb984..aa369872d 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -178,7 +178,10 @@ def grid_search( def evaluation_callback(ptycho): if plot_reconstructed_objects or return_reconstructed_objects: pbar.update(1) - return (ptycho.object_cropped, error_metric(ptycho)) + return ( + ptycho._return_projected_cropped_potential(), + error_metric(ptycho), + ) else: pbar.update(1) error_metric(ptycho) @@ -227,10 +230,7 @@ def evaluation_callback(ptycho): row_index, col_index = np.unravel_index(index, (nrows, ncols)) ax = fig.add_subplot(spec[row_index, col_index]) - if np.iscomplexobj(res[0]): - ax.imshow(np.angle(res[0]), cmap=cmap) - else: - ax.imshow(res[0], cmap=cmap) + ax.imshow(res[0], cmap=cmap) title_substrings = [ f"{param.name}: {val}" From 48675eb008546ae425b6928d3b3c047b9978e78d Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Mon, 15 Jan 2024 09:22:09 -0800 Subject: [PATCH 105/128] read/write bug fix Former-commit-id: 2deb45a1cc3519b23799565c1cdbe4f1651f479c --- py4DSTEM/process/phase/phase_base_class.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 103d8f467..2554aab47 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -1540,15 +1540,7 @@ def to_h5(self, group): # reconstruction metadata is_stack = self._save_iterations and hasattr(self, "object_iterations") - if is_stack: - num_iterations = len(self.object_iterations) - iterations = list(range(0, num_iterations, self._save_iterations_frequency)) - if num_iterations - 1 not in iterations: - iterations.append(num_iterations - 1) - - error = [self.error_iterations[i] for i in iterations] - else: - error = getattr(self, "error", 0.0) + error = self.error_iterations self.metadata = Metadata( name="reconstruction_metadata", @@ -1573,6 +1565,8 @@ def to_h5(self, group): self._probe_emd = Array(name="reconstruction_probe", data=asnumpy(self._probe)) if is_stack: + num_iterations = len(self.object_iterations) + iterations = list(range(0, num_iterations, self._save_iterations_frequency)) iterations_labels = [f"iteration_{i:03}" for i in iterations] # object From c0f5e969026674b47c06a8c2fe370bb62fe102a9 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Mon, 15 Jan 2024 16:58:33 -0800 Subject: [PATCH 106/128] show hanning window fix Former-commit-id: 770cd013789681d68f14b9ddfdc8f6ceea118ca6 --- py4DSTEM/visualize/show.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index 8462eec7d..b1b24a010 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -76,6 +76,7 @@ def show( theta=None, title=None, show_fft=False, + apply_hanning_window=True, show_cbar=False, **kwargs, ): @@ -305,6 +306,8 @@ def show( which will attempt to use it to overlay a 
scalebar. If True, uses calibraiton or pixelsize/pixelunits for scalebar. If False, no scalebar is added. show_fft (bool): if True, plots 2D-fft of array + apply_hanning_window (bool) + If True, a 2D Hann window is applied to the before FFT show_cbar (bool) : if True, adds cbar **kwargs: any keywords accepted by matplotlib's ax.matshow() @@ -369,9 +372,12 @@ def show( from py4DSTEM.visualize import show if show_fft: - n0 = ar.shape - w0 = np.hanning(n0[1]) * np.hanning(n0[0])[:, None] - ar = np.abs(np.fft.fftshift(np.fft.fft2(w0 * ar.copy()))) + if apply_hanning_window: + n0 = ar.shape + w0 = np.hanning(n0[1]) * np.hanning(n0[0])[:, None] + ar = np.abs(np.fft.fftshift(np.fft.fft2(w0 * ar.copy()))) + else: + ar = np.abs(np.fft.fftshift(np.fft.fft2(ar.copy()))) for a0 in range(num_images): im = show( ar[a0], @@ -451,7 +457,12 @@ def show( # Otherwise, plot one image if show_fft: if combine_images is False: - ar = np.abs(np.fft.fftshift(np.fft.fft2(ar.copy()))) + if apply_hanning_window: + n0 = ar.shape + w0 = np.hanning(n0[1]) * np.hanning(n0[0])[:, None] + ar = np.abs(np.fft.fftshift(np.fft.fft2(w0 * ar.copy()))) + else: + ar = np.abs(np.fft.fftshift(np.fft.fft2(ar.copy()))) # get image from a masked array if mask is not None: From 2845c56808b152c06469d051aa8c69d61f6bb30e Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Mon, 15 Jan 2024 17:00:12 -0800 Subject: [PATCH 107/128] typo fix Former-commit-id: 58c2e4379706a9b6da2d8c8bd448861fdb80c5f9 --- py4DSTEM/visualize/show.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index b1b24a010..7430992e0 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -307,7 +307,7 @@ def show( for scalebar. If False, no scalebar is added. 
show_fft (bool): if True, plots 2D-fft of array apply_hanning_window (bool) - If True, a 2D Hann window is applied to the before FFT + If True, a 2D Hann window is applied to the array before applying the FFT show_cbar (bool) : if True, adds cbar **kwargs: any keywords accepted by matplotlib's ax.matshow() From 8028a82378021d44011bbc6635ccb3192811978d Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Tue, 16 Jan 2024 15:40:44 -0800 Subject: [PATCH 108/128] mixedstate probe fourier constraint bug fix Former-commit-id: 22cb993b9a92c7ee11b4900d58deb493a59af9ca --- py4DSTEM/process/phase/ptychographic_constraints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/ptychographic_constraints.py b/py4DSTEM/process/phase/ptychographic_constraints.py index 425c8b0ad..281d6ef8a 100644 --- a/py4DSTEM/process/phase/ptychographic_constraints.py +++ b/py4DSTEM/process/phase/ptychographic_constraints.py @@ -1254,8 +1254,8 @@ def _probe_constraints( # Fourier amplitude (aperture) constraints if fix_probe_aperture: - current_probe[0] = self._probe_aperture_constraint( - current_probe[0], + current_probe = self._probe_aperture_constraint( + current_probe, initial_probe_aperture, ) elif constrain_probe_fourier_amplitude: From 9a6be677b16ff1cb0afe2bd749a16107a83b6ce6 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Wed, 17 Jan 2024 09:26:35 -0800 Subject: [PATCH 109/128] constrain first aperture only Former-commit-id: 8a055e17df1ff06f6a3353608805ce72b27ed0ce --- py4DSTEM/process/phase/ptychographic_constraints.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/process/phase/ptychographic_constraints.py b/py4DSTEM/process/phase/ptychographic_constraints.py index 281d6ef8a..2d76cd1cb 100644 --- a/py4DSTEM/process/phase/ptychographic_constraints.py +++ b/py4DSTEM/process/phase/ptychographic_constraints.py @@ -1254,9 +1254,9 @@ def _probe_constraints( # Fourier amplitude (aperture) constraints if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, + current_probe[0] = self._probe_aperture_constraint( + current_probe[0], + initial_probe_aperture[0], ) elif constrain_probe_fourier_amplitude: current_probe[0] = self._probe_fourier_amplitude_constraint( From bebf112f85b4c22912cbb44ac69d8d82ea43f875 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 19 Jan 2024 10:10:09 -0800 Subject: [PATCH 110/128] parallax verbosity to True Former-commit-id: ee11e9bb8d8d11860a26888ec9c08f0e434c2bb4 --- py4DSTEM/process/phase/parallax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index a0eea59df..650d3f374 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -79,7 +79,7 @@ def __init__( self, energy: float, datacube: DataCube = None, - verbose: bool = False, + verbose: bool = True, object_padding_px: Tuple[int, int] = (32, 32), device: str = "cpu", storage: str = None, From d6a4d52e868627b81561d81dd8eb0621edc22519 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 19 Jan 2024 10:35:49 -0800 Subject: [PATCH 111/128] ms butterworth bug Former-commit-id: f760e283840f293ce4e73cd4dec09636a4c6cb78 --- py4DSTEM/process/phase/ptychographic_constraints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/ptychographic_constraints.py b/py4DSTEM/process/phase/ptychographic_constraints.py index 
2d76cd1cb..82234afb7 100644 --- a/py4DSTEM/process/phase/ptychographic_constraints.py +++ b/py4DSTEM/process/phase/ptychographic_constraints.py @@ -178,8 +178,8 @@ def _object_butterworth_constraint( Constrained object estimate """ xp = self._xp - qx = xp.fft.fftfreq(current_object.shape[0], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[1], self.sampling[1]) + qx = xp.fft.fftfreq(current_object.shape[-1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[-2], self.sampling[1]) qya, qxa = xp.meshgrid(qy, qx) qra = xp.sqrt(qxa**2 + qya**2) From bd87b883916f01f57a88316d131e13f3fa9a49a2 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 19 Jan 2024 10:40:30 -0800 Subject: [PATCH 112/128] silly George Former-commit-id: b06e64feda12fd1af22f3b6dd68a1552fdebf36b --- py4DSTEM/process/phase/ptychographic_constraints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/ptychographic_constraints.py b/py4DSTEM/process/phase/ptychographic_constraints.py index 82234afb7..8a10b4df2 100644 --- a/py4DSTEM/process/phase/ptychographic_constraints.py +++ b/py4DSTEM/process/phase/ptychographic_constraints.py @@ -178,8 +178,8 @@ def _object_butterworth_constraint( Constrained object estimate """ xp = self._xp - qx = xp.fft.fftfreq(current_object.shape[-1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[-2], self.sampling[1]) + qx = xp.fft.fftfreq(current_object.shape[-2], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[-1], self.sampling[1]) qya, qxa = xp.meshgrid(qy, qx) qra = xp.sqrt(qxa**2 + qya**2) From bb8b2b36f3fa82bea384c17cd09c2acd26984a71 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 19 Jan 2024 11:22:37 -0800 Subject: [PATCH 113/128] cleaning up warnings Former-commit-id: f6b1c5487aad87f8d776480acee9622c99cf8835 --- py4DSTEM/process/phase/dpc.py | 5 +- .../magnetic_ptychographic_tomography.py | 3 +- .../process/phase/magnetic_ptychography.py | 8 +++- .../mixedstate_multislice_ptychography.py | 3 +- .../process/phase/mixedstate_ptychography.py | 3 +- .../process/phase/multislice_ptychography.py | 3 +- py4DSTEM/process/phase/parallax.py | 3 +- py4DSTEM/process/phase/phase_base_class.py | 47 ++++++++++++------- .../phase/ptychographic_constraints.py | 3 ++ .../process/phase/ptychographic_methods.py | 4 +- .../process/phase/ptychographic_tomography.py | 3 +- .../phase/ptychographic_visualizations.py | 3 -- .../process/phase/singleslice_ptychography.py | 3 +- 13 files changed, 57 insertions(+), 34 deletions(-) diff --git a/py4DSTEM/process/phase/dpc.py b/py4DSTEM/process/phase/dpc.py index 4997f733a..af004cc3c 100644 --- a/py4DSTEM/process/phase/dpc.py +++ b/py4DSTEM/process/phase/dpc.py @@ -3,6 +3,7 @@ namely DPC. 
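Referring back to the butterworth fftfreq fix in patches 111/112 above, a minimal sketch (not part of any patch; the object shape and sampling values are assumed) of why the frequency grids must be built from the last two axes:

import numpy as np

obj = np.zeros((6, 128, 96))  # e.g. a multislice object stack: (slices, rows, cols)
sampling = (0.1, 0.1)         # assumed real-space sampling

qx = np.fft.fftfreq(obj.shape[-2], sampling[0])  # length 128, matches rows
qy = np.fft.fftfreq(obj.shape[-1], sampling[1])  # length 96, matches cols
qya, qxa = np.meshgrid(qy, qx)                   # both (128, 96), i.e. obj.shape[-2:]
qra = np.sqrt(qxa**2 + qya**2)                   # broadcasts against the FFT of each slice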
""" +import sys import warnings from typing import Sequence, Tuple, Union @@ -22,7 +23,7 @@ from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction from py4DSTEM.visualize.vis_special import return_scaled_histogram_ordering -warnings.simplefilter(action="always", category=UserWarning) +warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) class DPC(PhaseReconstruction): @@ -773,8 +774,6 @@ def reconstruct( if (new_error > self.error) and backtrack: self._padded_object_phase = previous_iteration self._step_size /= 2 - if self._verbose: - print(f"Iteration {a0}, step reduced to {self._step_size}") continue self.error = new_error diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 156552cc2..19994ccab 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -3,6 +3,7 @@ namely magnetic ptychographic tomography. """ +import sys import warnings from typing import Mapping, Sequence, Tuple @@ -51,7 +52,7 @@ project_vector_field_divergence_periodic_3D, ) -warnings.simplefilter(action="always", category=UserWarning) +warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) class MagneticPtychographicTomography( diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index d03e85632..acfd852bc 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -3,6 +3,7 @@ namely magnetic ptychography. """ +import sys import warnings from typing import Mapping, Sequence, Tuple @@ -45,7 +46,7 @@ polar_symbols, ) -warnings.simplefilter(action="always", category=UserWarning) +warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) class MagneticPtychography( @@ -347,7 +348,10 @@ def preprocess( ) if self._verbose: - print(magnetic_contribution_msg) + warnings.warn( + magnetic_contribution_msg, + UserWarning, + ) if len(self._datacube) != self._num_measurements: raise ValueError( diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index b3bb0093a..27beda1fa 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -3,6 +3,7 @@ namely multislice ptychography. """ +import sys import warnings from typing import Mapping, Sequence, Tuple, Union @@ -45,7 +46,7 @@ polar_symbols, ) -warnings.simplefilter(action="always", category=UserWarning) +warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) class MixedstateMultislicePtychography( diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index eb4b66617..4b4fa00ac 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -3,6 +3,7 @@ namely mixed-state ptychography. 
""" +import sys import warnings from typing import Mapping, Tuple @@ -42,7 +43,7 @@ polar_symbols, ) -warnings.simplefilter(action="always", category=UserWarning) +warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) class MixedstatePtychography( diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index b75fb60f0..7b2cbf27d 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -3,6 +3,7 @@ namely multislice ptychography. """ +import sys import warnings from typing import Mapping, Sequence, Tuple, Union @@ -42,7 +43,7 @@ polar_symbols, ) -warnings.simplefilter(action="always", category=UserWarning) +warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) class MultislicePtychography( diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 650d3f374..159393382 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -3,6 +3,7 @@ images by aligning each virtual BF image. """ +import sys import warnings from typing import Tuple @@ -36,7 +37,7 @@ except (ModuleNotFoundError, ImportError): cp = np -warnings.simplefilter(action="always", category=UserWarning) +warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) _aberration_names = { (1, 0): "C1 ", diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 2554aab47..67101ce62 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -2,6 +2,7 @@ Module for reconstructing phase objects from 4DSTEM datasets using iterative methods. """ +import sys import warnings import matplotlib.pyplot as plt @@ -32,7 +33,7 @@ get_shifted_ar, ) -warnings.simplefilter(action="always", category=UserWarning) +warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) class PhaseReconstruction(Custom): @@ -878,9 +879,10 @@ def _solve_for_center_of_mass_relative_rotation( if verbose: if _rotation_best_transpose: - print("Diffraction intensities should be transposed.") - else: - print("No need to transpose diffraction intensities.") + warnings.warn( + "Diffraction intensities should be transposed.", + UserWarning, + ) else: # Rotation unknown @@ -985,7 +987,10 @@ def _solve_for_center_of_mass_relative_rotation( _rotation_best_rad = rotation_angles_rad[ind_min] if verbose: - print(("Best fit rotation = " f"{rotation_best_deg:.0f} degrees.")) + warnings.warn( + f"Best fit rotation = {rotation_best_deg:.0f} degrees.", + UserWarning, + ) if plot_rotation: figsize = kwargs.get("figsize", (8, 2)) @@ -1139,11 +1144,15 @@ def _solve_for_center_of_mass_relative_rotation( # Print summary if verbose: - print(("Best fit rotation = " f"{rotation_best_deg:.0f} degrees.")) + warnings.warn( + f"Best fit rotation = {rotation_best_deg:.0f} degrees.", + UserWarning, + ) if _rotation_best_transpose: - print("Diffraction intensities should be transposed.") - else: - print("No need to transpose diffraction intensities.") + warnings.warn( + "Diffraction intensities should be transposed.", + UserWarning, + ) # Plot Curl/Div rotation if plot_rotation: @@ -2057,41 +2066,45 @@ def _report_reconstruction_summary( ) ) else: - print( + warnings.warn( ( first_line + f"with the {reconstruction_method} algorithm, " f"with normalization_min: {normalization_min} and step _size: {step_size}, " f"in batches of max 
{max_batch_size} measurements." - ) + ), + UserWarning, ) else: # named projection set method if reconstruction_parameter is not None: - print( + warnings.warn( ( first_line + f"with the {reconstruction_method} algorithm, " f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) + ), + UserWarning, ) # generalized projections (or the even more rare charge-flipping) elif projection_a is not None: - print( + warnings.warn( ( first_line + f"with the {reconstruction_method} algorithm, " f"with normalization_min: {normalization_min} and (a,b,c): " f"{projection_a, projection_b, projection_c}." - ) + ), + UserWarning, ) # gradient descent else: - print( + warnings.warn( ( first_line + f"with the {reconstruction_method} algorithm, " f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) + ), + UserWarning, ) def _constraints( diff --git a/py4DSTEM/process/phase/ptychographic_constraints.py b/py4DSTEM/process/phase/ptychographic_constraints.py index 8a10b4df2..11c3ed56f 100644 --- a/py4DSTEM/process/phase/ptychographic_constraints.py +++ b/py4DSTEM/process/phase/ptychographic_constraints.py @@ -1,3 +1,4 @@ +import sys import warnings import numpy as np @@ -20,6 +21,8 @@ os.environ["CUPY_PYLOPS"] = "0" import pylops # this must follow the exception +warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) + class ObjectNDConstraintsMixin: """ diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 68a2d8334..e19290482 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -1,3 +1,4 @@ +import sys import warnings from typing import Sequence, Tuple @@ -25,7 +26,7 @@ except (ModuleNotFoundError, ImportError): cp = np -warnings.simplefilter(action="always", category=UserWarning) +warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) class ObjectNDMethodsMixin: @@ -1983,7 +1984,6 @@ def _position_correction( """ xp = self._xp - asnumpy = self._asnumpy storage = self._storage # resample to match data, note: this needs to happen in real-space diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index ff0eb785c..4823bdae0 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -3,6 +3,7 @@ namely joint ptychographic tomography. 
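For context, a minimal sketch (not part of the patch) of what the warnings.showwarning override used throughout this patch does; the warning text here is a made-up example:

import sys
import warnings

# print warnings to stderr as bare messages, without the usual
# "<file>:<line>: UserWarning:" prefix
warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr)

warnings.warn("Best fit rotation = 3 degrees.", UserWarning)
# stderr now shows just: Best fit rotation = 3 degrees.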
""" +import sys import warnings from typing import Mapping, Sequence, Tuple @@ -45,7 +46,7 @@ polar_symbols, ) -warnings.simplefilter(action="always", category=UserWarning) +warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) class PtychographicTomography( diff --git a/py4DSTEM/process/phase/ptychographic_visualizations.py b/py4DSTEM/process/phase/ptychographic_visualizations.py index 14db5dc5c..58dd224cf 100644 --- a/py4DSTEM/process/phase/ptychographic_visualizations.py +++ b/py4DSTEM/process/phase/ptychographic_visualizations.py @@ -1,4 +1,3 @@ -import warnings from typing import Tuple import matplotlib.pyplot as plt @@ -17,8 +16,6 @@ except (ModuleNotFoundError, ImportError): cp = np -warnings.simplefilter(action="always", category=UserWarning) - class VisualizationsMixin: """ diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index 6082a9b42..49da1a496 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -3,6 +3,7 @@ namely (single-slice) ptychography. """ +import sys import warnings from typing import Mapping, Tuple @@ -39,7 +40,7 @@ polar_symbols, ) -warnings.simplefilter(action="always", category=UserWarning) +warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) class SingleslicePtychography( From a2844b787bcf3e38f1f38587ae0bfc896f8ec043 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 19 Jan 2024 13:16:33 -0800 Subject: [PATCH 114/128] simplifying warnings, restructuring single slice regularization flags Former-commit-id: bdf2fe3b347f6d188ca37fd4d70fac4c46047193 --- py4DSTEM/process/phase/dpc.py | 3 - .../magnetic_ptychographic_tomography.py | 3 - .../process/phase/magnetic_ptychography.py | 3 - .../mixedstate_multislice_ptychography.py | 4 - .../process/phase/mixedstate_ptychography.py | 4 - .../process/phase/multislice_ptychography.py | 4 - py4DSTEM/process/phase/parallax.py | 3 - py4DSTEM/process/phase/phase_base_class.py | 1 + .../phase/ptychographic_constraints.py | 3 - .../process/phase/ptychographic_methods.py | 3 - .../process/phase/ptychographic_tomography.py | 3 - .../process/phase/singleslice_ptychography.py | 109 ++++++++---------- 12 files changed, 51 insertions(+), 92 deletions(-) diff --git a/py4DSTEM/process/phase/dpc.py b/py4DSTEM/process/phase/dpc.py index af004cc3c..fe9044595 100644 --- a/py4DSTEM/process/phase/dpc.py +++ b/py4DSTEM/process/phase/dpc.py @@ -3,7 +3,6 @@ namely DPC. """ -import sys import warnings from typing import Sequence, Tuple, Union @@ -23,8 +22,6 @@ from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction from py4DSTEM.visualize.vis_special import return_scaled_histogram_ordering -warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) - class DPC(PhaseReconstruction): """ diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 19994ccab..13ab8d846 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -3,7 +3,6 @@ namely magnetic ptychographic tomography. 
""" -import sys import warnings from typing import Mapping, Sequence, Tuple @@ -52,8 +51,6 @@ project_vector_field_divergence_periodic_3D, ) -warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) - class MagneticPtychographicTomography( VisualizationsMixin, diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index acfd852bc..75c667094 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -3,7 +3,6 @@ namely magnetic ptychography. """ -import sys import warnings from typing import Mapping, Sequence, Tuple @@ -46,8 +45,6 @@ polar_symbols, ) -warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) - class MagneticPtychography( VisualizationsMixin, diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index 27beda1fa..28758f477 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -3,8 +3,6 @@ namely multislice ptychography. """ -import sys -import warnings from typing import Mapping, Sequence, Tuple, Union import matplotlib.pyplot as plt @@ -46,8 +44,6 @@ polar_symbols, ) -warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) - class MixedstateMultislicePtychography( VisualizationsMixin, diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index 4b4fa00ac..bcce1ad3f 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -3,8 +3,6 @@ namely mixed-state ptychography. """ -import sys -import warnings from typing import Mapping, Tuple import matplotlib.pyplot as plt @@ -43,8 +41,6 @@ polar_symbols, ) -warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) - class MixedstatePtychography( VisualizationsMixin, diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index 7b2cbf27d..a0c46f898 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -3,8 +3,6 @@ namely multislice ptychography. """ -import sys -import warnings from typing import Mapping, Sequence, Tuple, Union import matplotlib.pyplot as plt @@ -43,8 +41,6 @@ polar_symbols, ) -warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) - class MultislicePtychography( VisualizationsMixin, diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 159393382..3156c665b 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -3,7 +3,6 @@ images by aligning each virtual BF image. 
""" -import sys import warnings from typing import Tuple @@ -37,8 +36,6 @@ except (ModuleNotFoundError, ImportError): cp = np -warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) - _aberration_names = { (1, 0): "C1 ", (1, 2): "stig ", diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 67101ce62..0e1e5ca5b 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -34,6 +34,7 @@ ) warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) +warnings.simplefilter("always", UserWarning) class PhaseReconstruction(Custom): diff --git a/py4DSTEM/process/phase/ptychographic_constraints.py b/py4DSTEM/process/phase/ptychographic_constraints.py index 11c3ed56f..8a10b4df2 100644 --- a/py4DSTEM/process/phase/ptychographic_constraints.py +++ b/py4DSTEM/process/phase/ptychographic_constraints.py @@ -1,4 +1,3 @@ -import sys import warnings import numpy as np @@ -21,8 +20,6 @@ os.environ["CUPY_PYLOPS"] = "0" import pylops # this must follow the exception -warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) - class ObjectNDConstraintsMixin: """ diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index e19290482..5401b1393 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -1,4 +1,3 @@ -import sys import warnings from typing import Sequence, Tuple @@ -26,8 +25,6 @@ except (ModuleNotFoundError, ImportError): cp = np -warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) - class ObjectNDMethodsMixin: """ diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 4823bdae0..67b57afcb 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -3,7 +3,6 @@ namely joint ptychographic tomography. """ -import sys import warnings from typing import Mapping, Sequence, Tuple @@ -46,8 +45,6 @@ polar_symbols, ) -warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) - class PtychographicTomography( VisualizationsMixin, diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index 49da1a496..6a2e6253c 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -3,8 +3,6 @@ namely (single-slice) ptychography. 
""" -import sys -import warnings from typing import Mapping, Tuple import matplotlib.pyplot as plt @@ -40,8 +38,6 @@ polar_symbols, ) -warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) - class SingleslicePtychography( VisualizationsMixin, @@ -567,7 +563,7 @@ def preprocess( def reconstruct( self, - max_iter: int = 8, + num_iter: int = 8, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, reconstruction_parameter_a: float = None, @@ -578,33 +574,33 @@ def reconstruct( step_size: float = 0.5, normalization_min: float = 1, positions_step_size: float = 0.5, - pure_phase_object_iter: int = 0, + pure_phase_object: bool = False, fix_probe_com: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, constrain_probe_amplitude_relative_radius: float = 0.5, constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude: bool = False, constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, + fix_positions: bool = True, fix_positions_com: bool = True, max_position_update_distance: float = None, max_position_total_distance: float = None, - global_affine_transformation: bool = True, + global_affine_transformation: bool = False, gaussian_filter_sigma: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, fit_probe_aberrations_remove_initial: bool = False, fit_probe_aberrations_using_scikit_image: bool = True, - butterworth_filter_iter: int = np.inf, + butterworth_filter: bool = True, q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, - tv_denoise_iter: int = np.inf, + tv_denoise: bool = True, tv_denoise_weight: float = None, tv_denoise_inner_iter: float = 40, object_positivity: bool = True, @@ -622,8 +618,8 @@ def reconstruct( Parameters -------- - max_iter: int, optional - Maximum number of iterations to run + num_iter: int, optional + Number of iterations to run reconstruction_method: str, optional Specifies which reconstruction algorithm to use, one of: "generalized-projections", @@ -650,28 +646,28 @@ def reconstruct( Probe normalization minimum as a fraction of the maximum overlap intensity positions_step_size: float, optional Positions update step size - pure_phase_object_iter: int, optional - Number of iterations where object amplitude is set to unity + pure_phase_object: bool, optional + If True, object amplitude is set to unity fix_probe_com: bool, optional If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. 
+ fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vaccum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency constrain_probe_fourier_amplitude_max_width_pixels: float Maximum pixel width of fitted sigmoid functions. constrain_probe_fourier_amplitude_constant_intensity: bool If True, the probe aperture is additionally constrained to a constant intensity. - fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate + fix_positions: bool, optional + If True, probe-positions are fixed fix_positions_com: bool, optional If True, fixes the positions CoM to the middle of the fov max_position_update_distance: float, optional @@ -682,10 +678,10 @@ def reconstruct( If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion + gaussian_filter: bool, optional + If True, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion fit_probe_aberrations_max_angular_order: int Max angular order of probe aberrations basis functions fit_probe_aberrations_max_radial_order: int @@ -696,16 +692,16 @@ def reconstruct( If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads to a documented bug where the kernel hangs.. If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter + butterworth_filter: bool, optional + If True, object is smoothed using butterworth filtering q_lowpass: float Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter - tv_denoise_iter: int, optional - Number of iterations to run using tv denoise filter on object + tv_denoise: bool, optional + If True, object is smoothed using TV denoising tv_denoise_weight: float Denoising weight. The greater `weight`, the more denoising. 
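A short usage sketch of the restructured single-slice reconstruct call (illustrative only; ptycho is an assumed, already-preprocessed SingleslicePtychography instance and the parameter values are arbitrary):

# old style (pre patch 114): integer iteration counts, e.g.
#   ptycho.reconstruct(max_iter=8, fix_probe_iter=0, fix_positions_iter=np.inf)
# new style: plain booleans applied uniformly to every iteration
ptycho.reconstruct(
    num_iter=8,                 # renamed from max_iter
    fix_probe=False,            # probe is updated every iteration
    fix_positions=True,         # probe positions stay fixed
    gaussian_filter=True,       # only acts while gaussian_filter_sigma is not None
    gaussian_filter_sigma=0.3,  # standard deviation in A, per the docstring
)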
tv_denoise_inner_iter: float @@ -773,7 +769,7 @@ def reconstruct( if self._verbose: self._report_reconstruction_summary( - max_iter, + num_iter, switch_object_iter, use_projection_scheme, reconstruction_method, @@ -799,7 +795,7 @@ def reconstruct( # main loop for a0 in tqdmnd( - max_iter, + num_iter, desc="Reconstructing object and probe", unit=" iter", disable=not progress_bar, @@ -868,11 +864,11 @@ def reconstruct( use_projection_scheme=use_projection_scheme, step_size=step_size, normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, + fix_probe=fix_probe, ) # position correction - if a0 >= fix_positions_iter: + if not fix_positions: self._positions_px[batch_indices] = self._position_correction( self._object, vectorized_patch_indices_row, @@ -898,36 +894,32 @@ def reconstruct( self._probe, self._positions_px, self._positions_px_initial, - fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude and not fix_probe, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, - fix_probe_aperture=a0 < fix_probe_aperture_iter, + fix_probe_aperture=fix_probe_aperture and not fix_probe, initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - fix_positions_com=fix_positions_com and a0 >= fix_positions_iter, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, + gaussian_filter=gaussian_filter and gaussian_filter_sigma is not None, gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter + butterworth_filter=butterworth_filter and (q_lowpass is not None or q_highpass is not None), q_lowpass=q_lowpass, q_highpass=q_highpass, butterworth_order=butterworth_order, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise=tv_denoise and tv_denoise_weight is not None, tv_denoise_weight=tv_denoise_weight, tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, @@ -935,8 +927,7 @@ def reconstruct( object_mask=self._object_fov_mask_inverse if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, - pure_phase_object=a0 < pure_phase_object_iter - and self._object_type == "complex", + 
pure_phase_object=pure_phase_object and self._object_type == "complex", ) self.error_iterations.append(error.item()) From 82644ede4833a34ca26859acd5d68dff7da2859d Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 20 Jan 2024 09:12:44 -0800 Subject: [PATCH 115/128] one more read write bug fix Former-commit-id: 108fd5fa81225c16b5da986b97a6245045de6a54 --- py4DSTEM/process/phase/phase_base_class.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 0e1e5ca5b..ca1b7c332 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -1696,13 +1696,12 @@ def _populate_instance(self, group): self._exit_waves = None # Check if stack - if hasattr(error, "__len__"): + if "_object_iterations_emd" in dict_data.keys(): self.object_iterations = list(dict_data["_object_iterations_emd"].data) self.probe_iterations = list(dict_data["_probe_iterations_emd"].data) - self.error_iterations = error - self.error = error[-1] - else: - self.error = error + + self.error_iterations = error + self.error = error[-1] # Slim preprocessing to enable visualize self._positions_px_com = xp.mean(self._positions_px, axis=0) From e18c4cbb5d03c514882c16e55622ea8ed4d806a6 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 20 Jan 2024 13:04:54 -0800 Subject: [PATCH 116/128] more flags more problems: multislice, mixed state, and mixed-multislice Former-commit-id: 5706d082df381194927355f3941af138d2744aaa --- .../mixedstate_multislice_ptychography.py | 125 ++++++++-------- .../process/phase/mixedstate_ptychography.py | 107 +++++++------- .../process/phase/multislice_ptychography.py | 133 +++++++++--------- .../process/phase/singleslice_ptychography.py | 6 +- 4 files changed, 178 insertions(+), 193 deletions(-) diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index 28758f477..907847cf3 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -682,7 +682,7 @@ def preprocess( def reconstruct( self, - max_iter: int = 8, + num_iter: int = 8, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, reconstruction_parameter_a: float = None, @@ -695,41 +695,41 @@ def reconstruct( positions_step_size: float = 0.9, fix_probe_com: bool = True, orthogonalize_probe: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, constrain_probe_amplitude_relative_radius: float = 0.5, constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude: bool = False, constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, + fix_positions: bool = True, fix_positions_com: bool = True, max_position_update_distance: float = None, max_position_total_distance: float = None, - global_affine_transformation: bool = True, + global_affine_transformation: bool = False, gaussian_filter_sigma: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, + gaussian_filter: bool = True, + fit_probe_aberrations: 
bool = False, fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, fit_probe_aberrations_remove_initial: bool = False, fit_probe_aberrations_using_scikit_image: bool = True, - butterworth_filter_iter: int = np.inf, + butterworth_filter: bool = True, q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, - kz_regularization_filter_iter: int = np.inf, + kz_regularization_filter: bool = True, kz_regularization_gamma: Union[float, np.ndarray] = None, - identical_slices_iter: int = 0, + identical_slices: bool = False, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, - pure_phase_object_iter: int = 0, - tv_denoise_iter_chambolle=np.inf, + pure_phase_object: bool = False, + tv_denoise_chambolle: bool = True, tv_denoise_weight_chambolle=None, tv_denoise_pad_chambolle=1, - tv_denoise_iter=np.inf, + tv_denoise: bool = True, tv_denoise_weights=None, tv_denoise_inner_iter=40, switch_object_iter: int = np.inf, @@ -744,7 +744,7 @@ def reconstruct( Parameters -------- - max_iter: int, optional + num_iter: int, optional Maximum number of iterations to run reconstruction_method: str, optional Specifies which reconstruction algorithm to use, one of: @@ -774,24 +774,24 @@ def reconstruct( Positions update step size fix_com: bool, optional If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. + fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vaccum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency constrain_probe_fourier_amplitude_max_width_pixels: float Maximum pixel width of fitted sigmoid functions. constrain_probe_fourier_amplitude_constant_intensity: bool If True, the probe aperture is additionally constrained to a constant intensity. 
- fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate + fix_positions: bool, optional + If True, probe-positions are fixed max_position_update_distance: float, optional Maximum allowed distance for update in A max_position_total_distance: float, optional @@ -800,10 +800,10 @@ def reconstruct( If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion fit_probe_aberrations_max_angular_order: bool Max angular order of probe aberrations basis functions fit_probe_aberrations_max_radial_order: bool @@ -814,36 +814,36 @@ def reconstruct( If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads to a documented bug where the kernel hangs.. If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering q_lowpass: float Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter - kz_regularization_filter_iter: int, optional - Number of iterations to run using kz regularization filter + kz_regularization_filter: bool, optional + If True and kz_regularization_gamma is not None, applies kz regularization filter kz_regularization_gamma, float, optional kz regularization strength - identical_slices_iter: int, optional - Number of iterations to run using identical slices + identical_slices: bool, optional + If True, object forced to identical slices object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration fix_potential_baseline: bool If true, the potential mean outside the FOV is forced to zero at each iteration - pure_phase_object_iter: int, optional - Number of iterations where object amplitude is set to unity - tv_denoise_iter_chambolle: bool - Number of iterations with TV denoisining + pure_phase_object: bool, optional + If True, object amplitude is set to unity + tv_denoise_chambolle: bool + If True and tv_denoise_weight_chambolle is not None, object is smoothed using TV denoisining tv_denoise_weight_chambolle: float weight of tv denoising constraint tv_denoise_pad_chambolle: int If not None, pads object at top and bottom with this many zeros before applying denoising tv_denoise: bool - If True, applies TV denoising on object + If True and tv_denoise_weights is not None, object is smoothed using TV denoising tv_denoise_weights: [float,float] Denoising weights[z weight, r weight]. The greater `weight`, the more denoising. 
@@ -903,7 +903,7 @@ def reconstruct( if self._verbose: self._report_reconstruction_summary( - max_iter, + num_iter, switch_object_iter, use_projection_scheme, reconstruction_method, @@ -929,7 +929,7 @@ def reconstruct( # main loop for a0 in tqdmnd( - max_iter, + num_iter, desc="Reconstructing object and probe", unit=" iter", disable=not progress_bar, @@ -998,11 +998,11 @@ def reconstruct( use_projection_scheme=use_projection_scheme, step_size=step_size, normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, + fix_probe=fix_probe, ) # position correction - if a0 >= fix_positions_iter: + if not fix_positions: self._positions_px[batch_indices] = self._position_correction( self._object, vectorized_patch_indices_row, @@ -1028,54 +1028,49 @@ def reconstruct( self._probe, self._positions_px, self._positions_px_initial, - fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude and not fix_probe, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, - fix_probe_aperture=a0 < fix_probe_aperture_iter, + fix_probe_aperture=fix_probe_aperture, initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - fix_positions_com=fix_positions_com and a0 >= fix_positions_iter, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, + gaussian_filter=gaussian_filter and gaussian_filter_sigma is not None, gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter + butterworth_filter=butterworth_filter and (q_lowpass is not None or q_highpass is not None), q_lowpass=q_lowpass, q_highpass=q_highpass, butterworth_order=butterworth_order, - kz_regularization_filter=a0 < kz_regularization_filter_iter + kz_regularization_filter=kz_regularization_filter and kz_regularization_gamma is not None, kz_regularization_gamma=kz_regularization_gamma[a0] if kz_regularization_gamma is not None and isinstance(kz_regularization_gamma, np.ndarray) else kz_regularization_gamma, - identical_slices=a0 < identical_slices_iter, + identical_slices=identical_slices, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse if fix_potential_baseline and 
self._object_fov_mask_inverse.sum() > 0 else None, - pure_phase_object=a0 < pure_phase_object_iter - and self._object_type == "complex", - tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle + pure_phase_object=pure_phase_object and self._object_type == "complex", + tv_denoise_chambolle=tv_denoise_chambolle and tv_denoise_weight_chambolle is not None, tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise=tv_denoise and tv_denoise_weights is not None, tv_denoise_weights=tv_denoise_weights, tv_denoise_inner_iter=tv_denoise_inner_iter, orthogonalize_probe=orthogonalize_probe, diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index bcce1ad3f..f5c5e6cd3 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -586,7 +586,7 @@ def preprocess( def reconstruct( self, - max_iter: int = 8, + num_iter: int = 8, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, reconstruction_parameter_a: float = None, @@ -597,34 +597,34 @@ def reconstruct( step_size: float = 0.5, normalization_min: float = 1, positions_step_size: float = 0.9, - pure_phase_object_iter: int = 0, + pure_phase_object: bool = False, fix_probe_com: bool = True, orthogonalize_probe: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, constrain_probe_amplitude_relative_radius: float = 0.5, constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude: bool = False, constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, + fix_positions: bool = True, fix_positions_com: bool = True, - global_affine_transformation: bool = True, + global_affine_transformation: bool = False, max_position_update_distance: float = None, max_position_total_distance: float = None, gaussian_filter_sigma: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, fit_probe_aberrations_remove_initial: bool = False, fit_probe_aberrations_using_scikit_image: bool = True, - butterworth_filter_iter: int = np.inf, + butterworth_filter: bool = True, q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, - tv_denoise_iter: int = np.inf, + tv_denoise: bool = True, tv_denoise_weight: float = None, tv_denoise_inner_iter: float = 40, object_positivity: bool = True, @@ -642,8 +642,8 @@ def reconstruct( Parameters -------- - max_iter: int, optional - Maximum number of iterations to run + num_iter: int, optional + Number of iterations to run reconstruction_method: str, optional Specifies which reconstruction algorithm to use, one of: "generalized-projections", @@ -670,28 +670,28 @@ def reconstruct( Probe normalization minimum as a fraction of the maximum overlap intensity positions_step_size: float, optional Positions update step size - pure_phase_object_iter: int, optional - Number of iterations 
where object amplitude is set to unity + pure_phase_object: bool, optional + If True, object amplitude is set to unity + fix_probe: bool, optional + If True, probe is fixed fix_probe_com: bool, optional If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. + fix_probe_aperture: bool, optional + If True, vaccum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional + constrain_probe_fourier_amplitude: bool, optional Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. constrain_probe_fourier_amplitude_max_width_pixels: float Maximum pixel width of fitted sigmoid functions. constrain_probe_fourier_amplitude_constant_intensity: bool If True, the probe aperture is additionally constrained to a constant intensity. - fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate + fix_positions: int, optional + If True, probe-positions are fixed fix_positions_com: bool, optional If True, fixes the positions CoM to the middle of the fov max_position_update_distance: float, optional @@ -702,10 +702,10 @@ def reconstruct( If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion fit_probe_aberrations_max_angular_order: bool Max angular order of probe aberrations basis functions fit_probe_aberrations_max_radial_order: bool @@ -716,16 +716,16 @@ def reconstruct( If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads to a documented bug where the kernel hangs.. If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering q_lowpass: float Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. 
Smaller gives a smoother filter - tv_denoise_iter: int, optional - Number of iterations to run using tv denoise filter on object + tv_denoise: bool, optional + If True and tv_denoise_weight is not None, object is smoothed using TV denoising tv_denoise_weight: float Denoising weight. The greater `weight`, the more denoising. tv_denoise_inner_iter: float @@ -746,9 +746,9 @@ def reconstruct( reset: bool, optional If True, previous reconstructions are ignored device: str, optional - if not none, overwrites self._device to set device preprocess will be perfomed on. + If not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional - if true, and device = 'gpu', clears the cached fft plan at the end of function calls + If true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- @@ -793,7 +793,7 @@ def reconstruct( if self._verbose: self._report_reconstruction_summary( - max_iter, + num_iter, switch_object_iter, use_projection_scheme, reconstruction_method, @@ -819,7 +819,7 @@ def reconstruct( # main loop for a0 in tqdmnd( - max_iter, + num_iter, desc="Reconstructing object and probe", unit=" iter", disable=not progress_bar, @@ -888,11 +888,11 @@ def reconstruct( use_projection_scheme=use_projection_scheme, step_size=step_size, normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, + fix_probe=fix_probe, ) # position correction - if a0 >= fix_positions_iter: + if not fix_positions: self._positions_px[batch_indices] = self._position_correction( self._object, vectorized_patch_indices_row, @@ -918,37 +918,33 @@ def reconstruct( self._probe, self._positions_px, self._positions_px_initial, - fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude and not fix_probe, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, - fix_probe_aperture=a0 < fix_probe_aperture_iter, + fix_probe_aperture=fix_probe_aperture, initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - fix_positions_com=fix_positions_com and a0 >= fix_positions_iter, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, + gaussian_filter=gaussian_filter 
and gaussian_filter_sigma is not None, gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter + butterworth_filter=butterworth_filter and (q_lowpass is not None or q_highpass is not None), q_lowpass=q_lowpass, q_highpass=q_highpass, butterworth_order=butterworth_order, orthogonalize_probe=orthogonalize_probe, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise=tv_denoise and tv_denoise_weight is not None, tv_denoise_weight=tv_denoise_weight, tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, @@ -956,8 +952,7 @@ def reconstruct( object_mask=self._object_fov_mask_inverse if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, - pure_phase_object=a0 < pure_phase_object_iter - and self._object_type == "complex", + pure_phase_object=pure_phase_object and self._object_type == "complex", ) self.error_iterations.append(error.item()) diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index a0c46f898..a0b247791 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -656,7 +656,7 @@ def preprocess( def reconstruct( self, - max_iter: int = 8, + num_iter: int = 8, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, reconstruction_parameter_a: float = None, @@ -668,41 +668,41 @@ def reconstruct( normalization_min: float = 1, positions_step_size: float = 0.9, fix_probe_com: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, constrain_probe_amplitude_relative_radius: float = 0.5, constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude: bool = False, constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, + fix_positions: bool = True, fix_positions_com: bool = True, max_position_update_distance: float = None, max_position_total_distance: float = None, - global_affine_transformation: bool = True, + global_affine_transformation: bool = False, gaussian_filter_sigma: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, fit_probe_aberrations_remove_initial: bool = False, fit_probe_aberrations_using_scikit_image: bool = True, - butterworth_filter_iter: int = np.inf, + butterworth_filter: bool = True, q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, - kz_regularization_filter_iter: int = np.inf, + kz_regularization_filter: bool = True, kz_regularization_gamma: float = None, - identical_slices_iter: int = 0, + identical_slices: bool = False, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, - pure_phase_object_iter: int = 0, - tv_denoise_iter_chambolle=np.inf, + pure_phase_object: bool = False, + tv_denoise_chambolle: bool = True, tv_denoise_weight_chambolle=None, tv_denoise_pad_chambolle=1, - tv_denoise_iter=np.inf, + tv_denoise: bool = True, tv_denoise_weights=None, 
tv_denoise_inner_iter=40, switch_object_iter: int = np.inf, @@ -717,8 +717,8 @@ def reconstruct( Parameters -------- - max_iter: int, optional - Maximum number of iterations to run + num_iter: int, optional + Number of iterations to run reconstruction_method: str, optional Specifies which reconstruction algorithm to use, one of: "generalized-projections", @@ -747,24 +747,24 @@ def reconstruct( Positions update step size fix_probe_com: bool, optional If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. + fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vacuum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency constrain_probe_fourier_amplitude_max_width_pixels: float Maximum pixel width of fitted sigmoid functions. constrain_probe_fourier_amplitude_constant_intensity: bool If True, the probe aperture is additionally constrained to a constant intensity. - fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate + fix_positions: bool, optional + If True, probe-positions are fixed fix_positions_com: bool, optional If True, fixes the positions CoM to the middle of the fov max_position_update_distance: float, optional @@ -775,10 +775,10 @@ def reconstruct( If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion fit_probe_aberrations_max_angular_order: bool Max angular order of probe aberrations basis functions fit_probe_aberrations_max_radial_order: bool @@ -789,36 +789,36 @@ def reconstruct( If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads to a documented bug where the kernel hangs.. If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. 
- butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering q_lowpass: float Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter - kz_regularization_filter_iter: int, optional - Number of iterations to run using kz regularization filter + kz_regularization_filter: bool, optional + If True and kz_regularization_gamma is not None, applies kz regularization filter kz_regularization_gamma, float, optional kz regularization strength - identical_slices_iter: int, optional - Number of iterations to run using identical slices + identical_slices: int, optional + If True, object forced to identical slices object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration fix_potential_baseline: bool If true, the potential mean outside the FOV is forced to zero at each iteration - pure_phase_object_iter: int, optional - Number of iterations where object amplitude is set to unity - tv_denoise_iter_chambolle: bool - Number of iterations with TV denoisining + pure_phase_object: bool, optional + If True, object amplitude is set to unity + tv_denoise_chambolle: bool + If True and tv_denoise_weight_chambolle is not None, object is smoothed using TV denoisining tv_denoise_weight_chambolle: float weight of tv denoising constraint tv_denoise_pad_chambolle: int - if not None, pads object at top and bottom with this many zeros before applying denoising + If not None, pads object at top and bottom with this many zeros before applying denoising tv_denoise: bool - If True, applies TV denoising on object + If True and tv_denoise_weights is not None, object is smoothed using TV denoising tv_denoise_weights: [float,float] Denoising weights[z weight, r weight]. The greater `weight`, the more denoising. @@ -834,9 +834,9 @@ def reconstruct( reset: bool, optional If True, previous reconstructions are ignored device: str, optional - if not none, overwrites self._device to set device preprocess will be perfomed on. + If not none, overwrites self._device to set device preprocess will be perfomed on. 
clear_fft_cache: bool, optional - if true, and device = 'gpu', clears the cached fft plan at the end of function calls + If true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- @@ -882,7 +882,7 @@ def reconstruct( if self._verbose: self._report_reconstruction_summary( - max_iter, + num_iter, switch_object_iter, use_projection_scheme, reconstruction_method, @@ -908,7 +908,7 @@ def reconstruct( # main loop for a0 in tqdmnd( - max_iter, + num_iter, desc="Reconstructing object and probe", unit=" iter", disable=not progress_bar, @@ -977,11 +977,11 @@ def reconstruct( use_projection_scheme=use_projection_scheme, step_size=step_size, normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, + fix_probe=fix_probe, ) # position correction - if a0 >= fix_positions_iter: + if not fix_positions: self._positions_px[batch_indices] = self._position_correction( self._object, vectorized_patch_indices_row, @@ -1007,51 +1007,46 @@ def reconstruct( self._probe, self._positions_px, self._positions_px_initial, - fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude and not fix_probe, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, - fix_probe_aperture=a0 < fix_probe_aperture_iter, + fix_probe_aperture=fix_probe_aperture and not fix_probe, initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - fix_positions_com=fix_positions_com and a0 >= fix_positions_iter, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, + gaussian_filter=gaussian_filter and gaussian_filter_sigma is not None, gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter + butterworth_filter=butterworth_filter and (q_lowpass is not None or q_highpass is not None), q_lowpass=q_lowpass, q_highpass=q_highpass, butterworth_order=butterworth_order, - kz_regularization_filter=a0 < kz_regularization_filter_iter + kz_regularization_filter=kz_regularization_filter and kz_regularization_gamma is not None, kz_regularization_gamma=kz_regularization_gamma, - identical_slices=a0 < identical_slices_iter, + identical_slices=identical_slices, object_positivity=object_positivity, 
shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, - pure_phase_object=a0 < pure_phase_object_iter - and self._object_type == "complex", - tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle + pure_phase_object=pure_phase_object and self._object_type == "complex", + tv_denoise_chambolle=a0 < tv_denoise_chambolle and tv_denoise_weight_chambolle is not None, tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise=tv_denoise and tv_denoise_weights is not None, tv_denoise_weights=tv_denoise_weights, tv_denoise_inner_iter=tv_denoise_inner_iter, ) diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index 6a2e6253c..1c73e2ade 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -679,7 +679,7 @@ def reconstruct( gaussian_filter_sigma: float, optional Standard deviation of gaussian kernel in A gaussian_filter: bool, optional - If True, object is smoothed using gaussian filtering + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering fit_probe_aberrations: bool, optional If True, probe aberrations are fitted to a low-order expansion fit_probe_aberrations_max_angular_order: int @@ -693,7 +693,7 @@ def reconstruct( to a documented bug where the kernel hangs.. If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. butterworth_filter: bool, optional - If True, object is smoothed using butterworth filtering + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering q_lowpass: float Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float @@ -701,7 +701,7 @@ def reconstruct( butterworth_order: float Butterworth filter order. Smaller gives a smoother filter tv_denoise: bool, optional - If True, object is smoothed using TV denoising + If True and tv_denoise_weight is not None, object is smoothed using TV denoising tv_denoise_weight: float Denoising weight. The greater `weight`, the more denoising. 
tv_denoise_inner_iter: float From bd72296701eefa75ca08b397ec760dae5c235f80 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 20 Jan 2024 13:11:07 -0800 Subject: [PATCH 117/128] dpc Former-commit-id: 7bc7b5548a5673d9f6cd9c8bee710a3bbfc216be --- py4DSTEM/process/phase/dpc.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/py4DSTEM/process/phase/dpc.py b/py4DSTEM/process/phase/dpc.py index fe9044595..5a2210d59 100644 --- a/py4DSTEM/process/phase/dpc.py +++ b/py4DSTEM/process/phase/dpc.py @@ -645,8 +645,8 @@ def reconstruct( backtrack: bool = True, progress_bar: bool = True, gaussian_filter_sigma: float = None, - gaussian_filter_iter: int = np.inf, - butterworth_filter_iter: int = np.inf, + gaussian_filter: bool = True, + butterworth_filter: bool = True, q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, @@ -675,10 +675,10 @@ def reconstruct( If True, reconstruction progress bar will be printed gaussian_filter_sigma: float, optional Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering q_lowpass: float Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float @@ -785,10 +785,9 @@ def reconstruct( # constraints self._padded_object_phase = self._constraints( self._padded_object_phase, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, + gaussian_filter=gaussian_filter and gaussian_filter_sigma is not None, gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter + butterworth_filter=butterworth_filter and (q_lowpass is not None or q_highpass is not None), q_lowpass=q_lowpass, q_highpass=q_highpass, From a1fe65272a2bbe7b0764bb385d597f72f69f0287 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 20 Jan 2024 13:11:36 -0800 Subject: [PATCH 118/128] mistake in multislice Former-commit-id: d176f83ee16ac739d943a464ba5970a73afcf7bc --- py4DSTEM/process/phase/multislice_ptychography.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index a0b247791..23808c535 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -1042,7 +1042,7 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, pure_phase_object=pure_phase_object and self._object_type == "complex", - tv_denoise_chambolle=a0 < tv_denoise_chambolle + tv_denoise_chambolle=tv_denoise_chambolle and tv_denoise_weight_chambolle is not None, tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, From 29a06a13dfc50e0394706ea42222eef15172f1ac Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Sun, 21 Jan 2024 18:51:04 +0000 Subject: [PATCH 119/128] flags to magnetic ptycho Former-commit-id: bc6ad35beadd44ee038bd04bceae67daa68f9c0d --- .../process/phase/magnetic_ptychography.py | 138 ++++++++---------- 1 file changed, 64 insertions(+), 74 deletions(-) diff --git 
a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 75c667094..35bbb8690 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -1091,7 +1091,7 @@ def _object_constraints( def reconstruct( self, - max_iter: int = 8, + num_iter: int = 8, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, reconstruction_parameter_a: float = None, @@ -1102,36 +1102,36 @@ def reconstruct( step_size: float = 0.5, normalization_min: float = 1, positions_step_size: float = 0.9, - pure_phase_object_iter: int = 0, + pure_phase_object: bool = False, fix_probe_com: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, constrain_probe_amplitude_relative_radius: float = 0.5, constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude: bool = False, constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, + fix_positions: bool = True, fix_positions_com: bool = True, max_position_update_distance: float = None, max_position_total_distance: float = None, - global_affine_transformation: bool = True, + global_affine_transformation: bool = False, gaussian_filter_sigma_e: float = None, gaussian_filter_sigma_m: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, fit_probe_aberrations_remove_initial: bool = False, fit_probe_aberrations_using_scikit_image: bool = True, - butterworth_filter_iter: int = np.inf, + butterworth_filter: bool = True, q_lowpass_e: float = None, q_lowpass_m: float = None, q_highpass_e: float = None, q_highpass_m: float = None, butterworth_order: float = 2, - tv_denoise_iter: int = np.inf, + tv_denoise: bool = True, tv_denoise_weight: float = None, tv_denoise_inner_iter: float = 40, object_positivity: bool = True, @@ -1150,8 +1150,8 @@ def reconstruct( Parameters -------- - max_iter: int, optional - Maximum number of iterations to run + num_iter: int, optional + Number of iterations to run reconstruction_method: str, optional Specifies which reconstruction algorithm to use, one of: "generalized-projections", @@ -1178,28 +1178,28 @@ def reconstruct( Probe normalization minimum as a fraction of the maximum overlap intensity positions_step_size: float, optional Positions update step size - pure_phase_object_iter: float, optional - Number of iterations where object amplitude is set to unity + pure_phase_object: bool, optional + If True, object amplitude is set to unity fix_probe_com: bool, optional If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. 
+ fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vaccum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency constrain_probe_fourier_amplitude_max_width_pixels: float Maximum pixel width of fitted sigmoid functions. constrain_probe_fourier_amplitude_constant_intensity: bool If True, the probe aperture is additionally constrained to a constant intensity. - fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate + fix_positions: bool, optional + If True, probe-positions are fixed fix_positions_com: bool, optional If True, fixes the positions CoM to the middle of the fov max_position_update_distance: float, optional @@ -1212,10 +1212,10 @@ def reconstruct( Standard deviation of gaussian kernel for electrostatic object in A gaussian_filter_sigma_m: float Standard deviation of gaussian kernel for magnetic object in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion fit_probe_aberrations_max_angular_order: bool Max angular order of probe aberrations basis functions fit_probe_aberrations_max_radial_order: bool @@ -1226,8 +1226,8 @@ def reconstruct( If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads to a documented bug where the kernel hangs.. If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering q_lowpass_e: float Cut-off frequency in A^-1 for low-pass filtering electrostatic object q_lowpass_m: float @@ -1238,8 +1238,8 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass filtering magnetic object butterworth_order: float Butterworth filter order. Smaller gives a smoother filter - tv_denoise_iter: int, optional - Number of iterations to run using tv denoise filter on object + tv_denoise: bool, optional + If True and tv_denoise_weight is not None, object is smoothed using TV denoising tv_denoise_weight: float Denoising weight. The greater `weight`, the more denoising. 
tv_denoise_inner_iter: float @@ -1320,7 +1320,7 @@ def reconstruct( if self._verbose: self._report_reconstruction_summary( - max_iter, + num_iter, switch_object_iter, use_projection_scheme, reconstruction_method, @@ -1347,12 +1347,9 @@ def reconstruct( if q_lowpass_m is None: q_lowpass_m = q_lowpass_e - if fix_positions_iter < 1: - fix_positions_iter = 1 # give position correction a chance - # main loop for a0 in tqdmnd( - max_iter, + num_iter, desc="Reconstructing object and probe", unit=" iter", disable=not progress_bar, @@ -1450,13 +1447,13 @@ def reconstruct( use_projection_scheme=use_projection_scheme, step_size=step_size, normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, + fix_probe=fix_probe, ) object_update -= self._object # position correction - if a0 >= fix_positions_iter: + if not fix_positions and a0 > 0: self._positions_px_all[ batch_indices ] = self._position_correction( @@ -1493,32 +1490,29 @@ def reconstruct( # probe and positions _probe = self._probe_constraints( _probe, - fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, - fix_probe_aperture=a0 < fix_probe_aperture_iter, + fix_probe_aperture=fix_probe_aperture and not fix_probe, initial_probe_aperture=_probe_initial_aperture, ) self._positions_px_all[batch_indices] = self._positions_constraints( self._positions_px_all[batch_indices], self._positions_px_initial_all[batch_indices], - fix_positions=a0 < fix_positions_iter, - fix_positions_com=fix_positions_com - and a0 >= fix_positions_iter, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, global_affine_transformation=global_affine_transformation, ) @@ -1533,41 +1527,37 @@ def reconstruct( _probe, self._positions_px_all[batch_indices], self._positions_px_initial_all[batch_indices], - fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < 
constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, - fix_probe_aperture=a0 < fix_probe_aperture_iter, + fix_probe_aperture=fix_probe_aperture and not fix_probe, initial_probe_aperture=_probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - fix_positions_com=fix_positions_com - and a0 >= fix_positions_iter, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter + gaussian_filter=gaussian_filter and gaussian_filter_sigma_m is not None, gaussian_filter_sigma_e=gaussian_filter_sigma_e, gaussian_filter_sigma_m=gaussian_filter_sigma_m, - butterworth_filter=a0 < butterworth_filter_iter + butterworth_filter=butterworth_filter and (q_lowpass_m is not None or q_highpass_m is not None), q_lowpass_e=q_lowpass_e, q_lowpass_m=q_lowpass_m, q_highpass_e=q_highpass_e, q_highpass_m=q_highpass_m, butterworth_order=butterworth_order, - tv_denoise=a0 < tv_denoise_iter - and tv_denoise_weight is not None, + tv_denoise=tv_denoise and tv_denoise_weight is not None, tv_denoise_weight=tv_denoise_weight, tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, @@ -1576,7 +1566,7 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, - pure_phase_object=a0 < pure_phase_object_iter + pure_phase_object=pure_phase_object and self._object_type == "complex", ) @@ -1589,18 +1579,18 @@ def reconstruct( # object only self._object = self._object_constraints( self._object, - gaussian_filter=a0 < gaussian_filter_iter + gaussian_filter=gaussian_filter and gaussian_filter_sigma_m is not None, gaussian_filter_sigma_e=gaussian_filter_sigma_e, gaussian_filter_sigma_m=gaussian_filter_sigma_m, - butterworth_filter=a0 < butterworth_filter_iter + butterworth_filter=butterworth_filter and (q_lowpass_m is not None or q_highpass_m is not None), q_lowpass_e=q_lowpass_e, q_lowpass_m=q_lowpass_m, q_highpass_e=q_highpass_e, q_highpass_m=q_highpass_m, butterworth_order=butterworth_order, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise=tv_denoise and tv_denoise_weight is not None, tv_denoise_weight=tv_denoise_weight, tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, @@ -1609,7 +1599,7 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, - pure_phase_object=a0 < pure_phase_object_iter + pure_phase_object=pure_phase_object and self._object_type == "complex", ) From 6e37d76d100652b32472793b5035ee5d582bbb36 Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Sun, 21 Jan 2024 19:27:45 +0000 Subject: [PATCH 120/128] ptycho tomo flags Former-commit-id: 
c11102d28593cf34709cad4c779f7878dfaca8c6 --- .../process/phase/ptychographic_tomography.py | 127 +++++++++--------- 1 file changed, 60 insertions(+), 67 deletions(-) diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 67b57afcb..8c156f384 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -758,7 +758,7 @@ def preprocess( def reconstruct( self, - max_iter: int = 8, + num_iter: int = 8, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, reconstruction_parameter_a: float = None, @@ -770,35 +770,35 @@ def reconstruct( normalization_min: float = 1, positions_step_size: float = 0.9, fix_probe_com: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, constrain_probe_amplitude_relative_radius: float = 0.5, constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude: bool = False, constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, + fix_positions: bool = True, fix_positions_com: bool = True, max_position_update_distance: float = None, max_position_total_distance: float = None, - global_affine_transformation: bool = True, + global_affine_transformation: bool = False, gaussian_filter_sigma: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, fit_probe_aberrations_remove_initial: bool = False, fit_probe_aberrations_using_scikit_image: bool = True, - butterworth_filter_iter: int = np.inf, + butterworth_filter: bool = True, q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, - tv_denoise_iter=np.inf, - tv_denoise_weights=None, + tv_denoise: bool = True, + tv_denoise_weights: float = None, tv_denoise_inner_iter=40, collective_measurement_updates: bool = True, store_iterations: bool = False, @@ -812,8 +812,8 @@ def reconstruct( Parameters -------- - max_iter: int, optional - Maximum number of iterations to run + num_iter: int, optional + Number of iterations to run reconstruction_method: str, optional Specifies which reconstruction algorithm to use, one of: "generalized-projections", @@ -842,24 +842,24 @@ def reconstruct( Positions update step size fix_probe_com: bool, optional If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. 
+ fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vaccum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency constrain_probe_fourier_amplitude_max_width_pixels: float Maximum pixel width of fitted sigmoid functions. constrain_probe_fourier_amplitude_constant_intensity: bool If True, the probe aperture is additionally constrained to a constant intensity. - fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate + fix_positions: bool, optional + If True, probe-positions are fixed fix_positions_com: bool, optional If True, fixes the positions CoM to the middle of the fov max_position_update_distance: float, optional @@ -870,10 +870,10 @@ def reconstruct( If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion fit_probe_aberrations_max_angular_order: bool Max angular order of probe aberrations basis functions fit_probe_aberrations_max_radial_order: bool @@ -884,8 +884,8 @@ def reconstruct( If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads to a documented bug where the kernel hangs.. If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering q_lowpass: float Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float @@ -894,8 +894,8 @@ def reconstruct( Butterworth filter order. Smaller gives a smoother filter object_positivity: bool, optional If True, forces object to be positive - tv_denoise: bool - If True, applies TV denoising on object + tv_denoise: bool, optional + If True and tv_denoise_weight is not None, object is smoothed using TV denoising tv_denoise_weights: [float,float] Denoising weights[z weight, r weight]. The greater `weight`, the more denoising. 
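A minimal usage sketch of the flag-based interface above (illustrative, not part of the patch): it assumes an already-preprocessed PtychographicTomography instance named `ptycho`, and that consecutive reconstruct() calls continue from the current estimate, which is how a schedule previously written with `*_iter` counts can still be expressed.

    # previously: ptycho.reconstruct(max_iter=8, fix_probe_iter=2)
    ptycho.reconstruct(
        num_iter=2,
        fix_probe=True,        # probe held fixed for these two iterations
        fix_positions=True,    # positions held fixed throughout
    )
    ptycho.reconstruct(
        num_iter=6,
        fix_probe=False,       # probe (and its constraints) now update
        fix_positions=True,
        gaussian_filter_sigma=0.3,  # hypothetical smoothing value, in A
        q_lowpass=0.25,             # hypothetical low-pass cutoff, in A^-1
    )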
@@ -960,7 +960,7 @@ def reconstruct( if self._verbose: self._report_reconstruction_summary( - max_iter, + num_iter, np.inf, use_projection_scheme, reconstruction_method, @@ -983,7 +983,7 @@ def reconstruct( # main loop for a0 in tqdmnd( - max_iter, + num_iter, desc="Reconstructing object and probe", unit=" iter", disable=not progress_bar, @@ -1089,11 +1089,11 @@ def reconstruct( use_projection_scheme=use_projection_scheme, step_size=step_size, normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, + fix_probe=fix_probe, ) # position correction - if a0 >= fix_positions_iter: + if not fix_positions: self._positions_px_all[ batch_indices ] = self._position_correction( @@ -1141,32 +1141,29 @@ def reconstruct( # probe and positions _probe = self._probe_constraints( _probe, - fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, - fix_probe_aperture=a0 < fix_probe_aperture_iter, + fix_probe_aperture=fix_probe_aperture and not fix_probe, initial_probe_aperture=_probe_initial_aperture, ) self._positions_px_all[batch_indices] = self._positions_constraints( self._positions_px_all[batch_indices], self._positions_px_initial_all[batch_indices], - fix_positions=a0 < fix_positions_iter, - fix_positions_com=fix_positions_com - and a0 >= fix_positions_iter, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, global_affine_transformation=global_affine_transformation, ) @@ -1181,32 +1178,29 @@ def reconstruct( _probe, self._positions_px_all[batch_indices], self._positions_px_initial_all[batch_indices], - fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, 
constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, - fix_probe_aperture=a0 < fix_probe_aperture_iter, + fix_probe_aperture=fix_probe_aperture and not fix_probe, initial_probe_aperture=_probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - fix_positions_com=fix_positions_com - and a0 >= fix_positions_iter, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter + gaussian_filter=gaussian_filter and gaussian_filter_sigma is not None, gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter + butterworth_filter=butterworth_filter and (q_lowpass is not None or q_highpass is not None), q_lowpass=q_lowpass, q_highpass=q_highpass, @@ -1217,8 +1211,7 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, - tv_denoise=a0 < tv_denoise_iter - and tv_denoise_weights is not None, + tv_denoise=tv_denoise and tv_denoise_weights is not None, tv_denoise_weights=tv_denoise_weights, tv_denoise_inner_iter=tv_denoise_inner_iter, ) @@ -1234,10 +1227,10 @@ def reconstruct( # object only self._object = self._object_constraints( self._object, - gaussian_filter=a0 < gaussian_filter_iter + gaussian_filter=gaussian_filter and gaussian_filter_sigma is not None, gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter + butterworth_filter=butterworth_filter and (q_lowpass is not None or q_highpass is not None), q_lowpass=q_lowpass, q_highpass=q_highpass, @@ -1248,7 +1241,7 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise=tv_denoise and tv_denoise_weights is not None, tv_denoise_weights=tv_denoise_weights, tv_denoise_inner_iter=tv_denoise_inner_iter, ) From 6e5970c60702e7fa4be9a790394e0e36d4ea423a Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Sun, 21 Jan 2024 20:01:33 +0000 Subject: [PATCH 121/128] magnetic ptycho tomo flags Former-commit-id: 88c509de119299b8c4acbe81c513c693e03862a3 --- .../magnetic_ptychographic_tomography.py | 128 ++++++++---------- 1 file changed, 59 insertions(+), 69 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 13ab8d846..80438e50a 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -835,7 +835,7 @@ def _constraints(self, current_object, current_probe, current_positions, **kwarg def reconstruct( self, - max_iter: int = 8, + num_iter: int = 8, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, reconstruction_parameter_a: float = None, @@ -847,28 +847,28 
@@ def reconstruct( normalization_min: float = 1, positions_step_size: float = 0.9, fix_probe_com: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, constrain_probe_amplitude_relative_radius: float = 0.5, constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude: bool = False, constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, + fix_positions: bool = True, fix_positions_com: bool = True, max_position_update_distance: float = None, max_position_total_distance: float = None, - global_affine_transformation: bool = True, + global_affine_transformation: bool = False, gaussian_filter_sigma_e: float = None, gaussian_filter_sigma_m: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, fit_probe_aberrations_max_angular_order: int = 4, fit_probe_aberrations_max_radial_order: int = 4, fit_probe_aberrations_remove_initial: bool = False, fit_probe_aberrations_using_scikit_image: bool = True, - butterworth_filter_iter: int = np.inf, + butterworth_filter: bool = True, q_lowpass_e: float = None, q_lowpass_m: float = None, q_highpass_e: float = None, @@ -877,7 +877,7 @@ def reconstruct( object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, - tv_denoise_iter=np.inf, + tv_denoise: bool = True, tv_denoise_weights=None, tv_denoise_inner_iter=40, collective_measurement_updates: bool = True, @@ -892,8 +892,8 @@ def reconstruct( Parameters -------- - max_iter: int, optional - Maximum number of iterations to run + num_iter: int, optional + Number of iterations to run reconstruction_method: str, optional Specifies which reconstruction algorithm to use, one of: "generalized-projections", @@ -922,24 +922,24 @@ def reconstruct( Positions update step size fix_probe_com: bool, optional If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. + fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vaccum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency constrain_probe_fourier_amplitude_max_width_pixels: float Maximum pixel width of fitted sigmoid functions. 
constrain_probe_fourier_amplitude_constant_intensity: bool If True, the probe aperture is additionally constrained to a constant intensity. - fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate + fix_positions: bool, optional + If True, probe-positions are fixed fix_positions_com: bool, optional If True, fixes the positions CoM to the middle of the fov max_position_update_distance: float, optional @@ -952,10 +952,10 @@ def reconstruct( Standard deviation of gaussian kernel for electrostatic object in A gaussian_filter_sigma_m: float Standard deviation of gaussian kernel for magnetic object in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion fit_probe_aberrations_max_angular_order: bool Max angular order of probe aberrations basis functions fit_probe_aberrations_max_radial_order: bool @@ -966,8 +966,8 @@ def reconstruct( If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads to a documented bug where the kernel hangs.. If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering q_lowpass: float Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float @@ -976,8 +976,8 @@ def reconstruct( Butterworth filter order. Smaller gives a smoother filter object_positivity: bool, optional If True, forces object to be positive - tv_denoise: bool - If True, applies TV denoising on object + tv_denoise: bool, optional + If True and tv_denoise_weight is not None, object is smoothed using TV denoising tv_denoise_weights: [float,float] Denoising weights[z weight, r weight]. The greater `weight`, the more denoising. 
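The flag semantics spelled out above reduce to a simple gating rule, mirrored in this self-contained illustrative helper (`constraint_gates` is not part of the library): probe constraints apply only while the probe itself is being updated, and the object filters additionally require their parameters to be set.

    def constraint_gates(
        fix_probe=False,
        fit_probe_aberrations=False,
        fix_probe_aperture=False,
        gaussian_filter=True,
        gaussian_filter_sigma_m=None,
        butterworth_filter=True,
        q_lowpass_m=None,
        q_highpass_m=None,
        tv_denoise=True,
        tv_denoise_weights=None,
    ):
        # booleans effectively passed to the constraint methods each iteration
        return {
            "fit_probe_aberrations": fit_probe_aberrations and not fix_probe,
            "fix_probe_aperture": fix_probe_aperture and not fix_probe,
            "gaussian_filter": gaussian_filter and gaussian_filter_sigma_m is not None,
            "butterworth_filter": butterworth_filter
            and (q_lowpass_m is not None or q_highpass_m is not None),
            "tv_denoise": tv_denoise and tv_denoise_weights is not None,
        }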
@@ -1053,7 +1053,7 @@ def reconstruct( if self._verbose: self._report_reconstruction_summary( - max_iter, + num_iter, np.inf, use_projection_scheme, reconstruction_method, @@ -1080,12 +1080,9 @@ def reconstruct( if q_lowpass_m is None: q_lowpass_m = q_lowpass_e - if fix_positions_iter < 1: - fix_positions_iter = 1 # give position correction a chance - # main loop for a0 in tqdmnd( - max_iter, + num_iter, desc="Reconstructing object and probe", unit=" iter", disable=not progress_bar, @@ -1200,11 +1197,11 @@ def reconstruct( use_projection_scheme=use_projection_scheme, step_size=step_size, normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, + fix_probe=fix_probe, ) # position correction - if a0 >= fix_positions_iter: + if not fix_positions and a0 > 0: self._positions_px_all[ batch_indices ] = self._position_correction( @@ -1255,32 +1252,29 @@ def reconstruct( # probe and positions _probe = self._probe_constraints( _probe, - fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, - fix_probe_aperture=a0 < fix_probe_aperture_iter, + fix_probe_aperture=fix_probe_aperture, initial_probe_aperture=_probe_initial_aperture, ) self._positions_px_all[batch_indices] = self._positions_constraints( self._positions_px_all[batch_indices], self._positions_px_initial_all[batch_indices], - fix_positions=a0 < fix_positions_iter, - fix_positions_com=fix_positions_com - and a0 >= fix_positions_iter, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, global_affine_transformation=global_affine_transformation, ) @@ -1295,33 +1289,30 @@ def reconstruct( _probe, self._positions_px_all[batch_indices], self._positions_px_initial_all[batch_indices], - fix_probe_com=fix_probe_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, + 
constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, - fix_probe_aperture=a0 < fix_probe_aperture_iter, + fix_probe_aperture=fix_probe_aperture, initial_probe_aperture=_probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - fix_positions_com=fix_positions_com - and a0 >= fix_positions_iter, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter + gaussian_filter=gaussian_filter and gaussian_filter_sigma_m is not None, gaussian_filter_sigma_e=gaussian_filter_sigma_e, gaussian_filter_sigma_m=gaussian_filter_sigma_m, - butterworth_filter=a0 < butterworth_filter_iter + butterworth_filter=butterworth_filter and (q_lowpass_m is not None or q_highpass_m is not None), q_lowpass_e=q_lowpass_e, q_lowpass_m=q_lowpass_m, @@ -1334,8 +1325,7 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, - tv_denoise=a0 < tv_denoise_iter - and tv_denoise_weights is not None, + tv_denoise=tv_denoise and tv_denoise_weights is not None, tv_denoise_weights=tv_denoise_weights, tv_denoise_inner_iter=tv_denoise_inner_iter, ) @@ -1351,11 +1341,11 @@ def reconstruct( # object only self._object = self._object_constraints_vector( self._object, - gaussian_filter=a0 < gaussian_filter_iter + gaussian_filter=gaussian_filter and gaussian_filter_sigma_m is not None, gaussian_filter_sigma_e=gaussian_filter_sigma_e, gaussian_filter_sigma_m=gaussian_filter_sigma_m, - butterworth_filter=a0 < butterworth_filter_iter + butterworth_filter=butterworth_filter and (q_lowpass_m is not None or q_highpass_m is not None), q_lowpass_e=q_lowpass_e, q_lowpass_m=q_lowpass_m, @@ -1368,7 +1358,7 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise=tv_denoise and tv_denoise_weights is not None, tv_denoise_weights=tv_denoise_weights, tv_denoise_inner_iter=tv_denoise_inner_iter, ) From 0738e87c08f595475a36f7197f55023387e52ce1 Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Sun, 21 Jan 2024 20:06:34 +0000 Subject: [PATCH 122/128] stricter flags Former-commit-id: 98b1b641e54830f08249c2c7830bd773e38b148d --- py4DSTEM/process/phase/magnetic_ptychographic_tomography.py | 4 ++-- py4DSTEM/process/phase/mixedstate_multislice_ptychography.py | 2 +- py4DSTEM/process/phase/mixedstate_ptychography.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index 80438e50a..faa3fb9be 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ 
b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -1266,7 +1266,7 @@ def reconstruct( fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, - fix_probe_aperture=fix_probe_aperture, + fix_probe_aperture=fix_probe_aperture and not fix_probe, initial_probe_aperture=_probe_initial_aperture, ) @@ -1303,7 +1303,7 @@ def reconstruct( fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, - fix_probe_aperture=fix_probe_aperture, + fix_probe_aperture=fix_probe_aperture and not fix_probe, initial_probe_aperture=_probe_initial_aperture, fix_positions=fix_positions, fix_positions_com=fix_positions_com and not fix_positions, diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index 907847cf3..d5a731931 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -1041,7 +1041,7 @@ def reconstruct( fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, - fix_probe_aperture=fix_probe_aperture, + fix_probe_aperture=fix_probe_aperture and not fix_probe, initial_probe_aperture=self._probe_initial_aperture, fix_positions=fix_positions, fix_positions_com=fix_positions_com and not fix_positions, diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index f5c5e6cd3..4f2fb7be0 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -931,7 +931,7 @@ def reconstruct( fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, - fix_probe_aperture=fix_probe_aperture, + fix_probe_aperture=fix_probe_aperture and not fix probe, initial_probe_aperture=self._probe_initial_aperture, fix_positions=fix_positions, fix_positions_com=fix_positions_com and not fix_positions, From 0d3f7cdcde913205e73df64607828334bbdaf0f3 Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Sun, 21 Jan 2024 20:10:23 +0000 Subject: [PATCH 123/128] perhaps i forgot an underscore Former-commit-id: d1197e81d2e6cc92a9f61c1ad4c33f413b74eb15 --- py4DSTEM/process/phase/mixedstate_ptychography.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index 4f2fb7be0..3a280e8d4 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -931,7 +931,7 @@ def reconstruct( fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, - fix_probe_aperture=fix_probe_aperture and not fix probe, + 
fix_probe_aperture=fix_probe_aperture and not fix_probe, initial_probe_aperture=self._probe_initial_aperture, fix_positions=fix_positions, fix_positions_com=fix_positions_com and not fix_positions, From bc5f58094028de59c174fbd2184ca17dd9ce7a2f Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Sun, 21 Jan 2024 20:30:23 +0000 Subject: [PATCH 124/128] small parallax change Former-commit-id: c2299162e69dbfd2daa8258a2eae32df73211b5d --- py4DSTEM/process/phase/parallax.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 3156c665b..58181d812 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -735,7 +735,7 @@ def reconstruct( self, max_alignment_bin: int = None, min_alignment_bin: int = 1, - max_iter_at_min_bin: int = 2, + num_iter_at_min_bin: int = 2, alignment_bin_values: list = None, cross_correlation_upsample_factor: int = 8, regularizer_matrix_size: Tuple[int, int] = (1, 1), @@ -759,7 +759,7 @@ def reconstruct( If None, the bright field disk radius is used min_alignment_bin: int, optional Minimum bin size for bright field alignment - max_iter_at_min_bin: int, optional + num_iter_at_min_bin: int, optional Number of iterations to run at the smallest bin size alignment_bin_values: list, optional If not None, explicitly sets the iteration bin values @@ -868,9 +868,9 @@ def reconstruct( bin_max = np.ceil(np.log(max_alignment_bin) / np.log(2)) bin_vals = 2 ** np.arange(bin_min, bin_max)[::-1] - if max_iter_at_min_bin > 1: + if num_iter_at_min_bin > 1: bin_vals = np.hstack( - (bin_vals, np.repeat(bin_vals[-1], max_iter_at_min_bin - 1)) + (bin_vals, np.repeat(bin_vals[-1], num_iter_at_min_bin - 1)) ) if plot_aligned_bf: From 0bc392a3519a73024f9df8e8620e8cc045d57e96 Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Sun, 21 Jan 2024 20:55:49 +0000 Subject: [PATCH 125/128] small read-write changes Former-commit-id: e733076d596654f002ae8d3c971e8b54d6bb3328 --- py4DSTEM/process/phase/phase_base_class.py | 37 +++++++++++----------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index ca1b7c332..a4e6cac6e 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -148,7 +148,13 @@ def attach_datacube(self, datacube: DataCube): self._datacube = datacube return self - def reinitialize_parameters(self, device: str = None, verbose: bool = None): + def reinitialize_parameters( + self, + device: str = None, + storage: str = None, + clear_fft_cache: bool = None, + verbose: bool = None, + ): """ Reinitializes common parameters. 
This is useful when loading a previously-saved reconstruction (which set device='cpu' and verbose=True for compatibility) , @@ -157,7 +163,11 @@ def reinitialize_parameters(self, device: str = None, verbose: bool = None): Parameters ---------- device: str, optional - If not None, imports and assigns appropriate device modules + If not None, assigns appropriate device modules + storage: str, optional + If not None, assigns appropriate storage modules + clear_fft_cache: bool, optional + If not None, sets the FFT caching parameter verbose: bool, optional If not None, sets the verbosity to verbose @@ -168,23 +178,10 @@ def reinitialize_parameters(self, device: str = None, verbose: bool = None): """ if device is not None: - if device == "cpu": - import scipy - - self._xp = np - self._asnumpy = np.asarray - self._scipy = scipy + self.set_device(device, clear_fft_cache) - elif device == "gpu": - from cupyx import scipy - - self._xp = cp - self._asnumpy = cp.asnumpy - self._scipy = scipy - - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - self._device = device + if storage is not None: + self.set_storage(storage) if verbose is not None: self._verbose = verbose @@ -1508,6 +1505,8 @@ def to_h5(self, group): "object_type": self._object_type, "verbose": self._verbose, "device": self._device, + "storage": self._storage, + "clear_fft_cache": self._clear_fft_cache, "name": self.name, "vacuum_probe_intensity": vacuum_probe_intensity, "positions": scan_positions, @@ -1656,6 +1655,8 @@ def _get_constructor_args(cls, group): "polar_parameters": polar_params, "verbose": True, # for compatibility "device": "cpu", # for compatibility + "storage": "cpu", # for compatibility + "clear_fft_cache": True, # for compatibility } class_specific_kwargs = {} From 98493efeed185c1c8c1ab560ccfb5389d396a81c Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Sun, 21 Jan 2024 21:17:50 +0000 Subject: [PATCH 126/128] removing switch_obj_iter Former-commit-id: 306c20dc4c664789db3057dea277ea310495db64 --- py4DSTEM/process/phase/phase_base_class.py | 35 +++++++++++++------ .../process/phase/singleslice_ptychography.py | 24 +++++-------- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index a4e6cac6e..a67640d9b 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -1710,6 +1710,29 @@ def _populate_instance(self, group): self.probe = self.probe_centered self._preprocessed = True + def _switch_object_type(self, object_type): + """ + Switches object type to/from "potential"/"complex" + + Returns + -------- + self: PhaseReconstruction + Self to enable chaining + """ + xp = self._xp + + match (self._object_type, object_type): + case ("potential", "complex"): + self._object_type = "complex" + self._object = xp.exp(1j * self._object, dtype=xp.complex64) + case ("complex", "potential"): + self._object_type = "potential" + self._object = xp.angle(self._object) + case _: + self._object_type = self._object_type + + return self + def _set_polar_parameters(self, parameters: dict): """ Set the probe aberrations dictionary. 
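A short sketch of the round trip the new `_switch_object_type` helper performs (assuming `ptycho` is any reconstruction instance whose object is currently of the "potential" type):

    ptycho._switch_object_type("complex")    # object -> exp(1j * object), stored as complex64
    ptycho._switch_object_type("potential")  # object -> angle(object), back to a potential
    ptycho._switch_object_type("potential")  # any other pairing leaves the object unchanged

Since the method returns self, the calls can also be chained.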
@@ -2032,7 +2055,6 @@ def _set_reconstruction_method_parameters( def _report_reconstruction_summary( self, max_iter, - switch_object_iter, use_projection_scheme, reconstruction_method, reconstruction_parameter, @@ -2046,16 +2068,7 @@ def _report_reconstruction_summary( """ """ # object type - if switch_object_iter > max_iter: - first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - else: - switch_object_type = ( - "complex" if self._object_type == "potential" else "potential" - ) - first_line = ( - f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " - f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " - ) + first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " # stochastic gradient descent if max_batch_size is not None: diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index 1c73e2ade..71fe65cd7 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -606,12 +606,12 @@ def reconstruct( object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, - switch_object_iter: int = np.inf, store_iterations: bool = False, progress_bar: bool = True, reset: bool = None, device: str = None, clear_fft_cache: bool = None, + object_type: str = None, ): """ Ptychographic reconstruction main method. @@ -712,9 +712,6 @@ def reconstruct( Phase shift in radians to be subtracted from the potential at each iteration fix_potential_baseline: bool If true, the potential mean outside the FOV is forced to zero at each iteration - switch_object_iter: int, optional - Iteration to switch object type between 'complex' and 'potential' or between - 'potential' and 'complex' store_iterations: bool, optional If True, reconstructed objects and probes are stored at each iteration progress_bar: bool, optional @@ -722,9 +719,11 @@ def reconstruct( reset: bool, optional If True, previous reconstructions are ignored device: str, optional - if not none, overwrites self._device to set device preprocess will be perfomed on. + If not none, overwrites self._device to set device preprocess will be perfomed on. 
clear_fft_cache: bool, optional - if true, and device = 'gpu', clears the cached fft plan at the end of function calls + If true, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + Overwrites self._object_type Returns -------- @@ -745,7 +744,9 @@ def reconstruct( ] self.copy_attributes_to_device(attrs, device) - xp = self._xp + if object_type is not None: + self._switch_object_type(object_type) + xp_storage = self._xp_storage device = self._device asnumpy = self._asnumpy @@ -770,7 +771,6 @@ def reconstruct( if self._verbose: self._report_reconstruction_summary( num_iter, - switch_object_iter, use_projection_scheme, reconstruction_method, reconstruction_parameter, @@ -802,14 +802,6 @@ def reconstruct( ): error = 0.0 - if a0 == switch_object_iter: - if self._object_type == "potential": - self._object_type = "complex" - self._object = xp.exp(1j * self._object, dtype=xp.complex64) - else: - self._object_type = "potential" - self._object = xp.angle(self._object) - # randomize if not use_projection_scheme: np.random.shuffle(shuffled_indices) From 3a4ff9ab25a8814d86241c8989624bc418cfdc28 Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Sun, 21 Jan 2024 21:25:34 +0000 Subject: [PATCH 127/128] remore switch_iter from all classes Former-commit-id: c4691d3cde7919cf5bca55a13c195199f6f20cbc --- .../magnetic_ptychographic_tomography.py | 1 - .../process/phase/magnetic_ptychography.py | 19 ++++++------------ .../mixedstate_multislice_ptychography.py | 20 ++++++------------- .../process/phase/mixedstate_ptychography.py | 20 ++++++------------- .../process/phase/multislice_ptychography.py | 20 ++++++------------- .../process/phase/ptychographic_tomography.py | 1 - 6 files changed, 24 insertions(+), 57 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index faa3fb9be..7249a9064 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -1054,7 +1054,6 @@ def reconstruct( if self._verbose: self._report_reconstruction_summary( num_iter, - np.inf, use_projection_scheme, reconstruction_method, reconstruction_parameter, diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 35bbb8690..d718b1a9e 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -1137,13 +1137,13 @@ def reconstruct( object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, - switch_object_iter: int = np.inf, store_iterations: bool = False, collective_measurement_updates: bool = True, progress_bar: bool = True, reset: bool = None, device: str = None, clear_fft_cache: bool = None, + object_type: str = None, ): """ Ptychographic reconstruction main method. 
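With `switch_object_iter` removed, the same mid-run switch is expressed through the new `object_type` keyword by splitting the run into two calls (a sketch only, assuming `ptycho` is preprocessed and that consecutive reconstruct() calls continue from the current estimate):

    # previously expressed as switch_object_iter=4 on a single 8-iteration call
    ptycho.reconstruct(num_iter=4)                         # first half, with the current object type
    ptycho.reconstruct(num_iter=4, object_type="complex")  # switches via _switch_object_type, then continues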
@@ -1250,9 +1250,6 @@ def reconstruct( Phase shift in radians to be subtracted from the potential at each iteration fix_potential_baseline: bool If true, the potential mean outside the FOV is forced to zero at each iteration - switch_object_iter: int, optional - Iteration to switch object type between 'complex' and 'potential' or between - 'potential' and 'complex' store_iterations: bool, optional If True, reconstructed objects and probes are stored at each iteration collective_measurement_updates: bool @@ -1265,6 +1262,8 @@ def reconstruct( if not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional if true, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + Overwrites self._object_type Returns -------- @@ -1285,6 +1284,9 @@ def reconstruct( ] self.copy_attributes_to_device(attrs, device) + if object_type is not None: + self._switch_object_type(object_type) + xp = self._xp xp_storage = self._xp_storage device = self._device @@ -1321,7 +1323,6 @@ def reconstruct( if self._verbose: self._report_reconstruction_summary( num_iter, - switch_object_iter, use_projection_scheme, reconstruction_method, reconstruction_parameter, @@ -1356,14 +1357,6 @@ def reconstruct( ): error = 0.0 - if a0 == switch_object_iter: - if self._object_type == "potential": - self._object_type = "complex" - self._object = xp.exp(1j * self._object) - else: - self._object_type = "potential" - self._object = xp.angle(self._object) - if collective_measurement_updates: collective_object = xp.zeros_like(self._object) diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index d5a731931..47dd67dd3 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -732,12 +732,12 @@ def reconstruct( tv_denoise: bool = True, tv_denoise_weights=None, tv_denoise_inner_iter=40, - switch_object_iter: int = np.inf, store_iterations: bool = False, progress_bar: bool = True, reset: bool = None, device: str = None, clear_fft_cache: bool = None, + object_type: str = None, ): """ Ptychographic reconstruction main method. @@ -849,15 +849,14 @@ def reconstruct( the more denoising. 
tv_denoise_inner_iter: float Number of iterations to run in inner loop of TV denoising - switch_object_iter: int, optional - Iteration to switch object type between 'complex' and 'potential' or between - 'potential' and 'complex' store_iterations: bool, optional If True, reconstructed objects and probes are stored at each iteration progress_bar: bool, optional If True, reconstruction progress is displayed reset: bool, optional If True, previous reconstructions are ignored + object_type: str, optional + Overwrites self._object_type Returns -------- @@ -879,7 +878,9 @@ def reconstruct( ] self.copy_attributes_to_device(attrs, device) - xp = self._xp + if object_type is not None: + self._switch_object_type(object_type) + xp_storage = self._xp_storage device = self._device asnumpy = self._asnumpy @@ -904,7 +905,6 @@ def reconstruct( if self._verbose: self._report_reconstruction_summary( num_iter, - switch_object_iter, use_projection_scheme, reconstruction_method, reconstruction_parameter, @@ -936,14 +936,6 @@ def reconstruct( ): error = 0.0 - if a0 == switch_object_iter: - if self._object_type == "potential": - self._object_type = "complex" - self._object = xp.exp(1j * self._object) - else: - self._object_type = "potential" - self._object = xp.angle(self._object) - # randomize if not use_projection_scheme: np.random.shuffle(shuffled_indices) diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index 3a280e8d4..6fbb72b5d 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -630,12 +630,12 @@ def reconstruct( object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, - switch_object_iter: int = np.inf, store_iterations: bool = False, progress_bar: bool = True, reset: bool = None, device: str = None, clear_fft_cache: bool = None, + object_type: str = None, ): """ Ptychographic reconstruction main method. @@ -736,9 +736,6 @@ def reconstruct( Phase shift in radians to be subtracted from the potential at each iteration fix_potential_baseline: bool If true, the potential mean outside the FOV is forced to zero at each iteration - switch_object_iter: int, optional - Iteration to switch object type between 'complex' and 'potential' or between - 'potential' and 'complex' store_iterations: bool, optional If True, reconstructed objects and probes are stored at each iteration progress_bar: bool, optional @@ -749,6 +746,8 @@ def reconstruct( If not none, overwrites self._device to set device preprocess will be perfomed on. 
clear_fft_cache: bool, optional If true, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + Overwrites self._object_type Returns -------- @@ -769,7 +768,9 @@ def reconstruct( ] self.copy_attributes_to_device(attrs, device) - xp = self._xp + if object_type is not None: + self._switch_object_type(object_type) + xp_storage = self._xp_storage device = self._device asnumpy = self._asnumpy @@ -794,7 +795,6 @@ def reconstruct( if self._verbose: self._report_reconstruction_summary( num_iter, - switch_object_iter, use_projection_scheme, reconstruction_method, reconstruction_parameter, @@ -826,14 +826,6 @@ def reconstruct( ): error = 0.0 - if a0 == switch_object_iter: - if self._object_type == "potential": - self._object_type = "complex" - self._object = xp.exp(1j * self._object) - else: - self._object_type = "potential" - self._object = xp.angle(self._object) - # randomize if not use_projection_scheme: np.random.shuffle(shuffled_indices) diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index 23808c535..03e636f57 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -705,12 +705,12 @@ def reconstruct( tv_denoise: bool = True, tv_denoise_weights=None, tv_denoise_inner_iter=40, - switch_object_iter: int = np.inf, store_iterations: bool = False, progress_bar: bool = True, reset: bool = None, device: str = None, clear_fft_cache: bool = None, + object_type: str = None, ): """ Ptychographic reconstruction main method. @@ -824,9 +824,6 @@ def reconstruct( the more denoising. tv_denoise_inner_iter: float Number of iterations to run in inner loop of TV denoising - switch_object_iter: int, optional - Iteration to switch object type between 'complex' and 'potential' or between - 'potential' and 'complex' store_iterations: bool, optional If True, reconstructed objects and probes are stored at each iteration progress_bar: bool, optional @@ -837,6 +834,8 @@ def reconstruct( If not none, overwrites self._device to set device preprocess will be perfomed on. 
clear_fft_cache: bool, optional If true, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + Overwrites self._object_type Returns -------- @@ -858,7 +857,9 @@ def reconstruct( ] self.copy_attributes_to_device(attrs, device) - xp = self._xp + if object_type is not None: + self._switch_object_type(object_type) + xp_storage = self._xp_storage device = self._device asnumpy = self._asnumpy @@ -883,7 +884,6 @@ def reconstruct( if self._verbose: self._report_reconstruction_summary( num_iter, - switch_object_iter, use_projection_scheme, reconstruction_method, reconstruction_parameter, @@ -915,14 +915,6 @@ def reconstruct( ): error = 0.0 - if a0 == switch_object_iter: - if self._object_type == "potential": - self._object_type = "complex" - self._object = xp.exp(1j * self._object) - else: - self._object_type = "potential" - self._object = xp.angle(self._object) - # randomize if not use_projection_scheme: np.random.shuffle(shuffled_indices) diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 8c156f384..b4a29fa5d 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -961,7 +961,6 @@ def reconstruct( if self._verbose: self._report_reconstruction_summary( num_iter, - np.inf, use_projection_scheme, reconstruction_method, reconstruction_parameter, From 88f0071861800bed2c3c4795beb00a7e62337210 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sun, 21 Jan 2024 14:56:28 -0800 Subject: [PATCH 128/128] show_fft bug Former-commit-id: 8b752aa4386839c6eec0ef00cf388e8c0e8e0d88 --- py4DSTEM/process/phase/ptychographic_methods.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 5401b1393..6ab349f30 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -177,7 +177,6 @@ def show_object_fft( self, obj=None, apply_hanning_window=True, - crop_to_min_frequency=True, scalebar=True, pixelsize=None, pixelunits=None, @@ -192,8 +191,6 @@ def show_object_fft( If None is specified, uses the `object_fft` property apply_hanning_window: bool, optional If True, a 2D Hann window is applied to the object before FFT - crop_to_min_frequency: bool, optional - If True, a square FFT is plotted, cropping to the smallest axis scalebar: bool, optional if True, adds scalebar to probe pixelunits: str, optional @@ -211,13 +208,6 @@ def show_object_fft( if pixelunits is None: pixelunits = r"$\AA^{-1}$" - if crop_to_min_frequency: - sx, sy = object_fft.shape - s = min(sx, sy) - start_x = sx // 2 - (s // 2) - start_y = sy // 2 - (s // 2) - object_fft = object_fft[start_x : start_x + s, start_y : start_y + s] - figsize = kwargs.pop("figsize", (4, 4)) cmap = kwargs.pop("cmap", "magma") ticks = kwargs.pop("ticks", False) @@ -239,6 +229,7 @@ def show_object_fft( pixelunits=pixelunits, vmin=vmin, vmax=vmax, + aspect=object_fft.shape[1] / object_fft.shape[0], **kwargs, )
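An illustrative note on the show_object_fft change: with the square crop_to_min_frequency crop removed, the full (possibly rectangular) FFT is shown together with its own aspect ratio, presumably so the whole array is displayed rather than a square center crop. The value involved is simply:

    import numpy as np

    object_fft = np.zeros((256, 384))                    # hypothetical rectangular FFT magnitude
    aspect = object_fft.shape[1] / object_fft.shape[0]   # 1.5 here; passed through to the plotting call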