From 01d6ad6d6fb2bb0628d886e47115c35b1fbdb5ff Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 09:57:01 -0800 Subject: [PATCH] moved single-slice forward and adjoint to methods.py --- .../phase/iterative_ptychographic_methods.py | 457 ++++++++++++++++++ .../iterative_singleslice_ptychography.py | 451 +---------------- 2 files changed, 459 insertions(+), 449 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index a2e968473..ac16d89a5 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -7,6 +7,7 @@ from py4DSTEM.process.phase.utils import ( AffineTransform, ComplexProbe, + fft_shift, rotate_point, spatial_frequencies, ) @@ -1298,3 +1299,459 @@ def _probe(self): probe += pr return probe / self._num_tilts + + +class ObjectNDProbeMethodsMixin: + """ + Mixin class for methods applicable to 2D, 2.5D, and 3D objects using a single probe. + """ + + def _overlap_projection(self, current_object, current_probe): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + shifted_probes:np.ndarray + fractionally-shifted probes + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + shifted_probes * object_patches + """ + + xp = self._xp + + shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) + + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object + + object_patches = complex_object[ + self._vectorized_patch_indices_row, self._vectorized_patch_indices_col + ] + + overlap = shifted_probes * object_patches + + return shifted_probes, object_patches, overlap + + def _gradient_descent_fourier_projection(self, amplitudes, overlap): + """ + Ptychographic fourier projection method for GD method. + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + overlap: np.ndarray + object * probe overlap + + Returns + -------- + exit_waves:np.ndarray + Difference between modified and estimated exit waves + error: float + Reconstruction error + """ + + xp = self._xp + fourier_overlap = xp.fft.fft2(overlap) + error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2) + + fourier_modified_overlap = amplitudes * xp.exp(1j * xp.angle(fourier_overlap)) + modified_overlap = xp.fft.ifft2(fourier_modified_overlap) + + exit_waves = modified_overlap - overlap + + return exit_waves, error + + def _projection_sets_fourier_projection( + self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c + ): + """ + Ptychographic fourier projection method for DM_AP and RAAR methods. + Generalized projection using three parameters: a,b,c + + DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha + DM: DM_AP(1.0), AP: DM_AP(0.0) + + RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 + DM : RAAR(1.0) + + RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 + DM: RRR(1.0) + + SUPERFLIP : a = 0, b = 1, c = 2 + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + overlap: np.ndarray + object * probe overlap + exit_waves: np.ndarray + previously estimated exit waves + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + exit_waves:np.ndarray + Updated exit_waves + error: float + Reconstruction error + """ + + xp = self._xp + projection_x = 1 - projection_a - projection_b + projection_y = 1 - projection_c + + if exit_waves is None: + exit_waves = overlap.copy() + + fourier_overlap = xp.fft.fft2(overlap) + error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2) + + factor_to_be_projected = projection_c * overlap + projection_y * exit_waves + fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) + + fourier_projected_factor = amplitudes * xp.exp( + 1j * xp.angle(fourier_projected_factor) + ) + projected_factor = xp.fft.ifft2(fourier_projected_factor) + + exit_waves = ( + projection_x * exit_waves + + projection_a * overlap + + projection_b * projected_factor + ) + + return exit_waves, error + + def _forward( + self, + current_object, + current_probe, + amplitudes, + exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic forward operator. + Calls _overlap_projection() and the appropriate _fourier_projection(). + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + amplitudes: np.ndarray + Normalized measured amplitudes + exit_waves: np.ndarray + previously estimated exit waves + use_projection_scheme: bool, + If True, use generalized projection update + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + shifted_probes:np.ndarray + fractionally-shifted probes + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + object * probe overlap + exit_waves:np.ndarray + Updated exit_waves + error: float + Reconstruction error + """ + + shifted_probes, object_patches, overlap = self._overlap_projection( + current_object, current_probe + ) + + if use_projection_scheme: + exit_waves, error = self._projection_sets_fourier_projection( + amplitudes, + overlap, + exit_waves, + projection_a, + projection_b, + projection_c, + ) + + else: + exit_waves, error = self._gradient_descent_fourier_projection( + amplitudes, overlap + ) + + return shifted_probes, object_patches, overlap, exit_waves, error + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_normalization = self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2 + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + if self._object_type == "potential": + current_object += step_size * ( + self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(object_patches) + * xp.conj(shifted_probes) + * exit_waves + ) + ) + * probe_normalization + ) + else: + current_object += step_size * ( + self._sum_overlapping_patches_bincounts( + xp.conj(shifted_probes) * exit_waves + ) + * probe_normalization + ) + + if not fix_probe: + object_normalization = xp.sum( + (xp.abs(object_patches) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += step_size * ( + xp.sum( + xp.conj(object_patches) * exit_waves, + axis=0, + ) + * object_normalization + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_normalization = self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2 + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + if self._object_type == "potential": + current_object = ( + self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(object_patches) + * xp.conj(shifted_probes) + * exit_waves + ) + ) + * probe_normalization + ) + else: + current_object = ( + self._sum_overlapping_patches_bincounts( + xp.conj(shifted_probes) * exit_waves + ) + * probe_normalization + ) + + if not fix_probe: + object_normalization = xp.sum( + (xp.abs(object_patches) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + xp.conj(object_patches) * exit_waves, + axis=0, + ) + * object_normalization + ) + + return current_object, current_probe + + def _adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + exit_waves, + use_projection_scheme: bool, + step_size: float, + normalization_min: float, + fix_probe: bool, + ): + """ + Ptychographic adjoint operator. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + use_projection_scheme: bool, + If True, use generalized projection update + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + + if use_projection_scheme: + current_object, current_probe = self._projection_sets_adjoint( + current_object, + current_probe, + object_patches, + shifted_probes, + exit_waves, + normalization_min, + fix_probe, + ) + else: + current_object, current_probe = self._gradient_descent_adjoint( + current_object, + current_probe, + object_patches, + shifted_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ) + + return current_object, current_probe diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index acd570424..7d054a6de 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -27,6 +27,7 @@ ) from py4DSTEM.process.phase.iterative_ptychographic_methods import ( ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ) from py4DSTEM.process.phase.utils import ( @@ -44,6 +45,7 @@ class SingleslicePtychographicReconstruction( PositionsConstraintsMixin, ProbeConstraintsMixin, ObjectNDConstraintsMixin, + ObjectNDProbeMethodsMixin, ProbeMethodsMixin, ObjectNDMethodsMixin, PtychographicReconstruction, @@ -524,455 +526,6 @@ def preprocess( return self - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - shifted_probes * object_patches - """ - - xp = self._xp - - shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - object_patches = complex_object[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - overlap = shifted_probes * object_patches - - return shifted_probes, object_patches, overlap - - def _gradient_descent_fourier_projection(self, amplitudes, overlap): - """ - Ptychographic fourier projection method for GD method. - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - - Returns - -------- - exit_waves:np.ndarray - Difference between modified and estimated exit waves - error: float - Reconstruction error - """ - - xp = self._xp - fourier_overlap = xp.fft.fft2(overlap) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2) - - fourier_modified_overlap = amplitudes * xp.exp(1j * xp.angle(fourier_overlap)) - modified_overlap = xp.fft.ifft2(fourier_modified_overlap) - - exit_waves = modified_overlap - overlap - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. - Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = overlap.copy() - - fourier_overlap = xp.fft.fft2(overlap) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2) - - factor_to_be_projected = projection_c * overlap + projection_y * exit_waves - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amplitudes * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * overlap - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - object * probe overlap - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - shifted_probes, object_patches, overlap = self._overlap_projection( - current_object, current_probe - ) - if use_projection_scheme: - exit_waves, error = self._projection_sets_fourier_projection( - amplitudes, - overlap, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, overlap - ) - - return shifted_probes, object_patches, overlap, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._object_type == "potential": - current_object += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(object_patches) - * xp.conj(shifted_probes) - * exit_waves - ) - ) - * probe_normalization - ) - elif self._object_type == "complex": - current_object += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes) * exit_waves - ) - * probe_normalization - ) - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(object_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += step_size * ( - xp.sum( - xp.conj(object_patches) * exit_waves, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._object_type == "potential": - current_object = ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(object_patches) - * xp.conj(shifted_probes) - * exit_waves - ) - ) - * probe_normalization - ) - elif self._object_type == "complex": - current_object = ( - self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes) * exit_waves - ) - * probe_normalization - ) - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(object_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - xp.conj(object_patches) * exit_waves, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - def _constraints( self, current_object,