Skip to content

Commit

Permalink
cleaned up single-slice reconstruct
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Jan 2, 2024
1 parent c8a2a84 commit f87a0b0
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 183 deletions.
180 changes: 180 additions & 0 deletions py4DSTEM/process/phase/iterative_base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -1805,6 +1805,186 @@ def _extract_vectorized_patch_indices(self):

return vectorized_patch_indices_row, vectorized_patch_indices_col

def _set_reconstruction_method_parameters(
self,
reconstruction_method,
reconstruction_parameter,
reconstruction_parameter_a,
reconstruction_parameter_b,
reconstruction_parameter_c,
step_size,
):
""""""

if reconstruction_method == "generalized-projections":
if (
reconstruction_parameter_a is None
or reconstruction_parameter_b is None
or reconstruction_parameter_c is None
):
raise ValueError(
(
"reconstruction_parameter_a/b/c must all be specified "
"when using reconstruction_method='generalized-projections'."
)
)

use_projection_scheme = True
projection_a = reconstruction_parameter_a
projection_b = reconstruction_parameter_b
projection_c = reconstruction_parameter_c
reconstruction_parameter = None
step_size = None
elif (
reconstruction_method == "DM_AP"
or reconstruction_method == "difference-map_alternating-projections"
):
if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
raise ValueError("reconstruction_parameter must be between 0-1.")

use_projection_scheme = True
projection_a = -reconstruction_parameter
projection_b = 1
projection_c = 1 + reconstruction_parameter
step_size = None
elif (
reconstruction_method == "RAAR"
or reconstruction_method == "relaxed-averaged-alternating-reflections"
):
if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
raise ValueError("reconstruction_parameter must be between 0-1.")

use_projection_scheme = True
projection_a = 1 - 2 * reconstruction_parameter
projection_b = reconstruction_parameter
projection_c = 2
step_size = None
elif (
reconstruction_method == "RRR"
or reconstruction_method == "relax-reflect-reflect"
):
if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0:
raise ValueError("reconstruction_parameter must be between 0-2.")

use_projection_scheme = True
projection_a = -reconstruction_parameter
projection_b = reconstruction_parameter
projection_c = 2
step_size = None
elif (
reconstruction_method == "SUPERFLIP"
or reconstruction_method == "charge-flipping"
):
use_projection_scheme = True
projection_a = 0
projection_b = 1
projection_c = 2
reconstruction_parameter = None
step_size = None
elif (
reconstruction_method == "GD" or reconstruction_method == "gradient-descent"
):
use_projection_scheme = False
projection_a = None
projection_b = None
projection_c = None
reconstruction_parameter = None
else:
raise ValueError(
(
"reconstruction_method must be one of 'generalized-projections', "
"'DM_AP' (or 'difference-map_alternating-projections'), "
"'RAAR' (or 'relaxed-averaged-alternating-reflections'), "
"'RRR' (or 'relax-reflect-reflect'), "
"'SUPERFLIP' (or 'charge-flipping'), "
f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}."
)
)

return (
use_projection_scheme,
projection_a,
projection_b,
projection_c,
reconstruction_parameter,
step_size,
)

def _report_reconstruction_summary(
self,
max_iter,
switch_object_iter,
use_projection_scheme,
reconstruction_method,
reconstruction_parameter,
projection_a,
projection_b,
projection_c,
normalization_min,
max_batch_size,
step_size,
):
""" """

# object type
if switch_object_iter > max_iter:
first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, "
else:
switch_object_type = (
"complex" if self._object_type == "potential" else "potential"
)
first_line = (
f"Performing {switch_object_iter} iterations using a {self._object_type} object type and "
f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, "
)

# stochastic gradient descent
if max_batch_size is not None:
if use_projection_scheme:
raise ValueError(
(
"Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. "
"Use reconstruction_method='GD' or set max_batch_size=None."
)
)
else:
print(
(
first_line + f"with the {reconstruction_method} algorithm, "
f"with normalization_min: {normalization_min} and step _size: {step_size}, "
f"in batches of max {max_batch_size} measurements."
)
)

else:
# named projection set method
if reconstruction_parameter is not None:
print(
(
first_line + f"with the {reconstruction_method} algorithm, "
f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}."
)
)

# generalized projections (or the even more rare charge-flipping)
elif projection_a is not None:
print(
(
first_line + f"with the {reconstruction_method} algorithm, "
f"with normalization_min: {normalization_min} and (a,b,c): "
f"{projection_a, projection_b, projection_c}."
)
)

# gradient descent
else:
print(
(
first_line + f"with the {reconstruction_method} algorithm, "
f"with normalization_min: {normalization_min} and step _size: {step_size}."
)
)

def _position_correction(
self,
relevant_object,
Expand Down
42 changes: 42 additions & 0 deletions py4DSTEM/process/phase/iterative_ptychographic_methods.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Sequence, Tuple

import matplotlib.pyplot as plt
Expand All @@ -20,6 +21,8 @@
except (ModuleNotFoundError, ImportError):
cp = np

warnings.simplefilter(action="always", category=UserWarning)


class ObjectNDMethodsMixin:
"""
Expand Down Expand Up @@ -1756,6 +1759,45 @@ def _adjoint(

return current_object, current_probe

def _reset_reconstruction(
self,
store_iterations,
reset,
):
""" """
if store_iterations and (not hasattr(self, "object_iterations") or reset):
self.object_iterations = []
self.probe_iterations = []

# reset can be True, False, or None (default)
if reset is True:
self.error_iterations = []
self._object = self._object_initial.copy()
self._probe = self._probe_initial.copy()
self._positions_px = self._positions_px_initial.copy()
self._object_type = self._object_type_initial
self._exit_waves = None

# delete positions affine transform
if hasattr(self, "_tf"):
del self._tf

elif reset is None:
# continued run
if hasattr(self, "error"):
warnings.warn(
(
"Continuing reconstruction from previous result. "
"Use reset=True for a fresh start."
),
UserWarning,
)

# first start
else:
self.error_iterations = []
self._exit_waves = None


class Object2p5DProbeMethodsMixin:
"""
Expand Down
Loading

0 comments on commit f87a0b0

Please sign in to comment.