From 14c1e66341db84c695401ce824cedb055ce4ec1e Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 1 Jan 2024 18:31:00 -0800 Subject: [PATCH] cleaned up overlap tomo reconstruct, different probes per tilt --- .../phase/iterative_overlap_tomography.py | 242 ++++-------------- .../phase/iterative_ptychographic_methods.py | 125 ++++++--- 2 files changed, 142 insertions(+), 225 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 1ef2d54cb..782739d74 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -516,8 +516,8 @@ def preprocess( ) self._probes_all.append(_probe) - self._probes_all_initial = _probe.copy() - self._probes_all_initial_aperture = xp.abs(xp.fft.fft2(_probe)) + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) del self._probe_init @@ -1034,7 +1034,7 @@ def _constraints( def reconstruct( self, - max_iter: int = 64, + max_iter: int = 8, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, reconstruction_parameter_a: float = None, @@ -1181,148 +1181,41 @@ def reconstruct( asnumpy = self._asnumpy xp = self._xp - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." - ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - "'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) if self._verbose: - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) - - # Position Correction + Collective Updates not yet implemented - if fix_positions_iter < max_iter: - raise NotImplementedError( - "Position correction is currently incompatible with collective updates." + self._report_reconstruction_summary( + max_iter, + np.inf, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + step_size, + max_batch_size, ) - # Batching + # batching + shuffled_indices = np.arange(self._num_diffraction_patterns) + unshuffled_indices = np.zeros_like(shuffled_indices) if max_batch_size is not None: xp.random.seed(seed_random) @@ -1330,37 +1223,7 @@ def reconstruct( max_batch_size = self._num_diffraction_patterns # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self._object = self._object_initial.copy() - self.error_iterations = [] - self._probe = self._probe_initial.copy() - self._positions_px_all = self._positions_px_initial_all.copy() - if hasattr(self, "_tf"): - del self._tf - - if use_projection_scheme: - self._exit_waves = [None] * self._num_tilts - else: - self._exit_waves = None - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." - ), - UserWarning, - ) - else: - self.error_iterations = [] - if use_projection_scheme: - self._exit_waves = [None] * self._num_tilts - else: - self._exit_waves = None + self._reset_reconstruction(store_iterations, reset, use_projection_scheme) # main loop for a0 in tqdmnd( @@ -1393,6 +1256,12 @@ def reconstruct( object_sliced = self._project_sliced_object( self._object, self._num_slices ) + + _probe = self._probes_all[self._active_tilt_index] + _probe_initial_aperture = self._probes_all_initial_aperture[ + self._active_tilt_index + ] + if not use_projection_scheme: object_sliced_old = object_sliced.copy() @@ -1440,14 +1309,14 @@ def reconstruct( # forward operator ( - propagated_probes, + shifted_probes, object_patches, - transmitted_probes, + overlap, self._exit_waves, batch_error, ) = self._forward( object_sliced, - self._probe, + _probe, amplitudes, self._exit_waves, use_projection_scheme, @@ -1457,11 +1326,11 @@ def reconstruct( ) # adjoint operator - object_sliced, self._probe = self._adjoint( + object_sliced, _probe = self._adjoint( object_sliced, - self._probe, + _probe, object_patches, - propagated_probes, + shifted_probes, self._exit_waves, use_projection_scheme=use_projection_scheme, step_size=step_size, @@ -1473,8 +1342,8 @@ def reconstruct( if a0 >= fix_positions_iter: positions_px[start:end] = self._position_correction( object_sliced, - self._probe, - transmitted_probes, + _probe, + overlap, amplitudes, self._positions_px, positions_step_size, @@ -1514,11 +1383,11 @@ def reconstruct( if not collective_tilt_updates: ( self._object, - self._probe, + _probe, self._positions_px_all[start_tilt:end_tilt], ) = self._constraints( self._object, - self._probe, + _probe, self._positions_px_all[start_tilt:end_tilt], fix_com=fix_com and a0 >= fix_probe_iter, constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter @@ -1535,7 +1404,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, + initial_probe_aperture=_probe_initial_aperture, fix_positions=a0 < fix_positions_iter, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter @@ -1568,11 +1437,11 @@ def reconstruct( ( self._object, - self._probe, + _probe, _, ) = self._constraints( self._object, - self._probe, + _probe, None, fix_com=fix_com and a0 >= fix_probe_iter, constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter @@ -1589,7 +1458,7 @@ def reconstruct( fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, + initial_probe_aperture=_probe_initial_aperture, fix_positions=True, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter @@ -1612,6 +1481,7 @@ def reconstruct( ) self.error_iterations.append(error.item()) + if store_iterations: self.object_iterations.append(asnumpy(self._object.copy())) self.probe_iterations.append(self.probe_centered) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_methods.py b/py4DSTEM/process/phase/iterative_ptychographic_methods.py index 172eacd57..f58507cf2 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_methods.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_methods.py @@ -174,6 +174,45 @@ def show_object_fft(self, obj=None, **kwargs): **kwargs, ) + def _reset_reconstruction( + self, + store_iterations, + reset, + ): + """ """ + if store_iterations and (not hasattr(self, "object_iterations") or reset): + self.object_iterations = [] + self.probe_iterations = [] + + # reset can be True, False, or None (default) + if reset is True: + self.error_iterations = [] + self._object = self._object_initial.copy() + self._probe = self._probe_initial.copy() + self._positions_px = self._positions_px_initial.copy() + self._object_type = self._object_type_initial + self._exit_waves = None + + # delete positions affine transform + if hasattr(self, "_tf"): + del self._tf + + elif reset is None: + # continued run + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + + # first start + else: + self.error_iterations = [] + self._exit_waves = None + @property def object_fft(self): """Fourier transform of current object estimate""" @@ -1291,6 +1330,53 @@ class ProbeListMethodsMixin: Overwrites ProbeMethodsMixin. """ + def _reset_reconstruction( + self, + store_iterations, + reset, + use_projection_scheme, + ): + """ """ + if store_iterations and (not hasattr(self, "object_iterations") or reset): + self.object_iterations = [] + self.probe_iterations = [] + + # reset can be True, False, or None (default) + if reset is True: + self.error_iterations = [] + self._object = self._object_initial.copy() + self._probes_all = [pr.copy() for pr in self._probes_all_initial] + self._positions_px_all = self._positions_px_initial_all.copy() + self._object_type = self._object_type_initial + + if use_projection_scheme: + self._exit_waves = [None] * self._num_tilts + else: + self._exit_waves = None + + # delete positions affine transform + if hasattr(self, "_tf"): + del self._tf + + elif reset is None: + # continued run + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + + # first start + else: + self.error_iterations = [] + if use_projection_scheme: + self._exit_waves = [None] * self._num_tilts + else: + self._exit_waves = None + @property def _probe(self): """Dummy property to return average probe""" @@ -1759,45 +1845,6 @@ def _adjoint( return current_object, current_probe - def _reset_reconstruction( - self, - store_iterations, - reset, - ): - """ """ - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - # reset can be True, False, or None (default) - if reset is True: - self.error_iterations = [] - self._object = self._object_initial.copy() - self._probe = self._probe_initial.copy() - self._positions_px = self._positions_px_initial.copy() - self._object_type = self._object_type_initial - self._exit_waves = None - - # delete positions affine transform - if hasattr(self, "_tf"): - del self._tf - - elif reset is None: - # continued run - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." - ), - UserWarning, - ) - - # first start - else: - self.error_iterations = [] - self._exit_waves = None - class Object2p5DProbeMethodsMixin: """