From 85d1b76b9aca83e01a0842c8f3ff01aa0d619717 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 5 Jan 2024 13:02:51 -0800 Subject: [PATCH] splitting up constraints --- .../process/phase/iterative_base_class.py | 9 + ...tive_mixedstate_multislice_ptychography.py | 209 +----------- .../iterative_mixedstate_ptychography.py | 172 +--------- .../iterative_multislice_ptychography.py | 219 +----------- .../iterative_ptychographic_constraints.py | 322 ++++++++++++++++++ .../iterative_ptychographic_tomography.py | 177 +--------- .../iterative_singleslice_ptychography.py | 195 +---------- 7 files changed, 336 insertions(+), 967 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 3a066de4e..f06b3aa29 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1980,6 +1980,15 @@ def _report_reconstruction_summary( ) ) + def _constraints(self, current_object, current_probe, current_positions, **kwargs): + """Wrapper function for all classes to inherit""" + + current_object = self._object_constraints(current_object, **kwargs) + current_probe = self._probe_constraints(current_probe, **kwargs) + current_positions = self._positions_constraints(current_positions, **kwargs) + + return current_object, current_probe, current_positions + @property def angular_sampling(self): """Angular sampling [mrad]""" diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index b00fea11b..79930c740 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -647,213 +647,6 @@ def preprocess( return self - def _constraints( - self, - current_object, - current_probe, - current_positions, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - kz_regularization_filter, - kz_regularization_gamma, - identical_slices, - object_positivity, - shrinkage_rad, - object_mask, - pure_phase_object, - tv_denoise_chambolle, - tv_denoise_weight_chambolle, - tv_denoise_pad_chambolle, - tv_denoise, - tv_denoise_weights, - tv_denoise_inner_iter, - orthogonalize_probe, - ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool - If True, probe Fourier amplitude is replaced by initial_probe_aperture - initial_probe_aperture: np.ndarray - Initial probe aperture to use in replacing probe Fourier amplitude - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter in A - gaussian_filter_sigma: float - Standard deviation of gaussian kernel - butterworth_filter: bool - If True, applies fourier-space butterworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - kz_regularization_filter: bool - If True, applies fourier-space arctan regularization filter - kz_regularization_gamma: float - Slice regularization strength - identical_slices: bool - If True, forces all object slices to be identical - object_positivity: bool - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - pure_phase_object: bool - If True, object amplitude is set to unity - tv_denoise_chambolle: bool - If True, performs TV denoising along z - tv_denoise_weight_chambolle: float - weight of tv denoising constraint - tv_denoise_pad_chambolle: int - If not None, pads object at top and bottom with this many zeros before applying denoising - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - orthogonalize_probe: bool - If True, probe will be orthogonalized - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - - if identical_slices: - current_object = self._object_identical_slices_constraint(current_object) - elif kz_regularization_filter: - current_object = self._object_kz_regularization_constraint( - current_object, - kz_regularization_gamma, - z_padding=1, - ) - elif tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, - tv_denoise_weights, - tv_denoise_inner_iter, - z_padding=1, - ) - elif tv_denoise_chambolle: - current_object = self._object_denoise_tv_chambolle( - current_object, - tv_denoise_weight_chambolle, - axis=0, - padding=tv_denoise_pad_chambolle, - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - if self._object_type == "complex": - current_object = self._object_threshold_constraint( - current_object, pure_phase_object - ) - elif object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - # These constraints don't _really_ make sense for mixed-state - if fix_probe_aperture: - raise NotImplementedError() - elif constrain_probe_fourier_amplitude: - raise NotImplementedError() - if fit_probe_aberrations: - raise NotImplementedError() - if constrain_probe_amplitude: - raise NotImplementedError() - - if orthogonalize_probe: - current_probe = self._probe_orthogonalization_constraint(current_probe) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - def reconstruct( self, max_iter: int = 8, @@ -1058,8 +851,8 @@ def reconstruct( projection_b, projection_c, normalization_min, - step_size, max_batch_size, + step_size, ) # batching diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 317f9316c..765f6b8b6 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -550,176 +550,6 @@ def preprocess( return self - def _constraints( - self, - current_object, - current_probe, - current_positions, - pure_phase_object, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - tv_denoise, - tv_denoise_weight, - tv_denoise_inner_iter, - orthogonalize_probe, - object_positivity, - shrinkage_rad, - object_mask, - ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - pure_phase_object: bool - If True, object amplitude is set to unity - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe fourier amplitude is replaced by initial probe aperture. - initial_probe_aperture: np.ndarray - initial probe aperture to use in replacing probe fourier amplitude - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - butterworth_filter: bool - If True, applies high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - orthogonalize_probe: bool - If True, probe will be orthogonalized - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - object_positivity: bool - If True, clips negative potential values - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - - if tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, tv_denoise_weight, tv_denoise_inner_iter - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - if self._object_type == "complex": - current_object = self._object_threshold_constraint( - current_object, pure_phase_object - ) - elif object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - # These constraints don't _really_ make sense for mixed-state - if fix_probe_aperture: - raise NotImplementedError() - elif constrain_probe_fourier_amplitude: - raise NotImplementedError() - if fit_probe_aberrations: - raise NotImplementedError() - if constrain_probe_amplitude: - raise NotImplementedError() - - if orthogonalize_probe: - current_probe = self._probe_orthogonalization_constraint(current_probe) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - def reconstruct( self, max_iter: int = 8, @@ -905,8 +735,8 @@ def reconstruct( projection_b, projection_c, normalization_min, - step_size, max_batch_size, + step_size, ) # Batching diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 6bb14d91a..91f9ea076 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -622,223 +622,6 @@ def preprocess( return self - def _constraints( - self, - current_object, - current_probe, - current_positions, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - kz_regularization_filter, - kz_regularization_gamma, - identical_slices, - object_positivity, - shrinkage_rad, - object_mask, - pure_phase_object, - tv_denoise_chambolle, - tv_denoise_weight_chambolle, - tv_denoise_pad_chambolle, - tv_denoise, - tv_denoise_weights, - tv_denoise_inner_iter, - ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool - If True, probe Fourier amplitude is replaced by initial_probe_aperture - initial_probe_aperture: np.ndarray - Initial probe aperture to use in replacing probe Fourier amplitude - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter in A - gaussian_filter_sigma: float - Standard deviation of gaussian kernel - butterworth_filter: bool - If True, applies fourier-space butterworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - kz_regularization_filter: bool - If True, applies fourier-space arctan regularization filter - kz_regularization_gamma: float - Slice regularization strength - identical_slices: bool - If True, forces all object slices to be identical - object_positivity: bool - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - pure_phase_object: bool - If True, object amplitude is set to unity - tv_denoise_chambolle: bool - If True, performs TV denoising along z - tv_denoise_weight_chambolle: float - weight of tv denoising constraint - tv_denoise_pad_chambolle: int - if not None, pads object at top and bottom with this many zeros before applying denoising - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - - if identical_slices: - current_object = self._object_identical_slices_constraint(current_object) - elif kz_regularization_filter: - current_object = self._object_kz_regularization_constraint( - current_object, - kz_regularization_gamma, - z_padding=1, - ) - elif tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, - tv_denoise_weights, - tv_denoise_inner_iter, - z_padding=1, - ) - elif tv_denoise_chambolle: - current_object = self._object_denoise_tv_chambolle( - current_object, - tv_denoise_weight_chambolle, - axis=0, - padding=tv_denoise_pad_chambolle, - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - if self._object_type == "complex": - current_object = self._object_threshold_constraint( - current_object, pure_phase_object - ) - elif object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - def reconstruct( self, max_iter: int = 8, @@ -1042,8 +825,8 @@ def reconstruct( projection_b, projection_c, normalization_min, - step_size, max_batch_size, + step_size, ) # Batching diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 538b9d7aa..5181984ed 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -408,6 +408,61 @@ def _object_denoise_tv_chambolle( return updated_object + def _object_constraints( + self, + current_object, + gaussian_filter, + gaussian_filter_sigma, + pure_phase_object, + butterworth_filter, + butterworth_order, + q_lowpass, + q_highpass, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ): + """ObjectNDConstraints wrapper function""" + + # smoothness + if gaussian_filter: + current_object = self._object_gaussian_constraint( + current_object, gaussian_filter_sigma, pure_phase_object + ) + if butterworth_filter: + current_object = self._object_butterworth_constraint( + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, tv_denoise_weight, tv_denoise_inner_iter + ) + + # L1-norm pushing vacuum to zero + if shrinkage_rad > 0.0 or object_mask is not None: + current_object = self._object_shrinkage_constraint( + current_object, + shrinkage_rad, + object_mask, + ) + + # amplitude threshold (complex) or positivity (potential) + if self._object_type == "complex": + current_object = self._object_threshold_constraint( + current_object, pure_phase_object + ) + elif object_positivity: + current_object = self._object_positivity_constraint(current_object) + + return current_object + class Object2p5DConstraintsMixin: """ @@ -591,6 +646,85 @@ def _object_identical_slices_constraint(self, current_object): return current_object + def _object_constraints( + self, + current_object, + gaussian_filter, + gaussian_filter_sigma, + pure_phase_object, + butterworth_filter, + butterworth_order, + q_lowpass, + q_highpass, + identical_slices, + kz_regularization_filter, + kz_regularization_gamma, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + tv_denoise_chambolle, + tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ): + """Object2p5DConstraints wrapper function""" + + # smoothness + if gaussian_filter: + current_object = self._object_gaussian_constraint( + current_object, gaussian_filter_sigma, pure_phase_object + ) + if butterworth_filter: + current_object = self._object_butterworth_constraint( + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ) + if identical_slices: + current_object = self._object_identical_slices_constraint(current_object) + elif kz_regularization_filter: + current_object = self._object_kz_regularization_constraint( + current_object, + kz_regularization_gamma, + z_padding=1, + ) + elif tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + z_padding=1, + ) + elif tv_denoise_chambolle: + current_object = self._object_denoise_tv_chambolle( + current_object, + tv_denoise_weight_chambolle, + axis=0, + padding=tv_denoise_pad_chambolle, + ) + + # L1-norm pushing vacuum to zero + if shrinkage_rad > 0.0 or object_mask is not None: + current_object = self._object_shrinkage_constraint( + current_object, + shrinkage_rad, + object_mask, + ) + + # amplitude threshold (complex) or positivity (potential) + if self._object_type == "complex": + current_object = self._object_threshold_constraint( + current_object, pure_phase_object + ) + elif object_positivity: + current_object = self._object_positivity_constraint(current_object) + + return current_object + class Object3DConstraintsMixin: """ @@ -699,6 +833,58 @@ def _object_butterworth_constraint( return current_object + def _object_constraints( + self, + current_object, + gaussian_filter, + gaussian_filter_sigma, + butterworth_filter, + butterworth_order, + q_lowpass, + q_highpass, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ): + """Object3DConstraints wrapper function""" + + # smoothness + if gaussian_filter: + current_object = self._object_gaussian_constraint( + current_object, gaussian_filter_sigma, pure_phase_object=False + ) + if butterworth_filter: + current_object = self._object_butterworth_constraint( + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + # L1-norm pushing vacuum to zero + if shrinkage_rad > 0.0 or object_mask is not None: + current_object = self._object_shrinkage_constraint( + current_object, + shrinkage_rad, + object_mask, + ) + + # Positivity + if object_positivity: + current_object = self._object_positivity_constraint(current_object) + + return current_object + class ProbeConstraintsMixin: """ @@ -892,6 +1078,60 @@ def _probe_aberration_fitting_constraint( return current_probe + def _probe_constraints( + self, + current_probe, + fix_com, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + fix_probe_aperture, + initial_probe_aperture, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + **kwargs, + ): + """ProbeConstraints wrapper function""" + + # CoM corner-centering + if fix_com: + current_probe = self._probe_center_of_mass_constraint(current_probe) + + # Fourier phase (aberrations) fitting + if fit_probe_aberrations: + current_probe = self._probe_aberration_fitting_constraint( + current_probe, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + ) + + # Fourier amplitude (aperture) constraints + if fix_probe_aperture: + current_probe = self._probe_aperture_constraint( + current_probe, + initial_probe_aperture, + ) + elif constrain_probe_fourier_amplitude: + current_probe = self._probe_fourier_amplitude_constraint( + current_probe, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + ) + + # Real-space amplitude constraint + if constrain_probe_amplitude: + current_probe = self._probe_amplitude_constraint( + current_probe, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + ) + + return current_probe + class ProbeMixedConstraintsMixin: """ @@ -961,6 +1201,67 @@ def _probe_orthogonalization_constraint(self, current_probe): intensities_order = xp.argsort(intensities, axis=None)[::-1] return current_probe[intensities_order] + def _probe_constraints( + self, + current_probe, + fix_com, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + fix_probe_aperture, + initial_probe_aperture, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + orthogonalize_probe, + **kwargs, + ): + """ProbeMixedConstraints wrapper function""" + + # CoM corner-centering + if fix_com: + current_probe = self._probe_center_of_mass_constraint(current_probe) + + # Fourier phase (aberrations) fitting + if fit_probe_aberrations: + for probe_idx in range(self._num_probes): + current_probe[probe_idx] = self._probe_aberration_fitting_constraint( + current_probe[probe_idx], + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + ) + + # Fourier amplitude (aperture) constraints + if fix_probe_aperture: + current_probe[0] = self._probe_aperture_constraint( + current_probe[0], + initial_probe_aperture, + ) + elif constrain_probe_fourier_amplitude: + current_probe[0] = self._probe_fourier_amplitude_constraint( + current_probe[0], + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + ) + + # Real-space amplitude constraint + if constrain_probe_amplitude: + for probe_idx in range(self._num_probes): + current_probe[probe_idx] = self._probe_amplitude_constraint( + current_probe[probe_idx], + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + ) + + # Probe orthogonalization + if orthogonalize_probe: + current_probe = self._probe_orthogonalization_constraint(current_probe) + + return current_probe + class PositionsConstraintsMixin: """ @@ -1032,3 +1333,24 @@ def _positions_affine_transformation_constraint( current_positions = tf(initial_positions, origin=self._positions_px_com, xp=xp) return current_positions + + def _positions_constraints( + self, + current_positions, + fix_positions, + global_affine_transformation, + **kwargs, + ): + """PositionsConstraints wrapper function""" + + if not fix_positions: + current_positions = self._positions_center_of_mass_constraint( + current_positions + ) + + if global_affine_transformation: + current_positions = self._positions_affine_transformation_constraint( + self._positions_px_initial, current_positions + ) + + return current_positions diff --git a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py index d341ffdc9..5b4892c5e 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_tomography.py @@ -697,181 +697,6 @@ def preprocess( return self - def _constraints( - self, - current_object, - current_probe, - current_positions, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - object_positivity, - shrinkage_rad, - object_mask, - tv_denoise, - tv_denoise_weights, - tv_denoise_inner_iter, - ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe Fourier amplitude is replaced by initial probe aperture. - initial_probe_aperture: np.ndarray, - Initial probe aperture to use in replacing probe Fourier amplitude. - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - butterworth_filter: bool - If True, applies fourier-space butterworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - object_positivity: bool - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object=False - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - if tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, - tv_denoise_weights, - tv_denoise_inner_iter, - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - if object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - def reconstruct( self, max_iter: int = 8, @@ -1051,8 +876,8 @@ def reconstruct( projection_b, projection_c, normalization_min, - step_size, max_batch_size, + step_size, ) if max_batch_size is not None: diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index d218651c7..3a0f5d77b 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -529,199 +529,6 @@ def preprocess( return self - def _constraints( - self, - current_object, - current_probe, - current_positions, - pure_phase_object, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - tv_denoise, - tv_denoise_weight, - tv_denoise_inner_iter, - object_positivity, - shrinkage_rad, - object_mask, - ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - pure_phase_object: bool - If True, object amplitude is set to unity - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe Fourier amplitude is replaced by initial probe aperture. - initial_probe_aperture: np.ndarray, - Initial probe aperture to use in replacing probe Fourier amplitude. - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - butterworth_filter: bool - If True, applies high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - object_positivity: bool - If True, clips negative potential values - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - # object constraints - - # smoothness - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object - ) - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - if tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, tv_denoise_weight, tv_denoise_inner_iter - ) - - # L1-norm pushing vacuum to zero - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - # amplitude threshold (complex) or positivity (potential) - if self._object_type == "complex": - current_object = self._object_threshold_constraint( - current_object, pure_phase_object - ) - elif object_positivity: - current_object = self._object_positivity_constraint(current_object) - - # probe constraints - - # CoM corner-centering - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - # Fourier amplitude (aperture) constraints - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - # Fourier phase (aberrations) fitting - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - # Real-space amplitude constraint - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - # position constraints - if not fix_positions: - # CoM centering - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - # global affine transformation - # TO-DO: generalize to higher-order basis? - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - def reconstruct( self, max_iter: int = 8, @@ -906,8 +713,8 @@ def reconstruct( projection_b, projection_c, normalization_min, - step_size, max_batch_size, + step_size, ) # batching