Skip to content

Commit

Permalink
correctly handling collective updates constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Jan 5, 2024
1 parent 360741a commit f23d8cc
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 199 deletions.
222 changes: 46 additions & 176 deletions py4DSTEM/process/phase/iterative_magnetic_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,131 +961,28 @@ def _gradient_descent_adjoint(

return current_object, current_probe

def _constraints(
def _object_constraints(
self,
current_object,
current_probe,
current_positions,
pure_phase_object,
fix_com,
fit_probe_aberrations,
fit_probe_aberrations_max_angular_order,
fit_probe_aberrations_max_radial_order,
constrain_probe_amplitude,
constrain_probe_amplitude_relative_radius,
constrain_probe_amplitude_relative_width,
constrain_probe_fourier_amplitude,
constrain_probe_fourier_amplitude_max_width_pixels,
constrain_probe_fourier_amplitude_constant_intensity,
fix_probe_aperture,
initial_probe_aperture,
fix_positions,
global_affine_transformation,
gaussian_filter,
gaussian_filter_sigma_e,
gaussian_filter_sigma_m,
butterworth_filter,
butterworth_order,
q_lowpass_e,
q_lowpass_m,
q_highpass_e,
q_highpass_m,
butterworth_order,
tv_denoise,
tv_denoise_weight,
tv_denoise_inner_iter,
object_positivity,
shrinkage_rad,
object_mask,
**kwargs,
):
"""
Ptychographic constraints operator.
Parameters
--------
current_object: np.ndarray
Current object estimate
current_probe: np.ndarray
Current probe estimate
current_positions: np.ndarray
Current positions estimate
pure_phase_object: bool
If True, object amplitude is set to unity
fix_com: bool
If True, probe CoM is fixed to the center
fit_probe_aberrations: bool
If True, fits the probe aberrations to a low-order expansion
fit_probe_aberrations_max_angular_order: bool
Max angular order of probe aberrations basis functions
fit_probe_aberrations_max_radial_order: bool
Max radial order of probe aberrations basis functions
constrain_probe_amplitude: bool
If True, probe amplitude is constrained by top hat function
constrain_probe_amplitude_relative_radius: float
Relative location of top-hat inflection point, between 0 and 0.5
constrain_probe_amplitude_relative_width: float
Relative width of top-hat sigmoid, between 0 and 0.5
constrain_probe_fourier_amplitude: bool
If True, probe aperture is constrained by fitting a sigmoid for each angular frequency.
constrain_probe_fourier_amplitude_max_width_pixels: float
Maximum pixel width of fitted sigmoid functions.
constrain_probe_fourier_amplitude_constant_intensity: bool
If True, the probe aperture is additionally constrained to a constant intensity.
fix_probe_aperture: bool,
If True, probe Fourier amplitude is replaced by initial probe aperture.
initial_probe_aperture: np.ndarray,
Initial probe aperture to use in replacing probe Fourier amplitude.
fix_positions: bool
If True, positions are not updated
gaussian_filter: bool
If True, applies real-space gaussian filter
gaussian_filter_sigma_e: float
Standard deviation of gaussian kernel for electrostatic object in A
gaussian_filter_sigma_m: float
Standard deviation of gaussian kernel for magnetic object in A
probe_gaussian_filter: bool
If True, applies reciprocal-space gaussian filtering on residual aberrations
probe_gaussian_filter_sigma: float
Standard deviation of gaussian kernel in A^-1
probe_gaussian_filter_fix_amplitude: bool
If True, only the probe phase is smoothed
butterworth_filter: bool
If True, applies high-pass butteworth filter
q_lowpass_e: float
Cut-off frequency in A^-1 for low-pass filtering electrostatic object
q_lowpass_m: float
Cut-off frequency in A^-1 for low-pass filtering magnetic object
q_highpass_e: float
Cut-off frequency in A^-1 for high-pass filtering electrostatic object
q_highpass_m: float
Cut-off frequency in A^-1 for high-pass filtering magnetic object
butterworth_order: float
Butterworth filter order. Smaller gives a smoother filter
tv_denoise: bool
If True, applies TV denoising on object
tv_denoise_weight: float
Denoising weight. The greater `weight`, the more denoising.
tv_denoise_inner_iter: float
Number of iterations to run in inner loop of TV denoising
warmup_iteration: bool
If True, constraints electrostatic object only
object_positivity: bool
If True, clips negative potential values
shrinkage_rad: float
Phase shift in radians to be subtracted from the potential at each iteration
object_mask: np.ndarray (boolean)
If not None, used to calculate additional shrinkage using masked-mean of object
Returns
--------
constrained_object: np.ndarray
Constrained object estimate
constrained_probe: np.ndarray
Constrained probe estimate
constrained_positions: np.ndarray
Constrained positions estimate
"""

# object constraints
"""MagneticObjectNDConstraints wrapper function"""

# smoothness
if gaussian_filter:
Expand Down Expand Up @@ -1132,56 +1029,7 @@ def _constraints(
elif object_positivity:
current_object[0] = self._object_positivity_constraint(current_object[0])

# probe constraints

# CoM corner-centering
if fix_com:
current_probe = self._probe_center_of_mass_constraint(current_probe)

# Fourier amplitude (aperture) constraints
if fix_probe_aperture:
current_probe = self._probe_aperture_constraint(
current_probe,
initial_probe_aperture,
)
elif constrain_probe_fourier_amplitude:
current_probe = self._probe_fourier_amplitude_constraint(
current_probe,
constrain_probe_fourier_amplitude_max_width_pixels,
constrain_probe_fourier_amplitude_constant_intensity,
)

# Fourier phase (aberrations) fitting
if fit_probe_aberrations:
current_probe = self._probe_aberration_fitting_constraint(
current_probe,
fit_probe_aberrations_max_angular_order,
fit_probe_aberrations_max_radial_order,
)

# Real-space amplitude constraint
if constrain_probe_amplitude:
current_probe = self._probe_amplitude_constraint(
current_probe,
constrain_probe_amplitude_relative_radius,
constrain_probe_amplitude_relative_width,
)

# position constraints
if not fix_positions:
# CoM centering
current_positions = self._positions_center_of_mass_constraint(
current_positions
)

# global affine transformation
# TO-DO: generalize to higher-order basis?
if global_affine_transformation:
current_positions = self._positions_affine_transformation_constraint(
self._positions_px_initial, current_positions
)

return current_object, current_probe, current_positions
return current_object

def reconstruct(
self,
Expand Down Expand Up @@ -1351,6 +1199,12 @@ def reconstruct(
asnumpy = self._asnumpy
xp = self._xp

if not collective_measurement_updates and self._verbose:
warnings.warn(
"Magnetic ptychography is much more robust with `collective_measurement_updates=True`.",
UserWarning,
)

# set and report reconstruction method
(
use_projection_scheme,
Expand All @@ -1370,7 +1224,7 @@ def reconstruct(

if use_projection_scheme:
raise NotImplementedError(
"Magnetic ptychography currently only implemented for gradient descent."
"Magnetic ptychography is currently only implemented for gradient descent."
)

if self._verbose:
Expand Down Expand Up @@ -1543,7 +1397,38 @@ def reconstruct(
unshuffled_indices
]

if not collective_measurement_updates:
if collective_measurement_updates:
# probe and positions
_probe = self._probe_constraints(
_probe,
fix_com=fix_com and a0 >= fix_probe_iter,
constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter
and a0 >= fix_probe_iter,
constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius,
constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width,
constrain_probe_fourier_amplitude=a0
< constrain_probe_fourier_amplitude_iter
and a0 >= fix_probe_iter,
constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels,
constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity,
fit_probe_aberrations=a0 < fit_probe_aberrations_iter
and a0 >= fix_probe_iter,
fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order,
fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order,
fix_probe_aperture=a0 < fix_probe_aperture_iter,
initial_probe_aperture=_probe_initial_aperture,
)

self._positions_px_all[
start_idx:end_idx
] = self._positions_constraints(
self._positions_px_all[start_idx:end_idx],
fix_positions=a0 < fix_positions_iter,
global_affine_transformation=global_affine_transformation,
)

else:
# object, probe, and positions
(
self._object,
_probe,
Expand Down Expand Up @@ -1601,24 +1486,9 @@ def reconstruct(
if collective_measurement_updates:
self._object += collective_object / self._num_measurements

self._object, _, _ = self._constraints(
# object only
self._object = self._object_constraints(
self._object,
None,
None,
fix_com=False,
constrain_probe_amplitude=False,
constrain_probe_amplitude_relative_radius=None,
constrain_probe_amplitude_relative_width=None,
constrain_probe_fourier_amplitude=False,
constrain_probe_fourier_amplitude_max_width_pixels=None,
constrain_probe_fourier_amplitude_constant_intensity=None,
fit_probe_aberrations=False,
fit_probe_aberrations_max_angular_order=None,
fit_probe_aberrations_max_radial_order=None,
fix_probe_aperture=False,
initial_probe_aperture=None,
fix_positions=True,
global_affine_transformation=None,
gaussian_filter=a0 < gaussian_filter_iter
and gaussian_filter_sigma_m is not None,
gaussian_filter_sigma_e=gaussian_filter_sigma_e,
Expand Down
58 changes: 35 additions & 23 deletions py4DSTEM/process/phase/iterative_ptychographic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@ def reconstruct(
tv_denoise_iter=np.inf,
tv_denoise_weights=None,
tv_denoise_inner_iter=40,
collective_tilt_updates: bool = False,
collective_tilt_updates: bool = True,
store_iterations: bool = False,
progress_bar: bool = True,
reset: bool = None,
Expand Down Expand Up @@ -1045,7 +1045,38 @@ def reconstruct(
unshuffled_indices
]

if not collective_tilt_updates:
if collective_tilt_updates:
# probe and positions
_probe = self._probe_constraints(
_probe,
fix_com=fix_com and a0 >= fix_probe_iter,
constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter
and a0 >= fix_probe_iter,
constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius,
constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width,
constrain_probe_fourier_amplitude=a0
< constrain_probe_fourier_amplitude_iter
and a0 >= fix_probe_iter,
constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels,
constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity,
fit_probe_aberrations=a0 < fit_probe_aberrations_iter
and a0 >= fix_probe_iter,
fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order,
fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order,
fix_probe_aperture=a0 < fix_probe_aperture_iter,
initial_probe_aperture=_probe_initial_aperture,
)

self._positions_px_all[
start_tilt:end_tilt
] = self._positions_constraints(
self._positions_px_all[start_tilt:end_tilt],
fix_positions=a0 < fix_positions_iter,
global_affine_transformation=global_affine_transformation,
)

else:
# object, probe, and positions
(
self._object,
_probe,
Expand Down Expand Up @@ -1100,28 +1131,9 @@ def reconstruct(
if collective_tilt_updates:
self._object += collective_object / self._num_tilts

(
self._object,
_,
_,
) = self._constraints(
# object only
self._object = self._object_constraints(
self._object,
None,
None,
fix_com=False,
constrain_probe_amplitude=False,
constrain_probe_amplitude_relative_radius=None,
constrain_probe_amplitude_relative_width=None,
constrain_probe_fourier_amplitude=False,
constrain_probe_fourier_amplitude_max_width_pixels=None,
constrain_probe_fourier_amplitude_constant_intensity=None,
fit_probe_aberrations=False,
fit_probe_aberrations_max_angular_order=None,
fit_probe_aberrations_max_radial_order=None,
fix_probe_aperture=False,
initial_probe_aperture=None,
fix_positions=True,
global_affine_transformation=None,
gaussian_filter=a0 < gaussian_filter_iter
and gaussian_filter_sigma is not None,
gaussian_filter_sigma=gaussian_filter_sigma,
Expand Down

0 comments on commit f23d8cc

Please sign in to comment.