Skip to content

Commit

Permalink
cleaned up overlap tomo reconstruct, different probes per tilt
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Jan 2, 2024
1 parent 09ec6bc commit 14c1e66
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 225 deletions.
242 changes: 56 additions & 186 deletions py4DSTEM/process/phase/iterative_overlap_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,8 +516,8 @@ def preprocess(
)

self._probes_all.append(_probe)
self._probes_all_initial = _probe.copy()
self._probes_all_initial_aperture = xp.abs(xp.fft.fft2(_probe))
self._probes_all_initial.append(_probe.copy())
self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe)))

del self._probe_init

Expand Down Expand Up @@ -1034,7 +1034,7 @@ def _constraints(

def reconstruct(
self,
max_iter: int = 64,
max_iter: int = 8,
reconstruction_method: str = "gradient-descent",
reconstruction_parameter: float = 1.0,
reconstruction_parameter_a: float = None,
Expand Down Expand Up @@ -1181,186 +1181,49 @@ def reconstruct(
asnumpy = self._asnumpy
xp = self._xp

# Reconstruction method

if reconstruction_method == "generalized-projections":
if (
reconstruction_parameter_a is None
or reconstruction_parameter_b is None
or reconstruction_parameter_c is None
):
raise ValueError(
(
"reconstruction_parameter_a/b/c must all be specified "
"when using reconstruction_method='generalized-projections'."
)
)

use_projection_scheme = True
projection_a = reconstruction_parameter_a
projection_b = reconstruction_parameter_b
projection_c = reconstruction_parameter_c
step_size = None
elif (
reconstruction_method == "DM_AP"
or reconstruction_method == "difference-map_alternating-projections"
):
if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
raise ValueError("reconstruction_parameter must be between 0-1.")

use_projection_scheme = True
projection_a = -reconstruction_parameter
projection_b = 1
projection_c = 1 + reconstruction_parameter
step_size = None
elif (
reconstruction_method == "RAAR"
or reconstruction_method == "relaxed-averaged-alternating-reflections"
):
if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
raise ValueError("reconstruction_parameter must be between 0-1.")

use_projection_scheme = True
projection_a = 1 - 2 * reconstruction_parameter
projection_b = reconstruction_parameter
projection_c = 2
step_size = None
elif (
reconstruction_method == "RRR"
or reconstruction_method == "relax-reflect-reflect"
):
if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0:
raise ValueError("reconstruction_parameter must be between 0-2.")

use_projection_scheme = True
projection_a = -reconstruction_parameter
projection_b = reconstruction_parameter
projection_c = 2
step_size = None
elif (
reconstruction_method == "SUPERFLIP"
or reconstruction_method == "charge-flipping"
):
use_projection_scheme = True
projection_a = 0
projection_b = 1
projection_c = 2
reconstruction_parameter = None
step_size = None
elif (
reconstruction_method == "GD" or reconstruction_method == "gradient-descent"
):
use_projection_scheme = False
projection_a = None
projection_b = None
projection_c = None
reconstruction_parameter = None
else:
raise ValueError(
(
"reconstruction_method must be one of 'generalized-projections', "
"'DM_AP' (or 'difference-map_alternating-projections'), "
"'RAAR' (or 'relaxed-averaged-alternating-reflections'), "
"'RRR' (or 'relax-reflect-reflect'), "
"'SUPERFLIP' (or 'charge-flipping'), "
f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}."
)
)
# set and report reconstruction method
(
use_projection_scheme,
projection_a,
projection_b,
projection_c,
reconstruction_parameter,
step_size,
) = self._set_reconstruction_method_parameters(
reconstruction_method,
reconstruction_parameter,
reconstruction_parameter_a,
reconstruction_parameter_b,
reconstruction_parameter_c,
step_size,
)

if self._verbose:
if max_batch_size is not None:
if use_projection_scheme:
raise ValueError(
(
"Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. "
"Use reconstruction_method='GD' or set max_batch_size=None."
)
)
else:
print(
(
f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, "
f"with normalization_min: {normalization_min} and step _size: {step_size}, "
f"in batches of max {max_batch_size} measurements."
)
)
else:
if reconstruction_parameter is not None:
if np.array(reconstruction_parameter).shape == (3,):
print(
(
f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, "
f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}."
)
)
else:
print(
(
f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, "
f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}."
)
)
else:
if step_size is not None:
print(
(
f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, "
f"with normalization_min: {normalization_min}."
)
)
else:
print(
(
f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, "
f"with normalization_min: {normalization_min} and step _size: {step_size}."
)
)

# Position Correction + Collective Updates not yet implemented
if fix_positions_iter < max_iter:
raise NotImplementedError(
"Position correction is currently incompatible with collective updates."
self._report_reconstruction_summary(
max_iter,
np.inf,
use_projection_scheme,
reconstruction_method,
reconstruction_parameter,
projection_a,
projection_b,
projection_c,
normalization_min,
step_size,
max_batch_size,
)

# Batching
# batching
shuffled_indices = np.arange(self._num_diffraction_patterns)
unshuffled_indices = np.zeros_like(shuffled_indices)

if max_batch_size is not None:
xp.random.seed(seed_random)
else:
max_batch_size = self._num_diffraction_patterns

# initialization
if store_iterations and (not hasattr(self, "object_iterations") or reset):
self.object_iterations = []
self.probe_iterations = []

if reset:
self._object = self._object_initial.copy()
self.error_iterations = []
self._probe = self._probe_initial.copy()
self._positions_px_all = self._positions_px_initial_all.copy()
if hasattr(self, "_tf"):
del self._tf

if use_projection_scheme:
self._exit_waves = [None] * self._num_tilts
else:
self._exit_waves = None
elif reset is None:
if hasattr(self, "error"):
warnings.warn(
(
"Continuing reconstruction from previous result. "
"Use reset=True for a fresh start."
),
UserWarning,
)
else:
self.error_iterations = []
if use_projection_scheme:
self._exit_waves = [None] * self._num_tilts
else:
self._exit_waves = None
self._reset_reconstruction(store_iterations, reset, use_projection_scheme)

# main loop
for a0 in tqdmnd(
Expand Down Expand Up @@ -1393,6 +1256,12 @@ def reconstruct(
object_sliced = self._project_sliced_object(
self._object, self._num_slices
)

_probe = self._probes_all[self._active_tilt_index]
_probe_initial_aperture = self._probes_all_initial_aperture[
self._active_tilt_index
]

if not use_projection_scheme:
object_sliced_old = object_sliced.copy()

Expand Down Expand Up @@ -1440,14 +1309,14 @@ def reconstruct(

# forward operator
(
propagated_probes,
shifted_probes,
object_patches,
transmitted_probes,
overlap,
self._exit_waves,
batch_error,
) = self._forward(
object_sliced,
self._probe,
_probe,
amplitudes,
self._exit_waves,
use_projection_scheme,
Expand All @@ -1457,11 +1326,11 @@ def reconstruct(
)

# adjoint operator
object_sliced, self._probe = self._adjoint(
object_sliced, _probe = self._adjoint(
object_sliced,
self._probe,
_probe,
object_patches,
propagated_probes,
shifted_probes,
self._exit_waves,
use_projection_scheme=use_projection_scheme,
step_size=step_size,
Expand All @@ -1473,8 +1342,8 @@ def reconstruct(
if a0 >= fix_positions_iter:
positions_px[start:end] = self._position_correction(
object_sliced,
self._probe,
transmitted_probes,
_probe,
overlap,
amplitudes,
self._positions_px,
positions_step_size,
Expand Down Expand Up @@ -1514,11 +1383,11 @@ def reconstruct(
if not collective_tilt_updates:
(
self._object,
self._probe,
_probe,
self._positions_px_all[start_tilt:end_tilt],
) = self._constraints(
self._object,
self._probe,
_probe,
self._positions_px_all[start_tilt:end_tilt],
fix_com=fix_com and a0 >= fix_probe_iter,
constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter
Expand All @@ -1535,7 +1404,7 @@ def reconstruct(
fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order,
fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order,
fix_probe_aperture=a0 < fix_probe_aperture_iter,
initial_probe_aperture=self._probe_initial_aperture,
initial_probe_aperture=_probe_initial_aperture,
fix_positions=a0 < fix_positions_iter,
global_affine_transformation=global_affine_transformation,
gaussian_filter=a0 < gaussian_filter_iter
Expand Down Expand Up @@ -1568,11 +1437,11 @@ def reconstruct(

(
self._object,
self._probe,
_probe,
_,
) = self._constraints(
self._object,
self._probe,
_probe,
None,
fix_com=fix_com and a0 >= fix_probe_iter,
constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter
Expand All @@ -1589,7 +1458,7 @@ def reconstruct(
fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order,
fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order,
fix_probe_aperture=a0 < fix_probe_aperture_iter,
initial_probe_aperture=self._probe_initial_aperture,
initial_probe_aperture=_probe_initial_aperture,
fix_positions=True,
global_affine_transformation=global_affine_transformation,
gaussian_filter=a0 < gaussian_filter_iter
Expand All @@ -1612,6 +1481,7 @@ def reconstruct(
)

self.error_iterations.append(error.item())

if store_iterations:
self.object_iterations.append(asnumpy(self._object.copy()))
self.probe_iterations.append(self.probe_centered)
Expand Down
Loading

0 comments on commit 14c1e66

Please sign in to comment.