From f23d8cc83f33d5b7b50b3ee4dbc9689e53707c66 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 5 Jan 2024 13:36:05 -0800 Subject: [PATCH] correctly handling collective updates constraints --- .../phase/iterative_magnetic_ptychography.py | 222 ++++-------------- .../iterative_ptychographic_tomography.py | 58 +++-- 2 files changed, 81 insertions(+), 199 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_magnetic_ptychography.py b/py4DSTEM/process/phase/iterative_magnetic_ptychography.py index b97820e04..c2bbce64e 100644 --- a/py4DSTEM/process/phase/iterative_magnetic_ptychography.py +++ b/py4DSTEM/process/phase/iterative_magnetic_ptychography.py @@ -961,131 +961,28 @@ def _gradient_descent_adjoint( return current_object, current_probe - def _constraints( + def _object_constraints( self, current_object, - current_probe, - current_positions, pure_phase_object, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, gaussian_filter, gaussian_filter_sigma_e, gaussian_filter_sigma_m, butterworth_filter, + butterworth_order, q_lowpass_e, q_lowpass_m, q_highpass_e, q_highpass_m, - butterworth_order, tv_denoise, tv_denoise_weight, tv_denoise_inner_iter, object_positivity, shrinkage_rad, object_mask, + **kwargs, ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - pure_phase_object: bool - If True, object amplitude is set to unity - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe Fourier amplitude is replaced by initial probe aperture. - initial_probe_aperture: np.ndarray, - Initial probe aperture to use in replacing probe Fourier amplitude. - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma_e: float - Standard deviation of gaussian kernel for electrostatic object in A - gaussian_filter_sigma_m: float - Standard deviation of gaussian kernel for magnetic object in A - probe_gaussian_filter: bool - If True, applies reciprocal-space gaussian filtering on residual aberrations - probe_gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A^-1 - probe_gaussian_filter_fix_amplitude: bool - If True, only the probe phase is smoothed - butterworth_filter: bool - If True, applies high-pass butteworth filter - q_lowpass_e: float - Cut-off frequency in A^-1 for low-pass filtering electrostatic object - q_lowpass_m: float - Cut-off frequency in A^-1 for low-pass filtering magnetic object - q_highpass_e: float - Cut-off frequency in A^-1 for high-pass filtering electrostatic object - q_highpass_m: float - Cut-off frequency in A^-1 for high-pass filtering magnetic object - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - warmup_iteration: bool - If True, constraints electrostatic object only - object_positivity: bool - If True, clips negative potential values - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - # object constraints + """MagneticObjectNDConstraints wrapper function""" # smoothness if gaussian_filter: @@ -1132,56 +1029,7 @@ def _constraints( elif object_positivity: current_object[0] = self._object_positivity_constraint(current_object[0]) - # probe constraints - - # CoM corner-centering - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - # Fourier amplitude (aperture) constraints - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - # Fourier phase (aberrations) fitting - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - # Real-space amplitude constraint - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - # position constraints - if not fix_positions: - # CoM centering - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - # global affine transformation - # TO-DO: generalize to higher-order basis? - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions + return current_object def reconstruct( self, @@ -1351,6 +1199,12 @@ def reconstruct( asnumpy = self._asnumpy xp = self._xp + if not collective_measurement_updates and self._verbose: + warnings.warn( + "Magnetic ptychography is much more robust with `collective_measurement_updates=True`.", + UserWarning, + ) + # set and report reconstruction method ( use_projection_scheme, @@ -1370,7 +1224,7 @@ def reconstruct( if use_projection_scheme: raise NotImplementedError( - "Magnetic ptychography currently only implemented for gradient descent." + "Magnetic ptychography is currently only implemented for gradient descent." ) if self._verbose: @@ -1543,7 +1397,38 @@ def reconstruct( unshuffled_indices ] - if not collective_measurement_updates: + if collective_measurement_updates: + # probe and positions + _probe = self._probe_constraints( + _probe, + fix_com=fix_com and a0 >= fix_probe_iter, + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter + and a0 >= fix_probe_iter, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=_probe_initial_aperture, + ) + + self._positions_px_all[ + start_idx:end_idx + ] = self._positions_constraints( + self._positions_px_all[start_idx:end_idx], + fix_positions=a0 < fix_positions_iter, + global_affine_transformation=global_affine_transformation, + ) + + else: + # object, probe, and positions ( self._object, _probe, @@ -1601,24 +1486,9 @@ def reconstruct( if collective_measurement_updates: self._object += collective_object / self._num_measurements - self._object, _, _ = self._constraints( + # object only + self._object = self._object_constraints( self._object, - None, - None, - fix_com=False, - constrain_probe_amplitude=False, - constrain_probe_amplitude_relative_radius=None, - constrain_probe_amplitude_relative_width=None, - constrain_probe_fourier_amplitude=False, - constrain_probe_fourier_amplitude_max_width_pixels=None, - constrain_probe_fourier_amplitude_constant_intensity=None, - fit_probe_aberrations=False, - fit_probe_aberrations_max_angular_order=None, - fit_probe_aberrations_max_radial_order=None, - fix_probe_aperture=False, - initial_probe_aperture=None, - fix_positions=True, - global_affine_transformation=None, gaussian_filter=a0 < gaussian_filter_iter and gaussian_filter_sigma_m is not None, gaussian_filter_sigma_e=gaussian_filter_sigma_e, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py index 5b4892c5e..c4374c93a 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py @@ -738,7 +738,7 @@ def reconstruct( tv_denoise_iter=np.inf, tv_denoise_weights=None, tv_denoise_inner_iter=40, - collective_tilt_updates: bool = False, + collective_tilt_updates: bool = True, store_iterations: bool = False, progress_bar: bool = True, reset: bool = None, @@ -1045,7 +1045,38 @@ def reconstruct( unshuffled_indices ] - if not collective_tilt_updates: + if collective_tilt_updates: + # probe and positions + _probe = self._probe_constraints( + _probe, + fix_com=fix_com and a0 >= fix_probe_iter, + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter + and a0 >= fix_probe_iter, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=_probe_initial_aperture, + ) + + self._positions_px_all[ + start_tilt:end_tilt + ] = self._positions_constraints( + self._positions_px_all[start_tilt:end_tilt], + fix_positions=a0 < fix_positions_iter, + global_affine_transformation=global_affine_transformation, + ) + + else: + # object, probe, and positions ( self._object, _probe, @@ -1100,28 +1131,9 @@ def reconstruct( if collective_tilt_updates: self._object += collective_object / self._num_tilts - ( - self._object, - _, - _, - ) = self._constraints( + # object only + self._object = self._object_constraints( self._object, - None, - None, - fix_com=False, - constrain_probe_amplitude=False, - constrain_probe_amplitude_relative_radius=None, - constrain_probe_amplitude_relative_width=None, - constrain_probe_fourier_amplitude=False, - constrain_probe_fourier_amplitude_max_width_pixels=None, - constrain_probe_fourier_amplitude_constant_intensity=None, - fit_probe_aberrations=False, - fit_probe_aberrations_max_angular_order=None, - fit_probe_aberrations_max_radial_order=None, - fix_probe_aperture=False, - initial_probe_aperture=None, - fix_positions=True, - global_affine_transformation=None, gaussian_filter=a0 < gaussian_filter_iter and gaussian_filter_sigma is not None, gaussian_filter_sigma=gaussian_filter_sigma,