Skip to content

Commit

Permalink
serial loop over virtual masks, proper normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Sep 20, 2024
1 parent c89b16d commit 51c64ad
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions py4DSTEM/process/phase/ptychographic_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -1721,13 +1721,21 @@ def _gradient_descent_fourier_projection(
)

if virtual_detector_masks is not None:
masked_values = xp.sum(
fourier_overlap[:, None, :, :] * virtual_detector_masks[None, :, :, :],
axis=(-1, -2),
).transpose()
fourier_overlap = xp.zeros_like(fourier_overlap)
for mask, value in zip(virtual_detector_masks, masked_values):
fourier_overlap[:, mask] = value[:, None] / xp.sum(mask)
inverse_mask = (1 - virtual_detector_masks.sum(0)).astype(xp.bool_)
old_fourier_overlap_sum = xp.sum(xp.abs(fourier_overlap) ** 2)

# serial loop to allow large number of detector masks
for mask in virtual_detector_masks:
val = xp.sum(fourier_overlap * mask, axis=(-1, -2)) / xp.sum(mask)
fourier_overlap[..., mask] = val[:, None]

fourier_overlap[..., inverse_mask] = 0.0

# normalize to avoid losing electrons
new_fourier_overlap_sum = xp.sum(xp.abs(fourier_overlap) ** 2)
fourier_overlap *= xp.sqrt(
old_fourier_overlap_sum / new_fourier_overlap_sum
)

if fourier_mask is not None:
fourier_overlap *= fourier_mask
Expand Down

0 comments on commit 51c64ad

Please sign in to comment.