Skip to content

Commit

Permalink
add 3D object support mask
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Nov 14, 2024
1 parent 6467b16 commit 0f32bd9
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
5 changes: 5 additions & 0 deletions py4DSTEM/process/phase/ptychographic_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,7 @@ def _object_constraints(
object_positivity,
shrinkage_rad,
object_mask,
object_real_space_support_mask,
**kwargs,
):
"""Object3DConstraints wrapper function"""
Expand Down Expand Up @@ -899,6 +900,10 @@ def _object_constraints(
object_mask,
)

# 3D support mask
if object_real_space_support_mask is not None:
current_object *= object_real_space_support_mask

# Positivity
if object_positivity:
current_object = self._object_positivity_constraint(current_object)
Expand Down
21 changes: 21 additions & 0 deletions py4DSTEM/process/phase/ptychographic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,7 @@ def reconstruct(
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
probe_real_space_support_mask: np.ndarray = None,
object_real_space_support_mask: np.ndarray = None,
tv_denoise: bool = True,
tv_denoise_weights: float = None,
tv_denoise_inner_iter=40,
Expand Down Expand Up @@ -1159,6 +1160,11 @@ def reconstruct(
if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks).astype(xp.bool_)

if object_real_space_support_mask is not None:
object_real_space_support_mask = xp.asarray(
object_real_space_support_mask, dtype=xp.float32
)

# only return b/w iterations
old_rot_matrix = np.eye(3) # identity

Expand Down Expand Up @@ -1226,6 +1232,13 @@ def reconstruct(
use_fourier_rotation,
)

if not collective_measurement_updates:
object_real_space_support_mask = self._rotate_zxy_volume(
object_real_space_support_mask,
rot_matrix @ old_rot_matrix.T,
use_fourier_rotation,
)

object_sliced = self._project_sliced_object(
self._object, self._num_slices
)
Expand Down Expand Up @@ -1434,6 +1447,7 @@ def reconstruct(
and self._object_fov_mask_inverse.sum() > 0
else None
),
object_real_space_support_mask=object_real_space_support_mask,
tv_denoise=tv_denoise and tv_denoise_weights is not None,
tv_denoise_weights=tv_denoise_weights,
tv_denoise_inner_iter=tv_denoise_inner_iter,
Expand All @@ -1449,6 +1463,12 @@ def reconstruct(
use_fourier_rotation,
)

object_real_space_support_mask = self._rotate_zxy_volume(
object_real_space_support_mask,
old_rot_matrix,
use_fourier_rotation,
)

# object only
self._object = self._object_constraints(
self._object,
Expand All @@ -1468,6 +1488,7 @@ def reconstruct(
and self._object_fov_mask_inverse.sum() > 0
else None
),
object_real_space_support_mask=object_real_space_support_mask,
tv_denoise=tv_denoise and tv_denoise_weights is not None,
tv_denoise_weights=tv_denoise_weights,
tv_denoise_inner_iter=tv_denoise_inner_iter,
Expand Down

0 comments on commit 0f32bd9

Please sign in to comment.