Skip to content

Commit

Permalink
small tweaks to mixed-state virtual-detector ptycho
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Nov 15, 2024
1 parent 7259309 commit 7216112
Showing 1 changed file with 64 additions and 32 deletions.
96 changes: 64 additions & 32 deletions py4DSTEM/process/phase/ptychographic_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -1479,10 +1479,10 @@ def _initialize_probe(
for i_probe in range(1, num_probes):
shift_x = xp.exp(
-2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx)
)
).astype(xp.complex64)
shift_y = xp.exp(
-2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy)
)
).astype(xp.complex64)
_probes[i_probe] = (
_probes[i_probe - 1] * shift_x[:, None] * shift_y[None]
)
Expand Down Expand Up @@ -1750,31 +1750,44 @@ def _gradient_descent_fourier_projection(
xp=xp,
)

if fourier_mask is not None:
fourier_overlap *= fourier_mask

if virtual_detector_masks is not None:
mask_sums = virtual_detector_masks.sum((-1, -2))
inverse_mask = (1 - virtual_detector_masks.sum(0)).astype(xp.bool_)
old_fourier_overlap_sum = xp.sum(xp.abs(fourier_overlap) ** 2)

# serial loop to allow large number of detector masks
for mask in virtual_detector_masks:
val = xp.sum(fourier_overlap * mask, axis=(-1, -2)) / xp.sum(mask)
fourier_overlap[..., mask] = val[:, None]

abs_fourier_overlap = self._return_farfield_amplitudes(fourier_overlap)
old_fourier_overlap_sum = xp.sum(abs_fourier_overlap**2)
fourier_overlap[..., inverse_mask] = 0.0

# normalize to avoid losing electrons
fourier_overlap_binned = xp.full_like(
fourier_overlap, fill_value=1e-16, dtype=xp.float32
)
for mask, mask_sum in zip(virtual_detector_masks, mask_sums):
val = xp.sqrt(
xp.sum(abs_fourier_overlap**2 * mask, axis=(-1, -2)) / mask_sum
)
fourier_overlap_binned[..., mask] = val[:, None]

new_fourier_overlap_sum = xp.sum(xp.abs(fourier_overlap) ** 2)
fourier_overlap *= xp.sqrt(
fourier_overlap_binned *= xp.sqrt(
old_fourier_overlap_sum / new_fourier_overlap_sum
)
fourier_modified_overlap = (
fourier_overlap * amplitudes / fourier_overlap_binned
)
farfield_amplitudes = self._return_farfield_amplitudes(
fourier_overlap_binned
)
else:
fourier_modified_overlap = amplitudes * xp.exp(
1j * xp.angle(fourier_overlap)
)
farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap)

if fourier_mask is not None:
fourier_overlap *= fourier_mask

farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap)
error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2)
fourier_modified_overlap = amplitudes * xp.exp(1j * xp.angle(fourier_overlap))

fourier_modified_overlap = fourier_modified_overlap - fourier_overlap

if fourier_mask is not None:
fourier_modified_overlap *= fourier_mask

Expand Down Expand Up @@ -3010,27 +3023,46 @@ def _gradient_descent_fourier_projection(
xp=xp,
)

if virtual_detector_masks is not None:
masked_values = xp.sum(
fourier_overlap[:, :, None, :, :]
* virtual_detector_masks[None, None, :, :, :],
axis=(-1, -2),
).transpose(2, 0, 1)
fourier_overlap = xp.zeros_like(fourier_overlap)
for mask, value in zip(virtual_detector_masks, masked_values):
fourier_overlap[..., mask] = value[:, :, None] / xp.sum(mask)

if fourier_mask is not None:
fourier_overlap *= fourier_mask

farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap)
error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2)
if virtual_detector_masks is not None:
mask_sums = virtual_detector_masks.sum((-1, -2))
inverse_mask = (1 - virtual_detector_masks.sum(0)).astype(xp.bool_)
abs_fourier_overlap = self._return_farfield_amplitudes(fourier_overlap)
old_fourier_overlap_sum = xp.sum(abs_fourier_overlap**2)
fourier_overlap[..., inverse_mask] = 0.0

farfield_amplitudes[farfield_amplitudes == 0.0] = np.inf
amplitude_modification = amplitudes / farfield_amplitudes
fourier_overlap_binned = xp.full_like(
fourier_overlap, fill_value=1e-16, dtype=xp.float32
)
for mask, mask_sum in zip(virtual_detector_masks, mask_sums):
val = xp.sqrt(
xp.sum(abs_fourier_overlap**2 * mask, axis=(-1, -2))
/ mask_sum
/ self._num_probes
)
fourier_overlap_binned[..., mask] = val[:, None, None]

fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap
abs_fourier_overlap = self._return_farfield_amplitudes(fourier_overlap)
new_fourier_overlap_sum = xp.sum(abs_fourier_overlap**2)
fourier_overlap_binned *= xp.sqrt(
old_fourier_overlap_sum / new_fourier_overlap_sum
)
fourier_modified_overlap = (
fourier_overlap * amplitudes[:, None] / fourier_overlap_binned
)
farfield_amplitudes = self._return_farfield_amplitudes(
fourier_overlap_binned
)
else:
farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap)
farfield_amplitudes[farfield_amplitudes == 0.0] = np.inf

amplitude_modification = amplitudes / farfield_amplitudes
fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap

error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2)
fourier_modified_overlap = fourier_modified_overlap - fourier_overlap

if fourier_mask is not None:
Expand Down

0 comments on commit 7216112

Please sign in to comment.