From 0f32bd98039ef3783466533f3737ce26483470f2 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 13 Nov 2024 19:02:04 -0800 Subject: [PATCH] add 3D object support mask --- .../phase/ptychographic_constraints.py | 5 +++++ .../process/phase/ptychographic_tomography.py | 21 +++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/py4DSTEM/process/phase/ptychographic_constraints.py b/py4DSTEM/process/phase/ptychographic_constraints.py index 39342ab75..b0a2c2dff 100644 --- a/py4DSTEM/process/phase/ptychographic_constraints.py +++ b/py4DSTEM/process/phase/ptychographic_constraints.py @@ -868,6 +868,7 @@ def _object_constraints( object_positivity, shrinkage_rad, object_mask, + object_real_space_support_mask, **kwargs, ): """Object3DConstraints wrapper function""" @@ -899,6 +900,10 @@ def _object_constraints( object_mask, ) + # 3D support mask + if object_real_space_support_mask is not None: + current_object *= object_real_space_support_mask + # Positivity if object_positivity: current_object = self._object_positivity_constraint(current_object) diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index 32e036bea..f0c01b43d 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -960,6 +960,7 @@ def reconstruct( detector_fourier_mask: np.ndarray = None, virtual_detector_masks: Sequence[np.ndarray] = None, probe_real_space_support_mask: np.ndarray = None, + object_real_space_support_mask: np.ndarray = None, tv_denoise: bool = True, tv_denoise_weights: float = None, tv_denoise_inner_iter=40, @@ -1159,6 +1160,11 @@ def reconstruct( if virtual_detector_masks is not None: virtual_detector_masks = xp.asarray(virtual_detector_masks).astype(xp.bool_) + if object_real_space_support_mask is not None: + object_real_space_support_mask = xp.asarray( + object_real_space_support_mask, dtype=xp.float32 + ) + # only return b/w iterations old_rot_matrix = np.eye(3) # identity @@ -1226,6 +1232,13 @@ def reconstruct( use_fourier_rotation, ) + if not collective_measurement_updates: + object_real_space_support_mask = self._rotate_zxy_volume( + object_real_space_support_mask, + rot_matrix @ old_rot_matrix.T, + use_fourier_rotation, + ) + object_sliced = self._project_sliced_object( self._object, self._num_slices ) @@ -1434,6 +1447,7 @@ def reconstruct( and self._object_fov_mask_inverse.sum() > 0 else None ), + object_real_space_support_mask=object_real_space_support_mask, tv_denoise=tv_denoise and tv_denoise_weights is not None, tv_denoise_weights=tv_denoise_weights, tv_denoise_inner_iter=tv_denoise_inner_iter, @@ -1449,6 +1463,12 @@ def reconstruct( use_fourier_rotation, ) + object_real_space_support_mask = self._rotate_zxy_volume( + object_real_space_support_mask, + old_rot_matrix, + use_fourier_rotation, + ) + # object only self._object = self._object_constraints( self._object, @@ -1468,6 +1488,7 @@ def reconstruct( and self._object_fov_mask_inverse.sum() > 0 else None ), + object_real_space_support_mask=object_real_space_support_mask, tv_denoise=tv_denoise and tv_denoise_weights is not None, tv_denoise_weights=tv_denoise_weights, tv_denoise_inner_iter=tv_denoise_inner_iter,