Skip to content

Commit

Permalink
Merge pull request #683 from juliedactyl/phase_contrast
Browse files Browse the repository at this point in the history
Segmented and Other Geometry Detector Ptychography
  • Loading branch information
smribet authored Sep 18, 2024
2 parents 73bd240 + 2a3bd6f commit 50d0cee
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 2 deletions.
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/magnetic_ptychographic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,7 @@ def reconstruct(
shrinkage_rad: float = 0.0,
fix_potential_baseline: bool = True,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
probe_real_space_support_mask: np.ndarray = None,
tv_denoise: bool = True,
tv_denoise_weights=None,
Expand Down Expand Up @@ -990,6 +991,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels).
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
probe_real_space_support_mask: np.ndarray
Corner-centered boolean mask, outside of which the probe amplitude will be set to zero.
store_iterations: bool, optional
Expand Down Expand Up @@ -1081,6 +1085,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

if gaussian_filter_sigma_m is None:
gaussian_filter_sigma_m = gaussian_filter_sigma_e

Expand Down Expand Up @@ -1188,6 +1195,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme,
projection_a,
projection_b,
Expand Down
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/magnetic_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,7 @@ def reconstruct(
shrinkage_rad: float = 0.0,
fix_potential_baseline: bool = True,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
probe_real_space_support_mask: np.ndarray = None,
store_iterations: bool = False,
collective_measurement_updates: bool = True,
Expand Down Expand Up @@ -1310,6 +1311,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels).
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
probe_real_space_support_mask: np.ndarray
Corner-centered boolean mask, outside of which the probe amplitude will be set to zero.
store_iterations: bool, optional
Expand Down Expand Up @@ -1407,6 +1411,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

if gaussian_filter_sigma_m is None:
gaussian_filter_sigma_m = gaussian_filter_sigma_e

Expand Down Expand Up @@ -1489,6 +1496,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme=use_projection_scheme,
projection_a=projection_a,
projection_b=projection_b,
Expand Down
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/mixedstate_multislice_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,7 @@ def reconstruct(
fix_potential_baseline: bool = True,
vacuum_mask: np.ndarray = None,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
probe_real_space_support_mask: np.ndarray = None,
pure_phase_object: bool = False,
tv_denoise_chambolle: bool = True,
Expand Down Expand Up @@ -885,6 +886,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels).
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
probe_real_space_support_mask: np.ndarray
Corner-centered boolean mask, outside of which the probe amplitude will be set to zero.
pure_phase_object: bool, optional
Expand Down Expand Up @@ -984,6 +988,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

# main loop
for a0 in tqdmnd(
num_iter,
Expand Down Expand Up @@ -1031,6 +1038,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme,
projection_a,
projection_b,
Expand Down
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/mixedstate_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,7 @@ def reconstruct(
fix_potential_baseline: bool = True,
vacuum_mask: np.ndarray = None,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
probe_real_space_support_mask: np.ndarray = None,
store_iterations: bool = False,
progress_bar: bool = True,
Expand Down Expand Up @@ -791,6 +792,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels).
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
probe_real_space_support_mask: np.ndarray
Corner-centered boolean mask, outside of which the probe amplitude will be set to zero.
store_iterations: bool, optional
Expand Down Expand Up @@ -878,6 +882,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

# main loop
for a0 in tqdmnd(
num_iter,
Expand Down Expand Up @@ -925,6 +932,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme,
projection_a,
projection_b,
Expand Down
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/multislice_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,7 @@ def reconstruct(
fix_potential_baseline: bool = True,
vacuum_mask: np.ndarray = None,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
probe_real_space_support_mask: np.ndarray = None,
pure_phase_object: bool = False,
tv_denoise_chambolle: bool = True,
Expand Down Expand Up @@ -857,6 +858,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels).
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
probe_real_space_support_mask: np.ndarray
Corner-centered boolean mask, outside of which the probe amplitude will be set to zero.
pure_phase_object: bool, optional
Expand Down Expand Up @@ -960,6 +964,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

# main loop
for a0 in tqdmnd(
num_iter,
Expand Down Expand Up @@ -1007,6 +1014,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme,
projection_a,
projection_b,
Expand Down
55 changes: 53 additions & 2 deletions py4DSTEM/process/phase/ptychographic_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -1679,7 +1679,9 @@ def cross_correlate_amplitudes_to_probe_aperture(

return self

def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask):
def _gradient_descent_fourier_projection(
self, amplitudes, overlap, fourier_mask, virtual_detector_masks
):
"""
Ptychographic fourier projection method for GD method.
Expand All @@ -1692,6 +1694,9 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
fourier_mask: np.ndarray
Mask to apply at the detector-plane for zeroing-out unreliable gradients
Useful when detector has artifacts such as dead-pixels
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
Returns
--------
Expand All @@ -1715,6 +1720,15 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
xp=xp,
)

if virtual_detector_masks is not None:
masked_values = xp.sum(
fourier_overlap[:, None, :, :] * virtual_detector_masks[None, :, :, :],
axis=(-1, -2),
).transpose()
fourier_overlap = xp.zeros_like(fourier_overlap)
for mask, value in zip(virtual_detector_masks, masked_values):
fourier_overlap[:, mask] = value[:, None] / xp.sum(mask)

if fourier_mask is not None:
fourier_overlap *= fourier_mask

Expand Down Expand Up @@ -1746,6 +1760,7 @@ def _projection_sets_fourier_projection(
overlap,
exit_waves,
fourier_mask,
virtual_detector_masks,
projection_a,
projection_b,
projection_c,
Expand Down Expand Up @@ -1777,6 +1792,9 @@ def _projection_sets_fourier_projection(
Mask to apply at the detector-plane for zeroing-out unreliable gradients
Useful when detector has artifacts such as dead-pixels
Currently not implemented for projection-sets
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
projection_a: float
projection_b: float
projection_c: float
Expand All @@ -1792,6 +1810,9 @@ def _projection_sets_fourier_projection(
if fourier_mask is not None:
raise NotImplementedError()

if virtual_detector_masks is not None:
raise NotImplementedError()

xp = self._xp
projection_x = 1 - projection_a - projection_b
projection_y = 1 - projection_c
Expand Down Expand Up @@ -1849,6 +1870,7 @@ def _forward(
amplitudes,
exit_waves,
fourier_mask,
virtual_detector_masks,
use_projection_scheme,
projection_a,
projection_b,
Expand All @@ -1871,6 +1893,9 @@ def _forward(
fourier_mask: np.ndarray
Mask to apply at the detector-plane for zeroing-out unreliable gradients
Useful when detector has artifacts such as dead-pixels
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
use_projection_scheme: bool,
If True, use generalized projection update
projection_a: float
Expand Down Expand Up @@ -1907,6 +1932,7 @@ def _forward(
overlap,
exit_waves,
fourier_mask,
virtual_detector_masks,
projection_a,
projection_b,
projection_c,
Expand All @@ -1917,6 +1943,7 @@ def _forward(
amplitudes,
overlap,
fourier_mask,
virtual_detector_masks,
)

return shifted_probes, object_patches, overlap, exit_waves, error
Expand Down Expand Up @@ -2904,7 +2931,9 @@ def _return_farfield_amplitudes(self, fourier_overlap):
xp = self._xp
return xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1))

def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask):
def _gradient_descent_fourier_projection(
self, amplitudes, overlap, fourier_mask, virtual_detector_masks
):
"""
Ptychographic fourier projection method for GD method.
Expand All @@ -2917,6 +2946,9 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
fourier_mask: np.ndarray
Mask to apply at the detector-plane for zeroing-out unreliable gradients
Useful when detector has artifacts such as dead-pixels
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
Returns
--------
Expand All @@ -2940,8 +2972,19 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
xp=xp,
)

if virtual_detector_masks is not None:
masked_values = xp.sum(
fourier_overlap[:, :, None, :, :]
* virtual_detector_masks[None, None, :, :, :],
axis=(-1, -2),
).transpose(2, 0, 1)
fourier_overlap = xp.zeros_like(fourier_overlap)
for mask, value in zip(virtual_detector_masks, masked_values):
fourier_overlap[..., mask] = value[:, :, None] / xp.sum(mask)

if fourier_mask is not None:
fourier_overlap *= fourier_mask

farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap)
error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2)

Expand All @@ -2951,6 +2994,7 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap

fourier_modified_overlap = fourier_modified_overlap - fourier_overlap

if fourier_mask is not None:
fourier_modified_overlap *= fourier_mask

Expand All @@ -2974,6 +3018,7 @@ def _projection_sets_fourier_projection(
overlap,
exit_waves,
fourier_mask,
virtual_detector_masks,
projection_a,
projection_b,
projection_c,
Expand Down Expand Up @@ -3005,6 +3050,9 @@ def _projection_sets_fourier_projection(
Mask to apply at the detector-plane for zeroing-out unreliable gradients
Useful when detector has artifacts such as dead-pixels
Currently not implemented for projection sets
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
projection_a: float
projection_b: float
projection_c: float
Expand All @@ -3020,6 +3068,9 @@ def _projection_sets_fourier_projection(
if fourier_mask is not None:
raise NotImplementedError()

if virtual_detector_masks is not None:
raise NotImplementedError()

xp = self._xp
projection_x = 1 - projection_a - projection_b
projection_y = 1 - projection_c
Expand Down
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/ptychographic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,7 @@ def reconstruct(
shrinkage_rad: float = 0.0,
fix_potential_baseline: bool = True,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
probe_real_space_support_mask: np.ndarray = None,
tv_denoise: bool = True,
tv_denoise_weights: float = None,
Expand Down Expand Up @@ -910,6 +911,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to apply at the detector-plane for zeroing-out unreliable gradients.
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
probe_real_space_support_mask: np.ndarray
Corner-centered boolean mask, outside of which the probe amplitude will be set to zero.
store_iterations: bool, optional
Expand Down Expand Up @@ -990,6 +994,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

# main loop
for a0 in tqdmnd(
num_iter,
Expand Down Expand Up @@ -1082,6 +1089,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme,
projection_a,
projection_b,
Expand Down
Loading

0 comments on commit 50d0cee

Please sign in to comment.