diff --git a/.github/workflows/check_install_dev.yml b/.github/workflows/check_install_dev.yml
index 4e9d16f77..a960dc2f2 100644
--- a/.github/workflows/check_install_dev.yml
+++ b/.github/workflows/check_install_dev.yml
@@ -16,7 +16,7 @@ jobs:
         allow_failure: [false]
         runs-on: [ubuntu-latest]
         architecture: [x86_64]
-        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        python-version: ["3.10", "3.11", "3.12"]
         # include:
         #   - python-version: "3.12.0-beta.4"
         #     runs-on: ubuntu-latest
diff --git a/.github/workflows/check_install_main.yml b/.github/workflows/check_install_main.yml
index a276cab17..d27278ba9 100644
--- a/.github/workflows/check_install_main.yml
+++ b/.github/workflows/check_install_main.yml
@@ -16,7 +16,7 @@ jobs:
         allow_failure: [false]
         runs-on: [ubuntu-latest, windows-latest, macos-latest]
         architecture: [x86_64]
-        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        python-version: ["3.10", "3.11", "3.12"]
         #include:
         #   - python-version: "3.12.0-beta.4"
         #     runs-on: ubuntu-latest
diff --git a/.github/workflows/check_install_quick.yml b/.github/workflows/check_install_quick.yml
index f83ee0b73..0d20bd759 100644
--- a/.github/workflows/check_install_quick.yml
+++ b/.github/workflows/check_install_quick.yml
@@ -20,7 +20,7 @@ jobs:
         allow_failure: [false]
         runs-on: [ubuntu-latest]
         architecture: [x86_64]
-        python-version: ["3.9", "3.12"]
+        python-version: ["3.10", "3.12"]
         # Currently no public runners available for this but this or arm64 should work next time
         # include:
         #   - python-version: "3.10"
@@ -42,4 +42,4 @@ jobs:
         python -c "import py4DSTEM; print(py4DSTEM.__version__)"
     # - name: Check machine arch
     #   run: |
-    #     python -c "import platform; print(platform.machine())"
\ No newline at end of file
+    #     python -c "import platform; print(platform.machine())"
diff --git a/py4DSTEM/preprocess/preprocess.py b/py4DSTEM/preprocess/preprocess.py
index fb4983622..9db7895d3 100644
--- a/py4DSTEM/preprocess/preprocess.py
+++ b/py4DSTEM/preprocess/preprocess.py
@@ -576,7 +576,9 @@ def resample_data_diffraction(
             resampling_factor = np.array(output_size) / np.array(datacube.shape[-2:])

         resampling_factor = np.concatenate(((1, 1), resampling_factor))
-        datacube.data = zoom(datacube.data, resampling_factor, order=1)
+        datacube.data = zoom(
+            datacube.data, resampling_factor, order=1, mode="grid-wrap", grid_mode=True
+        )
         datacube.calibration.set_Q_pixel_size(
             datacube.calibration.get_Q_pixel_size() / resampling_factor[2]
         )
diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py
index 1005a619d..ecfeaa1d2 100644
--- a/py4DSTEM/process/phase/__init__.py
+++ b/py4DSTEM/process/phase/__init__.py
@@ -2,15 +2,15 @@

 _emd_hook = True

-from py4DSTEM.process.phase.iterative_dpc import DPCReconstruction
-from py4DSTEM.process.phase.iterative_mixedstate_multislice_ptychography import MixedstateMultislicePtychographicReconstruction
-from py4DSTEM.process.phase.iterative_mixedstate_ptychography import MixedstatePtychographicReconstruction
-from py4DSTEM.process.phase.iterative_multislice_ptychography import MultislicePtychographicReconstruction
-from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import OverlapMagneticTomographicReconstruction
-from py4DSTEM.process.phase.iterative_overlap_tomography import OverlapTomographicReconstruction
-from py4DSTEM.process.phase.iterative_parallax import ParallaxReconstruction
-from py4DSTEM.process.phase.iterative_simultaneous_ptychography import SimultaneousPtychographicReconstruction
-from
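
# Context for the resample_data_diffraction hunk further up: scipy.ndimage.zoom
# defaults to grid_mode=False, which interprets the zoom factor in terms of the
# spacing between pixel centers; grid_mode=True interprets it in terms of the
# full pixel extent (like binning), and mode="grid-wrap" handles the boundary
# periodically, consistent with Fourier-style resampling of diffraction data.
# A minimal standalone sketch with a dummy 4D array (not py4DSTEM data):
import numpy as np
from scipy.ndimage import zoom

dp = np.random.rand(4, 4, 64, 64)  # stand-in (Rx, Ry, Qx, Qy) stack
factor = np.array((1, 1, 0.5, 0.5))  # leave scan dims alone, downsample Q by 2

resampled_default = zoom(dp, factor, order=1)
resampled_grid = zoom(dp, factor, order=1, mode="grid-wrap", grid_mode=True)
print(resampled_default.shape, resampled_grid.shape)  # both (4, 4, 32, 32)
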
py4DSTEM.process.phase.iterative_singleslice_ptychography import SingleslicePtychographicReconstruction +from py4DSTEM.process.phase.dpc import DPC +from py4DSTEM.process.phase.magnetic_ptychographic_tomography import MagneticPtychographicTomography +from py4DSTEM.process.phase.magnetic_ptychography import MagneticPtychography +from py4DSTEM.process.phase.mixedstate_multislice_ptychography import MixedstateMultislicePtychography +from py4DSTEM.process.phase.mixedstate_ptychography import MixedstatePtychography +from py4DSTEM.process.phase.multislice_ptychography import MultislicePtychography +from py4DSTEM.process.phase.parallax import Parallax +from py4DSTEM.process.phase.ptychographic_tomography import PtychographicTomography +from py4DSTEM.process.phase.singleslice_ptychography import SingleslicePtychography from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer # fmt: on diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/dpc.py similarity index 85% rename from py4DSTEM/process/phase/iterative_dpc.py rename to py4DSTEM/process/phase/dpc.py index 11adc0c70..5a2210d59 100644 --- a/py4DSTEM/process/phase/iterative_dpc.py +++ b/py4DSTEM/process/phase/dpc.py @@ -19,12 +19,11 @@ from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube -from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction +from py4DSTEM.visualize.vis_special import return_scaled_histogram_ordering -warnings.simplefilter(action="always", category=UserWarning) - -class DPCReconstruction(PhaseReconstruction): +class DPC(PhaseReconstruction): """ Iterative Differential Phase Constrast Reconstruction Class. @@ -42,7 +41,11 @@ class DPCReconstruction(PhaseReconstruction): verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + Device calculation will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. 
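
# The renamed entry points above (DPC, Parallax, MixedstatePtychography, ...)
# replace the old *Reconstruction class names, and the constructors gain
# `storage` and `clear_fft_cache` alongside `device`. A hypothetical call
# pattern based only on the signatures shown in this diff (the random
# DataCube is a stand-in for real data):
import numpy as np
import py4DSTEM

datacube = py4DSTEM.DataCube(np.random.rand(8, 8, 64, 64).astype("float32"))

dpc = py4DSTEM.process.phase.DPC(
    datacube=datacube,
    energy=300e3,          # eV
    device="gpu",          # where computation runs ('cpu' or 'gpu')
    storage="cpu",         # where infrequently-used arrays are kept
    clear_fft_cache=True,  # free cached GPU FFT plans after each call
)
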
Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls name: str, optional Class name """ @@ -54,24 +57,17 @@ def __init__( energy: float = None, verbose: bool = True, device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, name: str = "dpc_reconstruction", ): Custom.__init__(self, name=name) - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter + if storage is None: + storage = device - self._gaussian_filter = gaussian_filter - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_device(device, clear_fft_cache) + self.set_storage(storage) self.set_save_defaults() @@ -82,7 +78,6 @@ def __init__( # Metadata self._energy = energy self._verbose = verbose - self._device = device self._preprocessed = False def to_h5(self, group): @@ -234,15 +229,18 @@ def preprocess( self, dp_mask: np.ndarray = None, padding_factor: float = 2, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + rotation_angles_deg: np.ndarray = None, maximize_divergence: bool = False, fit_function: str = "plane", force_com_rotation: float = None, force_com_transpose: bool = None, force_com_shifts: Union[Sequence[np.ndarray], Sequence[float]] = None, + vectorized_com_calculation: bool = True, force_com_measured: Sequence[np.ndarray] = None, plot_center_of_mass: str = "default", plot_rotation: bool = True, + device: str = None, + clear_fft_cache: bool = None, **kwargs, ): """ @@ -271,6 +269,8 @@ def preprocess( Force whether diffraction intensities need to be transposed. 
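
# The constructor above swaps the old inline cpu/gpu branching for the
# set_device()/set_storage() helpers, presumably provided by the
# PhaseReconstruction base class (not shown in this diff). A minimal
# standalone sketch of the dispatch pattern being factored out, based on the
# removed code:
import numpy as np

try:
    import cupy as cp
except (ImportError, ModuleNotFoundError):
    cp = None


def resolve_array_module(device: str):
    """Return the array module (numpy or cupy) for a device string."""
    if device == "cpu":
        return np
    if device == "gpu":
        if cp is None:
            raise ValueError("device='gpu' requested but cupy is not installed")
        return cp
    raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}")
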
force_com_shifts: tuple of ndarrays (CoMx, CoMy) Force CoM fitted shifts + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized force_com_measured: tuple of ndarrays (CoMx measured, CoMy measured) Force CoM measured shifts plot_center_of_mass: str, optional @@ -284,7 +284,12 @@ def preprocess( self: DPCReconstruction Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + xp = self._xp + device = self._device + storage = self._storage # set additional metadata self._dp_mask = dp_mask @@ -303,7 +308,7 @@ def preprocess( data=np.empty(force_com_measured[0].shape + (1, 1)) ) - self._intensities = self._extract_intensities_and_calibrations_from_datacube( + _intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=False, ) @@ -316,10 +321,11 @@ def preprocess( self._com_normalized_x, self._com_normalized_y, ) = self._calculate_intensities_center_of_mass( - self._intensities, + _intensities, dp_mask=self._dp_mask, fit_function=fit_function, com_shifts=force_com_shifts, + vectorized_calculation=vectorized_com_calculation, com_measured=force_com_measured, ) @@ -328,8 +334,6 @@ def preprocess( self._rotation_best_transpose, self._com_x, self._com_y, - self.com_x, - self.com_y, ) = self._solve_for_center_of_mass_relative_rotation( self._com_measured_x, self._com_measured_y, @@ -344,11 +348,23 @@ def preprocess( **kwargs, ) + # explicitly transfer arrays to storage + attrs = [ + "_com_measured_x", + "_com_measured_y", + "_com_fitted_x", + "_com_fitted_y", + "_com_normalized_x", + "_com_normalized_y", + ] + self.copy_attributes_to_device(attrs, storage) + # Object Initialization padded_object_shape = np.round( np.array(self._grid_scan_shape) * padding_factor ).astype("int") self._padded_object_phase = xp.zeros(padded_object_shape, dtype=xp.float32) + if self._object_phase is not None: self._padded_object_phase[ : self._grid_scan_shape[0], : self._grid_scan_shape[1] @@ -357,20 +373,23 @@ def preprocess( self._padded_object_phase_initial = self._padded_object_phase.copy() # Fourier coordinates and operators - kx = xp.fft.fftfreq(padded_object_shape[0], d=self._scan_sampling[0]) - ky = xp.fft.fftfreq(padded_object_shape[1], d=self._scan_sampling[1]) + kx = xp.fft.fftfreq(padded_object_shape[0], d=self._scan_sampling[0]).astype( + xp.float32 + ) + ky = xp.fft.fftfreq(padded_object_shape[1], d=self._scan_sampling[1]).astype( + xp.float32 + ) kya, kxa = xp.meshgrid(ky, kx) + k_den = kxa**2 + kya**2 k_den[0, 0] = np.inf k_den = 1 / k_den + self._kx_op = -1j * 0.25 * kxa * k_den self._ky_op = -1j * 0.25 * kya * k_den self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(self._device, self._clear_fft_cache) return self @@ -413,6 +432,7 @@ def _forward( """ xp = self._xp + asnumpy = self._asnumpy dx, dy = self._scan_sampling # centered finite-differences @@ -431,8 +451,9 @@ def _forward( obj_dx[mask_inv] = 0 obj_dy[mask_inv] = 0 - new_error = xp.mean(obj_dx[mask] ** 2 + obj_dy[mask] ** 2) / ( - xp.mean(self._com_x.ravel() ** 2 + self._com_y.ravel() ** 2) + new_error = asnumpy( + xp.mean(obj_dx[mask] ** 2 + obj_dy[mask] ** 2) + / (xp.mean(self._com_x.ravel() ** 2 + self._com_y.ravel() ** 2)) ) return obj_dx, obj_dy, new_error, step_size @@ -516,9 +537,9 @@ def _object_gaussian_constraint(self, current_object, gaussian_filter_sigma): constrained_object: np.ndarray 
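
# The _kx_op/_ky_op arrays built in preprocess() above are the kernels of a
# Fourier-space integration of the measured center-of-mass (gradient) signal.
# A standalone sketch of that identity, up to the constant prefactor used in
# the DPC update:
import numpy as np


def integrate_gradient(grad_x, grad_y, sampling=(1.0, 1.0)):
    """Recover a scalar field whose gradient matches (grad_x, grad_y), up to scale."""
    kx = np.fft.fftfreq(grad_x.shape[0], d=sampling[0])[:, None]
    ky = np.fft.fftfreq(grad_x.shape[1], d=sampling[1])[None, :]
    k2 = kx**2 + ky**2
    k2[0, 0] = np.inf       # the k = 0 (mean) component is undetermined; drop it
    inv_k2 = 1 / k2         # 1 / inf == 0 at the origin, as in the code above
    op_x = -1j * kx * inv_k2
    op_y = -1j * ky * inv_k2
    phase = np.fft.ifft2(op_x * np.fft.fft2(grad_x) + op_y * np.fft.fft2(grad_y))
    return np.real(phase)
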
Constrained object estimate """ - gaussian_filter = self._gaussian_filter - + gaussian_filter = self._scipy.ndimage.gaussian_filter gaussian_filter_sigma /= self.sampling[0] + current_object = gaussian_filter(current_object, gaussian_filter_sigma) return current_object @@ -558,44 +579,13 @@ def _object_butterworth_constraint( if q_lowpass: env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - current_object_mean = xp.mean(current_object) + current_object_mean = xp.mean(current_object, axis=(-2, -1), keepdims=True) current_object -= current_object_mean current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env) current_object += current_object_mean return xp.real(current_object) - def _object_anti_gridding_contraint(self, current_object): - """ - Zero outer pixels of object fft to remove gridding artifacts - - Parameters - -------- - current_object: np.ndarray - Current object estimate - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - - # find indices to zero - width_x = current_object.shape[0] - width_y = current_object.shape[1] - ind_min_x = int(xp.floor(width_x / 2) - 2) - ind_max_x = int(xp.ceil(width_x / 2) + 2) - ind_min_y = int(xp.floor(width_y / 2) - 2) - ind_max_y = int(xp.ceil(width_y / 2) + 2) - - # zero pixels - object_fft = xp.fft.fft2(current_object) - object_fft[ind_min_x:ind_max_x] = 0 - object_fft[:, ind_min_y:ind_max_y] = 0 - - return xp.real(xp.fft.ifft2(object_fft)) - def _constraints( self, current_object, @@ -605,7 +595,6 @@ def _constraints( q_lowpass, q_highpass, butterworth_order, - anti_gridding, ): """ DPC constraints operator. @@ -626,9 +615,6 @@ def _constraints( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter - anti_gridding: bool - If true, zero outer pixels of object fft to remove - gridding artifacts Returns -------- @@ -648,11 +634,6 @@ def _constraints( butterworth_order, ) - if anti_gridding: - current_object = self._object_anti_gridding_contraint( - current_object, - ) - return current_object def reconstruct( @@ -664,13 +645,14 @@ def reconstruct( backtrack: bool = True, progress_bar: bool = True, gaussian_filter_sigma: float = None, - gaussian_filter_iter: int = np.inf, - butterworth_filter_iter: int = np.inf, + gaussian_filter: bool = True, + butterworth_filter: bool = True, q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, - anti_gridding: float = True, store_iterations: bool = False, + device: str = None, + clear_fft_cache: bool = None, ): """ Performs Iterative DPC Reconstruction: @@ -693,21 +675,22 @@ def reconstruct( If True, reconstruction progress bar will be printed gaussian_filter_sigma: float, optional Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering q_lowpass: float Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. 
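
# The _object_butterworth_constraint above multiplies the object's FFT by a
# radially symmetric low/high-pass Butterworth envelope (cutoffs in A^-1).
# The same envelope, written standalone; the constraint subtracts the object
# mean before filtering and adds it back afterwards, so the mean survives a
# high-pass:
import numpy as np


def butterworth_envelope(shape, sampling, q_lowpass=None, q_highpass=None, order=2):
    qx = np.fft.fftfreq(shape[0], sampling[0])
    qy = np.fft.fftfreq(shape[1], sampling[1])
    qra = np.sqrt(qx[:, None] ** 2 + qy[None, :] ** 2)
    env = np.ones_like(qra)
    if q_highpass:
        env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * order))
    if q_lowpass:
        env *= 1 / (1 + (qra / q_lowpass) ** (2 * order))
    return env
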
Smaller gives a smoother filter - anti_gridding: bool - If true, zero outer pixels of object fft to remove - gridding artifacts store_iterations: bool, optional If True, all reconstruction iterations will be stored + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- @@ -715,18 +698,34 @@ def reconstruct( Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probe", + "_probe_initial", + "_probe_initial_aperture", + ] + self.copy_attributes_to_device(attrs, device) + xp = self._xp + device = self._device asnumpy = self._asnumpy # Restart if store_iterations and (not hasattr(self, "object_phase_iterations") or reset): self.object_phase_iterations = [] - if reset: + if reset is True: self.error = np.inf self.error_iterations = [] self._step_size = step_size if step_size is not None else 0.5 self._padded_object_phase = self._padded_object_phase_initial.copy() + elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -772,8 +771,6 @@ def reconstruct( if (new_error > self.error) and backtrack: self._padded_object_phase = previous_iteration self._step_size /= 2 - if self._verbose: - print(f"Iteration {a0}, step reduced to {self._step_size}") continue self.error = new_error @@ -788,18 +785,17 @@ def reconstruct( # constraints self._padded_object_phase = self._constraints( self._padded_object_phase, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, + gaussian_filter=gaussian_filter and gaussian_filter_sigma is not None, gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter + butterworth_filter=butterworth_filter and (q_lowpass is not None or q_highpass is not None), q_lowpass=q_lowpass, q_highpass=q_highpass, butterworth_order=butterworth_order, - anti_gridding=anti_gridding, ) self.error_iterations.append(self.error.item()) + if store_iterations: self.object_phase_iterations.append( asnumpy( @@ -822,9 +818,7 @@ def reconstruct( ] self.object_phase = asnumpy(self._object_phase) - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(self._device, self._clear_fft_cache) return self @@ -846,6 +840,8 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (5, 6)) cmap = kwargs.pop("cmap", "magma") + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) if plot_convergence: spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) @@ -863,10 +859,15 @@ def _visualize_last_iteration( ] ax1 = fig.add_subplot(spec[0]) - im = ax1.imshow(self.object_phase, extent=extent, cmap=cmap, **kwargs) + + obj, vmin, vmax = return_scaled_histogram_ordering( + self.object_phase, vmin, vmax + ) + im = ax1.imshow(obj, extent=extent, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs) + ax1.set_ylabel(f"x [{self._scan_units[0]}]") ax1.set_xlabel(f"y [{self._scan_units[1]}]") - ax1.set_title(f"DPC phase reconstruction - NMSE error: {self.error:.3e}") + ax1.set_title("Reconstructed object phase") if cbar: divider = make_axes_locatable(ax1) @@ -878,10 +879,12 @@ def _visualize_last_iteration( errors = self.error_iterations ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(len(errors)), errors, **kwargs) + 
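
# With the changes above, the gaussian/butterworth constraints are toggled by
# booleans plus their parameters rather than per-iteration counts, and the
# device can be overridden per call. A hypothetical end-to-end call pattern
# using only arguments that appear in this diff, assuming `dpc` is a DPC
# instance like the one constructed in the earlier sketch:
dpc = dpc.preprocess(
    plot_rotation=False,
)
dpc = dpc.reconstruct(
    reset=True,
    store_iterations=True,
    gaussian_filter_sigma=0.3,  # A; active while gaussian_filter=True (default)
    q_lowpass=0.25,             # A^-1; activates the butterworth constraint
    device="gpu",               # optional per-call override of the device
)
dpc.visualize()
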
ax2.set_xlabel("Iteration number") ax2.set_ylabel("Log NMSE error") ax2.yaxis.tick_right() + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -906,7 +909,6 @@ def _visualize_all_iterations( iterations_grid: Tuple[int,int] Grid dimensions to plot reconstruction iterations """ - if not hasattr(self, "object_phase_iterations"): raise ValueError( ( @@ -915,31 +917,41 @@ def _visualize_all_iterations( ) ) - if iterations_grid == "auto": - num_iter = len(self.error_iterations) + num_iter = len(self.object_phase_iterations) + if iterations_grid == "auto": if num_iter == 1: return self._visualize_last_iteration( + fig=fig, plot_convergence=plot_convergence, cbar=cbar, **kwargs, ) + else: iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) + else: + if iterations_grid[0] * iterations_grid[1] > num_iter: + raise ValueError() + auto_figsize = ( (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) if plot_convergence else (3 * iterations_grid[1], 3 * iterations_grid[0]) ) + figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + max_iter = num_iter - 1 total_grids = np.prod(iterations_grid) - errors = self.error_iterations - phases = self.object_phase_iterations - max_iter = len(phases) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) + grid_range = np.arange(0, max_iter + 1, max_iter // (total_grids - 1)) + + errors = np.array(self.error_iterations)[-num_iter:] + objects = [self.object_phase_iterations[n] for n in grid_range] extent = [ 0, @@ -966,25 +978,30 @@ def _visualize_all_iterations( ) for n, ax in enumerate(grid): + obj, vmin_n, vmax_n = return_scaled_histogram_ordering( + objects[n], vmin=vmin, vmax=vmax + ) im = ax.imshow( - phases[grid_range[n]], + obj, extent=extent, cmap=cmap, + vmin=vmin_n, + vmax=vmax_n, **kwargs, ) + ax.set_ylabel(f"x [{self._scan_units[0]}]") ax.set_xlabel(f"y [{self._scan_units[1]}]") + ax.set_title(f"Iter: {grid_range[n]} phase") + if cbar: grid.cbar_axes[n].colorbar(im) - ax.set_title( - f"Iteration: {grid_range[n]}\nNMSE error: {errors[grid_range[n]]:.3e}" - ) if plot_convergence: ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(len(errors)), errors, **kwargs) + ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_xlabel("Iteration number") - ax2.set_ylabel("Log NMSE error") + ax2.set_ylabel("NMSE error") ax2.yaxis.tick_right() spec.tight_layout(fig) @@ -1030,6 +1047,8 @@ def visualize( **kwargs, ) + self.clear_device_mem(self._device, self._clear_fft_cache) + return self @property diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py deleted file mode 100644 index 10dc40e00..000000000 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ /dev/null @@ -1,3705 +0,0 @@ -""" -Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, -namely multislice ptychography. 
-""" - -import warnings -from typing import Mapping, Sequence, Tuple, Union - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex - -try: - import cupy as cp -except (ModuleNotFoundError, ImportError): - cp = None - import os - - # make sure pylops doesn't try to use cupy - os.environ["CUPY_PYLOPS"] = "0" -import pylops # this must follow the exception -from emdfile import Custom, tqdmnd -from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.utils import ( - ComplexProbe, - fft_shift, - generate_batches, - polar_aliases, - polar_symbols, - spatial_frequencies, -) -from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar -from scipy.ndimage import rotate - -warnings.simplefilter(action="always", category=UserWarning) - - -class MixedstateMultislicePtychographicReconstruction(PtychographicReconstruction): - """ - Mixed-State Multislice Ptychographic Reconstruction Class. - - Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) - Reconstructed probe dimensions : (N,Sx,Sy) - Reconstructed object dimensions : (T,Px,Py) - - such that (Sx,Sy) is the region-of-interest (ROI) size of our N probes - and (Px,Py) is the padded-object size we position our ROI around in - each of the T slices. - - Parameters - ---------- - energy: float - The electron energy of the wave functions in eV - num_probes: int, optional - Number of mixed-state probes - num_slices: int - Number of slices to use in the forward model - slice_thicknesses: float or Sequence[float] - Slice thicknesses in angstroms. If float, all slices are assigned the same thickness - datacube: DataCube, optional - Input 4D diffraction pattern intensities - semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess in mrad - semiangle_cutoff_pixels: float, optional - Semiangle cutoff for the initial probe guess in pixels - rolloff: float, optional - Semiangle rolloff for the initial probe guess - vacuum_probe_intensity: np.ndarray, optional - Vacuum probe to use as intensity aperture for initial probe guess - polar_parameters: dict, optional - Mapping from aberration symbols to their corresponding values. All aberration - magnitudes should be given in Å and angles should be given in radians. - object_padding_px: Tuple[int,int], optional - Pixel dimensions to pad object with - If None, the padding is set to half the probe ROI dimensions - initial_object_guess: np.ndarray, optional - Initial guess for complex-valued object of dimensions (Px,Py) - If None, initialized to 1.0j - initial_probe_guess: np.ndarray, optional - Initial guess for complex-valued probe of dimensions (Sx,Sy). 
If None, - initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations - initial_scan_positions: np.ndarray, optional - Probe positions in Å for each diffraction intensity - If None, initialized to a grid scan - theta_x: float - x tilt of propagator (in degrees) - theta_y: float - y tilt of propagator (in degrees) - middle_focus: bool - if True, adds half the sample thickness to the defocus - object_type: str, optional - The object can be reconstructed as a real potential ('potential') or a complex - object ('complex') - positions_mask: np.ndarray, optional - Boolean real space mask to select positions in datacube to skip for reconstruction - verbose: bool, optional - If True, class methods will inherit this and print additional information - device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' - name: str, optional - Class name - kwargs: - Provide the aberration coefficients as keyword arguments. - """ - - # Class-specific Metadata - _class_specific_metadata = ("_num_probes", "_num_slices", "_slice_thicknesses") - - def __init__( - self, - energy: float, - num_slices: int, - slice_thicknesses: Union[float, Sequence[float]], - num_probes: int = None, - datacube: DataCube = None, - semiangle_cutoff: float = None, - semiangle_cutoff_pixels: float = None, - rolloff: float = 2.0, - vacuum_probe_intensity: np.ndarray = None, - polar_parameters: Mapping[str, float] = None, - object_padding_px: Tuple[int, int] = None, - initial_object_guess: np.ndarray = None, - initial_probe_guess: np.ndarray = None, - initial_scan_positions: np.ndarray = None, - theta_x: float = 0, - theta_y: float = 0, - middle_focus: bool = False, - object_type: str = "complex", - positions_mask: np.ndarray = None, - verbose: bool = True, - device: str = "cpu", - name: str = "multi-slice_ptychographic_reconstruction", - **kwargs, - ): - Custom.__init__(self, name=name) - - if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe): - if num_probes is None: - raise ValueError( - ( - "If initial_probe_guess is None, or a ComplexProbe object, " - "num_probes must be specified." - ) - ) - else: - if len(initial_probe_guess.shape) != 3: - raise ValueError( - "Specified initial_probe_guess must have dimensions (N,Sx,Sy)." 
- ) - num_probes = initial_probe_guess.shape[0] - - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from scipy.special import erf - - self._erf = erf - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from cupyx.scipy.special import erf - - self._erf = erf - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - - for key in kwargs.keys(): - if (key not in polar_symbols) and (key not in polar_aliases.keys()): - raise ValueError("{} not a recognized parameter".format(key)) - - if np.isscalar(slice_thicknesses): - mean_slice_thickness = slice_thicknesses - else: - mean_slice_thickness = np.mean(slice_thicknesses) - - if middle_focus: - if "defocus" in kwargs: - kwargs["defocus"] += mean_slice_thickness * num_slices / 2 - elif "C10" in kwargs: - kwargs["C10"] -= mean_slice_thickness * num_slices / 2 - elif polar_parameters is not None and "defocus" in polar_parameters: - polar_parameters["defocus"] = ( - polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2 - ) - elif polar_parameters is not None and "C10" in polar_parameters: - polar_parameters["C10"] = ( - polar_parameters["C10"] - mean_slice_thickness * num_slices / 2 - ) - - self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) - - if polar_parameters is None: - polar_parameters = {} - - polar_parameters.update(kwargs) - self._set_polar_parameters(polar_parameters) - - slice_thicknesses = np.array(slice_thicknesses) - if slice_thicknesses.shape == (): - slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1) - elif slice_thicknesses.shape[0] != (num_slices - 1): - raise ValueError( - ( - f"slice_thicknesses must have length {num_slices - 1}, " - f"not {slice_thicknesses.shape[0]}." - ) - ) - - if object_type != "potential" and object_type != "complex": - raise ValueError( - f"object_type must be either 'potential' or 'complex', not {object_type}" - ) - - self.set_save_defaults() - - # Data - self._datacube = datacube - self._object = initial_object_guess - self._probe = initial_probe_guess - - # Common Metadata - self._vacuum_probe_intensity = vacuum_probe_intensity - self._scan_positions = initial_scan_positions - self._energy = energy - self._semiangle_cutoff = semiangle_cutoff - self._semiangle_cutoff_pixels = semiangle_cutoff_pixels - self._rolloff = rolloff - self._object_type = object_type - self._positions_mask = positions_mask - self._object_padding_px = object_padding_px - self._verbose = verbose - self._device = device - self._preprocessed = False - - # Class-specific Metadata - self._num_probes = num_probes - self._num_slices = num_slices - self._slice_thicknesses = slice_thicknesses - self._theta_x = theta_x - self._theta_y = theta_y - - def _precompute_propagator_arrays( - self, - gpts: Tuple[int, int], - sampling: Tuple[float, float], - energy: float, - slice_thicknesses: Sequence[float], - theta_x: float, - theta_y: float, - ): - """ - Precomputes propagator arrays complex wave-function will be convolved by, - for all slice thicknesses. 
- - Parameters - ---------- - gpts: Tuple[int,int] - Wavefunction pixel dimensions - sampling: Tuple[float,float] - Wavefunction sampling in A - energy: float - The electron energy of the wave functions in eV - slice_thicknesses: Sequence[float] - Array of slice thicknesses in A - theta_x: float - x tilt of propagator (in degrees) - theta_y: float - y tilt of propagator (in degrees) - - Returns - ------- - propagator_arrays: np.ndarray - (T,Sx,Sy) shape array storing propagator arrays - """ - xp = self._xp - - # Frequencies - kx, ky = spatial_frequencies(gpts, sampling) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) - - # Propagators - wavelength = electron_wavelength_angstrom(energy) - num_slices = slice_thicknesses.shape[0] - propagators = xp.empty( - (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 - ) - - theta_x = np.deg2rad(theta_x) - theta_y = np.deg2rad(theta_y) - - for i, dz in enumerate(slice_thicknesses): - propagators[i] = xp.exp( - 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) - ) - propagators[i] *= xp.exp( - 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) - ) - - return propagators - - def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): - """ - Propagates array by Fourier convolving array with propagator_array. - - Parameters - ---------- - array: np.ndarray - Wavefunction array to be convolved - propagator_array: np.ndarray - Propagator array to convolve array with - - Returns - ------- - propagated_array: np.ndarray - Fourier-convolved array - """ - xp = self._xp - - return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) - - def preprocess( - self, - diffraction_intensities_shape: Tuple[int, int] = None, - reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, - dp_mask: np.ndarray = None, - fit_function: str = "plane", - plot_center_of_mass: str = "default", - plot_rotation: bool = True, - maximize_divergence: bool = False, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), - plot_probe_overlaps: bool = True, - force_com_rotation: float = None, - force_com_transpose: float = None, - force_com_shifts: float = None, - force_scan_sampling: float = None, - force_angular_sampling: float = None, - force_reciprocal_sampling: float = None, - object_fov_mask: np.ndarray = None, - crop_patterns: bool = False, - **kwargs, - ): - """ - Ptychographic preprocessing step. - Calls the base class methods: - - _extract_intensities_and_calibrations_from_datacube, - _compute_center_of_mass(), - _solve_CoM_rotation(), - _normalize_diffraction_intensities() - _calculate_scan_positions_in_px() - - Additionally, it initializes an (T,Px,Py) array of 1.0j - and a complex probe using the specified polar parameters. - - Parameters - ---------- - diffraction_intensities_shape: Tuple[int,int], optional - Pixel dimensions (Qx',Qy') of the resampled diffraction intensities - If None, no resampling of diffraction intenstities is performed - reshaping_method: str, optional - Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional - Padded diffraction intensities shape. - If None, no padding is performed - dp_mask: ndarray, optional - Mask for datacube intensities (Qx,Qy) - fit_function: str, optional - 2D fitting function for CoM fitting. 
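
# _precompute_propagator_arrays above builds one Fresnel propagator per slice
# thickness (plus optional beam-tilt phase ramps). The core single-slice,
# zero-tilt operator and its application, written standalone:
import numpy as np


def fresnel_propagator(shape, sampling, wavelength, dz):
    """exp(-i * pi * lambda * dz * k^2) on the FFT frequency grid."""
    kx = np.fft.fftfreq(shape[0], sampling[0])[:, None]
    ky = np.fft.fftfreq(shape[1], sampling[1])[None, :]
    return np.exp(-1.0j * np.pi * wavelength * dz * (kx**2 + ky**2))


def propagate(wavefunction, propagator):
    """Fourier-space convolution, as in _propagate_array."""
    return np.fft.ifft2(np.fft.fft2(wavefunction) * propagator)
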
One of 'plane','parabola','bezier_two' - plot_center_of_mass: str, optional - If 'default', the corrected CoM arrays will be displayed - If 'all', the computed and fitted CoM arrays will be displayed - plot_rotation: bool, optional - If True, the CoM curl minimization search result will be displayed - maximize_divergence: bool, optional - If True, the divergence of the CoM gradient vector field is maximized - rotation_angles_deg: np.darray, optional - Array of angles in degrees to perform curl minimization over - plot_probe_overlaps: bool, optional - If True, initial probe overlaps scanned over the object will be displayed - force_com_rotation: float (degrees), optional - Force relative rotation angle between real and reciprocal space - force_com_transpose: bool, optional - Force whether diffraction intensities need to be transposed. - force_com_shifts: tuple of ndarrays (CoMx, CoMy) - Amplitudes come from diffraction patterns shifted with - the CoM in the upper left corner for each probe unless - shift is overwritten. - force_scan_sampling: float, optional - Override DataCube real space scan pixel size calibrations, in Angstrom - force_angular_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in mrad - force_reciprocal_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in A^-1 - object_fov_mask: np.ndarray (boolean) - Boolean mask of FOV. Used to calculate additional shrinkage of object - If None, probe_overlap intensity is thresholded - crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns when centering - - Returns - -------- - self: MixedstateMultislicePtychographicReconstruction - Self to accommodate chaining - """ - xp = self._xp - asnumpy = self._asnumpy - - # set additional metadata - self._diffraction_intensities_shape = diffraction_intensities_shape - self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape - self._dp_mask = dp_mask - - if self._datacube is None: - raise ValueError( - ( - "The preprocess() method requires a DataCube. " - "Please run ptycho.attach_datacube(DataCube) first." 
- ) - ) - - if self._positions_mask is not None and self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) - self._positions_mask = np.asarray(self._positions_mask, dtype="bool") - - ( - self._datacube, - self._vacuum_probe_intensity, - self._dp_mask, - force_com_shifts, - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube, - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=self._vacuum_probe_intensity, - dp_mask=self._dp_mask, - com_shifts=force_com_shifts, - ) - - self._intensities = self._extract_intensities_and_calibrations_from_datacube( - self._datacube, - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - self._com_measured_x, - self._com_measured_y, - self._com_fitted_x, - self._com_fitted_y, - self._com_normalized_x, - self._com_normalized_y, - ) = self._calculate_intensities_center_of_mass( - self._intensities, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts, - ) - - ( - self._rotation_best_rad, - self._rotation_best_transpose, - self._com_x, - self._com_y, - self.com_x, - self.com_y, - ) = self._solve_for_center_of_mass_relative_rotation( - self._com_measured_x, - self._com_measured_y, - self._com_normalized_x, - self._com_normalized_y, - rotation_angles_deg=rotation_angles_deg, - plot_rotation=plot_rotation, - plot_center_of_mass=plot_center_of_mass, - maximize_divergence=maximize_divergence, - force_com_rotation=force_com_rotation, - force_com_transpose=force_com_transpose, - **kwargs, - ) - - ( - self._amplitudes, - self._mean_diffraction_intensity, - ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, - crop_patterns, - self._positions_mask, - ) - - # explicitly delete namespace - self._num_diffraction_patterns = self._amplitudes.shape[0] - self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) - del self._intensities - - self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions, self._positions_mask - ) - - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - if self._object_type == "potential": - self._object = xp.zeros((self._num_slices, p, q), dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.ones((self._num_slices, p, q), dtype=xp.complex64) - else: - if self._object_type == "potential": - self._object = xp.asarray(self._object, dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.asarray(self._object, dtype=xp.complex64) - - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type - self._object_shape = self._object.shape[-2:] - - self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) - self._positions_px_com = 
xp.mean(self._positions_px, axis=0) - self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - self._positions_px_initial = self._positions_px.copy() - self._positions_initial = self._positions_px_initial.copy() - self._positions_initial[:, 0] *= self.sampling[0] - self._positions_initial[:, 1] *= self.sampling[1] - - # Vectorized Patches - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - # Probe Initialization - if self._probe is None or isinstance(self._probe, ComplexProbe): - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, - device=self._device, - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - _probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - else: - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - _probe = self._probe._array - else: - self._probe._xp = xp - _probe = self._probe.build()._array - - self._probe = xp.zeros( - (self._num_probes,) + tuple(self._region_of_interest_shape), - dtype=xp.complex64, - ) - sx, sy = self._region_of_interest_shape - self._probe[0] = _probe - - # Randomly shift phase of other probes - for i_probe in range(1, self._num_probes): - shift_x = xp.exp( - -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx) - ) - shift_y = xp.exp( - -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy) - ) - self._probe[i_probe] = ( - self._probe[i_probe - 1] * shift_x[:, None] * shift_y[None] - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe[0])) ** 2) - self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) - - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) - - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = None # Doesn't really make sense for mixed-state - - self._known_aberrations_array = ComplexProbe( - energy=self._energy, - gpts=self._region_of_interest_shape, - sampling=self.sampling, - parameters=self._polar_parameters, - device=self._device, - )._evaluate_ctf() - - # Precomputed propagator arrays - self._propagator_arrays = self._precompute_propagator_arrays( - self._region_of_interest_shape, - self.sampling, - self._energy, - self._slice_thicknesses, - self._theta_x, - self._theta_y, - ) - - # overlaps - shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 
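
# The probe initialization above seeds additional mixed-state modes by
# applying random linear phase ramps (sub-pixel shifts) to the first mode,
# then scales the stack so the first mode's far-field intensity matches the
# mean measured diffraction intensity. The mode-seeding step on its own:
import numpy as np


def seed_probe_modes(probe, num_modes, rng=None):
    """probe: complex (Sx, Sy); returns a (num_modes, Sx, Sy) stack."""
    rng = np.random.default_rng() if rng is None else rng
    sx, sy = probe.shape
    modes = np.zeros((num_modes, sx, sy), dtype=np.complex64)
    modes[0] = probe
    for m in range(1, num_modes):
        shift_x = np.exp(-2j * np.pi * (rng.random() - 0.5) * np.fft.fftfreq(sx))
        shift_y = np.exp(-2j * np.pi * (rng.random() - 0.5) * np.fft.fftfreq(sy))
        modes[m] = modes[m - 1] * shift_x[:, None] * shift_y[None]
    return modes
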
1.0) - - if object_fov_mask is None: - self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) - else: - self._object_fov_mask = np.asarray(object_fov_mask) - self._object_fov_mask_inverse = np.invert(self._object_fov_mask) - - if plot_probe_overlaps: - figsize = kwargs.pop("figsize", (13, 4)) - chroma_boost = kwargs.pop("chroma_boost", 1) - - # initial probe - complex_probe_rgb = Complex2RGB( - self.probe_centered[0], - power=2, - chroma_boost=chroma_boost, - ) - - # propagated - propagated_probe = self._probe[0].copy() - - for s in range(self._num_slices - 1): - propagated_probe = self._propagate_array( - propagated_probe, self._propagator_arrays[s] - ) - complex_propagated_rgb = Complex2RGB( - asnumpy(self._return_centered_probe(propagated_probe)), - power=2, - chroma_boost=chroma_boost, - ) - - extent = [ - 0, - self.sampling[1] * self._object_shape[1], - self.sampling[0] * self._object_shape[0], - 0, - ] - - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) - - ax1.imshow( - complex_probe_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax1) - cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, - chroma_boost=chroma_boost, - ) - ax1.set_ylabel("x [A]") - ax1.set_xlabel("y [A]") - ax1.set_title("Initial probe[0] intensity") - - ax2.imshow( - complex_propagated_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax2) - cax2 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax2, chroma_boost=chroma_boost) - ax2.set_ylabel("x [A]") - ax2.set_xlabel("y [A]") - ax2.set_title("Propagated probe[0] intensity") - - ax3.imshow( - asnumpy(probe_overlap), - extent=extent, - cmap="Greys_r", - ) - ax3.scatter( - self.positions[:, 1], - self.positions[:, 0], - s=2.5, - color=(1, 0, 0, 1), - ) - ax3.set_ylabel("x [A]") - ax3.set_xlabel("y [A]") - ax3.set_xlim((extent[0], extent[1])) - ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object field of view") - - fig.tight_layout() - - self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - """ - - xp = self._xp - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - object_patches = complex_object[ - :, - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ] - - num_probe_positions = object_patches.shape[1] - - propagated_shape = ( - self._num_slices, - num_probe_positions, - self._num_probes, - self._region_of_interest_shape[0], - self._region_of_interest_shape[1], - ) - propagated_probes = xp.empty(propagated_shape, dtype=object_patches.dtype) - propagated_probes[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes = ( - xp.expand_dims(object_patches[s], axis=1) * propagated_probes[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes[s + 1] = self._propagate_array( - transmitted_probes, self._propagator_arrays[s] - ) - - return propagated_probes, object_patches, transmitted_probes - - def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): - """ - Ptychographic fourier projection method for GD method. - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - - Returns - -------- - exit_waves:np.ndarray - Exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) - error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) - - intensity_norm[intensity_norm == 0.0] = np.inf - amplitude_modification = amplitudes / intensity_norm - - fourier_modified_overlap = amplitude_modification[:, None] * fourier_exit_waves - modified_exit_wave = xp.fft.ifft2(fourier_modified_overlap) - - exit_waves = modified_exit_wave - transmitted_probes - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. 
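
# _gradient_descent_fourier_projection above enforces the measured Fourier
# amplitudes: the exit wave keeps its phase, its modulus is replaced by the
# data, and the update returned is the difference from the current exit wave.
# A single-probe-mode sketch (the class sums |F|^2 over mixed-state modes
# before normalizing):
import numpy as np


def amplitude_projection_update(exit_wave, measured_amplitude):
    fourier_exit_wave = np.fft.fft2(exit_wave)
    norm = np.abs(fourier_exit_wave)
    error = np.sum((measured_amplitude - norm) ** 2)
    norm[norm == 0.0] = np.inf  # avoid division by zero at empty pixels
    modified = np.fft.ifft2(measured_amplitude / norm * fourier_exit_wave)
    return modified - exit_wave, error  # exit-wave *difference*, as in the GD path
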
- Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = transmitted_probes.copy() - - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) - error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) - - factor_to_be_projected = ( - projection_c * transmitted_probes + projection_y * exit_waves - ) - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - intensity_norm_projected = xp.sqrt( - xp.sum(xp.abs(fourier_projected_factor) ** 2, axis=1) - ) - intensity_norm_projected[intensity_norm_projected == 0.0] = np.inf - - amplitude_modification = amplitudes / intensity_norm_projected - fourier_projected_factor *= amplitude_modification[:, None] - - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * transmitted_probes - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - ( - propagated_probes, - object_patches, - transmitted_probes, - ) = self._overlap_projection(current_object, current_probe) - - if use_projection_scheme: - exit_waves, error = self._projection_sets_fourier_projection( - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, transmitted_probes - ) - - return propagated_probes, object_patches, transmitted_probes, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. 
- Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = xp.zeros_like(current_object[s]) - object_update = xp.zeros_like(current_object[s]) - - for i_probe in range(self._num_probes): - probe_normalization += self._sum_overlapping_patches_bincounts( - xp.abs(probe[:, i_probe]) ** 2 - ) - - if self._object_type == "potential": - object_update += ( - step_size - * self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(obj) - * xp.conj(probe[:, i_probe]) - * exit_waves[:, i_probe] - ) - ) - ) - else: - object_update += ( - step_size - * self._sum_overlapping_patches_bincounts( - xp.conj(probe[:, i_probe]) * exit_waves[:, i_probe] - ) - ) - - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object[s] += object_update * probe_normalization - - # back-transmit - exit_waves *= xp.expand_dims(xp.conj(obj), axis=1) # / xp.abs(obj) ** 2 - - if s > 0: - # back-propagate - exit_waves = self._propagate_array( - exit_waves, xp.conj(self._propagator_arrays[s - 1]) - ) - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += ( - step_size - * xp.sum( - exit_waves, - axis=0, - ) - * object_normalization[None] - ) - - return current_object, current_probe - - def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. 
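
# The adjoint updates above never divide by the raw overlap intensity; they
# multiply by a softened reciprocal that is floored at a fraction of the
# maximum (normalization_min), which keeps the update stable where the probe
# barely illuminates the object. The weighting factor on its own:
import numpy as np


def regularized_reciprocal(overlap_intensity, normalization_min, eps=1e-16):
    return 1 / np.sqrt(
        eps
        + ((1 - normalization_min) * overlap_intensity) ** 2
        + (normalization_min * overlap_intensity.max()) ** 2
    )
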
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - # careful not to modify exit_waves in-place for projection set methods - exit_waves_copy = exit_waves.copy() - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = xp.zeros_like(current_object[s]) - object_update = xp.zeros_like(current_object[s]) - - for i_probe in range(self._num_probes): - probe_normalization += self._sum_overlapping_patches_bincounts( - xp.abs(probe[:, i_probe]) ** 2 - ) - - if self._object_type == "potential": - object_update += self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(obj) - * xp.conj(probe[:, i_probe]) - * exit_waves_copy[:, i_probe] - ) - ) - else: - object_update += self._sum_overlapping_patches_bincounts( - xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe] - ) - - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object[s] = object_update * probe_normalization - - # back-transmit - exit_waves_copy *= xp.expand_dims( - xp.conj(obj), axis=1 - ) # / xp.abs(obj) ** 2 - - if s > 0: - # back-propagate - exit_waves_copy = self._propagate_array( - exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) - ) - - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - exit_waves_copy, - axis=0, - ) - * object_normalization[None] - ) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - - def _position_correction( - self, - current_object, - current_probe, - transmitted_probes, - amplitudes, - current_positions, - positions_step_size, - constrain_position_distance, - ): - """ - Position correction using estimated intensity gradient. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe:np.ndarray - fractionally-shifted probes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - amplitudes: np.ndarray - Measured amplitudes - current_positions: np.ndarray - Current positions estimate - positions_step_size: float - Positions step size - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A - - Returns - -------- - updated_positions: np.ndarray - Updated positions estimate - """ - - xp = self._xp - - # Intensity gradient - exit_waves_fft = xp.fft.fft2(transmitted_probes) - exit_waves_fft_conj = xp.conj(exit_waves_fft) - estimated_intensity = xp.abs(exit_waves_fft) ** 2 - measured_intensity = amplitudes**2 - - flat_shape = (transmitted_probes.shape[0], -1) - difference_intensity = (measured_intensity - estimated_intensity).reshape( - flat_shape - ) - - # Computing perturbed exit waves one at a time to save on memory - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - # dx - obj_rolled_patches = complex_object[ - :, - (self._vectorized_patch_indices_row + 1) % self._object_shape[0], - self._vectorized_patch_indices_col, - ] - - propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - # dy - obj_rolled_patches = complex_object[ - :, - self._vectorized_patch_indices_row, - (self._vectorized_patch_indices_col + 1) % self._object_shape[1], - ] - - 
propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - partial_intensity_dx = 2 * xp.real( - exit_waves_dx_fft * exit_waves_fft_conj - ).reshape(flat_shape) - partial_intensity_dy = 2 * xp.real( - exit_waves_dy_fft * exit_waves_fft_conj - ).reshape(flat_shape) - - coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) - - # positions_update = xp.einsum( - # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity - # ) - - coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) - positions_update = ( - xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) - @ coefficients_matrix_T - @ difference_intensity[..., None] - ) - - if constrain_position_distance is not None: - constrain_position_distance /= xp.sqrt( - self.sampling[0] ** 2 + self.sampling[1] ** 2 - ) - x1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 0 - ] - y1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 1 - ] - x0 = self._positions_px_initial[:, 0] - y0 = self._positions_px_initial[:, 1] - if self._rotation_best_transpose: - x0, y0 = xp.array([y0, x0]) - x1, y1 = xp.array([y1, x1]) - - if self._rotation_best_rad is not None: - rotation_angle = self._rotation_best_rad - x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( - -rotation_angle - ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) - x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( - -rotation_angle - ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) - - outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( - x1 < (xp.min(x0) - constrain_position_distance) - ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( - y1 < (xp.min(y0) - constrain_position_distance) - ) > 0 - - positions_update[..., 0][outlier_ind] = 0 - - current_positions -= positions_step_size * positions_update[..., 0] - - return current_positions - - def _probe_center_of_mass_constraint(self, current_probe): - """ - Ptychographic center of mass constraint. - Used for centering corner-centered probe intensity. - - Parameters - -------- - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - constrained_probe: np.ndarray - Constrained probe estimate - """ - xp = self._xp - probe_intensity = xp.abs(current_probe[0]) ** 2 - - probe_x0, probe_y0 = get_CoM( - probe_intensity, device=self._device, corner_centered=True - ) - shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) - - return shifted_probe - - def _probe_orthogonalization_constraint(self, current_probe): - """ - Ptychographic probe-orthogonalization constraint. - Used to ensure mixed states are orthogonal to each other. 
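A standalone NumPy sketch of the mixed-state orthogonalization implemented just below; the helper name `orthogonalize_probes` and the random toy probes are illustrative only:

import numpy as np

def orthogonalize_probes(probes):
    """Orthogonalize mixed-state probes via the eigenvectors of their Gram matrix.

    probes has shape (n_probes, sx, sy). As in the constraint below, the
    eigendecomposition of the pairwise dot-product matrix is a cheaper route
    to the mixing matrix than a full SVD of the flattened probes.
    """
    n = probes.shape[0]
    gram = np.empty((n, n), dtype=probes.dtype)
    for i in range(n):
        for j in range(n):
            gram[i, j] = np.sum(probes[i].conj() * probes[j])

    _, evecs = np.linalg.eigh(gram)                # Hermitian eigendecomposition
    ortho = np.tensordot(evecs.T, probes, axes=1)  # mix the probe modes

    # sort modes by decreasing real-space intensity
    intensities = np.sum(np.abs(ortho) ** 2, axis=(-2, -1))
    return ortho[np.argsort(intensities)[::-1]]

rng = np.random.default_rng(0)
probes = rng.standard_normal((3, 8, 8)) + 1j * rng.standard_normal((3, 8, 8))
ortho = orthogonalize_probes(probes)
print(np.abs(np.sum(ortho[0].conj() * ortho[1])))  # cross term is now ~0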
- Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690 - - Parameters - -------- - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - constrained_probe: np.ndarray - Orthogonalized probe estimate - """ - xp = self._xp - n_probes = self._num_probes - - # compute upper half of P* @ P - pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype) - - for i in range(n_probes): - for j in range(i, n_probes): - pairwise_dot_product[i, j] = xp.sum( - current_probe[i].conj() * current_probe[j] - ) - - # compute eigenvectors (effectively cheaper way of computing V* from SVD) - _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U") - current_probe = xp.tensordot(evecs.T, current_probe, axes=1) - - # sort by real-space intensity - intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1)) - intensities_order = xp.argsort(intensities, axis=None)[::-1] - return current_probe[intensities_order] - - def _object_butterworth_constraint( - self, current_object, q_lowpass, q_highpass, butterworth_order - ): - """ - 2D Butterworth filter - Used for low/high-pass filtering object. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qya, qxa = xp.meshgrid(qy, qx) - qra = xp.sqrt(qxa**2 + qya**2) - - env = xp.ones_like(qra) - if q_highpass: - env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) - if q_lowpass: - env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - - current_object_mean = xp.mean(current_object) - current_object -= current_object_mean - current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env[None]) - current_object += current_object_mean - - if self._object_type == "potential": - current_object = xp.real(current_object) - - return current_object - - def _object_kz_regularization_constraint( - self, current_object, kz_regularization_gamma - ): - """ - Arctan regularization filter - - Parameters - -------- - current_object: np.ndarray - Current object estimate - kz_regularization_gamma: float - Slice regularization strength - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - - current_object = xp.pad( - current_object, pad_width=((1, 0), (0, 0), (0, 0)), mode="constant" - ) - - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]) - - kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0] - - qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") - qz2 = qza**2 * kz_regularization_gamma**2 - qr2 = qxa**2 + qya**2 - - w = 1 - 2 / np.pi * xp.arctan2(qz2, qr2) - - current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * w) - current_object = current_object[1:] - - if self._object_type == "potential": - current_object = xp.real(current_object) - - return current_object - - def _object_identical_slices_constraint(self, current_object): - """ - 
Strong regularization forcing all slices to be identical - - Parameters - -------- - current_object: np.ndarray - Current object estimate - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - object_mean = current_object.mean(0, keepdims=True) - current_object[:] = object_mean - - return current_object - - def _object_denoise_tv_pylops(self, current_object, weights, iterations): - """ - Performs second order TV denoising along x and y - - Parameters - ---------- - current_object: np.ndarray - Current object estimate - weights : [float, float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - iterations: float - Number of iterations to run in denoising algorithm. - `niter_out` in pylops - - Returns - ------- - constrained_object: np.ndarray - Constrained object estimate - - """ - xp = self._xp - - if xp.iscomplexobj(current_object): - current_object_tv = current_object - warnings.warn( - ("TV denoising is currently only supported for potential objects."), - UserWarning, - ) - - else: - # zero pad at top and bottom slice - pad_width = ((1, 1), (0, 0), (0, 0)) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" - ) - - # run tv denoising - nz, nx, ny = current_object.shape - niter_out = iterations - niter_in = 1 - Iop = pylops.Identity(nx * ny * nz) - - if weights[0] == 0: - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[1]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - elif weights[1] == 0: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - l1_regs = [z_gradient] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[0]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - else: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [z_gradient, xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=weights, - tol=1e-4, - tau=1.0, - show=False, - )[0] - - # remove padding - current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] - - return current_object_tv - - def _constraints( - self, - current_object, - current_probe, - current_positions, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - kz_regularization_filter, - kz_regularization_gamma, - identical_slices, - object_positivity, - shrinkage_rad, - 
object_mask, - pure_phase_object, - tv_denoise_chambolle, - tv_denoise_weight_chambolle, - tv_denoise_pad_chambolle, - tv_denoise, - tv_denoise_weights, - tv_denoise_inner_iter, - orthogonalize_probe, - ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool - If True, probe Fourier amplitude is replaced by initial_probe_aperture - initial_probe_aperture: np.ndarray - Initial probe aperture to use in replacing probe Fourier amplitude - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter in A - gaussian_filter_sigma: float - Standard deviation of gaussian kernel - butterworth_filter: bool - If True, applies fourier-space butterworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - kz_regularization_filter: bool - If True, applies fourier-space arctan regularization filter - kz_regularization_gamma: float - Slice regularization strength - identical_slices: bool - If True, forces all object slices to be identical - object_positivity: bool - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - pure_phase_object: bool - If True, object amplitude is set to unity - tv_denoise_chambolle: bool - If True, performs TV denoising along z - tv_denoise_weight_chambolle: float - weight of tv denoising constraint - tv_denoise_pad_chambolle: bool - if True, pads object at top and bottom with zeros before applying denoising - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. 
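For reference, a self-contained NumPy sketch of the Fourier-space Butterworth envelope applied by the object constraint above; the helper name `butterworth_envelope` and the sampling values are illustrative, not py4DSTEM API:

import numpy as np

def butterworth_envelope(shape, sampling, q_lowpass=None, q_highpass=None, order=2):
    """Band-pass envelope matching the object Butterworth constraint above.

    shape is (nx, ny), sampling the real-space pixel size in A, the cutoffs
    are in A^-1, and a smaller order gives a smoother roll-off.
    """
    qx = np.fft.fftfreq(shape[0], sampling[0])
    qy = np.fft.fftfreq(shape[1], sampling[1])
    qya, qxa = np.meshgrid(qy, qx)  # same index order as the class method
    qra = np.sqrt(qxa**2 + qya**2)

    env = np.ones_like(qra)
    if q_highpass:
        env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * order))
    if q_lowpass:
        env *= 1 / (1 + (qra / q_lowpass) ** (2 * order))
    return env

# apply to a mean-subtracted object slice, as the constraint does
obj = np.random.default_rng(0).standard_normal((64, 64))
env = butterworth_envelope(obj.shape, sampling=(0.2, 0.2), q_lowpass=1.0, q_highpass=0.05)
filtered = np.real(np.fft.ifft2(np.fft.fft2(obj - obj.mean()) * env)) + obj.mean()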
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - orthogonalize_probe: bool - If True, probe will be orthogonalized - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - - if identical_slices: - current_object = self._object_identical_slices_constraint(current_object) - elif kz_regularization_filter: - current_object = self._object_kz_regularization_constraint( - current_object, kz_regularization_gamma - ) - elif tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, - tv_denoise_weights, - tv_denoise_inner_iter, - ) - elif tv_denoise_chambolle: - current_object = self._object_denoise_tv_chambolle( - current_object, - tv_denoise_weight_chambolle, - axis=0, - pad_object=tv_denoise_pad_chambolle, - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - if self._object_type == "complex": - current_object = self._object_threshold_constraint( - current_object, pure_phase_object - ) - elif object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - # These constraints don't _really_ make sense for mixed-state - if fix_probe_aperture: - raise NotImplementedError() - elif constrain_probe_fourier_amplitude: - raise NotImplementedError() - if fit_probe_aberrations: - raise NotImplementedError() - if constrain_probe_amplitude: - raise NotImplementedError() - - if orthogonalize_probe: - current_probe = self._probe_orthogonalization_constraint(current_probe) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - - def reconstruct( - self, - max_iter: int = 64, - reconstruction_method: str = "gradient-descent", - reconstruction_parameter: float = 1.0, - reconstruction_parameter_a: float = None, - reconstruction_parameter_b: float = None, - reconstruction_parameter_c: float = None, - max_batch_size: int = None, - seed_random: int = None, - step_size: float = 0.5, - normalization_min: float = 1, - positions_step_size: float = 0.9, - fix_com: bool = True, - orthogonalize_probe: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, - constrain_probe_amplitude_relative_radius: float = 0.5, - constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, - constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, - constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, - global_affine_transformation: bool = True, - gaussian_filter_sigma: float = None, - gaussian_filter_iter: 
int = np.inf, - fit_probe_aberrations_iter: int = 0, - fit_probe_aberrations_max_angular_order: int = 4, - fit_probe_aberrations_max_radial_order: int = 4, - butterworth_filter_iter: int = np.inf, - q_lowpass: float = None, - q_highpass: float = None, - butterworth_order: float = 2, - kz_regularization_filter_iter: int = np.inf, - kz_regularization_gamma: Union[float, np.ndarray] = None, - identical_slices_iter: int = 0, - object_positivity: bool = True, - shrinkage_rad: float = 0.0, - fix_potential_baseline: bool = True, - pure_phase_object_iter: int = 0, - tv_denoise_iter_chambolle=np.inf, - tv_denoise_weight_chambolle=None, - tv_denoise_pad_chambolle=True, - tv_denoise_iter=np.inf, - tv_denoise_weights=None, - tv_denoise_inner_iter=40, - switch_object_iter: int = np.inf, - store_iterations: bool = False, - progress_bar: bool = True, - reset: bool = None, - ): - """ - Ptychographic reconstruction main method. - - Parameters - -------- - max_iter: int, optional - Maximum number of iterations to run - reconstruction_method: str, optional - Specifies which reconstruction algorithm to use, one of: - "generalized-projections", - "DM_AP" (or "difference-map_alternating-projections"), - "RAAR" (or "relaxed-averaged-alternating-reflections"), - "RRR" (or "relax-reflect-reflect"), - "SUPERFLIP" (or "charge-flipping"), or - "GD" (or "gradient_descent") - reconstruction_parameter: float, optional - Reconstruction parameter for various reconstruction methods above. - reconstruction_parameter_a: float, optional - Reconstruction parameter a for reconstruction_method='generalized-projections'. - reconstruction_parameter_b: float, optional - Reconstruction parameter b for reconstruction_method='generalized-projections'. - reconstruction_parameter_c: float, optional - Reconstruction parameter c for reconstruction_method='generalized-projections'. - max_batch_size: int, optional - Max number of probes to update at once - seed_random: int, optional - Seeds the random number generator, only applicable when max_batch_size is not None - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - positions_step_size: float, optional - Positions update step size - fix_com: bool, optional - If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. 
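The named reconstruction methods listed above map onto generalized-projection coefficients (a, b, c) inside reconstruct(); a small sketch of that mapping, with the helper name `projection_abc` being illustrative only:

def projection_abc(method, p=1.0):
    """Map the named projection-set methods to (a, b, c), mirroring the
    dispatch further down in reconstruct(). 'GD' is handled separately
    (gradient descent with an explicit step_size), so it is omitted here.
    """
    if method == "DM_AP":      # difference-map / alternating projections, p in [0, 1]
        return -p, 1, 1 + p
    if method == "RAAR":       # relaxed averaged alternating reflections, p in [0, 1]
        return 1 - 2 * p, p, 2
    if method == "RRR":        # relax-reflect-reflect, p in [0, 2]
        return -p, p, 2
    if method == "SUPERFLIP":  # charge flipping
        return 0, 1, 2
    raise ValueError(f"unknown method {method!r}")

print(projection_abc("DM_AP", 1.0))  # (-1.0, 1, 2.0) recovers the difference map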
- fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate - global_affine_transformation: bool, optional - If True, positions are assumed to be a global affine transform from initial scan - gaussian_filter_sigma: float, optional - Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - kz_regularization_filter_iter: int, optional - Number of iterations to run using kz regularization filter - kz_regularization_gamma, float, optional - kz regularization strength - identical_slices_iter: int, optional - Number of iterations to run using identical slices - object_positivity: bool, optional - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - fix_potential_baseline: bool - If true, the potential mean outside the FOV is forced to zero at each iteration - pure_phase_object_iter: int, optional - Number of iterations where object amplitude is set to unity - tv_denoise_iter_chambolle: bool - Number of iterations with TV denoisining - tv_denoise_weight_chambolle: float - weight of tv denoising constraint - tv_denoise_pad_chambolle: bool - if True, pads object at top and bottom with zeros before applying denoising - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - switch_object_iter: int, optional - Iteration to switch object type between 'complex' and 'potential' or between - 'potential' and 'complex' - store_iterations: bool, optional - If True, reconstructed objects and probes are stored at each iteration - progress_bar: bool, optional - If True, reconstruction progress is displayed - reset: bool, optional - If True, previous reconstructions are ignored - - Returns - -------- - self: MultislicePtychographicReconstruction - Self to accommodate chaining - """ - asnumpy = self._asnumpy - xp = self._xp - - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - "'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) - - if self._verbose: - if switch_object_iter > max_iter: - first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - else: - switch_object_type = ( - "complex" if self._object_type == "potential" else "potential" - ) - first_line = ( - f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " - f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " - ) - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - first_line + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." 
- ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) - - # Batching - shuffled_indices = np.arange(self._num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) - - if max_batch_size is not None: - xp.random.seed(seed_random) - else: - max_batch_size = self._num_diffraction_patterns - - # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self.error_iterations = [] - self._object = self._object_initial.copy() - self._probe = self._probe_initial.copy() - self._positions_px = self._positions_px_initial.copy() - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - self._exit_waves = None - self._object_type = self._object_type_initial - if hasattr(self, "_tf"): - del self._tf - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." - ), - UserWarning, - ) - else: - self.error_iterations = [] - self._exit_waves = None - - # main loop - for a0 in tqdmnd( - max_iter, - desc="Reconstructing object and probe", - unit=" iter", - disable=not progress_bar, - ): - error = 0.0 - - if a0 == switch_object_iter: - if self._object_type == "potential": - self._object_type = "complex" - self._object = xp.exp(1j * self._object) - elif self._object_type == "complex": - self._object_type = "potential" - self._object = xp.angle(self._object) - - # randomize - if not use_projection_scheme: - np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - self._num_diffraction_patterns - ) - positions_px = self._positions_px.copy()[shuffled_indices] - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[shuffled_indices[start:end]] - - # forward operator - ( - propagated_probes, - object_patches, - self._transmitted_probes, - self._exit_waves, - batch_error, - ) = self._forward( - self._object, - self._probe, - amplitudes, - self._exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ) - - # adjoint operator - self._object, self._probe = self._adjoint( - self._object, - self._probe, - object_patches, - propagated_probes, - self._exit_waves, - use_projection_scheme=use_projection_scheme, - step_size=step_size, - normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, - ) - - # position correction - if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( - self._object, - self._probe[0], - 
self._transmitted_probes[:, 0], - amplitudes, - self._positions_px, - positions_step_size, - constrain_position_distance, - ) - - error += batch_error - - # Normalize Error - error /= self._mean_diffraction_intensity * self._num_diffraction_patterns - - # constraints - self._positions_px = positions_px.copy()[unshuffled_indices] - self._object, self._probe, self._positions_px = self._constraints( - self._object, - self._probe, - self._positions_px, - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, - gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass is not None or q_highpass is not None), - q_lowpass=q_lowpass, - q_highpass=q_highpass, - butterworth_order=butterworth_order, - kz_regularization_filter=a0 < kz_regularization_filter_iter - and kz_regularization_gamma is not None, - kz_regularization_gamma=kz_regularization_gamma[a0] - if kz_regularization_gamma is not None - and isinstance(kz_regularization_gamma, np.ndarray) - else kz_regularization_gamma, - identical_slices=a0 < identical_slices_iter, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 - else None, - pure_phase_object=a0 < pure_phase_object_iter - and self._object_type == "complex", - tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle - and tv_denoise_weight_chambolle is not None, - tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, - tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, - tv_denoise_weights=tv_denoise_weights, - tv_denoise_inner_iter=tv_denoise_inner_iter, - orthogonalize_probe=orthogonalize_probe, - ) - - self.error_iterations.append(error.item()) - if store_iterations: - self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(self.probe_centered) - - # store result - self.object = asnumpy(self._object) - self.probe = self.probe_centered - self.error = error.item() - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax, - cbar: bool, - padding: int = 0, - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. 
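Most regularizers in the reconstruction loop above are switched on or off per iteration by comparing the iteration index against a *_iter argument; a minimal sketch of that gating pattern, with illustrative values:

import numpy as np

max_iter = 8
fix_probe_iter = 2
gaussian_filter_iter = np.inf  # np.inf means "apply at every iteration"
gaussian_filter_sigma = 0.5

for a0 in range(max_iter):
    # same comparisons as the _constraints() call above
    fix_probe = a0 < fix_probe_iter
    gaussian_filter = a0 < gaussian_filter_iter and gaussian_filter_sigma is not None
    print(a0, "fix_probe:", fix_probe, "gaussian_filter:", gaussian_filter)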
- - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - cmap = kwargs.pop("cmap", "magma") - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov( - np.sum(obj, axis=0), padding=padding - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - errors = np.array(self.error_iterations) - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - - convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) - - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - padding: int, - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - padding : int, optional - Pixels to pad by post rotating-cropping object - - """ - figsize = kwargs.pop("figsize", (8, 5)) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov( - np.sum(obj, axis=0), padding=padding - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=2, - height_ratios=[4, 1], - hspace=0.15, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=1, - 
width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - if plot_probe or plot_fourier_probe: - # Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Probe - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - - ax = fig.add_subplot(spec[0, 1]) - if plot_fourier_probe: - if remove_initial_probe_aberrations: - probe_array = self.probe_fourier_residual[0] - else: - probe_array = self.probe_fourier[0] - - probe_array = Complex2RGB( - probe_array, - chroma_boost=chroma_boost, - ) - - ax.set_title("Reconstructed Fourier probe[0]") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - self.probe[0], power=2, chroma_boost=chroma_boost - ) - ax.set_title("Reconstructed probe[0] intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - - else: - ax = fig.add_subplot(spec[0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if plot_convergence and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = np.array(self.error_iterations) - if plot_probe: - ax = fig.add_subplot(spec[1, :]) - else: - ax = fig.add_subplot(spec[1]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") - spec.tight_layout(fig) - - def _visualize_all_iterations( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - iterations_grid: Tuple[int, int], - padding: int, - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. 
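Both visualization paths build their imshow extents the same way: real-space panels in A, Fourier-space probe panels in mrad. A small sketch with made-up sampling and shape values:

sampling = (0.2, 0.2)          # A per pixel (x, y), illustrative
angular_sampling = (1.5, 1.5)  # mrad per pixel (kx, ky), illustrative
object_shape = (128, 160)
roi_shape = (64, 64)

# real-space object panel: y [A] horizontal, x [A] vertical (increasing downward)
real_extent = [0, sampling[1] * object_shape[1], sampling[0] * object_shape[0], 0]

# Fourier-space probe panel: symmetric about zero in kx/ky [mrad]
fourier_extent = [
    -angular_sampling[1] * roi_shape[1] / 2,
    angular_sampling[1] * roi_shape[1] / 2,
    angular_sampling[0] * roi_shape[0] / 2,
    -angular_sampling[0] * roi_shape[0] / 2,
]
# e.g. ax.imshow(image, extent=real_extent) with ax.set_ylabel("x [A]"), ax.set_xlabel("y [A]")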
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - asnumpy = self._asnumpy - - if not hasattr(self, "object_iterations"): - raise ValueError( - ( - "Object and probe iterations were not saved during reconstruction. " - "Please re-run using store_iterations=True." - ) - ) - - if iterations_grid == "auto": - num_iter = len(self.error_iterations) - - if num_iter == 1: - return self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - cbar=cbar, - padding=padding, - **kwargs, - ) - elif plot_probe or plot_fourier_probe: - iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) - else: - iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) - else: - if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: - raise ValueError() - - auto_figsize = ( - (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) - if plot_convergence - else (3 * iterations_grid[1], 3 * iterations_grid[0]) - ) - figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - errors = np.array(self.error_iterations) - - objects = [] - object_type = [] - - for obj in self.object_iterations: - if np.iscomplexobj(obj): - obj = np.angle(obj) - object_type.append("phase") - else: - object_type.append("potential") - objects.append( - self._crop_rotate_object_fov(np.sum(obj, axis=0), padding=padding) - ) - - if plot_probe or plot_fourier_probe: - total_grids = (np.prod(iterations_grid) / 2).astype("int") - probes = self.probe_iterations - else: - total_grids = np.prod(iterations_grid) - max_iter = len(objects) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) - - extent = [ - 0, - self.sampling[1] * objects[0].shape[1], - self.sampling[0] * objects[0].shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=2) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - grid = ImageGrid( - fig, - spec[0], - nrows_ncols=(1, iterations_grid[1]) - if (plot_probe or plot_fourier_probe) - else 
iterations_grid, - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - im = ax.imshow( - objects[grid_range[n]], - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: - grid.cbar_axes[n].colorbar(im) - - if plot_probe or plot_fourier_probe: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - grid = ImageGrid( - fig, - spec[1], - nrows_ncols=(1, iterations_grid[1]), - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - if plot_fourier_probe: - probe_array = asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]][0], - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) - - ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - probes[grid_range[n]][0], - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title(f"Iter: {grid_range[n]} probe[0]") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - add_colorbar_arg( - grid.cbar_axes[n], - chroma_boost=chroma_boost, - ) - - if plot_convergence: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - if plot_probe: - ax2 = fig.add_subplot(spec[2]) - else: - ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration number") - ax2.yaxis.tick_right() - - spec.tight_layout(fig) - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - padding: int = 0, - **kwargs, - ): - """ - Displays reconstructed object and probe. 
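When a grid of stored iterations is plotted, the panels are filled with evenly spaced iteration indices rather than every iteration; a runnable sketch of that selection, with illustrative numbers:

import numpy as np

iterations_grid = (2, 4)  # rows x cols; probes occupy the second row when plotted
plot_probe = True
num_stored = 17           # len(object_iterations), illustrative

total_grids = int(np.prod(iterations_grid) / 2) if plot_probe else int(np.prod(iterations_grid))
last_iter = num_stored - 1
grid_range = range(0, last_iter + 1, last_iter // (total_grids - 1))
print(list(grid_range))   # e.g. [0, 5, 10, 15] -> iterations shown in the grid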
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - return self - - def show_fourier_probe( - self, - probe=None, - remove_initial_probe_aberrations=False, - cbar=True, - scalebar=True, - pixelsize=None, - pixelunits=None, - **kwargs, - ): - """ - Plot probe in fourier space - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses the `probe_fourier` property - remove_initial_probe_aberrations: bool, optional - If True, removes initial probe aberrations from Fourier probe - scalebar: bool, optional - if True, adds scalebar to probe - pixelunits: str, optional - units for scalebar, default is A^-1 - pixelsize: float, optional - default is probe reciprocal sampling - """ - asnumpy = self._asnumpy - - if probe is None: - probe = list( - asnumpy( - self._return_fourier_probe( - probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - ) - else: - if isinstance(probe, np.ndarray) and probe.ndim == 2: - probe = [probe] - probe = [ - asnumpy( - self._return_fourier_probe( - pr, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - for pr in probe - ] - - if pixelsize is None: - pixelsize = self._reciprocal_sampling[1] - if pixelunits is None: - pixelunits = r"$\AA^{-1}$" - - chroma_boost = kwargs.pop("chroma_boost", 1) - - show_complex( - probe if len(probe) > 1 else probe[0], - cbar=cbar, - scalebar=scalebar, - pixelsize=pixelsize, - pixelunits=pixelunits, - ticks=False, - chroma_boost=chroma_boost, - **kwargs, - ) - - def show_transmitted_probe( - self, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations=False, - **kwargs, - ): - """ - Plots the min, max, and mean transmitted probe after propagation and transmission. 
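show_transmitted_probe below picks which probes to display by total intensity per scan position; a simplified single-mode NumPy sketch, with random data standing in for real transmitted probes:

import numpy as np

rng = np.random.default_rng(0)
# (num_positions, sx, sy); the class uses the first probe mode, here there is only one
transmitted = rng.standard_normal((100, 32, 32)) + 1j * rng.standard_normal((100, 32, 32))

intensities = np.sum(np.abs(transmitted) ** 2, axis=(-2, -1))
min_probe = transmitted[np.argmin(intensities)]   # weakest transmitted probe
max_probe = transmitted[np.argmax(intensities)]   # strongest transmitted probe
mean_probe = transmitted.mean(0)                  # mean transmitted probe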
- - Parameters - ---------- - plot_fourier_probe: boolean, optional - If True, the transmitted probes are also plotted in Fourier space - kwargs: - Passed to show_complex - """ - - xp = self._xp - asnumpy = self._asnumpy - - transmitted_probe_intensities = xp.sum( - xp.abs(self._transmitted_probes[:, 0]) ** 2, axis=(-2, -1) - ) - min_intensity_transmitted = self._transmitted_probes[ - xp.argmin(transmitted_probe_intensities), 0 - ] - max_intensity_transmitted = self._transmitted_probes[ - xp.argmax(transmitted_probe_intensities), 0 - ] - mean_transmitted = self._transmitted_probes[:, 0].mean(0) - probes = [ - asnumpy(self._return_centered_probe(probe)) - for probe in [ - mean_transmitted, - min_intensity_transmitted, - max_intensity_transmitted, - ] - ] - title = [ - "Mean Transmitted Probe", - "Min Intensity Transmitted Probe", - "Max Intensity Transmitted Probe", - ] - - if plot_fourier_probe: - bottom_row = [ - asnumpy( - self._return_fourier_probe( - probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - for probe in [ - mean_transmitted, - min_intensity_transmitted, - max_intensity_transmitted, - ] - ] - probes = [probes, bottom_row] - - title += [ - "Mean Transmitted Fourier Probe", - "Min Intensity Transmitted Fourier Probe", - "Max Intensity Transmitted Fourier Probe", - ] - - title = kwargs.get("title", title) - show_complex( - probes, - title=title, - **kwargs, - ) - - def show_slices( - self, - ms_object=None, - cbar: bool = True, - common_color_scale: bool = True, - padding: int = 0, - num_cols: int = 3, - show_fft: bool = False, - **kwargs, - ): - """ - Displays reconstructed slices of object - - Parameters - -------- - ms_object: nd.array, optional - Object to plot slices of. If None, uses current object - cbar: bool, optional - If True, displays a colorbar - padding: int, optional - Padding to leave uncropped - num_cols: int, optional - Number of GridSpec columns - show_fft: bool, optional - if True, plots fft of object slices - """ - - if ms_object is None: - ms_object = self._object - - rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) - if show_fft: - rotated_object = np.abs( - np.fft.fftshift( - np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) - ) - ) - rotated_shape = rotated_object.shape - - if np.iscomplexobj(rotated_object): - rotated_object = np.angle(rotated_object) - - extent = [ - 0, - self.sampling[1] * rotated_shape[2], - self.sampling[0] * rotated_shape[1], - 0, - ] - - num_rows = np.ceil(self._num_slices / num_cols).astype("int") - wspace = 0.35 if cbar else 0.15 - - axsize = kwargs.pop("axsize", (3, 3)) - cmap = kwargs.pop("cmap", "magma") - - if common_color_scale: - vals = np.sort(rotated_object.ravel()) - ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") - ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") - ind_vmin = np.max([0, ind_vmin]) - ind_vmax = np.min([len(vals) - 1, ind_vmax]) - vmin = vals[ind_vmin] - vmax = vals[ind_vmax] - if vmax == vmin: - vmin = vals[0] - vmax = vals[-1] - else: - vmax = None - vmin = None - vmin = kwargs.pop("vmin", vmin) - vmax = kwargs.pop("vmax", vmax) - - spec = GridSpec( - ncols=num_cols, - nrows=num_rows, - hspace=0.15, - wspace=wspace, - ) - - figsize = (axsize[0] * num_cols, axsize[1] * num_rows) - fig = plt.figure(figsize=figsize) - - for flat_index, obj_slice in enumerate(rotated_object): - row_index, col_index = np.unravel_index(flat_index, (num_rows, num_cols)) - ax = fig.add_subplot(spec[row_index, col_index]) - im = 
ax.imshow( - obj_slice, - cmap=cmap, - vmin=vmin, - vmax=vmax, - extent=extent, - **kwargs, - ) - - ax.set_title(f"Slice index: {flat_index}") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if row_index < num_rows - 1: - ax.set_xticks([]) - else: - ax.set_xlabel("y [A]") - - if col_index > 0: - ax.set_yticks([]) - else: - ax.set_ylabel("x [A]") - - spec.tight_layout(fig) - - def show_depth( - self, - x1: float, - x2: float, - y1: float, - y2: float, - specify_calibrated: bool = False, - gaussian_filter_sigma: float = None, - ms_object=None, - cbar: bool = False, - aspect: float = None, - plot_line_profile: bool = False, - **kwargs, - ): - """ - Displays line profile depth section - - Parameters - -------- - x1, x2, y1, y2: floats (pixels) - Line profile for depth section runs from (x1,y1) to (x2,y2) - Specified in pixels unless specify_calibrated is True - specify_calibrated: bool (optional) - If True, specify x1, x2, y1, y2 in A values instead of pixels - gaussian_filter_sigma: float (optional) - Standard deviation of gaussian kernel in A - ms_object: np.array - Object to plot slices of. If None, uses current object - cbar: bool, optional - If True, displays a colorbar - aspect: float, optional - aspect ratio for depth profile plot - plot_line_profile: bool - If True, also plots line profile showing where depth profile is taken - """ - if ms_object is not None: - ms_obj = ms_object - else: - ms_obj = self.object_cropped - - if specify_calibrated: - x1 /= self.sampling[0] - x2 /= self.sampling[0] - y1 /= self.sampling[1] - y2 /= self.sampling[1] - - if x2 == x1: - angle = 0 - elif y2 == y1: - angle = np.pi / 2 - else: - angle = np.arctan((x2 - x1) / (y2 - y1)) - - x0 = ms_obj.shape[1] / 2 - y0 = ms_obj.shape[2] / 2 - - if ( - x1 > ms_obj.shape[1] - or x2 > ms_obj.shape[1] - or y1 > ms_obj.shape[2] - or y2 > ms_obj.shape[2] - ): - raise ValueError("depth section must be in field of view of object") - - from py4DSTEM.process.phase.utils import rotate_point - - x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) - x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) - - rotated_object = np.roll( - rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), - int(x1_0), - axis=1, - ) - - if np.iscomplexobj(rotated_object): - rotated_object = np.angle(rotated_object) - if gaussian_filter_sigma is not None: - from scipy.ndimage import gaussian_filter - - gaussian_filter_sigma /= self.sampling[0] - rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) - - plot_im = rotated_object[ - :, 0, np.max((0, int(y1_0))) : np.min((int(y2_0), rotated_object.shape[2])) - ] - - extent = [ - 0, - self.sampling[1] * plot_im.shape[1], - self._slice_thicknesses[0] * plot_im.shape[0], - 0, - ] - - figsize = kwargs.pop("figsize", (6, 6)) - if not plot_line_profile: - fig, ax = plt.subplots(figsize=figsize) - im = ax.imshow(plot_im, cmap="magma", extent=extent) - if aspect is not None: - ax.set_aspect(aspect) - ax.set_xlabel("r [A]") - ax.set_ylabel("z [A]") - ax.set_title("Multislice depth profile") - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - else: - extent2 = [ - 0, - self.sampling[1] * ms_obj.shape[2], - self.sampling[0] * ms_obj.shape[1], - 0, - ] - fig, ax = plt.subplots(2, 1, figsize=figsize) - ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) - 
ax[0].plot( - [y1 * self.sampling[0], y2 * self.sampling[1]], - [x1 * self.sampling[0], x2 * self.sampling[1]], - color="red", - ) - ax[0].set_xlabel("y [A]") - ax[0].set_ylabel("x [A]") - ax[0].set_title("Multislice depth profile location") - - im = ax[1].imshow(plot_im, cmap="magma", extent=extent) - if aspect is not None: - ax[1].set_aspect(aspect) - ax[1].set_xlabel("r [A]") - ax[1].set_ylabel("z [A]") - ax[1].set_title("Multislice depth profile") - if cbar: - divider = make_axes_locatable(ax[1]) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - plt.tight_layout() - - def tune_num_slices_and_thicknesses( - self, - num_slices_guess=None, - thicknesses_guess=None, - num_slices_step_size=1, - thicknesses_step_size=20, - num_slices_values=3, - num_thicknesses_values=3, - update_defocus=False, - max_iter=5, - plot_reconstructions=True, - plot_convergence=True, - return_values=False, - **kwargs, - ): - """ - Run reconstructions over a parameters space of number of slices - and slice thicknesses. Should be run after the preprocess step. - - Parameters - ---------- - num_slices_guess: float, optional - initial starting guess for number of slices, rounds to nearest integer - if None, uses current initialized values - thicknesses_guess: float (A), optional - initial starting guess for thicknesses of slices assuming same - thickness for each slice - if None, uses current initialized values - num_slices_step_size: float, optional - size of change of number of slices for each step in parameter space - thicknesses_step_size: float (A), optional - size of change of slice thicknesses for each step in parameter space - num_slices_values: int, optional - number of number of slice values to test, must be >= 1 - num_thicknesses_values: int,optional - number of thicknesses values to test, must be >= 1 - update_defocus: bool, optional - if True, updates defocus based on estimated total thickness - max_iter: int, optional - number of iterations to run in ptychographic reconstruction - plot_reconstructions: bool, optional - if True, plot phase of reconstructed objects - plot_convergence: bool, optional - if True, plots error for each iteration for each reconstruction - return_values: bool, optional - if True, returns objects, convergence - - Returns - ------- - objects: list - reconstructed objects - convergence: np.ndarray - array of convergence values from reconstructions - """ - - # calculate number of slices and thicknesses values to test - if num_slices_guess is None: - num_slices_guess = self._num_slices - if thicknesses_guess is None: - thicknesses_guess = np.mean(self._slice_thicknesses) - - if num_slices_values == 1: - num_slices_step_size = 0 - - if num_thicknesses_values == 1: - thicknesses_step_size = 0 - - num_slices = np.linspace( - num_slices_guess - num_slices_step_size * (num_slices_values - 1) / 2, - num_slices_guess + num_slices_step_size * (num_slices_values - 1) / 2, - num_slices_values, - ) - - thicknesses = np.linspace( - thicknesses_guess - - thicknesses_step_size * (num_thicknesses_values - 1) / 2, - thicknesses_guess - + thicknesses_step_size * (num_thicknesses_values - 1) / 2, - num_thicknesses_values, - ) - - if return_values: - convergence = [] - objects = [] - - # current initialized values - current_verbose = self._verbose - current_num_slices = self._num_slices - current_thicknesses = self._slice_thicknesses - current_rotation_deg = self._rotation_best_rad * 180 / np.pi - current_transpose = 
self._rotation_best_transpose - current_defocus = -self._polar_parameters["C10"] - - # Gridspec to plot on - if plot_reconstructions: - if plot_convergence: - spec = GridSpec( - ncols=num_thicknesses_values, - nrows=num_slices_values * 2, - height_ratios=[1, 1 / 4] * num_slices_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_thicknesses_values, 5 * num_slices_values) - ) - else: - spec = GridSpec( - ncols=num_thicknesses_values, - nrows=num_slices_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_thicknesses_values, 4 * num_slices_values) - ) - - fig = plt.figure(figsize=figsize) - - progress_bar = kwargs.pop("progress_bar", False) - # run loop and plot along the way - self._verbose = False - for flat_index, (slices, thickness) in enumerate( - tqdmnd(num_slices, thicknesses, desc="Tuning angle and defocus") - ): - slices = int(slices) - self._num_slices = slices - self._slice_thicknesses = np.tile(thickness, slices - 1) - self._probe = None - self._object = None - if update_defocus: - defocus = current_defocus + slices / 2 * thickness - self._polar_parameters["C10"] = -defocus - - self.preprocess( - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - force_com_rotation=current_rotation_deg, - force_com_transpose=current_transpose, - ) - self.reconstruct( - reset=True, - store_iterations=True if plot_convergence else False, - max_iter=max_iter, - progress_bar=progress_bar, - **kwargs, - ) - - if plot_reconstructions: - row_index, col_index = np.unravel_index( - flat_index, (num_slices_values, num_thicknesses_values) - ) - - if plot_convergence: - object_ax = fig.add_subplot(spec[row_index * 2, col_index]) - convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=convergence_ax, - cbar=True, - ) - convergence_ax.yaxis.tick_right() - else: - object_ax = fig.add_subplot(spec[row_index, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=None, - cbar=True, - ) - - object_ax.set_title( - f" num slices = {slices:.0f}, slices thickness = {thickness:.1f} A \n error = {self.error:.3e}" - ) - object_ax.set_xticks([]) - object_ax.set_yticks([]) - - if return_values: - objects.append(self.object) - convergence.append(self.error_iterations.copy()) - - # initialize back to pre-tuning values - self._probe = None - self._object = None - self._num_slices = current_num_slices - self._slice_thicknesses = np.tile(current_thicknesses, current_num_slices - 1) - self._polar_parameters["C10"] = -current_defocus - self.preprocess( - force_com_rotation=current_rotation_deg, - force_com_transpose=current_transpose, - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - ) - self._verbose = current_verbose - - if plot_reconstructions: - spec.tight_layout(fig) - - if return_values: - return objects, convergence - - def _return_object_fft( - self, - obj=None, - ): - """ - Returns obj fft shifted to center of array - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - """ - asnumpy = self._asnumpy - - if obj is None: - obj = self._object - - obj = asnumpy(obj) - if np.iscomplexobj(obj): - obj = np.angle(obj) - - obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) - - def _return_self_consistency_errors( - self, - max_batch_size=None, 
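As a standalone illustration of how the tuning sweep above lays out its search grid around the initial guesses (the guesses and step sizes here are placeholders, not recommendations):

import numpy as np

num_slices_guess, num_slices_step_size, num_slices_values = 6, 2, 3
thicknesses_guess, thicknesses_step_size, num_thicknesses_values = 60.0, 20.0, 3

num_slices = np.linspace(
    num_slices_guess - num_slices_step_size * (num_slices_values - 1) / 2,
    num_slices_guess + num_slices_step_size * (num_slices_values - 1) / 2,
    num_slices_values,
)
thicknesses = np.linspace(
    thicknesses_guess - thicknesses_step_size * (num_thicknesses_values - 1) / 2,
    thicknesses_guess + thicknesses_step_size * (num_thicknesses_values - 1) / 2,
    num_thicknesses_values,
)
print(num_slices)   # [4. 6. 8.]
print(thicknesses)  # [40. 60. 80.]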
- ): - """Compute the self-consistency errors for each probe position""" - - xp = self._xp - asnumpy = self._asnumpy - - # Batch-size - if max_batch_size is None: - max_batch_size = self._num_diffraction_patterns - - # Re-initialize fractional positions and vector patches - errors = np.array([]) - positions_px = self._positions_px.copy() - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[start:end] - - # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) - - # Normalized mean-squared errors - batch_errors = xp.sum( - xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) - ) - errors = np.hstack((errors, batch_errors)) - - self._positions_px = positions_px.copy() - errors /= self._mean_diffraction_intensity - - return asnumpy(errors) - - def _return_projected_cropped_potential( - self, - ): - """Utility function to accommodate multiple classes""" - if self._object_type == "complex": - projected_cropped_potential = np.angle(self.object_cropped).sum(0) - else: - projected_cropped_potential = self.object_cropped.sum(0) - - return projected_cropped_potential diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py deleted file mode 100644 index 880858f30..000000000 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ /dev/null @@ -1,2430 +0,0 @@ -""" -Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, -namely mixed-state ptychography. -""" - -import warnings -from typing import Mapping, Tuple - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex - -try: - import cupy as cp -except (ModuleNotFoundError, ImportError): - cp = np - -from emdfile import Custom, tqdmnd -from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.utils import ( - ComplexProbe, - fft_shift, - generate_batches, - polar_aliases, - polar_symbols, -) -from py4DSTEM.process.utils import get_CoM, get_shifted_ar - -warnings.simplefilter(action="always", category=UserWarning) - - -class MixedstatePtychographicReconstruction(PtychographicReconstruction): - """ - Mixed-State Ptychographic Reconstruction Class. - - Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) - Reconstructed probe dimensions : (N,Sx,Sy) - Reconstructed object dimensions : (Px,Py) - - such that (Sx,Sy) is the region-of-interest (ROI) size of our N probes - and (Px,Py) is the padded-object size we position our ROI around in. 
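A minimal numpy sketch of the per-position self-consistency error computed above, assuming measured amplitudes of shape (N, Qx, Qy) and mixed-state overlaps of shape (N, num_probes, Qx, Qy); batching and GPU dispatch are omitted.

import numpy as np

def self_consistency_errors(amplitudes, overlap, mean_diffraction_intensity):
    # amplitudes: (N, Qx, Qy) measured far-field amplitudes
    # overlap:    (N, num_probes, Qx, Qy) shifted probes * object patches
    fourier_overlap = np.fft.fft2(overlap)
    # incoherent sum over the mixed probe modes
    modelled_amplitudes = np.sqrt(np.sum(np.abs(fourier_overlap) ** 2, axis=1))
    errors = np.sum(np.abs(amplitudes - modelled_amplitudes) ** 2, axis=(-2, -1))
    return errors / mean_diffraction_intensity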
- - Parameters - ---------- - energy: float - The electron energy of the wave functions in eV - datacube: DataCube - Input 4D diffraction pattern intensities - num_probes: int, optional - Number of mixed-state probes - semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess in mrad - semiangle_cutoff_pixels: float, optional - Semiangle cutoff for the initial probe guess in pixels - rolloff: float, optional - Semiangle rolloff for the initial probe guess - vacuum_probe_intensity: np.ndarray, optional - Vacuum probe to use as intensity aperture for initial probe guess - polar_parameters: dict, optional - Mapping from aberration symbols to their corresponding values. All aberration - magnitudes should be given in Å and angles should be given in radians. - object_padding_px: Tuple[int,int], optional - Pixel dimensions to pad object with - If None, the padding is set to half the probe ROI dimensions - initial_object_guess: np.ndarray, optional - Initial guess for complex-valued object of dimensions (Px,Py) - If None, initialized to 1.0j - initial_probe_guess: np.ndarray, optional - Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, - initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations - initial_scan_positions: np.ndarray, optional - Probe positions in Å for each diffraction intensity - If None, initialized to a grid scan - positions_mask: np.ndarray, optional - Boolean real space mask to select positions in datacube to skip for reconstruction - verbose: bool, optional - If True, class methods will inherit this and print additional information - device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' - name: str, optional - Class name - kwargs: - Provide the aberration coefficients as keyword arguments. - """ - - # Class-specific Metadata - _class_specific_metadata = ("_num_probes",) - - def __init__( - self, - energy: float, - datacube: DataCube = None, - num_probes: int = None, - semiangle_cutoff: float = None, - semiangle_cutoff_pixels: float = None, - rolloff: float = 2.0, - vacuum_probe_intensity: np.ndarray = None, - polar_parameters: Mapping[str, float] = None, - object_padding_px: Tuple[int, int] = None, - initial_object_guess: np.ndarray = None, - initial_probe_guess: np.ndarray = None, - initial_scan_positions: np.ndarray = None, - object_type: str = "complex", - positions_mask: np.ndarray = None, - verbose: bool = True, - device: str = "cpu", - name: str = "mixed-state_ptychographic_reconstruction", - **kwargs, - ): - Custom.__init__(self, name=name) - - if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe): - if num_probes is None: - raise ValueError( - ( - "If initial_probe_guess is None, or a ComplexProbe object, " - "num_probes must be specified." - ) - ) - else: - if len(initial_probe_guess.shape) != 3: - raise ValueError( - "Specified initial_probe_guess must have dimensions (N,Sx,Sy)." 
- ) - num_probes = initial_probe_guess.shape[0] - - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from scipy.special import erf - - self._erf = erf - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from cupyx.scipy.special import erf - - self._erf = erf - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - - for key in kwargs.keys(): - if (key not in polar_symbols) and (key not in polar_aliases.keys()): - raise ValueError("{} not a recognized parameter".format(key)) - - self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) - - if polar_parameters is None: - polar_parameters = {} - - polar_parameters.update(kwargs) - self._set_polar_parameters(polar_parameters) - - if object_type != "potential" and object_type != "complex": - raise ValueError( - f"object_type must be either 'potential' or 'complex', not {object_type}" - ) - - self.set_save_defaults() - - # Data - self._datacube = datacube - self._object = initial_object_guess - self._probe = initial_probe_guess - - # Common Metadata - self._vacuum_probe_intensity = vacuum_probe_intensity - self._scan_positions = initial_scan_positions - self._energy = energy - self._semiangle_cutoff = semiangle_cutoff - self._semiangle_cutoff_pixels = semiangle_cutoff_pixels - self._rolloff = rolloff - self._object_type = object_type - self._object_padding_px = object_padding_px - self._positions_mask = positions_mask - self._verbose = verbose - self._device = device - self._preprocessed = False - - # Class-specific Metadata - self._num_probes = num_probes - - def preprocess( - self, - diffraction_intensities_shape: Tuple[int, int] = None, - reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, - dp_mask: np.ndarray = None, - fit_function: str = "plane", - plot_center_of_mass: str = "default", - plot_rotation: bool = True, - maximize_divergence: bool = False, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), - plot_probe_overlaps: bool = True, - force_com_rotation: float = None, - force_com_transpose: float = None, - force_com_shifts: float = None, - force_scan_sampling: float = None, - force_angular_sampling: float = None, - force_reciprocal_sampling: float = None, - object_fov_mask: np.ndarray = None, - crop_patterns: bool = False, - **kwargs, - ): - """ - Ptychographic preprocessing step. - Calls the base class methods: - - _extract_intensities_and_calibrations_from_datacube, - _compute_center_of_mass(), - _solve_CoM_rotation(), - _normalize_diffraction_intensities() - _calculate_scan_positions_in_px() - - Additionally, it initializes an (Px,Py) array of 1.0j - and a complex probe using the specified polar parameters. - - Parameters - ---------- - diffraction_intensities_shape: Tuple[int,int], optional - Pixel dimensions (Qx',Qy') of the resampled diffraction intensities - If None, no resampling of diffraction intenstities is performed - reshaping_method: str, optional - Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional - Padded diffraction intensities shape. - If None, no padding is performed - dp_mask: ndarray, optional - Mask for datacube intensities (Qx,Qy) - fit_function: str, optional - 2D fitting function for CoM fitting. 
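The cpu/gpu branch above follows a simple dispatch pattern; a condensed sketch (the helper name select_backend is illustrative only):

def select_backend(device="cpu"):
    # Mirror of the dispatch above: the same names point at numpy/scipy or
    # cupy/cupyx depending on the requested device.
    if device == "cpu":
        import numpy as xp
        from scipy.ndimage import gaussian_filter
        from scipy.special import erf
        asnumpy = xp.asarray
    elif device == "gpu":
        import cupy as xp
        from cupyx.scipy.ndimage import gaussian_filter
        from cupyx.scipy.special import erf
        asnumpy = xp.asnumpy
    else:
        raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}")
    return xp, asnumpy, gaussian_filter, erf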
One of 'plane','parabola','bezier_two' - plot_center_of_mass: str, optional - If 'default', the corrected CoM arrays will be displayed - If 'all', the computed and fitted CoM arrays will be displayed - plot_rotation: bool, optional - If True, the CoM curl minimization search result will be displayed - maximize_divergence: bool, optional - If True, the divergence of the CoM gradient vector field is maximized - rotation_angles_deg: np.darray, optional - Array of angles in degrees to perform curl minimization over - plot_probe_overlaps: bool, optional - If True, initial probe overlaps scanned over the object will be displayed - force_com_rotation: float (degrees), optional - Force relative rotation angle between real and reciprocal space - force_com_transpose: bool, optional - Force whether diffraction intensities need to be transposed. - force_com_shifts: tuple of ndarrays (CoMx, CoMy) - Amplitudes come from diffraction patterns shifted with - the CoM in the upper left corner for each probe unless - shift is overwritten. - force_scan_sampling: float, optional - Override DataCube real space scan pixel size calibrations, in Angstrom - force_angular_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in mrad - force_reciprocal_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in A^-1 - object_fov_mask: np.ndarray (boolean) - Boolean mask of FOV. Used to calculate additional shrinkage of object - If None, probe_overlap intensity is thresholded - crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns when centering - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - xp = self._xp - asnumpy = self._asnumpy - - # set additional metadata - self._diffraction_intensities_shape = diffraction_intensities_shape - self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape - self._dp_mask = dp_mask - - if self._datacube is None: - raise ValueError( - ( - "The preprocess() method requires a DataCube. " - "Please run ptycho.attach_datacube(DataCube) first." 
- ) - ) - - if self._positions_mask is not None and self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) - self._positions_mask = np.asarray(self._positions_mask, dtype="bool") - - ( - self._datacube, - self._vacuum_probe_intensity, - self._dp_mask, - force_com_shifts, - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube, - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=self._vacuum_probe_intensity, - dp_mask=self._dp_mask, - com_shifts=force_com_shifts, - ) - - self._intensities = self._extract_intensities_and_calibrations_from_datacube( - self._datacube, - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - self._com_measured_x, - self._com_measured_y, - self._com_fitted_x, - self._com_fitted_y, - self._com_normalized_x, - self._com_normalized_y, - ) = self._calculate_intensities_center_of_mass( - self._intensities, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts, - ) - - ( - self._rotation_best_rad, - self._rotation_best_transpose, - self._com_x, - self._com_y, - self.com_x, - self.com_y, - ) = self._solve_for_center_of_mass_relative_rotation( - self._com_measured_x, - self._com_measured_y, - self._com_normalized_x, - self._com_normalized_y, - rotation_angles_deg=rotation_angles_deg, - plot_rotation=plot_rotation, - plot_center_of_mass=plot_center_of_mass, - maximize_divergence=maximize_divergence, - force_com_rotation=force_com_rotation, - force_com_transpose=force_com_transpose, - **kwargs, - ) - - ( - self._amplitudes, - self._mean_diffraction_intensity, - ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, - crop_patterns, - self._positions_mask, - ) - - # explicitly delete namespace - self._num_diffraction_patterns = self._amplitudes.shape[0] - self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) - del self._intensities - - self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions, self._positions_mask - ) - - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - if self._object_type == "potential": - self._object = xp.zeros((p, q), dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.ones((p, q), dtype=xp.complex64) - else: - if self._object_type == "potential": - self._object = xp.asarray(self._object, dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.asarray(self._object, dtype=xp.complex64) - - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type - self._object_shape = self._object.shape - - self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) - self._positions_px_com = xp.mean(self._positions_px, axis=0) - 
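A hypothetical call showing how the preprocessing options documented above fit together; `ptycho` stands in for a MixedstatePtychographicReconstruction instance and every value is a placeholder.

ptycho = ptycho.preprocess(
    diffraction_intensities_shape=(128, 128),  # resample patterns to (Qx', Qy')
    reshaping_method="fourier",
    fit_function="plane",
    plot_center_of_mass="default",
    plot_rotation=True,
    plot_probe_overlaps=True,
    force_com_rotation=None,   # degrees, if the rotation is already known
    object_fov_mask=None,      # thresholded from the probe overlap when None
    crop_patterns=False,
)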
self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - self._positions_px_initial = self._positions_px.copy() - self._positions_initial = self._positions_px_initial.copy() - self._positions_initial[:, 0] *= self.sampling[0] - self._positions_initial[:, 1] *= self.sampling[1] - - # Vectorized Patches - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - # Probe Initialization - if self._probe is None or isinstance(self._probe, ComplexProbe): - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, - device=self._device, - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - _probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - else: - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - _probe = self._probe._array - else: - self._probe._xp = xp - _probe = self._probe.build()._array - - self._probe = xp.zeros( - (self._num_probes,) + tuple(self._region_of_interest_shape), - dtype=xp.complex64, - ) - sx, sy = self._region_of_interest_shape - self._probe[0] = _probe - - # Randomly shift phase of other probes - for i_probe in range(1, self._num_probes): - shift_x = xp.exp( - -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx) - ) - shift_y = xp.exp( - -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy) - ) - self._probe[i_probe] = ( - self._probe[i_probe - 1] * shift_x[:, None] * shift_y[None] - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe[0])) ** 2) - self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) - - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) - - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = None # Doesn't really make sense for mixed-state - - self._known_aberrations_array = ComplexProbe( - energy=self._energy, - gpts=self._region_of_interest_shape, - sampling=self.sampling, - parameters=self._polar_parameters, - device=self._device, - )._evaluate_ctf() - - # overlaps - shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) - - if object_fov_mask is None: - self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) - else: - self._object_fov_mask = np.asarray(object_fov_mask) - self._object_fov_mask_inverse = np.invert(self._object_fov_mask) - - if 
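The mixed-state probe stack above is seeded by multiplying each new mode by a small random linear phase ramp and rescaling so the zeroth mode matches the mean diffraction intensity; a numpy-only sketch of that recipe (function name and shapes are assumptions):

import numpy as np

def init_mixed_state_probes(base_probe, num_probes, mean_diffraction_intensity, seed=None):
    rng = np.random.default_rng(seed)
    sx, sy = base_probe.shape
    probes = np.zeros((num_probes, sx, sy), dtype=np.complex64)
    probes[0] = base_probe
    for i in range(1, num_probes):
        # random linear phase ramps decorrelate successive modes
        shift_x = np.exp(-2j * np.pi * (rng.random() - 0.5) * np.fft.fftfreq(sx))
        shift_y = np.exp(-2j * np.pi * (rng.random() - 0.5) * np.fft.fftfreq(sy))
        probes[i] = probes[i - 1] * shift_x[:, None] * shift_y[None]
    # normalize the zeroth mode's far-field power to the mean measured intensity
    probe_intensity = np.sum(np.abs(np.fft.fft2(probes[0])) ** 2)
    return probes * np.sqrt(mean_diffraction_intensity / probe_intensity)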
plot_probe_overlaps: - figsize = kwargs.pop("figsize", (4.5 * self._num_probes + 4, 4)) - chroma_boost = kwargs.pop("chroma_boost", 1) - - # initial probe - complex_probe_rgb = Complex2RGB( - self.probe_centered, - power=2, - chroma_boost=chroma_boost, - ) - - extent = [ - 0, - self.sampling[1] * self._object_shape[1], - self.sampling[0] * self._object_shape[0], - 0, - ] - - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - fig, axs = plt.subplots(1, self._num_probes + 1, figsize=figsize) - - for i in range(self._num_probes): - axs[i].imshow( - complex_probe_rgb[i], - extent=probe_extent, - ) - axs[i].set_ylabel("x [A]") - axs[i].set_xlabel("y [A]") - axs[i].set_title(f"Initial probe[{i}] intensity") - - divider = make_axes_locatable(axs[i]) - cax = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax, chroma_boost=chroma_boost) - - axs[-1].imshow( - asnumpy(probe_overlap), - extent=extent, - cmap="Greys_r", - ) - axs[-1].scatter( - self.positions[:, 1], - self.positions[:, 0], - s=2.5, - color=(1, 0, 0, 1), - ) - axs[-1].set_ylabel("x [A]") - axs[-1].set_xlabel("y [A]") - axs[-1].set_xlim((extent[0], extent[1])) - axs[-1].set_ylim((extent[2], extent[3])) - axs[-1].set_title("Object field of view") - - fig.tight_layout() - - self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - shifted_probes * object_patches - """ - - xp = self._xp - - shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - object_patches = complex_object[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - overlap = shifted_probes * xp.expand_dims(object_patches, axis=1) - - return shifted_probes, object_patches, overlap - - def _gradient_descent_fourier_projection(self, amplitudes, overlap): - """ - Ptychographic fourier projection method for GD method. - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - - Returns - -------- - exit_waves:np.ndarray - Difference between modified and estimated exit waves - error: float - Reconstruction error - """ - - xp = self._xp - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) - error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) - - intensity_norm[intensity_norm == 0.0] = np.inf - amplitude_modification = amplitudes / intensity_norm - - fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap - modified_overlap = xp.fft.ifft2(fourier_modified_overlap) - - exit_waves = modified_overlap - overlap - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. 
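A compact numpy sketch of the gradient-descent Fourier projection above: the modelled exit waves keep their phase, their mode-summed amplitude is replaced by the measured one, and the difference to the current overlap is returned (the error here is the unnormalized sum of squared residuals).

import numpy as np

def gd_fourier_projection(amplitudes, overlap):
    # amplitudes: (N, Qx, Qy); overlap: (N, num_probes, Qx, Qy)
    fourier_overlap = np.fft.fft2(overlap)
    intensity_norm = np.sqrt(np.sum(np.abs(fourier_overlap) ** 2, axis=1))
    error = np.sum(np.abs(amplitudes - intensity_norm) ** 2)

    intensity_norm[intensity_norm == 0.0] = np.inf   # avoid divide-by-zero
    amplitude_modification = amplitudes / intensity_norm
    modified_overlap = np.fft.ifft2(amplitude_modification[:, None] * fourier_overlap)
    return modified_overlap - overlap, error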
- Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = overlap.copy() - - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) - error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) - - factor_to_be_projected = projection_c * overlap + projection_y * exit_waves - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - intensity_norm_projected = xp.sqrt( - xp.sum(xp.abs(fourier_projected_factor) ** 2, axis=1) - ) - intensity_norm_projected[intensity_norm_projected == 0.0] = np.inf - - amplitude_modification = amplitudes / intensity_norm_projected - fourier_projected_factor *= amplitude_modification[:, None] - - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * overlap - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - object * probe overlap - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - shifted_probes, object_patches, overlap = self._overlap_projection( - current_object, current_probe - ) - if use_projection_scheme: - exit_waves, error = self._projection_sets_fourier_projection( - amplitudes, - overlap, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, overlap - ) - - return shifted_probes, object_patches, overlap, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
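For quick reference, the named schemes documented above map onto the generalized-projection coefficients (a, b, c) as follows; this helper is an illustrative restatement of that table, not library API.

def generalized_projection_params(method, parameter=1.0):
    if method == "DM_AP":        # alpha in [0, 1]; DM = DM_AP(1.0), AP = DM_AP(0.0)
        return -parameter, 1.0, 1.0 + parameter
    if method == "RAAR":         # beta in [0, 1]; DM = RAAR(1.0)
        return 1.0 - 2.0 * parameter, parameter, 2.0
    if method == "RRR":          # gamma in [0, 2]; DM = RRR(1.0)
        return -parameter, parameter, 2.0
    if method == "SUPERFLIP":    # charge flipping
        return 0.0, 1.0, 2.0
    raise ValueError(f"unrecognized method: {method}")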
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - probe_normalization = xp.zeros_like(current_object) - object_update = xp.zeros_like(current_object) - - for i_probe in range(self._num_probes): - probe_normalization += self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes[:, i_probe]) ** 2 - ) - if self._object_type == "potential": - object_update += step_size * self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(object_patches) - * xp.conj(shifted_probes[:, i_probe]) - * exit_waves[:, i_probe] - ) - ) - else: - object_update += step_size * self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe] - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object += object_update * probe_normalization - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(object_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += step_size * ( - xp.sum( - xp.expand_dims(xp.conj(object_patches), axis=1) * exit_waves, - axis=0, - ) - * object_normalization[None] - ) - - return current_object, current_probe - - def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - probe_normalization = xp.zeros_like(current_object) - current_object = xp.zeros_like(current_object) - - for i_probe in range(self._num_probes): - probe_normalization += self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes[:, i_probe]) ** 2 - ) - if self._object_type == "potential": - current_object += self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(object_patches) - * xp.conj(shifted_probes[:, i_probe]) - * exit_waves[:, i_probe] - ) - ) - else: - current_object += self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe] - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object *= probe_normalization - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(object_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - xp.expand_dims(xp.conj(object_patches), axis=1) * exit_waves, - axis=0, - ) - * object_normalization[None] - ) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - - def _probe_center_of_mass_constraint(self, current_probe): - """ - Ptychographic center of mass constraint. 
- Used for centering corner-centered probe intensity. - - Parameters - -------- - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - constrained_probe: np.ndarray - Constrained probe estimate - """ - xp = self._xp - probe_intensity = xp.abs(current_probe[0]) ** 2 - - probe_x0, probe_y0 = get_CoM( - probe_intensity, device=self._device, corner_centered=True - ) - shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) - - return shifted_probe - - def _probe_orthogonalization_constraint(self, current_probe): - """ - Ptychographic probe-orthogonalization constraint. - Used to ensure mixed states are orthogonal to each other. - Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690 - - Parameters - -------- - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - constrained_probe: np.ndarray - Orthogonalized probe estimate - """ - xp = self._xp - n_probes = self._num_probes - - # compute upper half of P* @ P - pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype) - - for i in range(n_probes): - for j in range(i, n_probes): - pairwise_dot_product[i, j] = xp.sum( - current_probe[i].conj() * current_probe[j] - ) - - # compute eigenvectors (effectively cheaper way of computing V* from SVD) - _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U") - current_probe = xp.tensordot(evecs.T, current_probe, axes=1) - - # sort by real-space intensity - intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1)) - intensities_order = xp.argsort(intensities, axis=None)[::-1] - return current_probe[intensities_order] - - def _constraints( - self, - current_object, - current_probe, - current_positions, - pure_phase_object, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - tv_denoise, - tv_denoise_weight, - tv_denoise_inner_iter, - orthogonalize_probe, - object_positivity, - shrinkage_rad, - object_mask, - ): - """ - Ptychographic constraints operator. 
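The orthogonalization constraint above reduces to diagonalizing the Gram matrix of the probe modes and re-ordering them by real-space intensity; a numpy-only sketch:

import numpy as np

def orthogonalize_probes(probes):
    # probes: (num_probes, Sx, Sy) complex probe modes
    n = probes.shape[0]
    gram = np.empty((n, n), dtype=probes.dtype)
    for i in range(n):
        for j in range(i, n):          # only the upper triangle is needed
            gram[i, j] = np.sum(probes[i].conj() * probes[j])
    _, evecs = np.linalg.eigh(gram, UPLO="U")
    probes = np.tensordot(evecs.T, probes, axes=1)
    # sort modes by decreasing real-space intensity
    intensities = np.sum(np.abs(probes) ** 2, axis=(-2, -1))
    return probes[np.argsort(intensities)[::-1]]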
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - pure_phase_object: bool - If True, object amplitude is set to unity - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe fourier amplitude is replaced by initial probe aperture. - initial_probe_aperture: np.ndarray - initial probe aperture to use in replacing probe fourier amplitude - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - butterworth_filter: bool - If True, applies high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - orthogonalize_probe: bool - If True, probe will be orthogonalized - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising. 
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - object_positivity: bool - If True, clips negative potential values - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - - if tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, tv_denoise_weight, tv_denoise_inner_iter - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - if self._object_type == "complex": - current_object = self._object_threshold_constraint( - current_object, pure_phase_object - ) - elif object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - # These constraints don't _really_ make sense for mixed-state - if fix_probe_aperture: - raise NotImplementedError() - elif constrain_probe_fourier_amplitude: - raise NotImplementedError() - if fit_probe_aberrations: - raise NotImplementedError() - if constrain_probe_amplitude: - raise NotImplementedError() - - if orthogonalize_probe: - current_probe = self._probe_orthogonalization_constraint(current_probe) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - - def reconstruct( - self, - max_iter: int = 64, - reconstruction_method: str = "gradient-descent", - reconstruction_parameter: float = 1.0, - reconstruction_parameter_a: float = None, - reconstruction_parameter_b: float = None, - reconstruction_parameter_c: float = None, - max_batch_size: int = None, - seed_random: int = None, - step_size: float = 0.5, - normalization_min: float = 1, - positions_step_size: float = 0.9, - pure_phase_object_iter: int = 0, - fix_com: bool = True, - orthogonalize_probe: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, - constrain_probe_amplitude_relative_radius: float = 0.5, - constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, - constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, - constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, - global_affine_transformation: bool = True, - constrain_position_distance: float = None, - gaussian_filter_sigma: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, - fit_probe_aberrations_max_angular_order: int = 4, - fit_probe_aberrations_max_radial_order: int = 4, - 
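The butterworth arguments above (q_lowpass, q_highpass, butterworth_order) describe a Fourier-space band filter; the generic sketch below illustrates the idea only and is not the base-class implementation that _object_butterworth_constraint actually calls.

import numpy as np

def butterworth_filter_2d(obj, sampling, q_lowpass=None, q_highpass=None, order=2):
    # obj: real 2D array; sampling: (dx, dy) in A; cut-offs in A^-1
    qx = np.fft.fftfreq(obj.shape[0], sampling[0])
    qy = np.fft.fftfreq(obj.shape[1], sampling[1])
    qra = np.sqrt(qx[:, None] ** 2 + qy[None, :] ** 2)

    env = np.ones_like(qra)
    if q_highpass is not None:   # suppress low spatial frequencies
        env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * order))
    if q_lowpass is not None:    # suppress high spatial frequencies
        env *= 1 / (1 + (qra / q_lowpass) ** (2 * order))

    return np.real(np.fft.ifft2(np.fft.fft2(obj) * env))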
butterworth_filter_iter: int = np.inf, - q_lowpass: float = None, - q_highpass: float = None, - butterworth_order: float = 2, - tv_denoise_iter: int = np.inf, - tv_denoise_weight: float = None, - tv_denoise_inner_iter: float = 40, - object_positivity: bool = True, - shrinkage_rad: float = 0.0, - fix_potential_baseline: bool = True, - switch_object_iter: int = np.inf, - store_iterations: bool = False, - progress_bar: bool = True, - reset: bool = None, - ): - """ - Ptychographic reconstruction main method. - - Parameters - -------- - max_iter: int, optional - Maximum number of iterations to run - reconstruction_method: str, optional - Specifies which reconstruction algorithm to use, one of: - "generalized-projections", - "DM_AP" (or "difference-map_alternating-projections"), - "RAAR" (or "relaxed-averaged-alternating-reflections"), - "RRR" (or "relax-reflect-reflect"), - "SUPERFLIP" (or "charge-flipping"), or - "GD" (or "gradient_descent") - reconstruction_parameter: float, optional - Reconstruction parameter for various reconstruction methods above. - reconstruction_parameter_a: float, optional - Reconstruction parameter a for reconstruction_method='generalized-projections'. - reconstruction_parameter_b: float, optional - Reconstruction parameter b for reconstruction_method='generalized-projections'. - reconstruction_parameter_c: float, optional - Reconstruction parameter c for reconstruction_method='generalized-projections'. - max_batch_size: int, optional - Max number of probes to update at once - seed_random: int, optional - Seeds the random number generator, only applicable when max_batch_size is not None - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - positions_step_size: float, optional - Positions update step size - pure_phase_object_iter: int, optional - Number of iterations where object amplitude is set to unity - fix_com: bool, optional - If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. 
- fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float - Distance to constrain position correction within original field of view in A - global_affine_transformation: bool, optional - If True, positions are assumed to be a global affine transform from initial scan - gaussian_filter_sigma: float, optional - Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - tv_denoise_iter: int, optional - Number of iterations to run using tv denoise filter on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - object_positivity: bool, optional - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - fix_potential_baseline: bool - If true, the potential mean outside the FOV is forced to zero at each iteration - switch_object_iter: int, optional - Iteration to switch object type between 'complex' and 'potential' or between - 'potential' and 'complex' - store_iterations: bool, optional - If True, reconstructed objects and probes are stored at each iteration - progress_bar: bool, optional - If True, reconstruction progress is displayed - reset: bool, optional - If True, previous reconstructions are ignored - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - asnumpy = self._asnumpy - xp = self._xp - - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - "'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) - - if self._verbose: - if switch_object_iter > max_iter: - first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - else: - switch_object_type = ( - "complex" if self._object_type == "potential" else "potential" - ) - first_line = ( - f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " - f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " - ) - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - first_line + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." 
- ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) - - # Batching - shuffled_indices = np.arange(self._num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) - - if max_batch_size is not None: - xp.random.seed(seed_random) - else: - max_batch_size = self._num_diffraction_patterns - - # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self._object = self._object_initial.copy() - self.error_iterations = [] - self._probe = self._probe_initial.copy() - self._positions_px = self._positions_px_initial.copy() - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - self._exit_waves = None - self._object_type = self._object_type_initial - if hasattr(self, "_tf"): - del self._tf - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." - ), - UserWarning, - ) - else: - self.error_iterations = [] - self._exit_waves = None - - # main loop - for a0 in tqdmnd( - max_iter, - desc="Reconstructing object and probe", - unit=" iter", - disable=not progress_bar, - ): - error = 0.0 - - if a0 == switch_object_iter: - if self._object_type == "potential": - self._object_type = "complex" - self._object = xp.exp(1j * self._object) - elif self._object_type == "complex": - self._object_type = "potential" - self._object = xp.angle(self._object) - - # randomize - if not use_projection_scheme: - np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - self._num_diffraction_patterns - ) - positions_px = self._positions_px.copy()[shuffled_indices] - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[shuffled_indices[start:end]] - - # forward operator - ( - shifted_probes, - object_patches, - overlap, - self._exit_waves, - batch_error, - ) = self._forward( - self._object, - self._probe, - amplitudes, - self._exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ) - - # adjoint operator - self._object, self._probe = self._adjoint( - self._object, - self._probe, - object_patches, - shifted_probes, - self._exit_waves, - use_projection_scheme=use_projection_scheme, - step_size=step_size, - normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, - ) - - # position correction - if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( - self._object, - shifted_probes[:, 0], - overlap[:, 0], - 
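The stochastic update loop above walks through a freshly shuffled set of probe positions in contiguous batches; a minimal stand-in for that pattern (generate_batches_sketch is illustrative, the real code uses the generate_batches utility imported from py4DSTEM.process.phase.utils):

import numpy as np

def generate_batches_sketch(num_items, max_batch):
    # yield (start, end) index pairs covering range(num_items) in chunks
    for start in range(0, num_items, max_batch):
        yield start, min(start + max_batch, num_items)

num_patterns, max_batch_size = 10, 4
shuffled_indices = np.random.permutation(num_patterns)
for start, end in generate_batches_sketch(num_patterns, max_batch_size):
    batch = shuffled_indices[start:end]   # indices used to slice amplitudes/positions
    print(start, end, batch)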
amplitudes, - self._positions_px, - positions_step_size, - constrain_position_distance, - ) - - error += batch_error - - # Normalize Error - error /= self._mean_diffraction_intensity * self._num_diffraction_patterns - - # constraints - self._positions_px = positions_px.copy()[unshuffled_indices] - self._object, self._probe, self._positions_px = self._constraints( - self._object, - self._probe, - self._positions_px, - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, - gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass is not None or q_highpass is not None), - q_lowpass=q_lowpass, - q_highpass=q_highpass, - butterworth_order=butterworth_order, - orthogonalize_probe=orthogonalize_probe, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, - tv_denoise_weight=tv_denoise_weight, - tv_denoise_inner_iter=tv_denoise_inner_iter, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 - else None, - pure_phase_object=a0 < pure_phase_object_iter - and self._object_type == "complex", - ) - - self.error_iterations.append(error.item()) - if store_iterations: - self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(self.probe_centered) - - # store result - self.object = asnumpy(self._object) - self.probe = self.probe_centered - self.error = error.item() - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax: None, - cbar: bool, - padding: int = 0, - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. 
- - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - cmap = kwargs.pop("cmap", "magma") - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov(obj, padding=padding) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) - - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - padding: int, - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. - - Parameters - -------- - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - figsize = kwargs.pop("figsize", (8, 5)) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov(obj, padding=padding) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=2, - height_ratios=[4, 1], - hspace=0.15, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=1, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=1) - 
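The `width_ratios` arithmetic above keeps the object and probe panels at the same rendered height by dividing their extent aspect ratios. A standalone matplotlib sketch of that layout, with toy data and purely illustrative extents:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

object_extent = [0, 40.0, 30.0, 0]   # y-extent 40 A, x-extent 30 A
probe_extent = [0, 10.0, 10.0, 0]

# ratio of the two panel aspect ratios, so both render at equal height
width_ratio = (object_extent[1] / object_extent[2]) / (probe_extent[1] / probe_extent[2])

fig = plt.figure(figsize=(8, 5))
spec = GridSpec(
    ncols=2, nrows=2, height_ratios=[4, 1],
    width_ratios=[width_ratio, 1], hspace=0.15, wspace=0.35,
)

ax_obj = fig.add_subplot(spec[0, 0])
ax_obj.imshow(np.random.rand(300, 400), extent=object_extent, cmap="magma")
ax_probe = fig.add_subplot(spec[0, 1])
ax_probe.imshow(np.random.rand(100, 100), extent=probe_extent)
ax_conv = fig.add_subplot(spec[1, :])
ax_conv.semilogy(np.arange(20), np.geomspace(1e-1, 1e-3, 20))
ax_conv.set_ylabel("NMSE")
ax_conv.set_xlabel("Iteration number")
spec.tight_layout(fig)
```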
- if fig is None: - fig = plt.figure(figsize=figsize) - - if plot_probe or plot_fourier_probe: - # Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Probe - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - ax = fig.add_subplot(spec[0, 1]) - - if plot_fourier_probe: - if remove_initial_probe_aberrations: - probe_array = self.probe_fourier_residual[0] - else: - probe_array = self.probe_fourier[0] - - probe_array = Complex2RGB( - probe_array, - chroma_boost=chroma_boost, - ) - - ax.set_title("Reconstructed Fourier probe[0]") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - self.probe[0], - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title("Reconstructed probe[0] intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - - else: - ax = fig.add_subplot(spec[0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if plot_convergence and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = np.array(self.error_iterations) - if plot_probe: - ax = fig.add_subplot(spec[1, :]) - else: - ax = fig.add_subplot(spec[1]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") - spec.tight_layout(fig) - - def _visualize_all_iterations( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - iterations_grid: Tuple[int, int], - padding: int, - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - asnumpy = self._asnumpy - - if not hasattr(self, "object_iterations"): - raise ValueError( - ( - "Object and probe iterations were not saved during reconstruction. " - "Please re-run using store_iterations=True." - ) - ) - - if iterations_grid == "auto": - num_iter = len(self.error_iterations) - - if num_iter == 1: - return self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - cbar=cbar, - padding=padding, - **kwargs, - ) - elif plot_probe or plot_fourier_probe: - iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) - else: - iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) - else: - if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: - raise ValueError() - - auto_figsize = ( - (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) - if plot_convergence - else (3 * iterations_grid[1], 3 * iterations_grid[0]) - ) - figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - errors = np.array(self.error_iterations) - - objects = [] - object_type = [] - - for obj in self.object_iterations: - if np.iscomplexobj(obj): - obj = np.angle(obj) - object_type.append("phase") - else: - object_type.append("potential") - objects.append(self._crop_rotate_object_fov(obj, padding=padding)) - - if plot_probe or plot_fourier_probe: - total_grids = (np.prod(iterations_grid) / 2).astype("int") - probes = self.probe_iterations - else: - total_grids = np.prod(iterations_grid) - max_iter = len(objects) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) - - extent = [ - 0, - self.sampling[1] * objects[0].shape[1], - self.sampling[0] * objects[0].shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=2) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - grid = ImageGrid( - fig, - spec[0], - nrows_ncols=(1, iterations_grid[1]) - if (plot_probe or plot_fourier_probe) - else iterations_grid, - 
axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - im = ax.imshow( - objects[grid_range[n]], - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: - grid.cbar_axes[n].colorbar(im) - - if plot_probe or plot_fourier_probe: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - grid = ImageGrid( - fig, - spec[1], - nrows_ncols=(1, iterations_grid[1]), - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - if plot_fourier_probe: - probe_array = asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]][0], - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) - - ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - probes[grid_range[n]][0], - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title(f"Iter: {grid_range[n]} probe[0] intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - add_colorbar_arg( - grid.cbar_axes[n], - chroma_boost=chroma_boost, - ) - - if plot_convergence: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - if plot_probe: - ax2 = fig.add_subplot(spec[2]) - else: - ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration number") - ax2.yaxis.tick_right() - - spec.tight_layout(fig) - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - padding: int = 0, - **kwargs, - ): - """ - Displays reconstructed object and probe. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - - return self - - def show_fourier_probe( - self, - probe=None, - remove_initial_probe_aberrations=False, - cbar=True, - scalebar=True, - pixelsize=None, - pixelunits=None, - **kwargs, - ): - """ - Plot probe in fourier space - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses the `probe_fourier` property - remove_initial_probe_aberrations: bool, optional - If True, removes initial probe aberrations from Fourier probe - scalebar: bool, optional - if True, adds scalebar to probe - pixelunits: str, optional - units for scalebar, default is A^-1 - pixelsize: float, optional - default is probe reciprocal sampling - """ - asnumpy = self._asnumpy - - if probe is None: - probe = list( - asnumpy( - self._return_fourier_probe( - probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - ) - else: - if isinstance(probe, np.ndarray) and probe.ndim == 2: - probe = [probe] - probe = [ - asnumpy( - self._return_fourier_probe( - pr, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - for pr in probe - ] - - if pixelsize is None: - pixelsize = self._reciprocal_sampling[1] - if pixelunits is None: - pixelunits = r"$\AA^{-1}$" - - chroma_boost = kwargs.pop("chroma_boost", 1) - - show_complex( - probe if len(probe) > 1 else probe[0], - cbar=cbar, - scalebar=scalebar, - pixelsize=pixelsize, - pixelunits=pixelunits, - ticks=False, - chroma_boost=chroma_boost, - **kwargs, - ) - - def _return_self_consistency_errors( - self, - max_batch_size=None, - ): - """Compute the self-consistency errors for each probe position""" - - xp = self._xp - asnumpy = self._asnumpy - - # Batch-size - if max_batch_size is None: - max_batch_size = self._num_diffraction_patterns - - # Re-initialize fractional positions and vector patches - errors = np.array([]) - positions_px = self._positions_px.copy() - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - 
self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[start:end] - - # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) - - # Normalized mean-squared errors - batch_errors = xp.sum( - xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) - ) - errors = np.hstack((errors, batch_errors)) - - self._positions_px = positions_px.copy() - errors /= self._mean_diffraction_intensity - - return asnumpy(errors) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py deleted file mode 100644 index 39cb62fdd..000000000 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ /dev/null @@ -1,3465 +0,0 @@ -""" -Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, -namely multislice ptychography. -""" - -import warnings -from typing import Mapping, Sequence, Tuple, Union - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex - -try: - import cupy as cp -except (ModuleNotFoundError, ImportError): - cp = np - import os - - # make sure pylops doesn't try to use cupy - os.environ["CUPY_PYLOPS"] = "0" -import pylops # this must follow the exception -from emdfile import Custom, tqdmnd -from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.utils import ( - ComplexProbe, - fft_shift, - generate_batches, - polar_aliases, - polar_symbols, - spatial_frequencies, -) -from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar -from scipy.ndimage import rotate - -warnings.simplefilter(action="always", category=UserWarning) - - -class MultislicePtychographicReconstruction(PtychographicReconstruction): - """ - Multislice Ptychographic Reconstruction Class. - - Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) - Reconstructed probe dimensions : (Sx,Sy) - Reconstructed object dimensions : (T,Px,Py) - - such that (Sx,Sy) is the region-of-interest (ROI) size of our probe - and (Px,Py) is the padded-object size we position our ROI around in - each of the T slices. - - Parameters - ---------- - energy: float - The electron energy of the wave functions in eV - num_slices: int - Number of slices to use in the forward model - slice_thicknesses: float or Sequence[float] - Slice thicknesses in angstroms. If float, all slices are assigned the same thickness - datacube: DataCube, optional - Input 4D diffraction pattern intensities - semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess in mrad - semiangle_cutoff_pixels: float, optional - Semiangle cutoff for the initial probe guess in pixels - rolloff: float, optional - Semiangle rolloff for the initial probe guess - vacuum_probe_intensity: np.ndarray, optional - Vacuum probe to use as intensity aperture for initial probe guess - polar_parameters: dict, optional - Mapping from aberration symbols to their corresponding values. All aberration - magnitudes should be given in Å and angles should be given in radians. 
- object_padding_px: Tuple[int,int], optional - Pixel dimensions to pad object with - If None, the padding is set to half the probe ROI dimensions - initial_object_guess: np.ndarray, optional - Initial guess for complex-valued object of dimensions (Px,Py) - If None, initialized to 1.0j - initial_probe_guess: np.ndarray, optional - Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, - initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations - initial_scan_positions: np.ndarray, optional - Probe positions in Å for each diffraction intensity - If None, initialized to a grid scan - theta_x: float - x tilt of propagator (in degrees) - theta_y: float - y tilt of propagator (in degrees) - middle_focus: bool - if True, adds half the sample thickness to the defocus - object_type: str, optional - The object can be reconstructed as a real potential ('potential') or a complex - object ('complex') - positions_mask: np.ndarray, optional - Boolean real space mask to select positions in datacube to skip for reconstruction - verbose: bool, optional - If True, class methods will inherit this and print additional information - device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' - name: str, optional - Class name - kwargs: - Provide the aberration coefficients as keyword arguments. - """ - - # Class-specific Metadata - _class_specific_metadata = ("_num_slices", "_slice_thicknesses") - - def __init__( - self, - energy: float, - num_slices: int, - slice_thicknesses: Union[float, Sequence[float]], - datacube: DataCube = None, - semiangle_cutoff: float = None, - semiangle_cutoff_pixels: float = None, - rolloff: float = 2.0, - vacuum_probe_intensity: np.ndarray = None, - polar_parameters: Mapping[str, float] = None, - object_padding_px: Tuple[int, int] = None, - initial_object_guess: np.ndarray = None, - initial_probe_guess: np.ndarray = None, - initial_scan_positions: np.ndarray = None, - theta_x: float = 0, - theta_y: float = 0, - middle_focus: bool = False, - object_type: str = "complex", - positions_mask: np.ndarray = None, - verbose: bool = True, - device: str = "cpu", - name: str = "multi-slice_ptychographic_reconstruction", - **kwargs, - ): - Custom.__init__(self, name=name) - - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from scipy.special import erf - - self._erf = erf - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from cupyx.scipy.special import erf - - self._erf = erf - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - - for key in kwargs.keys(): - if (key not in polar_symbols) and (key not in polar_aliases.keys()): - raise ValueError("{} not a recognized parameter".format(key)) - - if np.isscalar(slice_thicknesses): - mean_slice_thickness = slice_thicknesses - else: - mean_slice_thickness = np.mean(slice_thicknesses) - - if middle_focus: - if "defocus" in kwargs: - kwargs["defocus"] += mean_slice_thickness * num_slices / 2 - elif "C10" in kwargs: - kwargs["C10"] -= mean_slice_thickness * num_slices / 2 - elif polar_parameters is not None and "defocus" in polar_parameters: - polar_parameters["defocus"] = ( - polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2 - ) - elif polar_parameters is not None and "C10" in polar_parameters: - polar_parameters["C10"] = ( - 
polar_parameters["C10"] - mean_slice_thickness * num_slices / 2 - ) - - self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) - - if polar_parameters is None: - polar_parameters = {} - - polar_parameters.update(kwargs) - self._set_polar_parameters(polar_parameters) - - slice_thicknesses = np.array(slice_thicknesses) - if slice_thicknesses.shape == (): - slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1) - elif slice_thicknesses.shape[0] != (num_slices - 1): - raise ValueError( - ( - f"slice_thicknesses must have length {num_slices - 1}, " - f"not {slice_thicknesses.shape[0]}." - ) - ) - - if object_type != "potential" and object_type != "complex": - raise ValueError( - f"object_type must be either 'potential' or 'complex', not {object_type}" - ) - - self.set_save_defaults() - - # Data - self._datacube = datacube - self._object = initial_object_guess - self._probe = initial_probe_guess - - # Common Metadata - self._vacuum_probe_intensity = vacuum_probe_intensity - self._scan_positions = initial_scan_positions - self._energy = energy - self._semiangle_cutoff = semiangle_cutoff - self._semiangle_cutoff_pixels = semiangle_cutoff_pixels - self._rolloff = rolloff - self._object_type = object_type - self._positions_mask = positions_mask - self._object_padding_px = object_padding_px - self._verbose = verbose - self._device = device - self._preprocessed = False - - # Class-specific Metadata - self._num_slices = num_slices - self._slice_thicknesses = slice_thicknesses - self._theta_x = theta_x - self._theta_y = theta_y - - def _precompute_propagator_arrays( - self, - gpts: Tuple[int, int], - sampling: Tuple[float, float], - energy: float, - slice_thicknesses: Sequence[float], - theta_x: float, - theta_y: float, - ): - """ - Precomputes propagator arrays complex wave-function will be convolved by, - for all slice thicknesses. - - Parameters - ---------- - gpts: Tuple[int,int] - Wavefunction pixel dimensions - sampling: Tuple[float,float] - Wavefunction sampling in A - energy: float - The electron energy of the wave functions in eV - slice_thicknesses: Sequence[float] - Array of slice thicknesses in A - theta_x: float - x tilt of propagator (in degrees) - theta_y: float - y tilt of propagator (in degrees) - - Returns - ------- - propagator_arrays: np.ndarray - (T,Sx,Sy) shape array storing propagator arrays - """ - xp = self._xp - - # Frequencies - kx, ky = spatial_frequencies(gpts, sampling) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) - - # Propagators - wavelength = electron_wavelength_angstrom(energy) - num_slices = slice_thicknesses.shape[0] - propagators = xp.empty( - (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 - ) - - theta_x = np.deg2rad(theta_x) - theta_y = np.deg2rad(theta_y) - - for i, dz in enumerate(slice_thicknesses): - propagators[i] = xp.exp( - 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) - ) - propagators[i] *= xp.exp( - 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) - ) - - return propagators - - def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): - """ - Propagates array by Fourier convolving array with propagator_array. 
- - Parameters - ---------- - array: np.ndarray - Wavefunction array to be convolved - propagator_array: np.ndarray - Propagator array to convolve array with - - Returns - ------- - propagated_array: np.ndarray - Fourier-convolved array - """ - xp = self._xp - - return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) - - def preprocess( - self, - diffraction_intensities_shape: Tuple[int, int] = None, - reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, - dp_mask: np.ndarray = None, - fit_function: str = "plane", - plot_center_of_mass: str = "default", - plot_rotation: bool = True, - maximize_divergence: bool = False, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), - plot_probe_overlaps: bool = True, - force_com_rotation: float = None, - force_com_transpose: float = None, - force_com_shifts: float = None, - force_scan_sampling: float = None, - force_angular_sampling: float = None, - force_reciprocal_sampling: float = None, - object_fov_mask: np.ndarray = None, - crop_patterns: bool = False, - **kwargs, - ): - """ - Ptychographic preprocessing step. - Calls the base class methods: - - _extract_intensities_and_calibrations_from_datacube, - _compute_center_of_mass(), - _solve_CoM_rotation(), - _normalize_diffraction_intensities() - _calculate_scan_positions_in_px() - - Additionally, it initializes an (T,Px,Py) array of 1.0j - and a complex probe using the specified polar parameters. - - Parameters - ---------- - diffraction_intensities_shape: Tuple[int,int], optional - Pixel dimensions (Qx',Qy') of the resampled diffraction intensities - If None, no resampling of diffraction intenstities is performed - reshaping_method: str, optional - Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional - Padded diffraction intensities shape. - If None, no padding is performed - dp_mask: ndarray, optional - Mask for datacube intensities (Qx,Qy) - fit_function: str, optional - 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' - plot_center_of_mass: str, optional - If 'default', the corrected CoM arrays will be displayed - If 'all', the computed and fitted CoM arrays will be displayed - plot_rotation: bool, optional - If True, the CoM curl minimization search result will be displayed - maximize_divergence: bool, optional - If True, the divergence of the CoM gradient vector field is maximized - rotation_angles_deg: np.darray, optional - Array of angles in degrees to perform curl minimization over - plot_probe_overlaps: bool, optional - If True, initial probe overlaps scanned over the object will be displayed - force_com_rotation: float (degrees), optional - Force relative rotation angle between real and reciprocal space - force_com_transpose: bool, optional - Force whether diffraction intensities need to be transposed. - force_com_shifts: tuple of ndarrays (CoMx, CoMy) - Amplitudes come from diffraction patterns shifted with - the CoM in the upper left corner for each probe unless - shift is overwritten. - force_scan_sampling: float, optional - Override DataCube real space scan pixel size calibrations, in Angstrom - force_angular_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in mrad - force_reciprocal_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in A^-1 - object_fov_mask: np.ndarray (boolean) - Boolean mask of FOV. 
Used to calculate additional shrinkage of object - If None, probe_overlap intensity is thresholded - crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns when centering - - Returns - -------- - self: MultislicePtychographicReconstruction - Self to accommodate chaining - """ - xp = self._xp - asnumpy = self._asnumpy - - # set additional metadata - self._diffraction_intensities_shape = diffraction_intensities_shape - self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape - self._dp_mask = dp_mask - - if self._datacube is None: - raise ValueError( - ( - "The preprocess() method requires a DataCube. " - "Please run ptycho.attach_datacube(DataCube) first." - ) - ) - - if self._positions_mask is not None and self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) - self._positions_mask = np.asarray(self._positions_mask, dtype="bool") - - ( - self._datacube, - self._vacuum_probe_intensity, - self._dp_mask, - force_com_shifts, - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube, - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=self._vacuum_probe_intensity, - dp_mask=self._dp_mask, - com_shifts=force_com_shifts, - ) - - self._intensities = self._extract_intensities_and_calibrations_from_datacube( - self._datacube, - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - self._com_measured_x, - self._com_measured_y, - self._com_fitted_x, - self._com_fitted_y, - self._com_normalized_x, - self._com_normalized_y, - ) = self._calculate_intensities_center_of_mass( - self._intensities, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts, - ) - - ( - self._rotation_best_rad, - self._rotation_best_transpose, - self._com_x, - self._com_y, - self.com_x, - self.com_y, - ) = self._solve_for_center_of_mass_relative_rotation( - self._com_measured_x, - self._com_measured_y, - self._com_normalized_x, - self._com_normalized_y, - rotation_angles_deg=rotation_angles_deg, - plot_rotation=plot_rotation, - plot_center_of_mass=plot_center_of_mass, - maximize_divergence=maximize_divergence, - force_com_rotation=force_com_rotation, - force_com_transpose=force_com_transpose, - **kwargs, - ) - - ( - self._amplitudes, - self._mean_diffraction_intensity, - ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, - crop_patterns, - self._positions_mask, - ) - - # explicitly delete namespace - self._num_diffraction_patterns = self._amplitudes.shape[0] - self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) - del self._intensities - - self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions, self._positions_mask - ) - - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), 
self._region_of_interest_shape[1]]).astype( - "int" - ) - if self._object_type == "potential": - self._object = xp.zeros((self._num_slices, p, q), dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.ones((self._num_slices, p, q), dtype=xp.complex64) - else: - if self._object_type == "potential": - self._object = xp.asarray(self._object, dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.asarray(self._object, dtype=xp.complex64) - - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type - self._object_shape = self._object.shape[-2:] - - self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - self._positions_px_initial = self._positions_px.copy() - self._positions_initial = self._positions_px_initial.copy() - self._positions_initial[:, 0] *= self.sampling[0] - self._positions_initial[:, 1] *= self.sampling[1] - - # Vectorized Patches - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - # Probe Initialization - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, - device=self._device, - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - self._probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) - - else: - if isinstance(self._probe, ComplexProbe): - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - self._probe = self._probe._array - else: - self._probe._xp = xp - self._probe = self._probe.build()._array - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - self._mean_diffraction_intensity / probe_intensity - ) - - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) - - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) - - self._known_aberrations_array = ComplexProbe( - energy=self._energy, - gpts=self._region_of_interest_shape, - sampling=self.sampling, - parameters=self._polar_parameters, - device=self._device, - )._evaluate_ctf() - - # Precomputed propagator arrays - self._propagator_arrays = self._precompute_propagator_arrays( - self._region_of_interest_shape, - 
self.sampling, - self._energy, - self._slice_thicknesses, - self._theta_x, - self._theta_y, - ) - - # overlaps - shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) - - if object_fov_mask is None: - self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) - else: - self._object_fov_mask = np.asarray(object_fov_mask) - self._object_fov_mask_inverse = np.invert(self._object_fov_mask) - - if plot_probe_overlaps: - figsize = kwargs.pop("figsize", (13, 4)) - chroma_boost = kwargs.pop("chroma_boost", 1) - - # initial probe - complex_probe_rgb = Complex2RGB( - self.probe_centered, - power=2, - chroma_boost=chroma_boost, - ) - - # propagated - propagated_probe = self._probe.copy() - - for s in range(self._num_slices - 1): - propagated_probe = self._propagate_array( - propagated_probe, self._propagator_arrays[s] - ) - complex_propagated_rgb = Complex2RGB( - asnumpy(self._return_centered_probe(propagated_probe)), - power=2, - chroma_boost=chroma_boost, - ) - - extent = [ - 0, - self.sampling[1] * self._object_shape[1], - self.sampling[0] * self._object_shape[0], - 0, - ] - - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) - - ax1.imshow( - complex_probe_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax1) - cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax1, chroma_boost=chroma_boost) - ax1.set_ylabel("x [A]") - ax1.set_xlabel("y [A]") - ax1.set_title("Initial probe intensity") - - ax2.imshow( - complex_propagated_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax2) - cax2 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax2, - chroma_boost=chroma_boost, - ) - ax2.set_ylabel("x [A]") - ax2.set_xlabel("y [A]") - ax2.set_title("Propagated probe intensity") - - ax3.imshow( - asnumpy(probe_overlap), - extent=extent, - cmap="Greys_r", - ) - ax3.scatter( - self.positions[:, 1], - self.positions[:, 0], - s=2.5, - color=(1, 0, 0, 1), - ) - ax3.set_ylabel("x [A]") - ax3.set_xlabel("y [A]") - ax3.set_xlim((extent[0], extent[1])) - ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object field of view") - - fig.tight_layout() - - self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. 
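The probe scaling near the end of `preprocess()` above simply rescales the initial probe so that its summed far-field intensity matches the mean measured diffraction intensity. A small sketch of that normalization:

```python
import numpy as np

def normalize_probe_to_mean_intensity(probe, mean_diffraction_intensity):
    """Scale a complex probe so sum(|FFT(probe)|^2) equals the target intensity,
    as done during probe initialization above."""
    probe_intensity = np.sum(np.abs(np.fft.fft2(probe)) ** 2)
    return probe * np.sqrt(mean_diffraction_intensity / probe_intensity)

# toy check: after scaling, the Fourier-space intensity sums to the target
rng = np.random.default_rng(0)
probe = rng.standard_normal((64, 64)) + 1j * rng.standard_normal((64, 64))
probe = normalize_probe_to_mean_intensity(probe, mean_diffraction_intensity=1e5)
print(np.sum(np.abs(np.fft.fft2(probe)) ** 2))  # ~1e5
```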
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - """ - - xp = self._xp - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - object_patches = complex_object[ - :, - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ] - - propagated_probes = xp.empty_like(object_patches) - propagated_probes[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes = object_patches[s] * propagated_probes[s] - - # propagate - if s + 1 < self._num_slices: - propagated_probes[s + 1] = self._propagate_array( - transmitted_probes, self._propagator_arrays[s] - ) - - return propagated_probes, object_patches, transmitted_probes - - def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): - """ - Ptychographic fourier projection method for GD method. - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - - Returns - -------- - exit_waves:np.ndarray - Exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - modified_exit_wave = xp.fft.ifft2( - amplitudes * xp.exp(1j * xp.angle(fourier_exit_waves)) - ) - - exit_waves = modified_exit_wave - transmitted_probes - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. 
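Stripped of sub-pixel shifts and patch indexing, the transmit-then-propagate forward model in `_overlap_projection` looks like the following for a single probe position, assuming complex-valued object slices:

```python
import numpy as np

def multislice_forward(object_slices, incident_probe, propagators):
    """Transmit-then-propagate multislice forward model, mirroring
    _overlap_projection above for one probe position.

    object_slices: (T, Sx, Sy) complex transmission functions
    incident_probe: (Sx, Sy) complex probe
    propagators: (T-1, Sx, Sy) Fresnel propagators between slices
    """
    num_slices = object_slices.shape[0]
    propagated = np.empty_like(object_slices)
    propagated[0] = incident_probe
    for s in range(num_slices):
        # transmit through slice s
        transmitted = object_slices[s] * propagated[s]
        # propagate to the next slice (skipped after the final slice)
        if s + 1 < num_slices:
            propagated[s + 1] = np.fft.ifft2(np.fft.fft2(transmitted) * propagators[s])
    return propagated, transmitted
```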
- Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = transmitted_probes.copy() - - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - factor_to_be_projected = ( - projection_c * transmitted_probes + projection_y * exit_waves - ) - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amplitudes * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * transmitted_probes - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - ( - propagated_probes, - object_patches, - transmitted_probes, - ) = self._overlap_projection(current_object, current_probe) - - if use_projection_scheme: - exit_waves, error = self._projection_sets_fourier_projection( - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, transmitted_probes - ) - - return propagated_probes, object_patches, transmitted_probes, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
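The generalized projection documented above can be written compactly: the presets map the named algorithms onto the (a, b, c) triplet, and the update mirrors `_projection_sets_fourier_projection`. The preset parameter values passed at the call site are illustrative, not class defaults:

```python
import numpy as np

# Named algorithms expressed through the (a, b, c) parameterization above.
PROJECTION_PRESETS = {
    "DM_AP": lambda alpha: (-alpha, 1.0, 1.0 + alpha),   # DM: DM_AP(1.0), AP: DM_AP(0.0)
    "RAAR": lambda beta: (1.0 - 2.0 * beta, beta, 2.0),  # DM: RAAR(1.0)
    "RRR": lambda gamma: (-gamma, gamma, 2.0),           # DM: RRR(1.0)
    "SUPERFLIP": lambda: (0.0, 1.0, 2.0),
}

def projection_sets_update(amplitudes, transmitted, exit_waves, a, b, c):
    """One generalized-projection exit-wave update and its error metric."""
    x, y = 1.0 - a - b, 1.0 - c
    if exit_waves is None:
        exit_waves = transmitted.copy()
    fourier_exit = np.fft.fft2(transmitted)
    error = np.sum(np.abs(amplitudes - np.abs(fourier_exit)) ** 2)
    # project c*transmitted + y*exit_waves onto the measured amplitudes
    factor = np.fft.fft2(c * transmitted + y * exit_waves)
    projected = np.fft.ifft2(amplitudes * np.exp(1j * np.angle(factor)))
    return x * exit_waves + a * transmitted + b * projected, error

# usage: RAAR with beta = 0.5
a, b, c = PROJECTION_PRESETS["RAAR"](0.5)
```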
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._object_type == "potential": - current_object[s] += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves) - ) - * probe_normalization - ) - elif self._object_type == "complex": - current_object[s] += step_size * ( - self._sum_overlapping_patches_bincounts(xp.conj(probe) * exit_waves) - * probe_normalization - ) - - # back-transmit - exit_waves *= xp.conj(obj) # / xp.abs(obj) ** 2 - - if s > 0: - # back-propagate - exit_waves = self._propagate_array( - exit_waves, xp.conj(self._propagator_arrays[s - 1]) - ) - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += ( - step_size - * xp.sum( - exit_waves, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. 
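The object step in `_gradient_descent_adjoint` is a probe-weighted scatter-add followed by a regularized normalization. A sketch for a single complex-valued slice, using a hypothetical `np.add.at`-based stand-in for `_sum_overlapping_patches_bincounts`:

```python
import numpy as np

def sum_overlapping_patches(patches, row_idx, col_idx, object_shape):
    """Hypothetical stand-in for _sum_overlapping_patches_bincounts:
    scatter-add (N, Sx, Sy) patches onto the padded object grid using the
    integer index arrays row_idx, col_idx of the same shape."""
    out = np.zeros(object_shape, dtype=patches.dtype)
    np.add.at(out, (row_idx, col_idx), patches)
    return out

def gradient_descent_object_update(obj_slice, probes, exit_waves, row_idx, col_idx,
                                   step_size, normalization_min):
    """Single-slice update for a complex object, as in the adjoint above."""
    probe_norm = sum_overlapping_patches(np.abs(probes) ** 2, row_idx, col_idx, obj_slice.shape)
    # regularized reciprocal: small overlaps are damped relative to the peak overlap
    probe_norm = 1 / np.sqrt(
        1e-16
        + ((1 - normalization_min) * probe_norm) ** 2
        + (normalization_min * probe_norm.max()) ** 2
    )
    update = sum_overlapping_patches(np.conj(probes) * exit_waves, row_idx, col_idx, obj_slice.shape)
    return obj_slice + step_size * update * probe_norm
```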
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - # careful not to modify exit_waves in-place for projection set methods - exit_waves_copy = exit_waves.copy() - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._object_type == "potential": - current_object[s] = ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves_copy) - ) - * probe_normalization - ) - elif self._object_type == "complex": - current_object[s] = ( - self._sum_overlapping_patches_bincounts( - xp.conj(probe) * exit_waves_copy - ) - * probe_normalization - ) - - # back-transmit - exit_waves_copy *= xp.conj(obj) # / xp.abs(obj) ** 2 - - if s > 0: - # back-propagate - exit_waves_copy = self._propagate_array( - exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) - ) - - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - exit_waves_copy, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - - def _position_correction( - self, - current_object, - current_probe, - transmitted_probes, - amplitudes, - current_positions, - positions_step_size, - constrain_position_distance, - ): - """ - Position correction using estimated intensity gradient. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe:np.ndarray - fractionally-shifted probes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - amplitudes: np.ndarray - Measured amplitudes - current_positions: np.ndarray - Current positions estimate - positions_step_size: float - Positions step size - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A - - Returns - -------- - updated_positions: np.ndarray - Updated positions estimate - """ - - xp = self._xp - - # Intensity gradient - exit_waves_fft = xp.fft.fft2(transmitted_probes) - exit_waves_fft_conj = xp.conj(exit_waves_fft) - estimated_intensity = xp.abs(exit_waves_fft) ** 2 - measured_intensity = amplitudes**2 - - flat_shape = (transmitted_probes.shape[0], -1) - difference_intensity = (measured_intensity - estimated_intensity).reshape( - flat_shape - ) - - # Computing perturbed exit waves one at a time to save on memory - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - # dx - obj_rolled_patches = complex_object[ - :, - (self._vectorized_patch_indices_row + 1) % self._object_shape[0], - self._vectorized_patch_indices_col, - ] - - propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - # dy - obj_rolled_patches = complex_object[ - :, - self._vectorized_patch_indices_row, - (self._vectorized_patch_indices_col + 1) % self._object_shape[1], - ] - - 
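The position update assembled from these perturbed exit waves is a per-position two-parameter least squares: the intensity partials with respect to a one-pixel x and y shift are regressed against the measured-minus-modelled intensity. In isolation, the normal-equation solve reads:

```python
import numpy as np

def position_update_least_squares(d_intensity, dI_dx, dI_dy):
    """Per-position least squares from _position_correction:
    solve [dI/dx, dI/dy] @ [dx, dy] ~= I_measured - I_model.

    d_intensity, dI_dx, dI_dy: arrays of shape (N, K), with K the number of
    flattened detector pixels. Returns the (N, 2) update vector; positions are
    then moved by -positions_step_size * update.
    """
    A = np.dstack((dI_dx, dI_dy))            # (N, K, 2) design matrix
    At = A.conj().swapaxes(-1, -2)           # (N, 2, K)
    update = np.linalg.inv(At @ A) @ At @ d_intensity[..., None]
    return update[..., 0]                    # (N, 2)
```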
propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - partial_intensity_dx = 2 * xp.real( - exit_waves_dx_fft * exit_waves_fft_conj - ).reshape(flat_shape) - partial_intensity_dy = 2 * xp.real( - exit_waves_dy_fft * exit_waves_fft_conj - ).reshape(flat_shape) - - coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) - - # positions_update = xp.einsum( - # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity - # ) - - coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) - positions_update = ( - xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) - @ coefficients_matrix_T - @ difference_intensity[..., None] - ) - - if constrain_position_distance is not None: - constrain_position_distance /= xp.sqrt( - self.sampling[0] ** 2 + self.sampling[1] ** 2 - ) - x1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 0 - ] - y1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 1 - ] - x0 = self._positions_px_initial[:, 0] - y0 = self._positions_px_initial[:, 1] - if self._rotation_best_transpose: - x0, y0 = xp.array([y0, x0]) - x1, y1 = xp.array([y1, x1]) - - if self._rotation_best_rad is not None: - rotation_angle = self._rotation_best_rad - x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( - -rotation_angle - ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) - x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( - -rotation_angle - ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) - - outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( - x1 < (xp.min(x0) - constrain_position_distance) - ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( - y1 < (xp.min(y0) - constrain_position_distance) - ) > 0 - - positions_update[..., 0][outlier_ind] = 0 - - current_positions -= positions_step_size * positions_update[..., 0] - - return current_positions - - def _object_butterworth_constraint( - self, current_object, q_lowpass, q_highpass, butterworth_order - ): - """ - 2D Butterworth filter - Used for low/high-pass filtering object. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qya, qxa = xp.meshgrid(qy, qx) - qra = xp.sqrt(qxa**2 + qya**2) - - env = xp.ones_like(qra) - if q_highpass: - env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) - if q_lowpass: - env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - - current_object_mean = xp.mean(current_object) - current_object -= current_object_mean - current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env[None]) - current_object += current_object_mean - - if self._object_type == "potential": - current_object = xp.real(current_object) - - return current_object - - def _object_kz_regularization_constraint( - self, current_object, kz_regularization_gamma - ): - """ - Arctan regularization filter - - Parameters - -------- - current_object: np.ndarray - Current object estimate - kz_regularization_gamma: float - Slice regularization strength - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - - current_object = xp.pad( - current_object, pad_width=((1, 0), (0, 0), (0, 0)), mode="constant" - ) - - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]) - - kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0] - - qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") - qz2 = qza**2 * kz_regularization_gamma**2 - qr2 = qxa**2 + qya**2 - - w = 1 - 2 / np.pi * xp.arctan2(qz2, qr2) - - current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * w) - current_object = current_object[1:] - - if self._object_type == "potential": - current_object = xp.real(current_object) - - return current_object - - def _object_identical_slices_constraint(self, current_object): - """ - Strong regularization forcing all slices to be identical - - Parameters - -------- - current_object: np.ndarray - Current object estimate - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - object_mean = current_object.mean(0, keepdims=True) - current_object[:] = object_mean - - return current_object - - def _object_denoise_tv_pylops(self, current_object, weights, iterations): - """ - Performs second order TV denoising along x and y - - Parameters - ---------- - current_object: np.ndarray - Current object estimate - weights : [float, float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - iterations: float - Number of iterations to run in denoising algorithm. 
- `niter_out` in pylops - - Returns - ------- - constrained_object: np.ndarray - Constrained object estimate - - """ - xp = self._xp - - if xp.iscomplexobj(current_object): - current_object_tv = current_object - warnings.warn( - ("TV denoising is currently only supported for potential objects."), - UserWarning, - ) - - else: - # zero pad at top and bottom slice - pad_width = ((1, 1), (0, 0), (0, 0)) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" - ) - - # run tv denoising - nz, nx, ny = current_object.shape - niter_out = iterations - niter_in = 1 - Iop = pylops.Identity(nx * ny * nz) - - if weights[0] == 0: - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[1]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - elif weights[1] == 0: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - l1_regs = [z_gradient] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[0]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - else: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [z_gradient, xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=weights, - tol=1e-4, - tau=1.0, - show=False, - )[0] - - # remove padding - current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] - - return current_object_tv - - def _constraints( - self, - current_object, - current_probe, - current_positions, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - kz_regularization_filter, - kz_regularization_gamma, - identical_slices, - object_positivity, - shrinkage_rad, - object_mask, - pure_phase_object, - tv_denoise_chambolle, - tv_denoise_weight_chambolle, - tv_denoise_pad_chambolle, - tv_denoise, - tv_denoise_weights, - tv_denoise_inner_iter, - ): - """ - Ptychographic constraints operator. 
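As a reference point for the Butterworth constraint defined above, this is the Fourier-space envelope it builds, written as a standalone 2D NumPy sketch (the sampling and cutoff values here are illustrative, not taken from the diff):

```python
import numpy as np

def butterworth_envelope(shape, sampling, q_lowpass=None, q_highpass=None, order=2):
    # radial spatial-frequency grid in A^-1
    qx = np.fft.fftfreq(shape[0], sampling[0])
    qy = np.fft.fftfreq(shape[1], sampling[1])
    qra = np.sqrt(qx[:, None] ** 2 + qy[None, :] ** 2)

    env = np.ones_like(qra)
    if q_highpass:
        env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * order))
    if q_lowpass:
        env *= 1 / (1 + (qra / q_lowpass) ** (2 * order))
    return env

# apply in Fourier space, preserving the object mean as the constraint above does
obj = np.random.rand(128, 128)
env = butterworth_envelope(obj.shape, sampling=(0.2, 0.2), q_lowpass=1.0, q_highpass=0.05)
obj_mean = obj.mean()
obj_filtered = np.real(np.fft.ifft2(np.fft.fft2(obj - obj_mean) * env)) + obj_mean
```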
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool - If True, probe Fourier amplitude is replaced by initial_probe_aperture - initial_probe_aperture: np.ndarray - Initial probe aperture to use in replacing probe Fourier amplitude - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter in A - gaussian_filter_sigma: float - Standard deviation of gaussian kernel - butterworth_filter: bool - If True, applies fourier-space butterworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - kz_regularization_filter: bool - If True, applies fourier-space arctan regularization filter - kz_regularization_gamma: float - Slice regularization strength - identical_slices: bool - If True, forces all object slices to be identical - object_positivity: bool - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - pure_phase_object: bool - If True, object amplitude is set to unity - tv_denoise_chambolle: bool - If True, performs TV denoising along z - tv_denoise_weight_chambolle: float - weight of tv denoising constraint - tv_denoise_pad_chambolle: bool - if True, pads object at top and bottom with zeros before applying denoising - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. 
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - - if identical_slices: - current_object = self._object_identical_slices_constraint(current_object) - elif kz_regularization_filter: - current_object = self._object_kz_regularization_constraint( - current_object, kz_regularization_gamma - ) - elif tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, - tv_denoise_weights, - tv_denoise_inner_iter, - ) - elif tv_denoise_chambolle: - current_object = self._object_denoise_tv_chambolle( - current_object, - tv_denoise_weight_chambolle, - axis=0, - pad_object=tv_denoise_pad_chambolle, - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - if self._object_type == "complex": - current_object = self._object_threshold_constraint( - current_object, pure_phase_object - ) - elif object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - - def reconstruct( - self, - max_iter: int = 64, - reconstruction_method: str = "gradient-descent", - reconstruction_parameter: float = 1.0, - reconstruction_parameter_a: float = None, - reconstruction_parameter_b: float = None, - reconstruction_parameter_c: float = None, - max_batch_size: int = None, - seed_random: int = None, - step_size: float = 0.5, - normalization_min: float = 1, - positions_step_size: float = 0.9, - fix_com: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, - constrain_probe_amplitude_relative_radius: float = 0.5, - constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, - constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, - 
constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, - global_affine_transformation: bool = True, - gaussian_filter_sigma: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, - fit_probe_aberrations_max_angular_order: int = 4, - fit_probe_aberrations_max_radial_order: int = 4, - butterworth_filter_iter: int = np.inf, - q_lowpass: float = None, - q_highpass: float = None, - butterworth_order: float = 2, - kz_regularization_filter_iter: int = np.inf, - kz_regularization_gamma: Union[float, np.ndarray] = None, - identical_slices_iter: int = 0, - object_positivity: bool = True, - shrinkage_rad: float = 0.0, - fix_potential_baseline: bool = True, - pure_phase_object_iter: int = 0, - tv_denoise_iter_chambolle=np.inf, - tv_denoise_weight_chambolle=None, - tv_denoise_pad_chambolle=True, - tv_denoise_iter=np.inf, - tv_denoise_weights=None, - tv_denoise_inner_iter=40, - switch_object_iter: int = np.inf, - store_iterations: bool = False, - progress_bar: bool = True, - reset: bool = None, - ): - """ - Ptychographic reconstruction main method. - - Parameters - -------- - max_iter: int, optional - Maximum number of iterations to run - reconstruction_method: str, optional - Specifies which reconstruction algorithm to use, one of: - "generalized-projections", - "DM_AP" (or "difference-map_alternating-projections"), - "RAAR" (or "relaxed-averaged-alternating-reflections"), - "RRR" (or "relax-reflect-reflect"), - "SUPERFLIP" (or "charge-flipping"), or - "GD" (or "gradient_descent") - reconstruction_parameter: float, optional - Reconstruction parameter for various reconstruction methods above. - reconstruction_parameter_a: float, optional - Reconstruction parameter a for reconstruction_method='generalized-projections'. - reconstruction_parameter_b: float, optional - Reconstruction parameter b for reconstruction_method='generalized-projections'. - reconstruction_parameter_c: float, optional - Reconstruction parameter c for reconstruction_method='generalized-projections'. - max_batch_size: int, optional - Max number of probes to update at once - seed_random: int, optional - Seeds the random number generator, only applicable when max_batch_size is not None - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - positions_step_size: float, optional - Positions update step size - fix_com: bool, optional - If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. 
- constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float - Distance to constrain position correction within original field of view in A - global_affine_transformation: bool, optional - If True, positions are assumed to be a global affine transform from initial scan - gaussian_filter_sigma: float, optional - Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - kz_regularization_filter_iter: int, optional - Number of iterations to run using kz regularization filter - kz_regularization_gamma, float, optional - kz regularization strength - identical_slices_iter: int, optional - Number of iterations to run using identical slices - object_positivity: bool, optional - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - fix_potential_baseline: bool - If true, the potential mean outside the FOV is forced to zero at each iteration - pure_phase_object_iter: int, optional - Number of iterations where object amplitude is set to unity - tv_denoise_iter_chambolle: bool - Number of iterations with TV denoisining - tv_denoise_weight_chambolle: float - weight of tv denoising constraint - tv_denoise_pad_chambolle: bool - if True, pads object at top and bottom with zeros before applying denoising - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - switch_object_iter: int, optional - Iteration to switch object type between 'complex' and 'potential' or between - 'potential' and 'complex' - store_iterations: bool, optional - If True, reconstructed objects and probes are stored at each iteration - progress_bar: bool, optional - If True, reconstruction progress is displayed - reset: bool, optional - If True, previous reconstructions are ignored - - Returns - -------- - self: MultislicePtychographicReconstruction - Self to accommodate chaining - """ - asnumpy = self._asnumpy - xp = self._xp - - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - "'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) - - if self._verbose: - if switch_object_iter > max_iter: - first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - else: - switch_object_type = ( - "complex" if self._object_type == "potential" else "potential" - ) - first_line = ( - f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " - f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " - ) - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - first_line + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." 
- ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) - - # Batching - shuffled_indices = np.arange(self._num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) - - if max_batch_size is not None: - xp.random.seed(seed_random) - else: - max_batch_size = self._num_diffraction_patterns - - # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self.error_iterations = [] - self._object = self._object_initial.copy() - self._probe = self._probe_initial.copy() - self._positions_px = self._positions_px_initial.copy() - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - self._exit_waves = None - self._object_type = self._object_type_initial - if hasattr(self, "_tf"): - del self._tf - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." - ), - UserWarning, - ) - else: - self.error_iterations = [] - self._exit_waves = None - - # main loop - for a0 in tqdmnd( - max_iter, - desc="Reconstructing object and probe", - unit=" iter", - disable=not progress_bar, - ): - error = 0.0 - - if a0 == switch_object_iter: - if self._object_type == "potential": - self._object_type = "complex" - self._object = xp.exp(1j * self._object) - elif self._object_type == "complex": - self._object_type = "potential" - self._object = xp.angle(self._object) - - # randomize - if not use_projection_scheme: - np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - self._num_diffraction_patterns - ) - positions_px = self._positions_px.copy()[shuffled_indices] - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[shuffled_indices[start:end]] - - # forward operator - ( - propagated_probes, - object_patches, - self._transmitted_probes, - self._exit_waves, - batch_error, - ) = self._forward( - self._object, - self._probe, - amplitudes, - self._exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ) - - # adjoint operator - self._object, self._probe = self._adjoint( - self._object, - self._probe, - object_patches, - propagated_probes, - self._exit_waves, - use_projection_scheme=use_projection_scheme, - step_size=step_size, - normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, - ) - - # position correction - if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( - self._object, - self._probe, - 
self._transmitted_probes, - amplitudes, - self._positions_px, - positions_step_size, - constrain_position_distance, - ) - - error += batch_error - - # Normalize Error - error /= self._mean_diffraction_intensity * self._num_diffraction_patterns - - # constraints - self._positions_px = positions_px.copy()[unshuffled_indices] - self._object, self._probe, self._positions_px = self._constraints( - self._object, - self._probe, - self._positions_px, - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, - gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass is not None or q_highpass is not None), - q_lowpass=q_lowpass, - q_highpass=q_highpass, - butterworth_order=butterworth_order, - kz_regularization_filter=a0 < kz_regularization_filter_iter - and kz_regularization_gamma is not None, - kz_regularization_gamma=kz_regularization_gamma[a0] - if kz_regularization_gamma is not None - and isinstance(kz_regularization_gamma, np.ndarray) - else kz_regularization_gamma, - identical_slices=a0 < identical_slices_iter, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 - else None, - pure_phase_object=a0 < pure_phase_object_iter - and self._object_type == "complex", - tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle - and tv_denoise_weight_chambolle is not None, - tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, - tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, - tv_denoise_weights=tv_denoise_weights, - tv_denoise_inner_iter=tv_denoise_inner_iter, - ) - - self.error_iterations.append(error.item()) - if store_iterations: - self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(self.probe_centered) - - # store result - self.object = asnumpy(self._object) - self.probe = self.probe_centered - self.error = error.item() - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax, - cbar: bool, - padding: int = 0, - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. 
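The if/elif block near the top of reconstruct() above maps each named algorithm onto three generalized-projection coefficients, with "GD" bypassing the projection path entirely. The same mapping, restated as a compact sketch in which alpha stands in for reconstruction_parameter:

```python
def projection_coefficients(method, alpha=1.0):
    # (a, b, c) coefficients used by the projection-set update
    if method in ("DM_AP", "difference-map_alternating-projections"):
        return -alpha, 1, 1 + alpha
    if method in ("RAAR", "relaxed-averaged-alternating-reflections"):
        return 1 - 2 * alpha, alpha, 2
    if method in ("RRR", "relax-reflect-reflect"):
        return -alpha, alpha, 2
    if method in ("SUPERFLIP", "charge-flipping"):
        return 0, 1, 2
    raise ValueError(f"unknown projection method: {method}")

print(projection_coefficients("RAAR", alpha=0.5))  # (0.0, 0.5, 2)
```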
- - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - cmap = kwargs.pop("cmap", "magma") - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov( - np.sum(obj, axis=0), padding=padding - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - errors = np.array(self.error_iterations) - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - - convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) - - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - padding: int, - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - padding : int, optional - Pixels to pad by post rotating-cropping object - - """ - figsize = kwargs.pop("figsize", (8, 5)) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov( - np.sum(obj, axis=0), padding=padding - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=2, - height_ratios=[4, 1], - hspace=0.15, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=1, - 
width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - if plot_probe or plot_fourier_probe: - # Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Probe - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - - ax = fig.add_subplot(spec[0, 1]) - if plot_fourier_probe: - if remove_initial_probe_aberrations: - probe_array = self.probe_fourier_residual - else: - probe_array = self.probe_fourier - - probe_array = Complex2RGB( - probe_array, - chroma_boost=chroma_boost, - ) - - ax.set_title("Reconstructed Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - self.probe, power=2, chroma_boost=chroma_boost - ) - ax.set_title("Reconstructed probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - - else: - ax = fig.add_subplot(spec[0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if plot_convergence and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = np.array(self.error_iterations) - if plot_probe: - ax = fig.add_subplot(spec[1, :]) - else: - ax = fig.add_subplot(spec[1]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") - spec.tight_layout(fig) - - def _visualize_all_iterations( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - iterations_grid: Tuple[int, int], - padding: int, - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. 
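The last-iteration plot above arranges the object panel and the convergence curve in a two-row GridSpec with a 4:1 height ratio. A stripped-down version of that layout with placeholder data:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

obj = np.random.rand(64, 64)       # placeholder for the cropped, rotated object
errors = np.logspace(0, -2, 32)    # placeholder NMSE history

fig = plt.figure(figsize=(5, 5))
spec = GridSpec(nrows=2, ncols=1, height_ratios=[4, 1], hspace=0.15)

ax_obj = fig.add_subplot(spec[0])
ax_obj.imshow(obj, cmap="magma")
ax_obj.set_title("Reconstructed object phase")

ax_err = fig.add_subplot(spec[1])
ax_err.semilogy(np.arange(len(errors)), errors)
ax_err.set_ylabel("NMSE")
ax_err.set_xlabel("Iteration number")
ax_err.yaxis.tick_right()
```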
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - asnumpy = self._asnumpy - - if not hasattr(self, "object_iterations"): - raise ValueError( - ( - "Object and probe iterations were not saved during reconstruction. " - "Please re-run using store_iterations=True." - ) - ) - - if iterations_grid == "auto": - num_iter = len(self.error_iterations) - - if num_iter == 1: - return self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - cbar=cbar, - padding=padding, - **kwargs, - ) - elif plot_probe or plot_fourier_probe: - iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) - else: - iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) - else: - if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: - raise ValueError() - - auto_figsize = ( - (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) - if plot_convergence - else (3 * iterations_grid[1], 3 * iterations_grid[0]) - ) - figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - errors = np.array(self.error_iterations) - - objects = [] - object_type = [] - - for obj in self.object_iterations: - if np.iscomplexobj(obj): - obj = np.angle(obj) - object_type.append("phase") - else: - object_type.append("potential") - objects.append( - self._crop_rotate_object_fov(np.sum(obj, axis=0), padding=padding) - ) - - if plot_probe or plot_fourier_probe: - total_grids = (np.prod(iterations_grid) / 2).astype("int") - probes = self.probe_iterations - else: - total_grids = np.prod(iterations_grid) - max_iter = len(objects) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) - - extent = [ - 0, - self.sampling[1] * objects[0].shape[1], - self.sampling[0] * objects[0].shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=2) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - grid = ImageGrid( - fig, - spec[0], - nrows_ncols=(1, iterations_grid[1]) - if (plot_probe or plot_fourier_probe) - else 
iterations_grid, - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - im = ax.imshow( - objects[grid_range[n]], - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: - grid.cbar_axes[n].colorbar(im) - - if plot_probe or plot_fourier_probe: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - grid = ImageGrid( - fig, - spec[1], - nrows_ncols=(1, iterations_grid[1]), - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - if plot_fourier_probe: - probe_array = asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]], - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) - - ax.set_title(f"Iter: {grid_range[n]} Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - probes[grid_range[n]], power=2, chroma_boost=chroma_boost - ) - ax.set_title(f"Iter: {grid_range[n]} probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - add_colorbar_arg( - grid.cbar_axes[n], - chroma_boost=chroma_boost, - ) - - if plot_convergence: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - if plot_probe: - ax2 = fig.add_subplot(spec[2]) - else: - ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration number") - ax2.yaxis.tick_right() - - spec.tight_layout(fig) - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - padding: int = 0, - **kwargs, - ): - """ - Displays reconstructed object and probe. 
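The all-iterations view above instead tiles its panels with mpl_toolkits ImageGrid, one colorbar per axis. A minimal sketch with placeholder arrays:

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

iterations = [np.random.rand(64, 64) for _ in range(4)]  # placeholder object iterations

fig = plt.figure(figsize=(12, 3))
grid = ImageGrid(
    fig, 111, nrows_ncols=(1, 4), axes_pad=(0.75, 0.5), cbar_mode="each", cbar_pad="2.5%"
)

for n, ax in enumerate(grid):
    im = ax.imshow(iterations[n], cmap="magma")
    ax.set_title(f"Iter: {n}")
    grid.cbar_axes[n].colorbar(im)
```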
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - return self - - def show_transmitted_probe( - self, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations=False, - **kwargs, - ): - """ - Plots the min, max, and mean transmitted probe after propagation and transmission. - - Parameters - ---------- - plot_fourier_probe: boolean, optional - If True, the transmitted probes are also plotted in Fourier space - kwargs: - Passed to show_complex - """ - - xp = self._xp - asnumpy = self._asnumpy - - transmitted_probe_intensities = xp.sum( - xp.abs(self._transmitted_probes) ** 2, axis=(-2, -1) - ) - min_intensity_transmitted = self._transmitted_probes[ - xp.argmin(transmitted_probe_intensities) - ] - max_intensity_transmitted = self._transmitted_probes[ - xp.argmax(transmitted_probe_intensities) - ] - mean_transmitted = self._transmitted_probes.mean(0) - probes = [ - asnumpy(self._return_centered_probe(probe)) - for probe in [ - mean_transmitted, - min_intensity_transmitted, - max_intensity_transmitted, - ] - ] - title = [ - "Mean Transmitted Probe", - "Min Intensity Transmitted Probe", - "Max Intensity Transmitted Probe", - ] - - if plot_fourier_probe: - bottom_row = [ - asnumpy( - self._return_fourier_probe( - probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - for probe in [ - mean_transmitted, - min_intensity_transmitted, - max_intensity_transmitted, - ] - ] - probes = [probes, bottom_row] - - title += [ - "Mean Transmitted Fourier Probe", - "Min Intensity Transmitted Fourier Probe", - "Max Intensity Transmitted Fourier Probe", - ] - - title = kwargs.get("title", title) - show_complex( - probes, - title=title, - **kwargs, - ) - - def show_slices( - self, - ms_object=None, - cbar: bool = True, - common_color_scale: bool = True, - padding: int = 0, - num_cols: int = 3, - show_fft: bool = False, - **kwargs, - ): - """ - Displays reconstructed slices of object - - Parameters - -------- - ms_object: nd.array, optional - Object to plot slices of. 
If None, uses current object - cbar: bool, optional - If True, displays a colorbar - padding: int, optional - Padding to leave uncropped - num_cols: int, optional - Number of GridSpec columns - show_fft: bool, optional - if True, plots fft of object slices - """ - - if ms_object is None: - ms_object = self._object - - rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) - if show_fft: - rotated_object = np.abs( - np.fft.fftshift( - np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) - ) - ) - rotated_shape = rotated_object.shape - - if np.iscomplexobj(rotated_object): - rotated_object = np.angle(rotated_object) - - extent = [ - 0, - self.sampling[1] * rotated_shape[2], - self.sampling[0] * rotated_shape[1], - 0, - ] - - num_rows = np.ceil(self._num_slices / num_cols).astype("int") - wspace = 0.35 if cbar else 0.15 - - axsize = kwargs.pop("axsize", (3, 3)) - cmap = kwargs.pop("cmap", "magma") - - if common_color_scale: - vals = np.sort(rotated_object.ravel()) - ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") - ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") - ind_vmin = np.max([0, ind_vmin]) - ind_vmax = np.min([len(vals) - 1, ind_vmax]) - vmin = vals[ind_vmin] - vmax = vals[ind_vmax] - if vmax == vmin: - vmin = vals[0] - vmax = vals[-1] - else: - vmax = None - vmin = None - vmin = kwargs.pop("vmin", vmin) - vmax = kwargs.pop("vmax", vmax) - - spec = GridSpec( - ncols=num_cols, - nrows=num_rows, - hspace=0.15, - wspace=wspace, - ) - - figsize = (axsize[0] * num_cols, axsize[1] * num_rows) - fig = plt.figure(figsize=figsize) - - for flat_index, obj_slice in enumerate(rotated_object): - row_index, col_index = np.unravel_index(flat_index, (num_rows, num_cols)) - ax = fig.add_subplot(spec[row_index, col_index]) - im = ax.imshow( - obj_slice, - cmap=cmap, - vmin=vmin, - vmax=vmax, - extent=extent, - **kwargs, - ) - - ax.set_title(f"Slice index: {flat_index}") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if row_index < num_rows - 1: - ax.set_xticks([]) - else: - ax.set_xlabel("y [A]") - - if col_index > 0: - ax.set_yticks([]) - else: - ax.set_ylabel("x [A]") - - spec.tight_layout(fig) - - def show_depth( - self, - x1: float, - x2: float, - y1: float, - y2: float, - specify_calibrated: bool = False, - gaussian_filter_sigma: float = None, - ms_object=None, - cbar: bool = False, - aspect: float = None, - plot_line_profile: bool = False, - **kwargs, - ): - """ - Displays line profile depth section - - Parameters - -------- - x1, x2, y1, y2: floats (pixels) - Line profile for depth section runs from (x1,y1) to (x2,y2) - Specified in pixels unless specify_calibrated is True - specify_calibrated: bool (optional) - If True, specify x1, x2, y1, y2 in A values instead of pixels - gaussian_filter_sigma: float (optional) - Standard deviation of gaussian kernel in A - ms_object: np.array - Object to plot slices of. 
If None, uses current object - cbar: bool, optional - If True, displays a colorbar - aspect: float, optional - aspect ratio for depth profile plot - plot_line_profile: bool - If True, also plots line profile showing where depth profile is taken - """ - if ms_object is not None: - ms_obj = ms_object - else: - ms_obj = self.object_cropped - - if specify_calibrated: - x1 /= self.sampling[0] - x2 /= self.sampling[0] - y1 /= self.sampling[1] - y2 /= self.sampling[1] - - if x2 == x1: - angle = 0 - elif y2 == y1: - angle = np.pi / 2 - else: - angle = np.arctan((x2 - x1) / (y2 - y1)) - - x0 = ms_obj.shape[1] / 2 - y0 = ms_obj.shape[2] / 2 - - if ( - x1 > ms_obj.shape[1] - or x2 > ms_obj.shape[1] - or y1 > ms_obj.shape[2] - or y2 > ms_obj.shape[2] - ): - raise ValueError("depth section must be in field of view of object") - - from py4DSTEM.process.phase.utils import rotate_point - - x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) - x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) - - rotated_object = np.roll( - rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), - -int(x1_0), - axis=1, - ) - - if np.iscomplexobj(rotated_object): - rotated_object = np.angle(rotated_object) - if gaussian_filter_sigma is not None: - from scipy.ndimage import gaussian_filter - - gaussian_filter_sigma /= self.sampling[0] - rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) - - plot_im = rotated_object[ - :, 0, np.max((0, int(y1_0))) : np.min((int(y2_0), rotated_object.shape[2])) - ] - - extent = [ - 0, - self.sampling[1] * plot_im.shape[1], - self._slice_thicknesses[0] * plot_im.shape[0], - 0, - ] - figsize = kwargs.pop("figsize", (6, 6)) - if not plot_line_profile: - fig, ax = plt.subplots(figsize=figsize) - im = ax.imshow(plot_im, cmap="magma", extent=extent) - if aspect is not None: - ax.set_aspect(aspect) - ax.set_xlabel("r [A]") - ax.set_ylabel("z [A]") - ax.set_title("Multislice depth profile") - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - else: - extent2 = [ - 0, - self.sampling[1] * ms_obj.shape[2], - self.sampling[0] * ms_obj.shape[1], - 0, - ] - - fig, ax = plt.subplots(2, 1, figsize=figsize) - ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) - ax[0].plot( - [y1 * self.sampling[0], y2 * self.sampling[1]], - [x1 * self.sampling[0], x2 * self.sampling[1]], - color="red", - ) - ax[0].set_xlabel("y [A]") - ax[0].set_ylabel("x [A]") - ax[0].set_title("Multislice depth profile location") - - im = ax[1].imshow(plot_im, cmap="magma", extent=extent) - if aspect is not None: - ax[1].set_aspect(aspect) - ax[1].set_xlabel("r [A]") - ax[1].set_ylabel("z [A]") - ax[1].set_title("Multislice depth profile") - if cbar: - divider = make_axes_locatable(ax[1]) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - plt.tight_layout() - - def tune_num_slices_and_thicknesses( - self, - num_slices_guess=None, - thicknesses_guess=None, - num_slices_step_size=1, - thicknesses_step_size=20, - num_slices_values=3, - num_thicknesses_values=3, - update_defocus=False, - max_iter=5, - plot_reconstructions=True, - plot_convergence=True, - return_values=False, - **kwargs, - ): - """ - Run reconstructions over a parameters space of number of slices - and slice thicknesses. Should be run after the preprocess step. 
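show_slices above shares one color scale across slices by clipping to roughly the 2nd and 98th percentiles of the sorted pixel values. The same clipping logic in isolation, as a sketch:

```python
import numpy as np

def common_color_limits(stack, low=0.02, high=0.98):
    # robust shared vmin/vmax across all slices, falling back to the full range if degenerate
    vals = np.sort(stack.ravel())
    vmin = vals[max(0, int(round((vals.size - 1) * low)))]
    vmax = vals[min(vals.size - 1, int(round((vals.size - 1) * high)))]
    if vmax == vmin:
        vmin, vmax = vals[0], vals[-1]
    return vmin, vmax

vmin, vmax = common_color_limits(np.random.rand(8, 64, 64))
```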
- - Parameters - ---------- - num_slices_guess: float, optional - initial starting guess for number of slices, rounds to nearest integer - if None, uses current initialized values - thicknesses_guess: float (A), optional - initial starting guess for thicknesses of slices assuming same - thickness for each slice - if None, uses current initialized values - num_slices_step_size: float, optional - size of change of number of slices for each step in parameter space - thicknesses_step_size: float (A), optional - size of change of slice thicknesses for each step in parameter space - num_slices_values: int, optional - number of number of slice values to test, must be >= 1 - num_thicknesses_values: int,optional - number of thicknesses values to test, must be >= 1 - update_defocus: bool, optional - if True, updates defocus based on estimated total thickness - max_iter: int, optional - number of iterations to run in ptychographic reconstruction - plot_reconstructions: bool, optional - if True, plot phase of reconstructed objects - plot_convergence: bool, optional - if True, plots error for each iteration for each reconstruction - return_values: bool, optional - if True, returns objects, convergence - - Returns - ------- - objects: list - reconstructed objects - convergence: np.ndarray - array of convergence values from reconstructions - """ - - # calculate number of slices and thicknesses values to test - if num_slices_guess is None: - num_slices_guess = self._num_slices - if thicknesses_guess is None: - thicknesses_guess = np.mean(self._slice_thicknesses) - - if num_slices_values == 1: - num_slices_step_size = 0 - - if num_thicknesses_values == 1: - thicknesses_step_size = 0 - - num_slices = np.linspace( - num_slices_guess - num_slices_step_size * (num_slices_values - 1) / 2, - num_slices_guess + num_slices_step_size * (num_slices_values - 1) / 2, - num_slices_values, - ) - - thicknesses = np.linspace( - thicknesses_guess - - thicknesses_step_size * (num_thicknesses_values - 1) / 2, - thicknesses_guess - + thicknesses_step_size * (num_thicknesses_values - 1) / 2, - num_thicknesses_values, - ) - - if return_values: - convergence = [] - objects = [] - - # current initialized values - current_verbose = self._verbose - current_num_slices = self._num_slices - current_thicknesses = self._slice_thicknesses - current_rotation_deg = self._rotation_best_rad * 180 / np.pi - current_transpose = self._rotation_best_transpose - current_defocus = -self._polar_parameters["C10"] - - # Gridspec to plot on - if plot_reconstructions: - if plot_convergence: - spec = GridSpec( - ncols=num_thicknesses_values, - nrows=num_slices_values * 2, - height_ratios=[1, 1 / 4] * num_slices_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_thicknesses_values, 5 * num_slices_values) - ) - else: - spec = GridSpec( - ncols=num_thicknesses_values, - nrows=num_slices_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_thicknesses_values, 4 * num_slices_values) - ) - - fig = plt.figure(figsize=figsize) - - progress_bar = kwargs.pop("progress_bar", False) - # run loop and plot along the way - self._verbose = False - for flat_index, (slices, thickness) in enumerate( - tqdmnd(num_slices, thicknesses, desc="Tuning angle and defocus") - ): - slices = int(slices) - self._num_slices = slices - self._slice_thicknesses = np.tile(thickness, slices - 1) - self._probe = None - self._object = None - if update_defocus: - defocus = current_defocus + slices / 2 * thickness - 
self._polar_parameters["C10"] = -defocus - - self.preprocess( - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - force_com_rotation=current_rotation_deg, - force_com_transpose=current_transpose, - ) - self.reconstruct( - reset=True, - store_iterations=True if plot_convergence else False, - max_iter=max_iter, - progress_bar=progress_bar, - **kwargs, - ) - - if plot_reconstructions: - row_index, col_index = np.unravel_index( - flat_index, (num_slices_values, num_thicknesses_values) - ) - - if plot_convergence: - object_ax = fig.add_subplot(spec[row_index * 2, col_index]) - convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=convergence_ax, - cbar=True, - ) - convergence_ax.yaxis.tick_right() - else: - object_ax = fig.add_subplot(spec[row_index, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=None, - cbar=True, - ) - - object_ax.set_title( - f" num slices = {slices:.0f}, slices thickness = {thickness:.1f} A \n error = {self.error:.3e}" - ) - object_ax.set_xticks([]) - object_ax.set_yticks([]) - - if return_values: - objects.append(self.object) - convergence.append(self.error_iterations.copy()) - - # initialize back to pre-tuning values - self._probe = None - self._object = None - self._num_slices = current_num_slices - self._slice_thicknesses = np.tile(current_thicknesses, current_num_slices - 1) - self._polar_parameters["C10"] = -current_defocus - self.preprocess( - force_com_rotation=current_rotation_deg, - force_com_transpose=current_transpose, - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - ) - self._verbose = current_verbose - - if plot_reconstructions: - spec.tight_layout(fig) - - if return_values: - return objects, convergence - - def _return_object_fft( - self, - obj=None, - ): - """ - Returns obj fft shifted to center of array - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - """ - asnumpy = self._asnumpy - - if obj is None: - obj = self._object - - obj = asnumpy(obj) - if np.iscomplexobj(obj): - obj = np.angle(obj) - - obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) - - def _return_projected_cropped_potential( - self, - ): - """Utility function to accommodate multiple classes""" - if self._object_type == "complex": - projected_cropped_potential = np.angle(self.object_cropped).sum(0) - else: - projected_cropped_potential = self.object_cropped.sum(0) - - return projected_cropped_potential diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py deleted file mode 100644 index 670ea5e40..000000000 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ /dev/null @@ -1,3389 +0,0 @@ -""" -Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, -namely overlap magnetic tomography. 
-""" - -import warnings -from typing import Mapping, Sequence, Tuple - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import make_axes_locatable -from py4DSTEM.visualize import show -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg -from scipy.ndimage import rotate as rotate_np - -try: - import cupy as cp -except (ModuleNotFoundError, ImportError): - cp = np - import os - - # make sure pylops doesn't try to use cupy - os.environ["CUPY_PYLOPS"] = "0" -import pylops # this must follow the exception - -from emdfile import Custom, tqdmnd -from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.utils import ( - ComplexProbe, - fft_shift, - generate_batches, - polar_aliases, - polar_symbols, - project_vector_field_divergence, - spatial_frequencies, -) -from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar - -warnings.simplefilter(action="always", category=UserWarning) - - -class OverlapMagneticTomographicReconstruction(PtychographicReconstruction): - """ - Overlap Magnetic Tomographic Reconstruction Class. - - List of diffraction intensities dimensions : (Rx,Ry,Qx,Qy) - Reconstructed probe dimensions : (Sx,Sy) - Reconstructed object dimensions : (Px,Py,Py) - - such that (Sx,Sy) is the region-of-interest (ROI) size of our probe - and (Px,Py,Py) is the padded-object electrostatic potential volume, - where x-axis is the tilt. - - Parameters - ---------- - datacube: List of DataCubes - Input list of 4D diffraction pattern intensities for different tilts - energy: float - The electron energy of the wave functions in eV - num_slices: int - Number of slices to use in the forward model - tilt_angles_deg: Sequence[float] - List of (\alpha, \beta) tilt angle tuple in degrees, - with the following Euler-angle convention: - - \alpha tilt around z-axis - - \beta tilt around x-axis - - -\alpha tilt around z-axis - semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess in mrad - semiangle_cutoff_pixels: float, optional - Semiangle cutoff for the initial probe guess in pixels - rolloff: float, optional - Semiangle rolloff for the initial probe guess - vacuum_probe_intensity: np.ndarray, optional - Vacuum probe to use as intensity aperture for initial probe guess - polar_parameters: dict, optional - Mapping from aberration symbols to their corresponding values. All aberration - magnitudes should be given in Å and angles should be given in radians. - object_padding_px: Tuple[int,int], optional - Pixel dimensions to pad object with - If None, the padding is set to half the probe ROI dimensions - initial_object_guess: np.ndarray, optional - Initial guess for complex-valued object of dimensions (Px,Py,Py) - If None, initialized to 1.0 - initial_probe_guess: np.ndarray, optional - Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, - initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations - initial_scan_positions: list of np.ndarray, optional - Probe positions in Å for each diffraction intensity per tilt - If None, initialized to a grid scan centered along tilt axis - verbose: bool, optional - If True, class methods will inherit this and print additional information - device: str, optional - Calculation device will be perfomed on. 
Must be 'cpu' or 'gpu' - object_type: str, optional - The object can be reconstructed as a real potential ('potential') or a complex - object ('complex') - positions_mask: np.ndarray, optional - Boolean real space mask to select positions in datacube to skip for reconstruction - name: str, optional - Class name - kwargs: - Provide the aberration coefficients as keyword arguments. - """ - - # Class-specific Metadata - _class_specific_metadata = ("_num_slices", "_tilt_angles_deg") - - def __init__( - self, - energy: float, - num_slices: int, - tilt_angles_deg: Sequence[Tuple[float, float]], - datacube: Sequence[DataCube] = None, - semiangle_cutoff: float = None, - semiangle_cutoff_pixels: float = None, - rolloff: float = 2.0, - vacuum_probe_intensity: np.ndarray = None, - polar_parameters: Mapping[str, float] = None, - object_padding_px: Tuple[int, int] = None, - object_type: str = "potential", - positions_mask: np.ndarray = None, - initial_object_guess: np.ndarray = None, - initial_probe_guess: np.ndarray = None, - initial_scan_positions: Sequence[np.ndarray] = None, - verbose: bool = True, - device: str = "cpu", - name: str = "overlap-magnetic-tomographic_reconstruction", - **kwargs, - ): - Custom.__init__(self, name=name) - - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter, rotate, zoom - - self._gaussian_filter = gaussian_filter - self._zoom = zoom - self._rotate = rotate - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter, rotate, zoom - - self._gaussian_filter = gaussian_filter - self._zoom = zoom - self._rotate = rotate - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - - for key in kwargs.keys(): - if (key not in polar_symbols) and (key not in polar_aliases.keys()): - raise ValueError("{} not a recognized parameter".format(key)) - - self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) - - if polar_parameters is None: - polar_parameters = {} - - polar_parameters.update(kwargs) - self._set_polar_parameters(polar_parameters) - - num_tilts = len(tilt_angles_deg) - if initial_scan_positions is None: - initial_scan_positions = [None] * num_tilts - - if object_type != "potential": - raise NotImplementedError() - - self.set_save_defaults() - - # Data - self._datacube = datacube - self._object = initial_object_guess - self._probe = initial_probe_guess - - # Common Metadata - self._vacuum_probe_intensity = vacuum_probe_intensity - self._scan_positions = initial_scan_positions - self._energy = energy - self._semiangle_cutoff = semiangle_cutoff - self._semiangle_cutoff_pixels = semiangle_cutoff_pixels - self._rolloff = rolloff - self._object_type = object_type - self._object_padding_px = object_padding_px - self._positions_mask = positions_mask - self._verbose = verbose - self._device = device - self._preprocessed = False - - # Class-specific Metadata - self._num_slices = num_slices - self._tilt_angles_deg = tuple(tilt_angles_deg) - self._num_tilts = num_tilts - - def _precompute_propagator_arrays( - self, - gpts: Tuple[int, int], - sampling: Tuple[float, float], - energy: float, - slice_thicknesses: Sequence[float], - ): - """ - Precomputes propagator arrays complex wave-function will be convolved by, - for all slice thicknesses. 
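The propagator precomputation described above reduces, for each slice thickness dz, to the Fresnel kernel exp(-i π λ dz (kx² + ky²)) evaluated on the wavefunction's spatial-frequency grid, and propagation is then a Fourier-space multiplication. A minimal standalone NumPy sketch of that kernel (the function name `fresnel_propagator` is illustrative, not part of the library; λ is the electron wavelength in Å, which the deleted code obtains from electron_wavelength_angstrom):

import numpy as np

def fresnel_propagator(gpts, sampling, wavelength, dz):
    # Spatial frequencies of the (Sx, Sy) wavefunction grid, in 1/Å
    kx = np.fft.fftfreq(gpts[0], d=sampling[0])
    ky = np.fft.fftfreq(gpts[1], d=sampling[1])
    k2 = kx[:, None] ** 2 + ky[None, :] ** 2
    # Paraxial free-space propagator for a single slice of thickness dz (Å)
    return np.exp(-1j * np.pi * wavelength * dz * k2).astype(np.complex64)

# Usage: psi_next = np.fft.ifft2(np.fft.fft2(psi) * fresnel_propagator(psi.shape, sampling, wavelength, dz))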
- - Parameters - ---------- - gpts: Tuple[int,int] - Wavefunction pixel dimensions - sampling: Tuple[float,float] - Wavefunction sampling in A - energy: float - The electron energy of the wave functions in eV - slice_thicknesses: Sequence[float] - Array of slice thicknesses in A - - Returns - ------- - propagator_arrays: np.ndarray - (T,Sx,Sy) shape array storing propagator arrays - """ - xp = self._xp - - # Frequencies - kx, ky = spatial_frequencies(gpts, sampling) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) - - # Propagators - wavelength = electron_wavelength_angstrom(energy) - num_slices = slice_thicknesses.shape[0] - propagators = xp.empty( - (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 - ) - for i, dz in enumerate(slice_thicknesses): - propagators[i] = xp.exp( - 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) - ) - - return propagators - - def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): - """ - Propagates array by Fourier convolving array with propagator_array. - - Parameters - ---------- - array: np.ndarray - Wavefunction array to be convolved - propagator_array: np.ndarray - Propagator array to convolve array with - - Returns - ------- - propagated_array: np.ndarray - Fourier-convolved array - """ - xp = self._xp - - return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) - - def _project_sliced_object(self, array: np.ndarray, output_z): - """ - Expands supersliced object or projects voxel-sliced object. - - Parameters - ---------- - array: np.ndarray - 3D array to expand/project - output_z: int - Output_dimension to expand/project array to. - If output_z > array.shape[0] array is expanded, else it's projected - - Returns - ------- - expanded_or_projected_array: np.ndarray - expanded or projected array - """ - xp = self._xp - input_z = array.shape[0] - - voxels_per_slice = np.ceil(input_z / output_z).astype("int") - pad_size = voxels_per_slice * output_z - input_z - - padded_array = xp.pad(array, ((0, pad_size), (0, 0), (0, 0))) - - return xp.sum( - padded_array.reshape( - ( - -1, - voxels_per_slice, - ) - + array.shape[1:] - ), - axis=1, - ) - - def _expand_sliced_object(self, array: np.ndarray, output_z): - """ - Expands supersliced object or projects voxel-sliced object. - - Parameters - ---------- - array: np.ndarray - 3D array to expand/project - output_z: int - Output_dimension to expand/project array to. 
- If output_z > array.shape[0] array is expanded, else it's projected - - Returns - ------- - expanded_or_projected_array: np.ndarray - expanded or projected array - """ - xp = self._xp - input_z = array.shape[0] - - voxels_per_slice = np.ceil(output_z / input_z).astype("int") - remainder_size = voxels_per_slice - (voxels_per_slice * input_z - output_z) - - voxels_in_slice = xp.repeat(voxels_per_slice, input_z) - voxels_in_slice[-1] = remainder_size if remainder_size > 0 else voxels_per_slice - - normalized_array = array / xp.asarray(voxels_in_slice)[:, None, None] - return xp.repeat(normalized_array, voxels_per_slice, axis=0)[:output_z] - - def _euler_angle_rotate_volume( - self, - volume_array, - alpha_deg, - beta_deg, - ): - """ - Rotate 3D volume using alpha, beta, gamma Euler angles according to convention: - - - \\-alpha tilt around first axis (z) - - \\beta tilt around second axis (x) - - \\alpha tilt around first axis (z) - - Note: since we store array as zxy, the x- and y-axis rotations flip sign below. - - """ - - rotate = self._rotate - volume = volume_array.copy() - - alpha_deg, beta_deg = np.mod(np.array([alpha_deg, beta_deg]) + 180, 360) - 180 - - if alpha_deg == -180: - # print(f"rotation of {-beta_deg} around x") - volume = rotate( - volume, - beta_deg, - axes=(0, 2), - reshape=False, - order=3, - ) - elif alpha_deg == -90: - # print(f"rotation of {beta_deg} around y") - volume = rotate( - volume, - -beta_deg, - axes=(0, 1), - reshape=False, - order=3, - ) - elif alpha_deg == 0: - # print(f"rotation of {beta_deg} around x") - volume = rotate( - volume, - -beta_deg, - axes=(0, 2), - reshape=False, - order=3, - ) - elif alpha_deg == 90: - # print(f"rotation of {-beta_deg} around y") - volume = rotate( - volume, - beta_deg, - axes=(0, 1), - reshape=False, - order=3, - ) - else: - # print(( - # f"rotation of {-alpha_deg} around z, " - # f"rotation of {beta_deg} around x, " - # f"rotation of {alpha_deg} around z." - # )) - - volume = rotate( - volume, - -alpha_deg, - axes=(1, 2), - reshape=False, - order=3, - ) - - volume = rotate( - volume, - -beta_deg, - axes=(0, 2), - reshape=False, - order=3, - ) - - volume = rotate( - volume, - alpha_deg, - axes=(1, 2), - reshape=False, - order=3, - ) - - return volume - - def preprocess( - self, - diffraction_intensities_shape: Tuple[int, int] = None, - reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, - dp_mask: np.ndarray = None, - fit_function: str = "plane", - plot_probe_overlaps: bool = True, - rotation_real_space_degrees: float = None, - diffraction_patterns_rotate_degrees: float = None, - diffraction_patterns_transpose: bool = None, - force_com_shifts: Sequence[float] = None, - progress_bar: bool = True, - force_scan_sampling: float = None, - force_angular_sampling: float = None, - force_reciprocal_sampling: float = None, - object_fov_mask: np.ndarray = None, - crop_patterns: bool = False, - **kwargs, - ): - """ - Ptychographic preprocessing step. - - Additionally, it initializes an (Px,Py, Py) array of 1.0 - and a complex probe using the specified polar parameters. - - Parameters - ---------- - diffraction_intensities_shape: Tuple[int,int], optional - Pixel dimensions (Qx',Qy') of the resampled diffraction intensities - If None, no resampling of diffraction intenstities is performed - reshaping_method: str, optional - Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional - Padded diffraction intensities shape. 
- If None, no padding is performed - dp_mask: ndarray, optional - Mask for datacube intensities (Qx,Qy) - fit_function: str, optional - 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' - plot_probe_overlaps: bool, optional - If True, initial probe overlaps scanned over the object will be displayed - rotation_real_space_degrees: float (degrees), optional - In plane rotation around z axis between x axis and tilt axis in - real space (forced to be in xy plane) - diffraction_patterns_rotate_degrees: float, optional - Relative rotation angle between real and reciprocal space - diffraction_patterns_transpose: bool, optional - Whether diffraction intensities need to be transposed. - force_com_shifts: list of tuple of ndarrays (CoMx, CoMy) - Amplitudes come from diffraction patterns shifted with - the CoM in the upper left corner for each probe unless - shift is overwritten. One tuple per tilt. - force_scan_sampling: float, optional - Override DataCube real space scan pixel size calibrations, in Angstrom - force_angular_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in mrad - force_reciprocal_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in A^-1 - object_fov_mask: np.ndarray (boolean) - Boolean mask of FOV. Used to calculate additional shrinkage of object - If None, probe_overlap intensity is thresholded - crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns when centering - - Returns - -------- - self: OverlapTomographicReconstruction - Self to accommodate chaining - """ - xp = self._xp - asnumpy = self._asnumpy - - # set additional metadata - self._diffraction_intensities_shape = diffraction_intensities_shape - self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape - self._dp_mask = dp_mask - - if self._datacube is None: - raise ValueError( - ( - "The preprocess() method requires a DataCube. " - "Please run ptycho.attach_datacube(DataCube) first." 
- ) - ) - - if self._positions_mask is not None: - self._positions_mask = np.asarray(self._positions_mask) - - if self._positions_mask.ndim == 2: - warnings.warn( - "2D `positions_mask` assumed the same for all measurements.", - UserWarning, - ) - self._positions_mask = np.tile( - self._positions_mask, (self._num_tilts, 1, 1) - ) - - if self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array."), - UserWarning, - ) - self._positions_mask = self._positions_mask.astype("bool") - else: - self._positions_mask = [None] * self._num_tilts - - # Prepopulate various arrays - - if self._positions_mask[0] is None: - num_probes_per_tilt = [0] - for dc in self._datacube: - rx, ry = dc.Rshape - num_probes_per_tilt.append(rx * ry) - - num_probes_per_tilt = np.array(num_probes_per_tilt) - else: - num_probes_per_tilt = np.insert( - self._positions_mask.sum(axis=(-2, -1)), 0, 0 - ) - - self._num_diffraction_patterns = num_probes_per_tilt.sum() - self._cum_probes_per_tilt = np.cumsum(num_probes_per_tilt) - - self._mean_diffraction_intensity = [] - self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) - - self._rotation_best_rad = np.deg2rad(diffraction_patterns_rotate_degrees) - self._rotation_best_transpose = diffraction_patterns_transpose - - if force_com_shifts is None: - force_com_shifts = [None] * self._num_tilts - - for tilt_index in tqdmnd( - self._num_tilts, - desc="Preprocessing data", - unit="tilt", - disable=not progress_bar, - ): - if tilt_index == 0: - ( - self._datacube[tilt_index], - self._vacuum_probe_intensity, - self._dp_mask, - force_com_shifts[tilt_index], - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[tilt_index], - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=self._vacuum_probe_intensity, - dp_mask=self._dp_mask, - com_shifts=force_com_shifts[tilt_index], - ) - - self._amplitudes = xp.empty( - (self._num_diffraction_patterns,) + self._datacube[0].Qshape - ) - self._region_of_interest_shape = np.array( - self._amplitudes[0].shape[-2:] - ) - - else: - ( - self._datacube[tilt_index], - _, - _, - force_com_shifts[tilt_index], - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[tilt_index], - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=None, - dp_mask=None, - com_shifts=force_com_shifts[tilt_index], - ) - - intensities = self._extract_intensities_and_calibrations_from_datacube( - self._datacube[tilt_index], - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - com_measured_x, - com_measured_y, - com_fitted_x, - com_fitted_y, - com_normalized_x, - com_normalized_y, - ) = self._calculate_intensities_center_of_mass( - intensities, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts[tilt_index], - ) - - ( - self._amplitudes[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ], - mean_diffraction_intensity_temp, - ) = self._normalize_diffraction_intensities( - intensities, - com_fitted_x, - com_fitted_y, - crop_patterns, - self._positions_mask[tilt_index], - ) - - self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) - - 
del ( - intensities, - com_measured_x, - com_measured_y, - com_fitted_x, - com_fitted_y, - com_normalized_x, - com_normalized_y, - ) - - self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] = self._calculate_scan_positions_in_pixels( - self._scan_positions[tilt_index], self._positions_mask[tilt_index] - ) - - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px_all, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - self._object = xp.zeros((4, q, p, q), dtype=xp.float32) - else: - self._object = xp.asarray(self._object, dtype=xp.float32) - - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type - self._object_shape = self._object.shape[-2:] - self._num_voxels = self._object.shape[1] - - # Center Probes - self._positions_px_all = xp.asarray(self._positions_px_all, dtype=xp.float32) - - for tilt_index in range(self._num_tilts): - self._positions_px = self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= ( - self._positions_px_com - xp.array(self._object_shape) / 2 - ) - - self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] = self._positions_px.copy() - - self._positions_px_initial_all = self._positions_px_all.copy() - self._positions_initial_all = self._positions_px_initial_all.copy() - self._positions_initial_all[:, 0] *= self.sampling[0] - self._positions_initial_all[:, 1] *= self.sampling[1] - - # Probe Initialization - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, device=self._device - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - self._probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - sum(self._mean_diffraction_intensity) - / self._num_tilts - / probe_intensity - ) - - else: - if isinstance(self._probe, ComplexProbe): - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - self._probe = self._probe._array - else: - self._probe._xp = xp - self._probe = self._probe.build()._array - - # Normalize 
probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - sum(self._mean_diffraction_intensity) - / self._num_tilts - / probe_intensity - ) - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) - - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) - - self._known_aberrations_array = ComplexProbe( - energy=self._energy, - gpts=self._region_of_interest_shape, - sampling=self.sampling, - parameters=self._polar_parameters, - device=self._device, - )._evaluate_ctf() - - # Precomputed propagator arrays - self._slice_thicknesses = np.tile( - self._object_shape[1] * self.sampling[1] / self._num_slices, - self._num_slices - 1, - ) - self._propagator_arrays = self._precompute_propagator_arrays( - self._region_of_interest_shape, - self.sampling, - self._energy, - self._slice_thicknesses, - ) - - # overlaps - if object_fov_mask is None: - probe_overlap_3D = xp.zeros_like(self._object[0]) - - for tilt_index in np.arange(self._num_tilts): - alpha_deg, beta_deg = self._tilt_angles_deg[tilt_index] - - probe_overlap_3D = self._euler_angle_rotate_volume( - probe_overlap_3D, - alpha_deg, - beta_deg, - ) - - self._positions_px = self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - shifted_probes = fft_shift( - self._probe, self._positions_px_fractional, xp - ) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts( - probe_intensities - ) - - probe_overlap_3D += probe_overlap[None] - - probe_overlap_3D = self._euler_angle_rotate_volume( - probe_overlap_3D, - alpha_deg, - -beta_deg, - ) - - probe_overlap_3D = self._gaussian_filter(probe_overlap_3D, 1.0) - self._object_fov_mask = asnumpy( - probe_overlap_3D > 0.25 * probe_overlap_3D.max() - ) - else: - self._object_fov_mask = np.asarray(object_fov_mask) - self._positions_px = self._positions_px_all[: self._cum_probes_per_tilt[1]] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) - - self._object_fov_mask_inverse = np.invert(self._object_fov_mask) - - if plot_probe_overlaps: - figsize = kwargs.pop("figsize", (13, 4)) - chroma_boost = kwargs.pop("chroma_boost", 1) - - # initial probe - complex_probe_rgb = Complex2RGB( - self.probe_centered, - power=2, - chroma_boost=chroma_boost, - ) - - # propagated - propagated_probe = self._probe.copy() - - for s in range(self._num_slices - 1): - propagated_probe = self._propagate_array( - propagated_probe, self._propagator_arrays[s] - ) - complex_propagated_rgb = Complex2RGB( - asnumpy(self._return_centered_probe(propagated_probe)), - power=2, - chroma_boost=chroma_boost, - ) - - extent = [ - 0, - self.sampling[1] * self._object_shape[1], - self.sampling[0] * self._object_shape[0], - 0, - ] - - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) - - ax1.imshow( - complex_probe_rgb, - extent=probe_extent, - ) - - divider = 
make_axes_locatable(ax1) - cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, - chroma_boost=chroma_boost, - ) - ax1.set_ylabel("x [A]") - ax1.set_xlabel("y [A]") - ax1.set_title("Initial probe intensity") - - ax2.imshow( - complex_propagated_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax2) - cax2 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax2, - chroma_boost=chroma_boost, - ) - ax2.set_ylabel("x [A]") - ax2.set_xlabel("y [A]") - ax2.set_title("Propagated probe intensity") - - ax3.imshow( - asnumpy(probe_overlap), - extent=extent, - cmap="Greys_r", - ) - ax3.scatter( - self.positions[0, :, 1], - self.positions[0, :, 0], - s=2.5, - color=(1, 0, 0, 1), - ) - ax3.set_ylabel("x [A]") - ax3.set_xlabel("y [A]") - ax3.set_xlim((extent[0], extent[1])) - ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object field of view") - - fig.tight_layout() - - self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _overlap_projection( - self, current_object_V, current_object_A_projected, current_probe - ): - """ - Ptychographic overlap projection method. - - Parameters - -------- - current_object_V: np.ndarray - Current electrostatic object estimate - current_object_A_projected: np.ndarray - Current projected magnetic object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - """ - - xp = self._xp - - complex_object = xp.exp(1j * (current_object_V + current_object_A_projected)) - object_patches = complex_object[ - :, self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - propagated_probes = xp.empty_like(object_patches) - propagated_probes[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes = object_patches[s] * propagated_probes[s] - - # propagate - if s + 1 < self._num_slices: - propagated_probes[s + 1] = self._propagate_array( - transmitted_probes, self._propagator_arrays[s] - ) - - return propagated_probes, object_patches, transmitted_probes - - def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): - """ - Ptychographic fourier projection method for GD method. - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - modified_exit_wave = xp.fft.ifft2( - amplitudes * xp.exp(1j * xp.angle(fourier_exit_waves)) - ) - - exit_waves = modified_exit_wave - transmitted_probes - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. 
- Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = transmitted_probes.copy() - - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - factor_to_be_projected = ( - projection_c * transmitted_probes + projection_y * exit_waves - ) - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amplitudes * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * transmitted_probes - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object_V, - current_object_A_projected, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). - - Parameters - -------- - current_object_V: np.ndarray - Current electrostatic object estimate - current_object_A_projected: np.ndarray - Current projected magnetic object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - propagated_probes:np.ndarray - Prop[object^n*probe^n] - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - ( - propagated_probes, - object_patches, - transmitted_probes, - ) = self._overlap_projection( - current_object_V, - current_object_A_projected, - current_probe, - ) - - if use_projection_scheme: - ( - exit_waves[self._active_tilt_index], - error, - ) = self._projection_sets_fourier_projection( - amplitudes, - transmitted_probes, - exit_waves[self._active_tilt_index], - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, transmitted_probes - ) - - return propagated_probes, object_patches, transmitted_probes, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object_V, - current_object_A_projected, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
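The generalized-projection update documented above (DM_AP, RAAR, RRR, SUPERFLIP as special cases of the three parameters a, b, c) can be restated compactly. The sketch below mirrors the Fourier-magnitude projection and recombination performed in _projection_sets_fourier_projection; the function name and argument names are illustrative only:

import numpy as np

def generalized_projection_update(exit_waves, transmitted, amplitudes, a, b, c):
    # Project c*psi + (1-c)*exit_waves onto the measured Fourier amplitudes,
    # then recombine with the three-parameter scheme.
    x, y = 1 - a - b, 1 - c
    factor = c * transmitted + y * exit_waves
    F = np.fft.fft2(factor)
    projected = np.fft.ifft2(amplitudes * np.exp(1j * np.angle(F)))
    return x * exit_waves + a * transmitted + b * projected

# (a, b, c) choices from the docstring above, e.g. DM == DM_AP(1.0):
#   DM_AP(alpha):  (-alpha, 1, 1 + alpha)
#   RAAR(beta):    (1 - 2*beta, beta, 2)
#   RRR(gamma):    (-gamma, gamma, 2)
#   SUPERFLIP:     (0, 1, 2)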
- - Parameters - -------- - current_object_V: np.ndarray - Current electrostatic object estimate - current_object_A_projected: np.ndarray - Current projected magnetic object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object_V: np.ndarray - Updated electrostatic object estimate - updated_object_A_projected: np.ndarray - Updated projected magnetic object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - object_update = step_size * ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves) - ) - * probe_normalization - ) - - current_object_V[s] += object_update - current_object_A_projected[s] += object_update - - # back-transmit - exit_waves *= xp.conj(obj) - - if s > 0: - # back-propagate - exit_waves = self._propagate_array( - exit_waves, xp.conj(self._propagator_arrays[s - 1]) - ) - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += ( - step_size - * xp.sum( - exit_waves, - axis=0, - ) - * object_normalization - ) - - return current_object_V, current_object_A_projected, current_probe - - def _projection_sets_adjoint( - self, - current_object_V, - current_object_A_projected, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. 
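Both adjoint variants divide the accumulated object (and probe) update by a regularized overlap intensity rather than the raw one: normalization_min blends in a fraction of the peak overlap so that weakly illuminated pixels do not amplify noise. A standalone sketch of that denominator, with an illustrative name:

import numpy as np

def regularized_reciprocal(overlap_intensity, normalization_min=1.0, eps=1e-16):
    # 1 / sqrt(eps + ((1-m)*I)^2 + (m*max(I))^2): close to 1/I where the
    # overlap I is strong, saturating near 1/(m*max(I)) where it is weak.
    m = normalization_min
    return 1.0 / np.sqrt(
        eps
        + ((1 - m) * overlap_intensity) ** 2
        + (m * overlap_intensity.max()) ** 2
    )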
- - Parameters - -------- - current_object_V: np.ndarray - Current electrostatic object estimate - current_object_A_projected: np.ndarray - Current projected magnetic object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object_V: np.ndarray - Updated electrostatic object estimate - updated_object_A_projected: np.ndarray - Updated projected magnetic object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - # careful not to modify exit_waves in-place for projection set methods - exit_waves_copy = exit_waves.copy() - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - object_update = ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves) - ) - * probe_normalization - ) - - current_object_V[s] = object_update - current_object_A_projected[s] = object_update - - # back-transmit - exit_waves_copy *= xp.conj(obj) - - if s > 0: - # back-propagate - exit_waves_copy = self._propagate_array( - exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) - ) - - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - exit_waves_copy, - axis=0, - ) - * object_normalization - ) - - return current_object_V, current_object_A_projected, current_probe - - def _adjoint( - self, - current_object_V, - current_object_A_projected, - current_probe, - object_patches, - propagated_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object_V: np.ndarray - Current electrostatic object estimate - current_object_A_projected: np.ndarray - Current projected magnetic object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object_V: np.ndarray - Updated electrostatic object estimate - updated_object_A_projected: np.ndarray - Updated projected magnetic object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - ( - current_object_V, - current_object_A_projected, - current_probe, - ) = self._projection_sets_adjoint( - current_object_V, - current_object_A_projected, - current_probe, - object_patches, - propagated_probes, - exit_waves[self._active_tilt_index], - normalization_min, - fix_probe, - ) - else: - ( - current_object_V, - current_object_A_projected, - current_probe, - ) = self._gradient_descent_adjoint( - current_object_V, - current_object_A_projected, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object_V, current_object_A_projected, current_probe - - def _position_correction( - self, - current_object, - current_probe, - transmitted_probes, - amplitudes, - current_positions, - positions_step_size, - constrain_position_distance, - ): - """ - Position correction using estimated intensity gradient. 
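The intensity-gradient position correction that follows amounts to a per-probe linear least-squares solve: with J holding the finite-difference partials of the modeled diffraction intensity along x and y, and ΔI the measured-minus-modeled intensity, the shift is the normal-equations solution (JᵀJ)⁻¹ Jᵀ ΔI. A toy sketch of that solve under assumed flattened shapes (names are illustrative; np.linalg.solve is used here in place of the explicit inverse):

import numpy as np

def position_update(dI_dx, dI_dy, delta_I):
    # dI_dx, dI_dy, delta_I: (num_probes, num_pixels) flattened intensity arrays
    J = np.stack((dI_dx, dI_dy), axis=-1)   # (N, K, 2)
    JT = J.swapaxes(-1, -2)                 # (N, 2, K)
    # Normal equations, solved independently for every probe position
    return np.linalg.solve(JT @ J, JT @ delta_I[..., None])[..., 0]  # (N, 2)

# positions -= positions_step_size * position_update(dI_dx, dI_dy, delta_I)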
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe:np.ndarray - fractionally-shifted probes - transmitted_probes: np.ndarray - Transmitted probes at each layer - amplitudes: np.ndarray - Measured amplitudes - current_positions: np.ndarray - Current positions estimate - positions_step_size: float - Positions step size - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A - - Returns - -------- - updated_positions: np.ndarray - Updated positions estimate - """ - - xp = self._xp - - # Intensity gradient - exit_waves_fft = xp.fft.fft2(transmitted_probes[-1]) - exit_waves_fft_conj = xp.conj(exit_waves_fft) - estimated_intensity = xp.abs(exit_waves_fft) ** 2 - measured_intensity = amplitudes**2 - - flat_shape = (transmitted_probes[-1].shape[0], -1) - difference_intensity = (measured_intensity - estimated_intensity).reshape( - flat_shape - ) - - # Computing perturbed exit waves one at a time to save on memory - - complex_object = xp.exp(1j * current_object) - - # dx - propagated_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - obj_rolled_patches = complex_object[ - :, - (self._vectorized_patch_indices_row + 1) % self._object_shape[0], - self._vectorized_patch_indices_col, - ] - - transmitted_probes_perturbed = xp.empty_like(obj_rolled_patches) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed[s] = obj_rolled_patches[s] * propagated_probes - - # propagate - if s + 1 < self._num_slices: - propagated_probes = self._propagate_array( - transmitted_probes_perturbed[s], self._propagator_arrays[s] - ) - - exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2( - transmitted_probes_perturbed[-1] - ) - - # dy - propagated_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - obj_rolled_patches = complex_object[ - :, - self._vectorized_patch_indices_row, - (self._vectorized_patch_indices_col + 1) % self._object_shape[1], - ] - - transmitted_probes_perturbed = xp.empty_like(obj_rolled_patches) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed[s] = obj_rolled_patches[s] * propagated_probes - - # propagate - if s + 1 < self._num_slices: - propagated_probes = self._propagate_array( - transmitted_probes_perturbed[s], self._propagator_arrays[s] - ) - - exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2( - transmitted_probes_perturbed[-1] - ) - - partial_intensity_dx = 2 * xp.real( - exit_waves_dx_fft * exit_waves_fft_conj - ).reshape(flat_shape) - partial_intensity_dy = 2 * xp.real( - exit_waves_dy_fft * exit_waves_fft_conj - ).reshape(flat_shape) - - coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) - - # positions_update = xp.einsum( - # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity - # ) - - coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) - positions_update = ( - xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) - @ coefficients_matrix_T - @ difference_intensity[..., None] - ) - - if constrain_position_distance is not None: - constrain_position_distance /= xp.sqrt( - self.sampling[0] ** 2 + self.sampling[1] ** 2 - ) - x1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 0 - ] - y1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 1 - ] - x0 = self._positions_px_initial[:, 0] - y0 = self._positions_px_initial[:, 1] - if self._rotation_best_transpose: - x0, y0 = xp.array([y0, 
x0]) - x1, y1 = xp.array([y1, x1]) - - if self._rotation_best_rad is not None: - rotation_angle = self._rotation_best_rad - x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( - -rotation_angle - ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) - x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( - -rotation_angle - ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) - - outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( - x1 < (xp.min(x0) - constrain_position_distance) - ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( - y1 < (xp.min(y0) - constrain_position_distance) - ) > 0 - - positions_update[..., 0][outlier_ind] = 0 - - current_positions -= positions_step_size * positions_update[..., 0] - - return current_positions - - def _object_gaussian_constraint(self, current_object, gaussian_filter_sigma): - """ - Ptychographic smoothness constraint. - Used for blurring object. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - gaussian_filter = self._gaussian_filter - - gaussian_filter_sigma /= self.sampling[0] - current_object = gaussian_filter(current_object, gaussian_filter_sigma) - - return current_object - - def _object_butterworth_constraint( - self, current_object, q_lowpass, q_highpass, butterworth_order - ): - """ - Butterworth filter - - Parameters - -------- - current_object: np.ndarray - Current object estimate - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - qz = xp.fft.fftfreq(current_object.shape[0], self.sampling[1]) - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") - qra = xp.sqrt(qza**2 + qxa**2 + qya**2) - - env = xp.ones_like(qra) - if q_highpass: - env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) - if q_lowpass: - env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - - current_object_mean = xp.mean(current_object) - current_object -= current_object_mean - current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * env) - current_object += current_object_mean - - return xp.real(current_object) - - def _divergence_free_constraint(self, vector_field): - """ - Leray projection operator - - Parameters - -------- - vector_field: np.ndarray - Current object vector as Az, Ax, Ay - - Returns - -------- - projected_vector_field: np.ndarray - Divergence-less object vector as Az, Ax, Ay - """ - xp = self._xp - - spacings = (self.sampling[1],) + self.sampling - vector_field = project_vector_field_divergence( - vector_field, spacings=spacings, xp=xp - ) - - return vector_field - - def _object_denoise_tv_pylops(self, current_object, weights, iterations): - """ - Performs second order TV denoising along x and y - - Parameters - ---------- - current_object: np.ndarray - Current object estimate - weights : [float, float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. 
- iterations: float - Number of iterations to run in denoising algorithm. - `niter_out` in pylops - - Returns - ------- - constrained_object: np.ndarray - Constrained object estimate - - """ - xp = self._xp - - if xp.iscomplexobj(current_object): - current_object_tv = current_object - warnings.warn( - ("TV denoising is currently only supported for potential objects."), - UserWarning, - ) - - else: - # zero pad at top and bottom slice - pad_width = ((1, 1), (0, 0), (0, 0)) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" - ) - - # run tv denoising - nz, nx, ny = current_object.shape - niter_out = iterations - niter_in = 1 - Iop = pylops.Identity(nx * ny * nz) - - if weights[0] == 0: - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[1]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - elif weights[1] == 0: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - l1_regs = [z_gradient] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[0]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - else: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [z_gradient, xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=weights, - tol=1e-4, - tau=1.0, - show=False, - )[0] - - # remove padding - current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] - - return current_object_tv - - def _constraints( - self, - current_object, - current_probe, - current_positions, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma_e, - gaussian_filter_sigma_m, - butterworth_filter, - q_lowpass_e, - q_lowpass_m, - q_highpass_e, - q_highpass_m, - butterworth_order, - object_positivity, - shrinkage_rad, - object_mask, - tv_denoise, - tv_denoise_weights, - tv_denoise_inner_iter, - ): - """ - Ptychographic constraints operator. 
- Calls _threshold_object_constraint() and _probe_center_of_mass_constraint() - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe Fourier amplitude is replaced by initial probe aperture. - initial_probe_aperture: np.ndarray, - Initial probe aperture to use in replacing probe Fourier amplitude. - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma_e: float - Standard deviation of gaussian kernel for electrostatic object in A - gaussian_filter_sigma_m: float - Standard deviation of gaussian kernel for magnetic object in A - butterworth_filter: bool - If True, applies high-pass butteworth filter - q_lowpass_e: float - Cut-off frequency in A^-1 for low-pass filtering electrostatic object - q_lowpass_m: float - Cut-off frequency in A^-1 for low-pass filtering magnetic object - q_highpass_e: float - Cut-off frequency in A^-1 for high-pass filtering electrostatic object - q_highpass_m: float - Cut-off frequency in A^-1 for high-pass filtering magnetic object - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - object_positivity: bool - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. 
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object[0] = self._object_gaussian_constraint( - current_object[0], gaussian_filter_sigma_e - ) - current_object[1] = self._object_gaussian_constraint( - current_object[1], gaussian_filter_sigma_m - ) - current_object[2] = self._object_gaussian_constraint( - current_object[2], gaussian_filter_sigma_m - ) - current_object[3] = self._object_gaussian_constraint( - current_object[3], gaussian_filter_sigma_m - ) - - if butterworth_filter: - current_object[0] = self._object_butterworth_constraint( - current_object[0], - q_lowpass_e, - q_highpass_e, - butterworth_order, - ) - current_object[1] = self._object_butterworth_constraint( - current_object[1], - q_lowpass_m, - q_highpass_m, - butterworth_order, - ) - current_object[2] = self._object_butterworth_constraint( - current_object[2], - q_lowpass_m, - q_highpass_m, - butterworth_order, - ) - current_object[3] = self._object_butterworth_constraint( - current_object[3], - q_lowpass_m, - q_highpass_m, - butterworth_order, - ) - - elif tv_denoise: - current_object[0] = self._object_denoise_tv_pylops( - current_object[0], - tv_denoise_weights, - tv_denoise_inner_iter, - ) - - current_object[1] = self._object_denoise_tv_pylops( - current_object[1], - tv_denoise_weights, - tv_denoise_inner_iter, - ) - - current_object[2] = self._object_denoise_tv_pylops( - current_object[2], - tv_denoise_weights, - tv_denoise_inner_iter, - ) - - current_object[3] = self._object_denoise_tv_pylops( - current_object[3], - tv_denoise_weights, - tv_denoise_inner_iter, - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object[0] = self._object_shrinkage_constraint( - current_object[0], - shrinkage_rad, - object_mask, - ) - - if object_positivity: - current_object[0] = self._object_positivity_constraint(current_object[0]) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - - def reconstruct( - self, - max_iter: int = 64, - reconstruction_method: str = "gradient-descent", - reconstruction_parameter: float = 1.0, - reconstruction_parameter_a: float = None, - reconstruction_parameter_b: float = None, - 
reconstruction_parameter_c: float = None, - max_batch_size: int = None, - seed_random: int = None, - step_size: float = 0.5, - normalization_min: float = 1, - positions_step_size: float = 0.9, - fix_com: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, - constrain_probe_amplitude_relative_radius: float = 0.5, - constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, - constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, - constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, - global_affine_transformation: bool = True, - gaussian_filter_sigma_e: float = None, - gaussian_filter_sigma_m: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, - fit_probe_aberrations_max_angular_order: int = 4, - fit_probe_aberrations_max_radial_order: int = 4, - butterworth_filter_iter: int = np.inf, - q_lowpass_e: float = None, - q_lowpass_m: float = None, - q_highpass_e: float = None, - q_highpass_m: float = None, - butterworth_order: float = 2, - object_positivity: bool = True, - shrinkage_rad: float = 0.0, - fix_potential_baseline: bool = True, - tv_denoise_iter=np.inf, - tv_denoise_weights=None, - tv_denoise_inner_iter=40, - collective_tilt_updates: bool = False, - store_iterations: bool = False, - progress_bar: bool = True, - reset: bool = None, - ): - """ - Ptychographic reconstruction main method. - - Parameters - -------- - max_iter: int, optional - Maximum number of iterations to run - reconstruction_method: str, optional - Specifies which reconstruction algorithm to use, one of: - "generalized-projections", - "DM_AP" (or "difference-map_alternating-projections"), - "RAAR" (or "relaxed-averaged-alternating-reflections"), - "RRR" (or "relax-reflect-reflect"), - "SUPERFLIP" (or "charge-flipping"), or - "GD" (or "gradient_descent") - reconstruction_parameter: float, optional - Reconstruction parameter for various reconstruction methods above. - reconstruction_parameter_a: float, optional - Reconstruction parameter a for reconstruction_method='generalized-projections'. - reconstruction_parameter_b: float, optional - Reconstruction parameter b for reconstruction_method='generalized-projections'. - reconstruction_parameter_c: float, optional - Reconstruction parameter c for reconstruction_method='generalized-projections'. - max_batch_size: int, optional - Max number of probes to update at once - seed_random: int, optional - Seeds the random number generator, only applicable when max_batch_size is not None - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - positions_step_size: float, optional - Positions update step size - fix_com: bool, optional - If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. 
- constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float, optional - Distance to constrain position correction within original - field of view in A - global_affine_transformation: bool, optional - If True, positions are assumed to be a global affine transform from initial scan - gaussian_filter_sigma_e: float - Standard deviation of gaussian kernel for electrostatic object in A - gaussian_filter_sigma_m: float - Standard deviation of gaussian kernel for magnetic object in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - object_positivity: bool, optional - If True, forces object to be positive - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - collective_tilt_updates: bool - if True perform collective tilt updates - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - store_iterations: bool, optional - If True, reconstructed objects and probes are stored at each iteration - progress_bar: bool, optional - If True, reconstruction progress is displayed - reset: bool, optional - If True, previous reconstructions are ignored - - Returns - -------- - self: OverlapMagneticTomographicReconstruction - Self to accommodate chaining - """ - asnumpy = self._asnumpy - xp = self._xp - - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - "'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) - - if self._verbose: - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." 
- ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) - - # Position Correction + Collective Updates not yet implemented - if fix_positions_iter < max_iter: - raise NotImplementedError( - "Position correction is currently incompatible with collective updates." - ) - - # Batching - - if max_batch_size is not None: - xp.random.seed(seed_random) - - # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self._object = self._object_initial.copy() - self.error_iterations = [] - self._probe = self._probe_initial.copy() - self._positions_px_all = self._positions_px_initial_all.copy() - if hasattr(self, "_tf"): - del self._tf - - if use_projection_scheme: - self._exit_waves = [None] * self._num_tilts - else: - self._exit_waves = None - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." - ), - UserWarning, - ) - else: - self.error_iterations = [] - if use_projection_scheme: - self._exit_waves = [None] * self._num_tilts - else: - self._exit_waves = None - - if gaussian_filter_sigma_m is None: - gaussian_filter_sigma_m = gaussian_filter_sigma_e - - if q_lowpass_m is None: - q_lowpass_m = q_lowpass_e - - # main loop - for a0 in tqdmnd( - max_iter, - desc="Reconstructing object and probe", - unit=" iter", - disable=not progress_bar, - ): - error = 0.0 - - if collective_tilt_updates: - collective_object = xp.zeros_like(self._object) - - tilt_indices = np.arange(self._num_tilts) - np.random.shuffle(tilt_indices) - - for tilt_index in tilt_indices: - tilt_error = 0.0 - self._active_tilt_index = tilt_index - - alpha_deg, beta_deg = self._tilt_angles_deg[self._active_tilt_index] - alpha, beta = np.deg2rad([alpha_deg, beta_deg]) - - # V - self._object[0] = self._euler_angle_rotate_volume( - self._object[0], - alpha_deg, - beta_deg, - ) - - # Az - self._object[1] = self._euler_angle_rotate_volume( - self._object[1], - alpha_deg, - beta_deg, - ) - - # Ax - self._object[2] = self._euler_angle_rotate_volume( - self._object[2], - alpha_deg, - beta_deg, - ) - - # Ay - self._object[3] = self._euler_angle_rotate_volume( - self._object[3], - alpha_deg, - beta_deg, - ) - - object_A = self._object[1] * np.cos(beta) + np.sin(beta) * ( - self._object[3] * np.cos(alpha) - self._object[2] * np.sin(alpha) - ) - - object_sliced_V = self._project_sliced_object( - self._object[0], self._num_slices - ) - - object_sliced_A = self._project_sliced_object( - object_A, self._num_slices - ) - - if not use_projection_scheme: - object_sliced_old_V = object_sliced_V.copy() - object_sliced_old_A = object_sliced_A.copy() - - start_tilt = self._cum_probes_per_tilt[self._active_tilt_index] - end_tilt = self._cum_probes_per_tilt[self._active_tilt_index + 1] - - num_diffraction_patterns = end_tilt - start_tilt - shuffled_indices = np.arange(num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) - - if max_batch_size is None: - current_max_batch_size = num_diffraction_patterns - else: - current_max_batch_size = max_batch_size - - # randomize - if not use_projection_scheme: - np.random.shuffle(shuffled_indices) - - unshuffled_indices[shuffled_indices] = np.arange( - num_diffraction_patterns - ) - - positions_px = 
self._positions_px_all[start_tilt:end_tilt].copy()[ - shuffled_indices - ] - initial_positions_px = self._positions_px_initial_all[ - start_tilt:end_tilt - ].copy()[shuffled_indices] - - for start, end in generate_batches( - num_diffraction_patterns, max_batch=current_max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_initial = initial_positions_px[start:end] - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - amplitudes = self._amplitudes[start_tilt:end_tilt][ - shuffled_indices[start:end] - ] - - # forward operator - ( - propagated_probes, - object_patches, - transmitted_probes, - self._exit_waves, - batch_error, - ) = self._forward( - object_sliced_V, - object_sliced_A, - self._probe, - amplitudes, - self._exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ) - - # adjoint operator - object_sliced_V, object_sliced_A, self._probe = self._adjoint( - object_sliced_V, - object_sliced_A, - self._probe, - object_patches, - propagated_probes, - self._exit_waves, - use_projection_scheme=use_projection_scheme, - step_size=step_size, - normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, - ) - - # position correction - if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( - object_sliced_V, - self._probe, - transmitted_probes, - amplitudes, - self._positions_px, - positions_step_size, - constrain_position_distance, - ) - - tilt_error += batch_error - - if not use_projection_scheme: - object_sliced_V -= object_sliced_old_V - object_sliced_A -= object_sliced_old_A - - object_update_V = self._expand_sliced_object( - object_sliced_V, self._num_voxels - ) - object_update_A = self._expand_sliced_object( - object_sliced_A, self._num_voxels - ) - - if collective_tilt_updates: - collective_object[0] += self._euler_angle_rotate_volume( - object_update_V, - alpha_deg, - -beta_deg, - ) - collective_object[1] += self._euler_angle_rotate_volume( - object_update_A * np.cos(beta), - alpha_deg, - -beta_deg, - ) - collective_object[2] -= self._euler_angle_rotate_volume( - object_update_A * np.sin(alpha) * np.sin(beta), - alpha_deg, - -beta_deg, - ) - collective_object[3] += self._euler_angle_rotate_volume( - object_update_A * np.cos(alpha) * np.sin(beta), - alpha_deg, - -beta_deg, - ) - else: - self._object[0] += object_update_V - self._object[1] += object_update_A * np.cos(beta) - self._object[2] -= object_update_A * np.sin(alpha) * np.sin(beta) - self._object[3] += object_update_A * np.cos(alpha) * np.sin(beta) - - self._object[0] = self._euler_angle_rotate_volume( - self._object[0], - alpha_deg, - -beta_deg, - ) - - self._object[1] = self._euler_angle_rotate_volume( - self._object[1], - alpha_deg, - -beta_deg, - ) - - self._object[2] = self._euler_angle_rotate_volume( - self._object[2], - alpha_deg, - -beta_deg, - ) - - self._object[3] = self._euler_angle_rotate_volume( - self._object[3], - alpha_deg, - -beta_deg, - ) - - # Normalize Error - tilt_error /= ( - self._mean_diffraction_intensity[self._active_tilt_index] - * num_diffraction_patterns - ) - error += tilt_error - - # constraints - self._positions_px_all[start_tilt:end_tilt] = positions_px.copy()[ - unshuffled_indices - ] - - if not collective_tilt_updates: - ( - self._object, - 
self._probe, - self._positions_px_all[start_tilt:end_tilt], - ) = self._constraints( - self._object, - self._probe, - self._positions_px_all[start_tilt:end_tilt], - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma_m is not None, - gaussian_filter_sigma_e=gaussian_filter_sigma_e, - gaussian_filter_sigma_m=gaussian_filter_sigma_m, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass_m is not None or q_highpass_m is not None), - q_lowpass_e=q_lowpass_e, - q_lowpass_m=q_lowpass_m, - q_highpass_e=q_highpass_e, - q_highpass_m=q_highpass_m, - butterworth_order=butterworth_order, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline - and self._object_fov_mask_inverse.sum() > 0 - else None, - tv_denoise=a0 < tv_denoise_iter - and tv_denoise_weights is not None, - tv_denoise_weights=tv_denoise_weights, - tv_denoise_inner_iter=tv_denoise_inner_iter, - ) - - # Normalize Error Over Tilts - error /= self._num_tilts - - self._object[1:] = self._divergence_free_constraint(self._object[1:]) - - if collective_tilt_updates: - self._object += collective_object / self._num_tilts - - ( - self._object, - self._probe, - _, - ) = self._constraints( - self._object, - self._probe, - None, - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=True, - global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma_m is not None, - 
gaussian_filter_sigma_e=gaussian_filter_sigma_e, - gaussian_filter_sigma_m=gaussian_filter_sigma_m, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass_m is not None or q_highpass_m is not None), - q_lowpass_e=q_lowpass_e, - q_lowpass_m=q_lowpass_m, - q_highpass_e=q_highpass_e, - q_highpass_m=q_highpass_m, - butterworth_order=butterworth_order, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline - and self._object_fov_mask_inverse.sum() > 0 - else None, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, - tv_denoise_weights=tv_denoise_weights, - tv_denoise_inner_iter=tv_denoise_inner_iter, - ) - - self.error_iterations.append(error.item()) - if store_iterations: - self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(self.probe_centered) - - # store result - self.object = asnumpy(self._object) - self.probe = self.probe_centered - self.error = error.item() - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _crop_rotate_object_manually( - self, - array, - angle, - x_lims, - y_lims, - ): - """ - Crops and rotates rotates object manually. - - Parameters - ---------- - array: np.ndarray - Object array to crop and rotate. Only operates on numpy arrays for comptatibility. - angle: float - In-plane angle in degrees to rotate by - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - - Returns - ------- - cropped_rotated_array: np.ndarray - Cropped and rotated object array - """ - - asnumpy = self._asnumpy - min_x, max_x = x_lims - min_y, max_y = y_lims - - if angle is not None: - rotated_array = rotate_np( - asnumpy(array), angle, reshape=False, axes=(-2, -1) - ) - else: - rotated_array = asnumpy(array) - - return rotated_array[..., min_x:max_x, min_y:max_y] - - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax, - cbar: bool, - projection_angle_deg: float, - projection_axes: Tuple[int, int], - x_lims: Tuple[int, int], - y_lims: Tuple[int, int], - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. 
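This helper, like the other visualization methods below, attaches its colorbar with matplotlib's axes_grid1 divider so the bar matches the height of the image axes. A minimal standalone sketch of that pattern, with illustrative data and variable names:

    import matplotlib.pyplot as plt
    import numpy as np
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    fig, ax = plt.subplots()
    im = ax.imshow(np.random.rand(64, 64), cmap="magma")

    # carve a slim axes off the right edge of `ax` and draw the colorbar there
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad="2.5%")
    fig.colorbar(im, cax=cax)
    plt.show()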
- - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - - cmap = kwargs.pop("cmap", "magma") - - asnumpy = self._asnumpy - - if projection_angle_deg is not None: - rotated_3d_obj = self._rotate( - self._object[0], - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - rotated_3d_obj = asnumpy(rotated_3d_obj) - else: - rotated_3d_obj = self.object[0] - - rotated_object = self._crop_rotate_object_manually( - rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - errors = np.array(self.error_iterations) - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - convergence_ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - projection_angle_deg: float, - projection_axes: Tuple[int, int], - x_lims: Tuple[int, int], - y_lims: Tuple[int, int], - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - figsize = kwargs.pop("figsize", (14, 10) if cbar else (12, 10)) - cmap_e = kwargs.pop("cmap_e", "magma") - cmap_m = kwargs.pop("cmap_m", "PuOr") - - asnumpy = self._asnumpy - - if projection_angle_deg is not None: - rotated_3d_obj_V = self._rotate( - self._object[0], - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - - rotated_3d_obj_Az = self._rotate( - self._object[1], - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - - rotated_3d_obj_Ax = self._rotate( - self._object[2], - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - - rotated_3d_obj_Ay = self._rotate( - self._object[3], - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - - rotated_3d_obj_V = asnumpy(rotated_3d_obj_V) - rotated_3d_obj_Az = asnumpy(rotated_3d_obj_Az) - rotated_3d_obj_Ax = asnumpy(rotated_3d_obj_Ax) - rotated_3d_obj_Ay = asnumpy(rotated_3d_obj_Ay) - else: - ( - rotated_3d_obj_V, - rotated_3d_obj_Az, - rotated_3d_obj_Ax, - rotated_3d_obj_Ay, - ) = self.object - - rotated_object_Vx = self._crop_rotate_object_manually( - rotated_3d_obj_V.sum(1).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Vy = self._crop_rotate_object_manually( - rotated_3d_obj_V.sum(2).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Vz = self._crop_rotate_object_manually( - rotated_3d_obj_V.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - rotated_object_Azx = self._crop_rotate_object_manually( - rotated_3d_obj_Az.sum(1).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Azy = self._crop_rotate_object_manually( - rotated_3d_obj_Az.sum(2).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Azz = self._crop_rotate_object_manually( - rotated_3d_obj_Az.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - rotated_object_Axx = self._crop_rotate_object_manually( - rotated_3d_obj_Ax.sum(1).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Axy = self._crop_rotate_object_manually( - rotated_3d_obj_Ax.sum(2).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Axz = self._crop_rotate_object_manually( - rotated_3d_obj_Ax.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - rotated_object_Ayx = self._crop_rotate_object_manually( - rotated_3d_obj_Ay.sum(1).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Ayy = self._crop_rotate_object_manually( - rotated_3d_obj_Ay.sum(2).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Ayz = self._crop_rotate_object_manually( - rotated_3d_obj_Ay.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - rotated_shape = rotated_object_Vx.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - arrays = [ - [ - rotated_object_Vx, - rotated_object_Axx, - rotated_object_Ayx, - rotated_object_Azx, - ], - [ - rotated_object_Vy, - rotated_object_Axy, - rotated_object_Ayy, - rotated_object_Azy, - ], - [ - rotated_object_Vz, - 
rotated_object_Axz, - rotated_object_Ayz, - rotated_object_Azz, - ], - ] - - titles = [ - [ - "V projected along x", - "Ax projected along x", - "Ay projected along x", - "Az projected along x", - ], - [ - "V projected along y", - "Ax projected along y", - "Ay projected along y", - "Az projected along y", - ], - [ - "V projected along z", - "Ax projected along z", - "Ay projected along z", - "Az projected along z", - ], - ] - - max_e = np.array( - [rotated_object_Vx.max(), rotated_object_Vy.max(), rotated_object_Vz.max()] - ).max() - max_m = np.array( - [ - [ - np.abs(rotated_object_Axx).max(), - np.abs(rotated_object_Ayx).max(), - np.abs(rotated_object_Azx).max(), - ], - [ - np.abs(rotated_object_Axy).max(), - np.abs(rotated_object_Ayy).max(), - np.abs(rotated_object_Azy).max(), - ], - [ - np.abs(rotated_object_Axz).max(), - np.abs(rotated_object_Ayz).max(), - np.abs(rotated_object_Azz).max(), - ], - ] - ).max() - - vmin_e = kwargs.pop("vmin_e", 0.0) - vmax_e = kwargs.pop("vmax_e", max_e) - vmin_m = kwargs.pop("vmin_m", -max_m) - vmax_m = kwargs.pop("vmax_m", max_m) - - if plot_convergence: - spec = GridSpec( - ncols=4, nrows=4, height_ratios=[4, 4, 4, 1], hspace=0.15, wspace=0.35 - ) - else: - spec = GridSpec(ncols=4, nrows=3, hspace=0.15, wspace=0.35) - - if fig is None: - fig = plt.figure(figsize=figsize) - - for sp in spec: - row, col = np.unravel_index(sp.num1, (4, 4)) - - if row < 3: - ax = fig.add_subplot(sp) - if sp.is_first_col(): - cmap = cmap_e - vmin = vmin_e - vmax = vmax_e - else: - cmap = cmap_m - vmin = vmin_m - vmax = vmax_m - - im = ax.imshow( - arrays[row][col], - cmap=cmap, - vmin=vmin, - vmax=vmax, - extent=extent, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - ax.set_title(titles[row][col]) - - if row < 2: - ax.set_xticks([]) - else: - ax.set_xlabel("y [A]") - - if col > 0: - ax.set_yticks([]) - else: - ax.set_ylabel("x [A]") - - if plot_convergence and hasattr(self, "error_iterations"): - errors = np.array(self.error_iterations) - - ax = fig.add_subplot(spec[-1, :]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - spec.tight_layout(fig) - - def _visualize_all_iterations( - self, - fig, - plot_convergence: bool, - iterations_grid: Tuple[int, int], - projection_angle_deg: float, - projection_axes: Tuple[int, int], - x_lims: Tuple[int, int], - y_lims: Tuple[int, int], - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - """ - raise NotImplementedError() - - def visualize( - self, - fig=None, - cbar: bool = True, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims=(None, None), - y_lims=(None, None), - **kwargs, - ): - """ - Displays reconstructed object and probe. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - - Returns - -------- - self: OverlapMagneticTomographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - cbar=cbar, - x_lims=x_lims, - y_lims=y_lims, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - cbar=cbar, - x_lims=x_lims, - y_lims=y_lims, - **kwargs, - ) - - return self - - def _return_object_fft( - self, - obj=None, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims: Tuple[int, int] = (None, None), - y_lims: Tuple[int, int] = (None, None), - ): - """ - Returns obj fft shifted to center of array - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - - xp = self._xp - asnumpy = self._asnumpy - - if obj is None: - obj = self._object[0] - else: - obj = xp.asarray(obj[0], dtype=xp.float32) - - if projection_angle_deg is not None: - rotated_3d_obj = self._rotate( - obj, - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - rotated_3d_obj = asnumpy(rotated_3d_obj) - else: - rotated_3d_obj = asnumpy(obj) - - rotated_object = self._crop_rotate_object_manually( - rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(rotated_object)))) - - def show_object_fft( - self, - obj=None, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims: 
Tuple[int, int] = (None, None), - y_lims: Tuple[int, int] = (None, None), - **kwargs, - ): - """ - Plot FFT of reconstructed object - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - if obj is None: - object_fft = self._return_object_fft( - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - ) - else: - object_fft = self._return_object_fft( - obj, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - ) - - figsize = kwargs.pop("figsize", (6, 6)) - cmap = kwargs.pop("cmap", "magma") - - pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) - show( - object_fft, - figsize=figsize, - cmap=cmap, - scalebar=True, - pixelsize=pixelsize, - ticks=False, - pixelunits=r"$\AA^{-1}$", - **kwargs, - ) - - @property - def positions(self): - """Probe positions [A]""" - - if self.angular_sampling is None: - return None - - asnumpy = self._asnumpy - positions_all = [] - for tilt_index in range(self._num_tilts): - positions = self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ].copy() - positions[:, 0] *= self.sampling[0] - positions[:, 1] *= self.sampling[1] - positions_all.append(asnumpy(positions)) - - return np.asarray(positions_all) - - def _return_self_consistency_errors( - self, - max_batch_size=None, - ): - """Compute the self-consistency errors for each probe position""" - raise NotImplementedError() - - def _return_projected_cropped_potential( - self, - ): - """Utility function to accommodate multiple classes""" - raise NotImplementedError() - - def show_uncertainty_visualization( - self, - errors=None, - max_batch_size=None, - projected_cropped_potential=None, - kde_sigma=None, - plot_histogram=True, - plot_contours=False, - **kwargs, - ): - """Plot uncertainty visualization using self-consistency errors""" - raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py deleted file mode 100644 index 749028b83..000000000 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ /dev/null @@ -1,3286 +0,0 @@ -""" -Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, -namely overlap tomography. 
-""" - -import warnings -from typing import Mapping, Sequence, Tuple - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable -from py4DSTEM.visualize import show -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg -from scipy.ndimage import rotate as rotate_np - -try: - import cupy as cp -except (ModuleNotFoundError, ImportError): - cp = np - import os - - # make sure pylops doesn't try to use cupy - os.environ["CUPY_PYLOPS"] = "0" -import pylops # this must follow the exception - -from emdfile import Custom, tqdmnd -from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.utils import ( - ComplexProbe, - fft_shift, - generate_batches, - polar_aliases, - polar_symbols, - spatial_frequencies, -) -from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar - -warnings.simplefilter(action="always", category=UserWarning) - - -class OverlapTomographicReconstruction(PtychographicReconstruction): - """ - Overlap Tomographic Reconstruction Class. - - List of diffraction intensities dimensions : (Rx,Ry,Qx,Qy) - Reconstructed probe dimensions : (Sx,Sy) - Reconstructed object dimensions : (Px,Py,Py) - - such that (Sx,Sy) is the region-of-interest (ROI) size of our probe - and (Px,Py,Py) is the padded-object electrostatic potential volume, - where x-axis is the tilt. - - Parameters - ---------- - datacube: List of DataCubes - Input list of 4D diffraction pattern intensities - energy: float - The electron energy of the wave functions in eV - num_slices: int - Number of slices to use in the forward model - tilt_orientation_matrices: Sequence[np.ndarray] - List of orientation matrices for each tilt - semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess in mrad - semiangle_cutoff_pixels: float, optional - Semiangle cutoff for the initial probe guess in pixels - rolloff: float, optional - Semiangle rolloff for the initial probe guess - vacuum_probe_intensity: np.ndarray, optional - Vacuum probe to use as intensity aperture for initial probe guess - polar_parameters: dict, optional - Mapping from aberration symbols to their corresponding values. All aberration - magnitudes should be given in Å and angles should be given in radians. - object_padding_px: Tuple[int,int], optional - Pixel dimensions to pad object with - If None, the padding is set to half the probe ROI dimensions - initial_object_guess: np.ndarray, optional - Initial guess for complex-valued object of dimensions (Px,Py,Py) - If None, initialized to 1.0 - initial_probe_guess: np.ndarray, optional - Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, - initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations - initial_scan_positions: list of np.ndarray, optional - Probe positions in Å for each diffraction intensity per tilt - If None, initialized to a grid scan centered along tilt axis - verbose: bool, optional - If True, class methods will inherit this and print additional information - device: str, optional - Calculation device will be perfomed on. 
Must be 'cpu' or 'gpu' - object_type: str, optional - The object can be reconstructed as a real potential ('potential') or a complex - object ('complex') - positions_mask: np.ndarray, optional - Boolean real space mask to select positions to ignore in reconstruction - name: str, optional - Class name - kwargs: - Provide the aberration coefficients as keyword arguments. - """ - - # Class-specific Metadata - _class_specific_metadata = ("_num_slices", "_tilt_orientation_matrices") - _swap_zxy_to_xyz = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) - - def __init__( - self, - energy: float, - num_slices: int, - tilt_orientation_matrices: Sequence[np.ndarray], - datacube: Sequence[DataCube] = None, - semiangle_cutoff: float = None, - semiangle_cutoff_pixels: float = None, - rolloff: float = 2.0, - vacuum_probe_intensity: np.ndarray = None, - polar_parameters: Mapping[str, float] = None, - object_padding_px: Tuple[int, int] = None, - object_type: str = "potential", - positions_mask: np.ndarray = None, - initial_object_guess: np.ndarray = None, - initial_probe_guess: np.ndarray = None, - initial_scan_positions: Sequence[np.ndarray] = None, - verbose: bool = True, - device: str = "cpu", - name: str = "overlap-tomographic_reconstruction", - **kwargs, - ): - Custom.__init__(self, name=name) - - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import affine_transform, gaussian_filter, rotate, zoom - - self._gaussian_filter = gaussian_filter - self._zoom = zoom - self._rotate = rotate - self._affine_transform = affine_transform - from scipy.special import erf - - self._erf = erf - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import ( - affine_transform, - gaussian_filter, - rotate, - zoom, - ) - - self._gaussian_filter = gaussian_filter - self._zoom = zoom - self._rotate = rotate - self._affine_transform = affine_transform - from cupyx.scipy.special import erf - - self._erf = erf - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - - for key in kwargs.keys(): - if (key not in polar_symbols) and (key not in polar_aliases.keys()): - raise ValueError("{} not a recognized parameter".format(key)) - - self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) - - if polar_parameters is None: - polar_parameters = {} - - polar_parameters.update(kwargs) - self._set_polar_parameters(polar_parameters) - - num_tilts = len(tilt_orientation_matrices) - if initial_scan_positions is None: - initial_scan_positions = [None] * num_tilts - - if object_type != "potential": - raise NotImplementedError() - - self.set_save_defaults() - - # Data - self._datacube = datacube - self._object = initial_object_guess - self._probe = initial_probe_guess - - # Common Metadata - self._vacuum_probe_intensity = vacuum_probe_intensity - self._scan_positions = initial_scan_positions - self._energy = energy - self._semiangle_cutoff = semiangle_cutoff - self._semiangle_cutoff_pixels = semiangle_cutoff_pixels - self._rolloff = rolloff - self._object_type = object_type - self._object_padding_px = object_padding_px - self._positions_mask = positions_mask - self._verbose = verbose - self._device = device - self._preprocessed = False - - # Class-specific Metadata - self._num_slices = num_slices - self._tilt_orientation_matrices = tuple(tilt_orientation_matrices) - self._num_tilts = num_tilts - - def _precompute_propagator_arrays( - self, - gpts: Tuple[int, int], - sampling: Tuple[float, float], - energy: float, - 
slice_thicknesses: Sequence[float], - ): - """ - Precomputes propagator arrays complex wave-function will be convolved by, - for all slice thicknesses. - - Parameters - ---------- - gpts: Tuple[int,int] - Wavefunction pixel dimensions - sampling: Tuple[float,float] - Wavefunction sampling in A - energy: float - The electron energy of the wave functions in eV - slice_thicknesses: Sequence[float] - Array of slice thicknesses in A - - Returns - ------- - propagator_arrays: np.ndarray - (T,Sx,Sy) shape array storing propagator arrays - """ - xp = self._xp - - # Frequencies - kx, ky = spatial_frequencies(gpts, sampling) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) - - # Propagators - wavelength = electron_wavelength_angstrom(energy) - num_slices = slice_thicknesses.shape[0] - propagators = xp.empty( - (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 - ) - for i, dz in enumerate(slice_thicknesses): - propagators[i] = xp.exp( - 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) - ) - - return propagators - - def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): - """ - Propagates array by Fourier convolving array with propagator_array. - - Parameters - ---------- - array: np.ndarray - Wavefunction array to be convolved - propagator_array: np.ndarray - Propagator array to convolve array with - - Returns - ------- - propagated_array: np.ndarray - Fourier-convolved array - """ - xp = self._xp - - return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) - - def _project_sliced_object(self, array: np.ndarray, output_z): - """ - Expands supersliced object or projects voxel-sliced object. - - Parameters - ---------- - array: np.ndarray - 3D array to expand/project - output_z: int - Output_dimension to expand/project array to. - If output_z > array.shape[0] array is expanded, else it's projected - - Returns - ------- - expanded_or_projected_array: np.ndarray - expanded or projected array - """ - xp = self._xp - input_z = array.shape[0] - - voxels_per_slice = np.ceil(input_z / output_z).astype("int") - pad_size = voxels_per_slice * output_z - input_z - - padded_array = xp.pad(array, ((0, pad_size), (0, 0), (0, 0))) - - return xp.sum( - padded_array.reshape( - ( - -1, - voxels_per_slice, - ) - + array.shape[1:] - ), - axis=1, - ) - - def _expand_sliced_object(self, array: np.ndarray, output_z): - """ - Expands supersliced object or projects voxel-sliced object. - - Parameters - ---------- - array: np.ndarray - 3D array to expand/project - output_z: int - Output_dimension to expand/project array to. 
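The propagator arrays precomputed above are the standard Fresnel (near-field) transfer functions, P(k) = exp(-i pi lambda dz |k|^2), applied by multiplication in Fourier space. A minimal numpy sketch of the same idea, with illustrative helper names and np.fft.fftfreq standing in for the package's spatial_frequencies utility:

    import numpy as np

    def fresnel_propagator(gpts, sampling, wavelength, dz):
        # spatial frequencies of the wavefunction grid, in 1/Angstrom
        kx = np.fft.fftfreq(gpts[0], d=sampling[0])
        ky = np.fft.fftfreq(gpts[1], d=sampling[1])
        k2 = kx[:, None] ** 2 + ky[None, :] ** 2
        # quadratic phase accumulated over a slice of thickness dz
        return np.exp(-1.0j * np.pi * wavelength * dz * k2)

    def propagate(array, propagator):
        # convolution is a product in Fourier space
        return np.fft.ifft2(np.fft.fft2(array) * propagator)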
- If output_z > array.shape[0] array is expanded, else it's projected - - Returns - ------- - expanded_or_projected_array: np.ndarray - expanded or projected array - """ - xp = self._xp - input_z = array.shape[0] - - voxels_per_slice = np.ceil(output_z / input_z).astype("int") - remainder_size = voxels_per_slice - (voxels_per_slice * input_z - output_z) - - voxels_in_slice = xp.repeat(voxels_per_slice, input_z) - voxels_in_slice[-1] = remainder_size if remainder_size > 0 else voxels_per_slice - - normalized_array = array / xp.asarray(voxels_in_slice)[:, None, None] - return xp.repeat(normalized_array, voxels_per_slice, axis=0)[:output_z] - - def _rotate_zxy_volume( - self, - volume_array, - rot_matrix, - ): - """ """ - - xp = self._xp - affine_transform = self._affine_transform - swap_zxy_to_xyz = self._swap_zxy_to_xyz - - volume = volume_array.copy() - volume_shape = xp.asarray(volume.shape) - tf = xp.asarray(swap_zxy_to_xyz.T @ rot_matrix.T @ swap_zxy_to_xyz) - - in_center = (volume_shape - 1) / 2 - out_center = tf @ in_center - offset = in_center - out_center - - volume = affine_transform(volume, tf, offset=offset, order=3) - - return volume - - def preprocess( - self, - diffraction_intensities_shape: Tuple[int, int] = None, - reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, - dp_mask: np.ndarray = None, - fit_function: str = "plane", - plot_probe_overlaps: bool = True, - rotation_real_space_degrees: float = None, - diffraction_patterns_rotate_degrees: float = None, - diffraction_patterns_transpose: bool = None, - force_com_shifts: Sequence[float] = None, - force_scan_sampling: float = None, - force_angular_sampling: float = None, - force_reciprocal_sampling: float = None, - progress_bar: bool = True, - object_fov_mask: np.ndarray = None, - crop_patterns: bool = False, - **kwargs, - ): - """ - Ptychographic preprocessing step. - - Additionally, it initializes an (Px,Py, Py) array of 1.0 - and a complex probe using the specified polar parameters. - - Parameters - ---------- - diffraction_intensities_shape: Tuple[int,int], optional - Pixel dimensions (Qx',Qy') of the resampled diffraction intensities - If None, no resampling of diffraction intenstities is performed - reshaping_method: str, optional - Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional - Padded diffraction intensities shape. - If None, no padding is performed - dp_mask: ndarray, optional - Mask for datacube intensities (Qx,Qy) - fit_function: str, optional - 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' - plot_probe_overlaps: bool, optional - If True, initial probe overlaps scanned over the object will be displayed - rotation_real_space_degrees: float (degrees), optional - In plane rotation around z axis between x axis and tilt axis in - real space (forced to be in xy plane) - diffraction_patterns_rotate_degrees: float, optional - Relative rotation angle between real and reciprocal space - diffraction_patterns_transpose: bool, optional - Whether diffraction intensities need to be transposed. - force_com_shifts: list of tuple of ndarrays (CoMx, CoMy) - Amplitudes come from diffraction patterns shifted with - the CoM in the upper left corner for each probe unless - shift is overwritten. One tuple per tilt. 
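The _project_sliced_object / _expand_sliced_object pair above converts between the full voxel grid and the coarser multislice grid: groups of voxels are summed into superslices, and each superslice is later spread back over the voxels it covers. A simplified numpy sketch (the original additionally renormalizes a partially filled last slice):

    import numpy as np

    def project_slices(volume, num_slices):
        # sum groups of ceil(nz / num_slices) voxels into num_slices superslices
        nz = volume.shape[0]
        per_slice = int(np.ceil(nz / num_slices))
        padded = np.pad(volume, ((0, per_slice * num_slices - nz), (0, 0), (0, 0)))
        return padded.reshape((num_slices, per_slice) + volume.shape[1:]).sum(axis=1)

    def expand_slices(slices, num_voxels):
        # spread each superslice evenly back over its voxels, then crop to num_voxels
        per_slice = int(np.ceil(num_voxels / slices.shape[0]))
        return np.repeat(slices / per_slice, per_slice, axis=0)[:num_voxels]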
- force_scan_sampling: float, optional - Override DataCube real space scan pixel size calibrations, in Angstrom - force_angular_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in mrad - force_reciprocal_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in A^-1 - object_fov_mask: np.ndarray (boolean) - Boolean mask of FOV. Used to calculate additional shrinkage of object - If None, probe_overlap intensity is thresholded - crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns when centering - - Returns - -------- - self: OverlapTomographicReconstruction - Self to accommodate chaining - """ - xp = self._xp - asnumpy = self._asnumpy - - # set additional metadata - self._diffraction_intensities_shape = diffraction_intensities_shape - self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape - self._dp_mask = dp_mask - - if self._datacube is None: - raise ValueError( - ( - "The preprocess() method requires a DataCube. " - "Please run ptycho.attach_datacube(DataCube) first." - ) - ) - - if self._positions_mask is not None: - self._positions_mask = np.asarray(self._positions_mask) - - if self._positions_mask.ndim == 2: - warnings.warn( - "2D `positions_mask` assumed the same for all measurements.", - UserWarning, - ) - self._positions_mask = np.tile( - self._positions_mask, (self._num_tilts, 1, 1) - ) - - if self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array."), - UserWarning, - ) - self._positions_mask = self._positions_mask.astype("bool") - else: - self._positions_mask = [None] * self._num_tilts - - # Prepopulate various arrays - - if self._positions_mask[0] is None: - num_probes_per_tilt = [0] - for dc in self._datacube: - rx, ry = dc.Rshape - num_probes_per_tilt.append(rx * ry) - - num_probes_per_tilt = np.array(num_probes_per_tilt) - else: - num_probes_per_tilt = np.insert( - self._positions_mask.sum(axis=(-2, -1)), 0, 0 - ) - - self._num_diffraction_patterns = num_probes_per_tilt.sum() - self._cum_probes_per_tilt = np.cumsum(num_probes_per_tilt) - - self._mean_diffraction_intensity = [] - self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) - - self._rotation_best_rad = np.deg2rad(diffraction_patterns_rotate_degrees) - self._rotation_best_transpose = diffraction_patterns_transpose - - if force_com_shifts is None: - force_com_shifts = [None] * self._num_tilts - - for tilt_index in tqdmnd( - self._num_tilts, - desc="Preprocessing data", - unit="tilt", - disable=not progress_bar, - ): - if tilt_index == 0: - ( - self._datacube[tilt_index], - self._vacuum_probe_intensity, - self._dp_mask, - force_com_shifts[tilt_index], - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[tilt_index], - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=self._vacuum_probe_intensity, - dp_mask=self._dp_mask, - com_shifts=force_com_shifts[tilt_index], - ) - - self._amplitudes = xp.empty( - (self._num_diffraction_patterns,) + self._datacube[0].Qshape - ) - self._region_of_interest_shape = np.array( - self._amplitudes[0].shape[-2:] - ) - - else: - ( - self._datacube[tilt_index], - _, - _, - force_com_shifts[tilt_index], - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[tilt_index], - diffraction_intensities_shape=self._diffraction_intensities_shape, - 
reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=None, - dp_mask=None, - com_shifts=force_com_shifts[tilt_index], - ) - - intensities = self._extract_intensities_and_calibrations_from_datacube( - self._datacube[tilt_index], - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - com_measured_x, - com_measured_y, - com_fitted_x, - com_fitted_y, - com_normalized_x, - com_normalized_y, - ) = self._calculate_intensities_center_of_mass( - intensities, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts[tilt_index], - ) - - ( - self._amplitudes[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ], - mean_diffraction_intensity_temp, - ) = self._normalize_diffraction_intensities( - intensities, - com_fitted_x, - com_fitted_y, - crop_patterns, - self._positions_mask[tilt_index], - ) - - self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) - - del ( - intensities, - com_measured_x, - com_measured_y, - com_fitted_x, - com_fitted_y, - com_normalized_x, - com_normalized_y, - ) - - self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] = self._calculate_scan_positions_in_pixels( - self._scan_positions[tilt_index], self._positions_mask[tilt_index] - ) - - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px_all, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - self._object = xp.zeros((q, p, q), dtype=xp.float32) - else: - self._object = xp.asarray(self._object, dtype=xp.float32) - - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type - self._object_shape = self._object.shape[-2:] - self._num_voxels = self._object.shape[0] - - # Center Probes - self._positions_px_all = xp.asarray(self._positions_px_all, dtype=xp.float32) - - for tilt_index in range(self._num_tilts): - self._positions_px = self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= ( - self._positions_px_com - xp.array(self._object_shape) / 2 - ) - - self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] = self._positions_px.copy() - - self._positions_px_initial_all = self._positions_px_all.copy() - self._positions_initial_all = self._positions_px_initial_all.copy() - self._positions_initial_all[:, 0] *= self.sampling[0] - self._positions_initial_all[:, 1] *= self.sampling[1] - - # Probe Initialization - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, device=self._device - ) - self._vacuum_probe_intensity = 
get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - self._probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - sum(self._mean_diffraction_intensity) - / self._num_tilts - / probe_intensity - ) - - else: - if isinstance(self._probe, ComplexProbe): - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - self._probe = self._probe._array - else: - self._probe._xp = xp - self._probe = self._probe.build()._array - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - sum(self._mean_diffraction_intensity) - / self._num_tilts - / probe_intensity - ) - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) - - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) - - self._known_aberrations_array = ComplexProbe( - energy=self._energy, - gpts=self._region_of_interest_shape, - sampling=self.sampling, - parameters=self._polar_parameters, - device=self._device, - )._evaluate_ctf() - - # Precomputed propagator arrays - self._slice_thicknesses = np.tile( - self._object_shape[1] * self.sampling[1] / self._num_slices, - self._num_slices - 1, - ) - self._propagator_arrays = self._precompute_propagator_arrays( - self._region_of_interest_shape, - self.sampling, - self._energy, - self._slice_thicknesses, - ) - - # overlaps - if object_fov_mask is None: - probe_overlap_3D = xp.zeros_like(self._object) - old_rot_matrix = np.eye(3) # identity - - for tilt_index in np.arange(self._num_tilts): - rot_matrix = self._tilt_orientation_matrices[tilt_index] - - probe_overlap_3D = self._rotate_zxy_volume( - probe_overlap_3D, - rot_matrix @ old_rot_matrix.T, - ) - - self._positions_px = self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - shifted_probes = fft_shift( - self._probe, self._positions_px_fractional, xp - ) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts( - probe_intensities - ) - - probe_overlap_3D += probe_overlap[None] - old_rot_matrix = rot_matrix - - probe_overlap_3D = self._rotate_zxy_volume( - probe_overlap_3D, - old_rot_matrix.T, - ) - - probe_overlap_3D = self._gaussian_filter(probe_overlap_3D, 1.0) - self._object_fov_mask = asnumpy( - probe_overlap_3D > 0.25 * probe_overlap_3D.max() - ) - else: - self._object_fov_mask = np.asarray(object_fov_mask) - self._positions_px = self._positions_px_all[: self._cum_probes_per_tilt[1]] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - 
probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) - - self._object_fov_mask_inverse = np.invert(self._object_fov_mask) - - if plot_probe_overlaps: - figsize = kwargs.pop("figsize", (13, 4)) - chroma_boost = kwargs.pop("chroma_boost", 1) - - # initial probe - complex_probe_rgb = Complex2RGB( - self.probe_centered, - power=2, - chroma_boost=chroma_boost, - ) - - # propagated - propagated_probe = self._probe.copy() - - for s in range(self._num_slices - 1): - propagated_probe = self._propagate_array( - propagated_probe, self._propagator_arrays[s] - ) - complex_propagated_rgb = Complex2RGB( - asnumpy(self._return_centered_probe(propagated_probe)), - power=2, - chroma_boost=chroma_boost, - ) - - extent = [ - 0, - self.sampling[1] * self._object_shape[1], - self.sampling[0] * self._object_shape[0], - 0, - ] - - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) - - ax1.imshow( - complex_probe_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax1) - cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, - chroma_boost=chroma_boost, - ) - ax1.set_ylabel("x [A]") - ax1.set_xlabel("y [A]") - ax1.set_title("Initial probe intensity") - - ax2.imshow( - complex_propagated_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax2) - cax2 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax2, - chroma_boost=chroma_boost, - ) - ax2.set_ylabel("x [A]") - ax2.set_xlabel("y [A]") - ax2.set_title("Propagated probe intensity") - - ax3.imshow( - asnumpy(probe_overlap), - extent=extent, - cmap="Greys_r", - ) - ax3.scatter( - self.positions[0, :, 1], - self.positions[0, :, 0], - s=2.5, - color=(1, 0, 0, 1), - ) - ax3.set_ylabel("x [A]") - ax3.set_xlabel("y [A]") - ax3.set_xlim((extent[0], extent[1])) - ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object field of view") - - fig.tight_layout() - - self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. 
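The overlap projection documented here is the usual multislice forward model: at each slice the shifted probe is multiplied by the slice transmission function exp(i V_s), then Fresnel-propagated to the next slice. A compact sketch of that transmit-propagate loop, omitting the batched patch extraction and fractional shifts of the actual method (propagators as in the Fresnel sketch earlier):

    import numpy as np

    def multislice_forward(probe, potential_slices, propagators):
        # potential_slices: (S, Nx, Ny) real potential; propagators: (S-1, Nx, Ny)
        waves = probe
        for s, v_slice in enumerate(potential_slices):
            waves = waves * np.exp(1j * v_slice)  # transmit through slice s
            if s + 1 < len(potential_slices):
                # propagate to the next slice
                waves = np.fft.ifft2(np.fft.fft2(waves) * propagators[s])
        return waves  # exit wave after the last slice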
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - """ - - xp = self._xp - - complex_object = xp.exp(1j * current_object) - object_patches = complex_object[ - :, self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - propagated_probes = xp.empty_like(object_patches) - propagated_probes[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes = object_patches[s] * propagated_probes[s] - - # propagate - if s + 1 < self._num_slices: - propagated_probes[s + 1] = self._propagate_array( - transmitted_probes, self._propagator_arrays[s] - ) - - return propagated_probes, object_patches, transmitted_probes - - def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): - """ - Ptychographic fourier projection method for GD method. - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - modified_exit_wave = xp.fft.ifft2( - amplitudes * xp.exp(1j * xp.angle(fourier_exit_waves)) - ) - - exit_waves = modified_exit_wave - transmitted_probes - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. 
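The gradient-descent variant above imposes the measured Fourier amplitudes while keeping the current Fourier phase, and reports the squared amplitude residual as the error. A minimal sketch of that modulus constraint, with illustrative names:

    import numpy as np

    def amplitude_projection(measured_amplitudes, exit_wave):
        # impose the measured Fourier modulus, keep the current Fourier phase
        F = np.fft.fft2(exit_wave)
        error = np.sum(np.abs(measured_amplitudes - np.abs(F)) ** 2)
        projected = np.fft.ifft2(measured_amplitudes * np.exp(1j * np.angle(F)))
        return projected - exit_wave, error  # update difference and residual error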
- Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = transmitted_probes.copy() - - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - factor_to_be_projected = ( - projection_c * transmitted_probes + projection_y * exit_waves - ) - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amplitudes * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * transmitted_probes - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - propagated_probes:np.ndarray - Prop[object^n*probe^n] - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - ( - propagated_probes, - object_patches, - transmitted_probes, - ) = self._overlap_projection(current_object, current_probe) - - if use_projection_scheme: - ( - exit_waves[self._active_tilt_index], - error, - ) = self._projection_sets_fourier_projection( - amplitudes, - transmitted_probes, - exit_waves[self._active_tilt_index], - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, transmitted_probes - ) - - return propagated_probes, object_patches, transmitted_probes, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
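The three-parameter (a, b, c) schemes listed in the generalized-projection docstring above can be summarized in a small lookup helper. This is only a restatement of that docstring, not code from the file, and the helper name is made up:

def generalized_projection_abc(scheme, parameter=1.0):
    if scheme == "DM_AP":      # alpha = parameter; DM is DM_AP(1.0), AP is DM_AP(0.0)
        return -parameter, 1.0, 1.0 + parameter
    if scheme == "RAAR":       # beta = parameter; DM is RAAR(1.0)
        return 1.0 - 2.0 * parameter, parameter, 2.0
    if scheme == "RRR":        # gamma = parameter; DM is RRR(1.0)
        return -parameter, parameter, 2.0
    if scheme == "SUPERFLIP":  # charge flipping
        return 0.0, 1.0, 2.0
    raise ValueError(f"unknown scheme: {scheme}")

print(generalized_projection_abc("RAAR", 0.5))  # (0.0, 0.5, 2.0)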
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object[s] += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves) - ) - * probe_normalization - ) - - # back-transmit - exit_waves *= xp.conj(obj) - - if s > 0: - # back-propagate - exit_waves = self._propagate_array( - exit_waves, xp.conj(self._propagator_arrays[s - 1]) - ) - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += ( - step_size - * xp.sum( - exit_waves, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. 
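The object and probe updates above divide by a regularized overlap intensity rather than the raw overlap. A standalone sketch of that weighting (illustrative only) shows how normalization_min interpolates between a uniform 1/max weight and an approximate 1/overlap weight:

import numpy as np

def regularized_inverse(probe_overlap_intensity, normalization_min=1.0, eps=1e-16):
    return 1.0 / np.sqrt(
        eps
        + ((1 - normalization_min) * probe_overlap_intensity) ** 2
        + (normalization_min * np.max(probe_overlap_intensity)) ** 2
    )

overlap = np.array([0.0, 0.1, 1.0, 10.0])
print(regularized_inverse(overlap, normalization_min=1.0))  # uniform 1/max(overlap) weighting
print(regularized_inverse(overlap, normalization_min=0.1))  # approaches 1/overlap where overlap is
                                                            # strong, capped near 1/(0.1*max) where weak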
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - # careful not to modify exit_waves in-place for projection set methods - exit_waves_copy = exit_waves.copy() - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object[s] = ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves_copy) - ) - * probe_normalization - ) - - # back-transmit - exit_waves_copy *= xp.conj(obj) - - if s > 0: - # back-propagate - exit_waves_copy = self._propagate_array( - exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) - ) - - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - exit_waves_copy, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
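Both adjoint methods above back-propagate with the complex conjugate of the forward propagator; for an FFT-multiply-IFFT operator, that conjugation is exactly the adjoint. A quick numerical check, assuming that conventional form for the propagation step (the helper below is illustrative, not the package's implementation):

import numpy as np

def propagate(x, h):
    return np.fft.ifft2(np.fft.fft2(x) * h)

rng = np.random.default_rng(1)
shape = (32, 32)
x = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
y = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
h = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

# <P x, y> == <x, P^H y>, with P^H realized via the conjugated transfer function
lhs = np.vdot(propagate(x, h), y)
rhs = np.vdot(x, propagate(y, np.conj(h)))
print(np.allclose(lhs, rhs))  # True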
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves[self._active_tilt_index], - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - - def _position_correction( - self, - current_object, - current_probe, - transmitted_probes, - amplitudes, - current_positions, - positions_step_size, - constrain_position_distance, - ): - """ - Position correction using estimated intensity gradient. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe:np.ndarray - fractionally-shifted probes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - amplitudes: np.ndarray - Measured amplitudes - current_positions: np.ndarray - Current positions estimate - positions_step_size: float - Positions step size - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A - - Returns - -------- - updated_positions: np.ndarray - Updated positions estimate - """ - - xp = self._xp - - # Intensity gradient - exit_waves_fft = xp.fft.fft2(transmitted_probes) - exit_waves_fft_conj = xp.conj(exit_waves_fft) - estimated_intensity = xp.abs(exit_waves_fft) ** 2 - measured_intensity = amplitudes**2 - - flat_shape = (transmitted_probes.shape[0], -1) - difference_intensity = (measured_intensity - estimated_intensity).reshape( - flat_shape - ) - - # Computing perturbed exit waves one at a time to save on memory - - complex_object = xp.exp(1j * current_object) - - # dx - obj_rolled_patches = complex_object[ - :, - (self._vectorized_patch_indices_row + 1) % self._object_shape[0], - self._vectorized_patch_indices_col, - ] - - propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - # dy - obj_rolled_patches = complex_object[ - :, - self._vectorized_patch_indices_row, - (self._vectorized_patch_indices_col + 1) % self._object_shape[1], - ] - - propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - 
propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - partial_intensity_dx = 2 * xp.real( - exit_waves_dx_fft * exit_waves_fft_conj - ).reshape(flat_shape) - partial_intensity_dy = 2 * xp.real( - exit_waves_dy_fft * exit_waves_fft_conj - ).reshape(flat_shape) - - coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) - - # positions_update = xp.einsum( - # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity - # ) - - coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) - positions_update = ( - xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) - @ coefficients_matrix_T - @ difference_intensity[..., None] - ) - - if constrain_position_distance is not None: - constrain_position_distance /= xp.sqrt( - self.sampling[0] ** 2 + self.sampling[1] ** 2 - ) - x1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 0 - ] - y1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 1 - ] - x0 = self._positions_px_initial[:, 0] - y0 = self._positions_px_initial[:, 1] - if self._rotation_best_transpose: - x0, y0 = xp.array([y0, x0]) - x1, y1 = xp.array([y1, x1]) - - if self._rotation_best_rad is not None: - rotation_angle = self._rotation_best_rad - x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( - -rotation_angle - ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) - x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( - -rotation_angle - ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) - - outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( - x1 < (xp.min(x0) - constrain_position_distance) - ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( - y1 < (xp.min(y0) - constrain_position_distance) - ) > 0 - - positions_update[..., 0][outlier_ind] = 0 - current_positions -= positions_step_size * positions_update[..., 0] - - return current_positions - - def _object_gaussian_constraint(self, current_object, gaussian_filter_sigma): - """ - Ptychographic smoothness constraint. - Used for blurring object. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - gaussian_filter = self._gaussian_filter - - gaussian_filter_sigma /= self.sampling[0] - current_object = gaussian_filter(current_object, gaussian_filter_sigma) - - return current_object - - def _object_butterworth_constraint( - self, current_object, q_lowpass, q_highpass, butterworth_order - ): - """ - Butterworth filter - - Parameters - -------- - current_object: np.ndarray - Current object estimate - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - qz = xp.fft.fftfreq(current_object.shape[0], self.sampling[1]) - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") - qra = xp.sqrt(qza**2 + qxa**2 + qya**2) - - env = xp.ones_like(qra) - if q_highpass: - env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) - if q_lowpass: - env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - - current_object_mean = xp.mean(current_object) - current_object -= current_object_mean - current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * env) - current_object += current_object_mean - return xp.real(current_object) - - def _object_denoise_tv_pylops(self, current_object, weights, iterations): - """ - Performs second order TV denoising along x and y - - Parameters - ---------- - current_object: np.ndarray - Current object estimate - weights : [float, float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - iterations: float - Number of iterations to run in denoising algorithm. - `niter_out` in pylops - - Returns - ------- - constrained_object: np.ndarray - Constrained object estimate - - """ - xp = self._xp - - if xp.iscomplexobj(current_object): - current_object_tv = current_object - warnings.warn( - ("TV denoising is currently only supported for potential objects."), - UserWarning, - ) - - else: - # zero pad at top and bottom slice - pad_width = ((1, 1), (0, 0), (0, 0)) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" - ) - - # run tv denoising - nz, nx, ny = current_object.shape - niter_out = iterations - niter_in = 1 - Iop = pylops.Identity(nx * ny * nz) - - if weights[0] == 0: - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[1]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - elif weights[1] == 0: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - l1_regs = [z_gradient] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[0]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - else: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [z_gradient, xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=weights, - tol=1e-4, - tau=1.0, - show=False, - )[0] - - # remove padding - current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] - - return current_object_tv - - def _constraints( - self, - current_object, - current_probe, - current_positions, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - 
constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - object_positivity, - shrinkage_rad, - object_mask, - tv_denoise, - tv_denoise_weights, - tv_denoise_inner_iter, - ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe Fourier amplitude is replaced by initial probe aperture. - initial_probe_aperture: np.ndarray, - Initial probe aperture to use in replacing probe Fourier amplitude. - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - butterworth_filter: bool - If True, applies fourier-space butterworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - object_positivity: bool - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. 
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - if tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, - tv_denoise_weights, - tv_denoise_inner_iter, - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - if object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - - def reconstruct( - self, - max_iter: int = 64, - reconstruction_method: str = "gradient-descent", - reconstruction_parameter: float = 1.0, - reconstruction_parameter_a: float = None, - reconstruction_parameter_b: float = None, - reconstruction_parameter_c: float = None, - max_batch_size: int = None, - seed_random: int = None, - step_size: float = 0.5, - normalization_min: float = 1, - positions_step_size: float = 0.9, - fix_com: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, - constrain_probe_amplitude_relative_radius: float = 0.5, - constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, - constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, - constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, - global_affine_transformation: bool = True, - gaussian_filter_sigma: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, - fit_probe_aberrations_max_angular_order: int = 4, - fit_probe_aberrations_max_radial_order: int = 4, - butterworth_filter_iter: int = np.inf, - q_lowpass: float = None, - q_highpass: float = None, - butterworth_order: float = 2, - object_positivity: bool = True, - shrinkage_rad: float = 0.0, 
- fix_potential_baseline: bool = True, - tv_denoise_iter=np.inf, - tv_denoise_weights=None, - tv_denoise_inner_iter=40, - collective_tilt_updates: bool = False, - store_iterations: bool = False, - progress_bar: bool = True, - reset: bool = None, - ): - """ - Ptychographic reconstruction main method. - - Parameters - -------- - max_iter: int, optional - Maximum number of iterations to run - reconstruction_method: str, optional - Specifies which reconstruction algorithm to use, one of: - "generalized-projections", - "DM_AP" (or "difference-map_alternating-projections"), - "RAAR" (or "relaxed-averaged-alternating-reflections"), - "RRR" (or "relax-reflect-reflect"), - "SUPERFLIP" (or "charge-flipping"), or - "GD" (or "gradient_descent") - reconstruction_parameter: float, optional - Reconstruction parameter for various reconstruction methods above. - reconstruction_parameter_a: float, optional - Reconstruction parameter a for reconstruction_method='generalized-projections'. - reconstruction_parameter_b: float, optional - Reconstruction parameter b for reconstruction_method='generalized-projections'. - reconstruction_parameter_c: float, optional - Reconstruction parameter c for reconstruction_method='generalized-projections'. - max_batch_size: int, optional - Max number of probes to update at once - seed_random: int, optional - Seeds the random number generator, only applicable when max_batch_size is not None - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - positions_step_size: float, optional - Positions update step size - fix_com: bool, optional - If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. 
- fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float, optional - Distance to constrain position correction within original - field of view in A - global_affine_transformation: bool, optional - If True, positions are assumed to be a global affine transform from initial scan - gaussian_filter_sigma: float, optional - Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - object_positivity: bool, optional - If True, forces object to be positive - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - collective_tilt_updates: bool - if True perform collective tilt updates - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - store_iterations: bool, optional - If True, reconstructed objects and probes are stored at each iteration - progress_bar: bool, optional - If True, reconstruction progress is displayed - reset: bool, optional - If True, previous reconstructions are ignored - - Returns - -------- - self: OverlapTomographicReconstruction - Self to accommodate chaining - """ - asnumpy = self._asnumpy - xp = self._xp - - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - "'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) - - if self._verbose: - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." 
- ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) - - # Position Correction + Collective Updates not yet implemented - if fix_positions_iter < max_iter: - raise NotImplementedError( - "Position correction is currently incompatible with collective updates." - ) - - # Batching - - if max_batch_size is not None: - xp.random.seed(seed_random) - else: - max_batch_size = self._num_diffraction_patterns - - # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self._object = self._object_initial.copy() - self.error_iterations = [] - self._probe = self._probe_initial.copy() - self._positions_px_all = self._positions_px_initial_all.copy() - if hasattr(self, "_tf"): - del self._tf - - if use_projection_scheme: - self._exit_waves = [None] * self._num_tilts - else: - self._exit_waves = None - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." - ), - UserWarning, - ) - else: - self.error_iterations = [] - if use_projection_scheme: - self._exit_waves = [None] * self._num_tilts - else: - self._exit_waves = None - - # main loop - for a0 in tqdmnd( - max_iter, - desc="Reconstructing object and probe", - unit=" iter", - disable=not progress_bar, - ): - error = 0.0 - - if collective_tilt_updates: - collective_object = xp.zeros_like(self._object) - - tilt_indices = np.arange(self._num_tilts) - np.random.shuffle(tilt_indices) - - old_rot_matrix = np.eye(3) # identity - - for tilt_index in tilt_indices: - self._active_tilt_index = tilt_index - - tilt_error = 0.0 - - rot_matrix = self._tilt_orientation_matrices[self._active_tilt_index] - self._object = self._rotate_zxy_volume( - self._object, - rot_matrix @ old_rot_matrix.T, - ) - - object_sliced = self._project_sliced_object( - self._object, self._num_slices - ) - if not use_projection_scheme: - object_sliced_old = object_sliced.copy() - - start_tilt = self._cum_probes_per_tilt[self._active_tilt_index] - end_tilt = self._cum_probes_per_tilt[self._active_tilt_index + 1] - - num_diffraction_patterns = end_tilt - start_tilt - shuffled_indices = np.arange(num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) - - # randomize - if not use_projection_scheme: - np.random.shuffle(shuffled_indices) - - unshuffled_indices[shuffled_indices] = np.arange( - num_diffraction_patterns - ) - - positions_px = self._positions_px_all[start_tilt:end_tilt].copy()[ - shuffled_indices - ] - initial_positions_px = self._positions_px_initial_all[ - start_tilt:end_tilt - ].copy()[shuffled_indices] - - for start, end in generate_batches( - num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_initial = initial_positions_px[start:end] - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - amplitudes = self._amplitudes[start_tilt:end_tilt][ - shuffled_indices[start:end] - ] - - # forward operator - ( - propagated_probes, - object_patches, - transmitted_probes, - self._exit_waves, - 
batch_error, - ) = self._forward( - object_sliced, - self._probe, - amplitudes, - self._exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ) - - # adjoint operator - object_sliced, self._probe = self._adjoint( - object_sliced, - self._probe, - object_patches, - propagated_probes, - self._exit_waves, - use_projection_scheme=use_projection_scheme, - step_size=step_size, - normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, - ) - - # position correction - if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( - object_sliced, - self._probe, - transmitted_probes, - amplitudes, - self._positions_px, - positions_step_size, - constrain_position_distance, - ) - - tilt_error += batch_error - - if not use_projection_scheme: - object_sliced -= object_sliced_old - - object_update = self._expand_sliced_object( - object_sliced, self._num_voxels - ) - - if collective_tilt_updates: - collective_object += self._rotate_zxy_volume( - object_update, rot_matrix.T - ) - else: - self._object += object_update - - old_rot_matrix = rot_matrix - - # Normalize Error - tilt_error /= ( - self._mean_diffraction_intensity[self._active_tilt_index] - * num_diffraction_patterns - ) - error += tilt_error - - # constraints - self._positions_px_all[start_tilt:end_tilt] = positions_px.copy()[ - unshuffled_indices - ] - - if not collective_tilt_updates: - ( - self._object, - self._probe, - self._positions_px_all[start_tilt:end_tilt], - ) = self._constraints( - self._object, - self._probe, - self._positions_px_all[start_tilt:end_tilt], - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, - gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass is not None or q_highpass is not None), - q_lowpass=q_lowpass, - q_highpass=q_highpass, - butterworth_order=butterworth_order, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline - and self._object_fov_mask_inverse.sum() > 0 - else None, - tv_denoise=a0 < tv_denoise_iter - and tv_denoise_weights is not None, - tv_denoise_weights=tv_denoise_weights, - tv_denoise_inner_iter=tv_denoise_inner_iter, - ) - - self._object = self._rotate_zxy_volume(self._object, old_rot_matrix.T) - - # Normalize Error Over Tilts - error /= self._num_tilts - - if collective_tilt_updates: - 
self._object += collective_object / self._num_tilts - - ( - self._object, - self._probe, - _, - ) = self._constraints( - self._object, - self._probe, - None, - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=True, - global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, - gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass is not None or q_highpass is not None), - q_lowpass=q_lowpass, - q_highpass=q_highpass, - butterworth_order=butterworth_order, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline - and self._object_fov_mask_inverse.sum() > 0 - else None, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, - tv_denoise_weights=tv_denoise_weights, - tv_denoise_inner_iter=tv_denoise_inner_iter, - ) - - self.error_iterations.append(error.item()) - if store_iterations: - self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(self.probe_centered) - - # store result - self.object = asnumpy(self._object) - self.probe = self.probe_centered - self.error = error.item() - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _crop_rotate_object_manually( - self, - array, - angle, - x_lims, - y_lims, - ): - """ - Crops and rotates rotates object manually. - - Parameters - ---------- - array: np.ndarray - Object array to crop and rotate. Only operates on numpy arrays for comptatibility. - angle: float - In-plane angle in degrees to rotate by - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - - Returns - ------- - cropped_rotated_array: np.ndarray - Cropped and rotated object array - """ - - asnumpy = self._asnumpy - min_x, max_x = x_lims - min_y, max_y = y_lims - - if angle is not None: - rotated_array = rotate_np( - asnumpy(array), angle, reshape=False, axes=(-2, -1) - ) - else: - rotated_array = asnumpy(array) - - return rotated_array[..., min_x:max_x, min_y:max_y] - - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax, - cbar: bool, - projection_angle_deg: float, - projection_axes: Tuple[int, int], - x_lims: Tuple[int, int], - y_lims: Tuple[int, int], - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. 
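A standalone version of the crop-and-rotate pattern used by _crop_rotate_object_manually above, with made-up limits for illustration:

import numpy as np
from scipy.ndimage import rotate

def crop_rotate(array, angle_deg, x_lims, y_lims):
    """Rotate in-plane without reshaping, then slice the requested window."""
    if angle_deg is not None:
        array = rotate(array, angle_deg, reshape=False, axes=(-2, -1))
    return array[..., slice(*x_lims), slice(*y_lims)]

volume = np.random.rand(4, 128, 128)  # e.g. a (slices, x, y) object array
print(crop_rotate(volume, 30.0, (32, 96), (32, 96)).shape)  # (4, 64, 64)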
- - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - - cmap = kwargs.pop("cmap", "magma") - - asnumpy = self._asnumpy - - if projection_angle_deg is not None: - rotated_3d_obj = self._rotate( - self._object, - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - rotated_3d_obj = asnumpy(rotated_3d_obj) - else: - rotated_3d_obj = self.object - - rotated_object = self._crop_rotate_object_manually( - rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - errors = np.array(self.error_iterations) - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - convergence_ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - projection_angle_deg: float, - projection_axes: Tuple[int, int], - x_lims: Tuple[int, int], - y_lims: Tuple[int, int], - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. 
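The plotting code above repeatedly appends a thin colorbar axis to the right of each image axis. A minimal self-contained example of that layout idiom (not taken from the file):

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

fig, ax = plt.subplots(figsize=(4, 4))
# extent = [left, right, bottom, top] with top=0 puts the origin at the
# upper-left, matching the convention used in the plotting code above
im = ax.imshow(np.random.rand(64, 64), extent=[0, 10, 10, 0], cmap="magma")
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad="2.5%")
fig.colorbar(im, cax=cax)
ax.set_ylabel("x [A]")
ax.set_xlabel("y [A]")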
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - asnumpy = self._asnumpy - - figsize = kwargs.pop("figsize", (8, 5)) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - asnumpy = self._asnumpy - - if projection_angle_deg is not None: - rotated_3d_obj = self._rotate( - self._object, - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - rotated_3d_obj = asnumpy(rotated_3d_obj) - else: - rotated_3d_obj = self.object - - rotated_object = self._crop_rotate_object_manually( - rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=2, - height_ratios=[4, 1], - hspace=0.15, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=1, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - if plot_probe or plot_fourier_probe: - # Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Reconstructed object projection") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Probe - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - - ax = fig.add_subplot(spec[0, 1]) - if plot_fourier_probe: - if remove_initial_probe_aberrations: - probe_array = self.probe_fourier_residual - else: - probe_array = self.probe_fourier - - probe_array = Complex2RGB( - probe_array, - chroma_boost=chroma_boost, - ) - - ax.set_title("Reconstructed Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - self.probe, 
- power=2, - chroma_boost=chroma_boost, - ) - ax.set_title("Reconstructed probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - ax_cb, - chroma_boost=chroma_boost, - ) - else: - ax = fig.add_subplot(spec[0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Reconstructed object projection") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if plot_convergence and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = np.array(self.error_iterations) - if plot_probe: - ax = fig.add_subplot(spec[1, :]) - else: - ax = fig.add_subplot(spec[1]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") - spec.tight_layout(fig) - - def _visualize_all_iterations( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - iterations_grid: Tuple[int, int], - projection_angle_deg: float, - projection_axes: Tuple[int, int], - x_lims: Tuple[int, int], - y_lims: Tuple[int, int], - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - asnumpy = self._asnumpy - - if not hasattr(self, "object_iterations"): - raise ValueError( - ( - "Object and probe iterations were not saved during reconstruction. " - "Please re-run using store_iterations=True." 
- ) - ) - - if iterations_grid == "auto": - num_iter = len(self.error_iterations) - - if num_iter == 1: - return self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - cbar=cbar, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - **kwargs, - ) - elif plot_probe or plot_fourier_probe: - iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) - else: - iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) - else: - if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: - raise ValueError() - - auto_figsize = ( - (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) - if plot_convergence - else (3 * iterations_grid[1], 3 * iterations_grid[0]) - ) - figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - errors = np.array(self.error_iterations) - - if projection_angle_deg is not None: - objects = [ - self._crop_rotate_object_manually( - rotate_np( - obj, - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ).sum(0), - angle=None, - x_lims=x_lims, - y_lims=y_lims, - ) - for obj in self.object_iterations - ] - else: - objects = [ - self._crop_rotate_object_manually( - obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - for obj in self.object_iterations - ] - - if plot_probe or plot_fourier_probe: - total_grids = (np.prod(iterations_grid) / 2).astype("int") - probes = self.probe_iterations - else: - total_grids = np.prod(iterations_grid) - max_iter = len(objects) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) - - extent = [ - 0, - self.sampling[1] * objects[0].shape[1], - self.sampling[0] * objects[0].shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=2) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - grid = ImageGrid( - fig, - spec[0], - nrows_ncols=(1, iterations_grid[1]) if plot_probe else iterations_grid, - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - im = ax.imshow( - objects[grid_range[n]], - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_title(f"Iter: {grid_range[n]} Object") - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: - grid.cbar_axes[n].colorbar(im) - - if plot_probe or plot_fourier_probe: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - grid = ImageGrid( - fig, - spec[1], - nrows_ncols=(1, iterations_grid[1]), - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in 
enumerate(grid): - if plot_fourier_probe: - probe_array = asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]], - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) - - ax.set_title(f"Iter: {grid_range[n]} Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - probes[grid_range[n]], - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title(f"Iter: {grid_range[n]} probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - add_colorbar_arg( - grid.cbar_axes[n], - chroma_boost=chroma_boost, - ) - - if plot_convergence: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - if plot_probe: - ax2 = fig.add_subplot(spec[2]) - else: - ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration number") - ax2.yaxis.tick_right() - - spec.tight_layout(fig) - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims=(None, None), - y_lims=(None, None), - **kwargs, - ): - """ - Displays reconstructed object and probe. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - - Returns - -------- - self: OverlapTomographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - **kwargs, - ) - - return self - - def _return_object_fft( - self, - obj=None, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims: Tuple[int, int] = (None, None), - y_lims: 
Tuple[int, int] = (None, None), - ): - """ - Returns obj fft shifted to center of array - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - - xp = self._xp - asnumpy = self._asnumpy - - if obj is None: - obj = self._object - else: - obj = xp.asarray(obj, dtype=xp.float32) - - if projection_angle_deg is not None: - rotated_3d_obj = self._rotate( - obj, - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - rotated_3d_obj = asnumpy(rotated_3d_obj) - else: - rotated_3d_obj = asnumpy(obj) - - rotated_object = self._crop_rotate_object_manually( - rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(rotated_object)))) - - def show_object_fft( - self, - obj=None, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims: Tuple[int, int] = (None, None), - y_lims: Tuple[int, int] = (None, None), - **kwargs, - ): - """ - Plot FFT of reconstructed object - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - if obj is None: - object_fft = self._return_object_fft( - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - ) - else: - object_fft = self._return_object_fft( - obj, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - ) - - figsize = kwargs.pop("figsize", (6, 6)) - cmap = kwargs.pop("cmap", "magma") - - pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) - show( - object_fft, - figsize=figsize, - cmap=cmap, - scalebar=True, - pixelsize=pixelsize, - ticks=False, - pixelunits=r"$\AA^{-1}$", - **kwargs, - ) - - @property - def positions(self): - """Probe positions [A]""" - - if self.angular_sampling is None: - return None - - asnumpy = self._asnumpy - positions_all = [] - for tilt_index in range(self._num_tilts): - positions = self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ].copy() - positions[:, 0] *= self.sampling[0] - positions[:, 1] *= self.sampling[1] - positions_all.append(asnumpy(positions)) - - return np.asarray(positions_all) - - def _return_self_consistency_errors( - self, - max_batch_size=None, - ): - """Compute the self-consistency errors for each probe position""" - raise NotImplementedError() - - def _return_projected_cropped_potential( - self, - ): - """Utility function to accommodate multiple classes""" - raise NotImplementedError() - - def show_uncertainty_visualization( - self, - errors=None, - max_batch_size=None, - projected_cropped_potential=None, - kde_sigma=None, - plot_histogram=True, - plot_contours=False, - **kwargs, - ): - """Plot uncertainty visualization using self-consistency errors""" - raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py 
b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py deleted file mode 100644 index 59bf61da2..000000000 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ /dev/null @@ -1,647 +0,0 @@ -import warnings - -import numpy as np -from py4DSTEM.process.phase.utils import ( - array_slice, - estimate_global_transformation_ransac, - fft_shift, - fit_aberration_surface, - regularize_probe_amplitude, -) -from py4DSTEM.process.utils import get_CoM - -try: - import cupy as cp -except (ModuleNotFoundError, ImportError): - cp = np - import os - - # make sure pylops doesn't try to use cupy - os.environ["CUPY_PYLOPS"] = "0" -import pylops # this must follow the exception - - -class PtychographicConstraints: - """ - Container class for PtychographicReconstruction methods. - """ - - def _object_threshold_constraint(self, current_object, pure_phase_object): - """ - Ptychographic threshold constraint. - Used for avoiding the scaling ambiguity between probe and object. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - pure_phase_object: bool - If True, object amplitude is set to unity - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - phase = xp.angle(current_object) - - if pure_phase_object: - amplitude = 1.0 - else: - amplitude = xp.minimum(xp.abs(current_object), 1.0) - - return amplitude * xp.exp(1.0j * phase) - - def _object_shrinkage_constraint(self, current_object, shrinkage_rad, object_mask): - """ - Ptychographic shrinkage constraint. - Used to ensure electrostatic potential is positive. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - - if self._object_type == "complex": - phase = xp.angle(current_object) - amp = xp.abs(current_object) - - if object_mask is not None: - shrinkage_rad += phase[..., object_mask].mean() - - phase -= shrinkage_rad - - current_object = amp * xp.exp(1.0j * phase) - else: - if object_mask is not None: - shrinkage_rad += current_object[..., object_mask].mean() - - current_object -= shrinkage_rad - - return current_object - - def _object_positivity_constraint(self, current_object): - """ - Ptychographic positivity constraint. - Used to ensure potential is positive. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - - return xp.maximum(current_object, 0.0) - - def _object_gaussian_constraint( - self, current_object, gaussian_filter_sigma, pure_phase_object - ): - """ - Ptychographic smoothness constraint. - Used for blurring object. 
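For readers skimming this removal: the amplitude-threshold constraint deleted above clamps the object modulus at one while keeping its phase (or forces unit modulus for pure-phase objects). A minimal NumPy sketch of the same idea; the function and argument names below are illustrative, not py4DSTEM API:

    import numpy as np

    def object_threshold(obj, pure_phase_object=False):
        # Keep the phase; clamp (or fix) the amplitude to avoid the probe/object scaling ambiguity.
        phase = np.angle(obj)
        amplitude = 1.0 if pure_phase_object else np.minimum(np.abs(obj), 1.0)
        return amplitude * np.exp(1j * phase)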
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - pure_phase_object: bool - If True, gaussian blur performed on phase only - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - gaussian_filter = self._gaussian_filter - gaussian_filter_sigma /= self.sampling[0] - - if pure_phase_object: - phase = xp.angle(current_object) - phase = gaussian_filter(phase, gaussian_filter_sigma) - current_object = xp.exp(1.0j * phase) - else: - current_object = gaussian_filter(current_object, gaussian_filter_sigma) - - return current_object - - def _object_butterworth_constraint( - self, - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ): - """ - Ptychographic butterworth filter. - Used for low/high-pass filtering object. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - qx = xp.fft.fftfreq(current_object.shape[0], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[1], self.sampling[1]) - - qya, qxa = xp.meshgrid(qy, qx) - qra = xp.sqrt(qxa**2 + qya**2) - - env = xp.ones_like(qra) - if q_highpass: - env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) - if q_lowpass: - env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - - current_object_mean = xp.mean(current_object) - current_object -= current_object_mean - current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env) - current_object += current_object_mean - - if self._object_type == "potential": - current_object = xp.real(current_object) - - return current_object - - def _object_denoise_tv_pylops(self, current_object, weight, iterations): - """ - Performs second order TV denoising along x and y - - Parameters - ---------- - current_object: np.ndarray - Current object estimate - weight : float - Denoising weight. The greater `weight`, the more denoising (at - the expense of fidelity to `input`). - iterations: float - Number of iterations to run in denoising algorithm. - `niter_out` in pylops - - Returns - ------- - constrained_object: np.ndarray - Constrained object estimate - - """ - xp = self._xp - - if xp.iscomplexobj(current_object): - current_object_tv = current_object - warnings.warn( - ("TV denoising is currently only supported for potential objects."), - UserWarning, - ) - - else: - nx, ny = current_object.shape - niter_out = iterations - niter_in = 1 - Iop = pylops.Identity(nx * ny) - xy_laplacian = pylops.Laplacian( - (nx, ny), axes=(0, 1), edge=False, kind="backward" - ) - - l1_regs = [xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weight], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - current_object_tv = current_object_tv.reshape(current_object.shape) - - return current_object_tv - - def _object_denoise_tv_chambolle( - self, - current_object, - weight, - axis, - pad_object, - eps=2.0e-4, - max_num_iter=200, - scaling=None, - ): - """ - Perform total-variation denoising on n-dimensional images. 
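The Butterworth constraint deleted above builds a radially symmetric low/high-pass envelope in reciprocal space and applies it around the object mean. A standalone NumPy sketch under the same conventions (illustrative names; cutoffs in A^-1):

    import numpy as np

    def butterworth_filter(obj, sampling, q_lowpass=None, q_highpass=None, order=2):
        # Radial spatial-frequency grid in A^-1.
        qx = np.fft.fftfreq(obj.shape[0], sampling[0])
        qy = np.fft.fftfreq(obj.shape[1], sampling[1])
        qya, qxa = np.meshgrid(qy, qx)
        qra = np.sqrt(qxa**2 + qya**2)

        env = np.ones_like(qra)
        if q_highpass:
            env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * order))
        if q_lowpass:
            env *= 1 / (1 + (qra / q_lowpass) ** (2 * order))

        # Filter around the mean so the DC level is preserved.
        mean = obj.mean()
        filtered = np.fft.ifft2(np.fft.fft2(obj - mean) * env) + mean
        return filtered.real if np.isrealobj(obj) else filtered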
- - Parameters - ---------- - current_object: np.ndarray - Current object estimate - weight : float, optional - Denoising weight. The greater `weight`, the more denoising (at - the expense of fidelity to `input`). - axis: int or tuple - Axis for denoising, if None uses all axes - pad_object: bool - if True, pads object with zeros along axes of blurring - eps : float, optional - Relative difference of the value of the cost function that determines - the stop criterion. The algorithm stops when: - - (E_(n-1) - E_n) < eps * E_0 - - max_num_iter : int, optional - Maximal number of iterations used for the optimization. - scaling : tuple, optional - Scale weight of tv denoise on different axes - - Returns - ------- - constrained_object: np.ndarray - Constrained object estimate - - Notes - ----- - Rudin, Osher and Fatemi algorithm. - Adapted skimage.restoration.denoise_tv_chambolle. - """ - xp = self._xp - if xp.iscomplexobj(current_object): - updated_object = current_object - warnings.warn( - ("TV denoising is currently only supported for potential objects."), - UserWarning, - ) - else: - current_object_sum = xp.sum(current_object) - if axis is None: - ndim = xp.arange(current_object.ndim).tolist() - elif isinstance(axis, tuple): - ndim = list(axis) - else: - ndim = [axis] - - if pad_object: - pad_width = ((0, 0),) * current_object.ndim - pad_width = list(pad_width) - for ax in range(len(ndim)): - pad_width[ndim[ax]] = (1, 1) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" - ) - - p = xp.zeros( - (current_object.ndim,) + current_object.shape, - dtype=current_object.dtype, - ) - g = xp.zeros_like(p) - d = xp.zeros_like(current_object) - - i = 0 - while i < max_num_iter: - if i > 0: - # d will be the (negative) divergence of p - d = -p.sum(0) - slices_d = [ - slice(None), - ] * current_object.ndim - slices_p = [ - slice(None), - ] * (current_object.ndim + 1) - for ax in range(len(ndim)): - slices_d[ndim[ax]] = slice(1, None) - slices_p[ndim[ax] + 1] = slice(0, -1) - slices_p[0] = ndim[ax] - d[tuple(slices_d)] += p[tuple(slices_p)] - slices_d[ndim[ax]] = slice(None) - slices_p[ndim[ax] + 1] = slice(None) - updated_object = current_object + d - else: - updated_object = current_object - E = (d**2).sum() - - # g stores the gradients of updated_object along each axis - # e.g. g[0] is the first order finite difference along axis 0 - slices_g = [ - slice(None), - ] * (current_object.ndim + 1) - for ax in range(len(ndim)): - slices_g[ndim[ax] + 1] = slice(0, -1) - slices_g[0] = ndim[ax] - g[tuple(slices_g)] = xp.diff(updated_object, axis=ndim[ax]) - slices_g[ndim[ax] + 1] = slice(None) - if scaling is not None: - scaling /= xp.max(scaling) - g *= xp.array(scaling)[:, xp.newaxis, xp.newaxis] - norm = xp.sqrt((g**2).sum(axis=0))[xp.newaxis, ...] - E += weight * norm.sum() - tau = 1.0 / (2.0 * len(ndim)) - norm *= tau / weight - norm += 1.0 - p -= tau * g - p /= norm - E /= float(current_object.size) - if i == 0: - E_init = E - E_previous = E - else: - if xp.abs(E_previous - E) < eps * E_init: - break - else: - E_previous = E - i += 1 - - if pad_object: - for ax in range(len(ndim)): - slices = array_slice(ndim[ax], current_object.ndim, 1, -1) - updated_object = updated_object[slices] - updated_object = ( - updated_object / xp.sum(updated_object) * current_object_sum - ) - - return updated_object - - def _probe_center_of_mass_constraint(self, current_probe): - """ - Ptychographic center of mass constraint. - Used for centering corner-centered probe intensity. 
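The probe-centering constraint deleted above relies on the fft_shift utility to apply a subpixel translation. A generic Fourier-shift sketch for reference (the sign convention may differ from py4DSTEM's helper; names are illustrative):

    import numpy as np

    def fourier_shift(array, shift):
        # Multiply by a linear phase ramp in reciprocal space to shift by (dx, dy) pixels.
        kx = np.fft.fftfreq(array.shape[0])[:, None]
        ky = np.fft.fftfreq(array.shape[1])[None, :]
        ramp = np.exp(-2j * np.pi * (kx * shift[0] + ky * shift[1]))
        shifted = np.fft.ifft2(np.fft.fft2(array) * ramp)
        return shifted if np.iscomplexobj(array) else shifted.real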
- - Parameters - -------- - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - constrained_probe: np.ndarray - Constrained probe estimate - """ - xp = self._xp - - probe_intensity = xp.abs(current_probe) ** 2 - - probe_x0, probe_y0 = get_CoM( - probe_intensity, device=self._device, corner_centered=True - ) - shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) - - return shifted_probe - - def _probe_amplitude_constraint( - self, current_probe, relative_radius, relative_width - ): - """ - Ptychographic top-hat filtering of probe. - - Parameters - ---------- - current_probe: np.ndarray - Current positions estimate - relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - - Returns - -------- - constrained_probe: np.ndarray - Constrained probe estimate - """ - xp = self._xp - erf = self._erf - - probe_intensity = xp.abs(current_probe) ** 2 - current_probe_sum = xp.sum(probe_intensity) - - X = xp.fft.fftfreq(current_probe.shape[0])[:, None] - Y = xp.fft.fftfreq(current_probe.shape[1])[None] - r = xp.hypot(X, Y) - relative_radius - - sigma = np.sqrt(np.pi) / relative_width - tophat_mask = 0.5 * (1 - erf(sigma * r / (1 - r**2))) - - updated_probe = current_probe * tophat_mask - updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - - return updated_probe * normalization - - def _probe_fourier_amplitude_constraint( - self, - current_probe, - width_max_pixels, - enforce_constant_intensity, - ): - """ - Ptychographic top-hat filtering of Fourier probe. - - Parameters - ---------- - current_probe: np.ndarray - Current positions estimate - threshold: np.ndarray - Threshold value for current probe fourier mask. Value should - be between 0 and 1, where 1 uses the maximum amplitude to threshold. - relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - - Returns - -------- - constrained_probe: np.ndarray - Constrained probe estimate - """ - xp = self._xp - asnumpy = self._asnumpy - - current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) - current_probe_fft = xp.fft.fft2(current_probe) - - updated_probe_fft, _, _, _ = regularize_probe_amplitude( - asnumpy(current_probe_fft), - width_max_pixels=width_max_pixels, - nearest_angular_neighbor_averaging=5, - enforce_constant_intensity=enforce_constant_intensity, - corner_centered=True, - ) - - updated_probe_fft = xp.asarray(updated_probe_fft) - updated_probe = xp.fft.ifft2(updated_probe_fft) - updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - - return updated_probe * normalization - - def _probe_aperture_constraint( - self, - current_probe, - initial_probe_aperture, - ): - """ - Ptychographic constraint to fix Fourier amplitude to initial aperture. 
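The top-hat amplitude constraint deleted above damps the corner-centered probe beyond a relative radius with an erf sigmoid, then rescales to conserve total probe power. A NumPy/SciPy sketch mirroring those steps (illustrative names):

    import numpy as np
    from scipy.special import erf

    def probe_tophat(probe, relative_radius=0.5, relative_width=0.05):
        power_in = np.sum(np.abs(probe) ** 2)

        # Corner-centered radial coordinate, measured from the inflection radius.
        X = np.fft.fftfreq(probe.shape[0])[:, None]
        Y = np.fft.fftfreq(probe.shape[1])[None, :]
        r = np.hypot(X, Y) - relative_radius

        sigma = np.sqrt(np.pi) / relative_width
        tophat = 0.5 * (1 - erf(sigma * r / (1 - r**2)))

        updated = probe * tophat
        # Renormalize so the constrained probe carries the same power.
        return updated * np.sqrt(power_in / np.sum(np.abs(updated) ** 2))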
- - Parameters - ---------- - current_probe: np.ndarray - Current positions estimate - - Returns - -------- - constrained_probe: np.ndarray - Constrained probe estimate - """ - xp = self._xp - - current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) - current_probe_fft_phase = xp.angle(xp.fft.fft2(current_probe)) - - updated_probe = xp.fft.ifft2( - xp.exp(1j * current_probe_fft_phase) * initial_probe_aperture - ) - updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - - return updated_probe * normalization - - def _probe_aberration_fitting_constraint( - self, - current_probe, - max_angular_order, - max_radial_order, - ): - """ - Ptychographic probe smoothing constraint. - Removes/adds known (initialization) aberrations before/after smoothing. - - Parameters - ---------- - current_probe: np.ndarray - Current positions estimate - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A^-1 - fix_amplitude: bool - If True, only the phase is smoothed - - Returns - -------- - constrained_probe: np.ndarray - Constrained probe estimate - """ - - xp = self._xp - - fourier_probe = xp.fft.fft2(current_probe) - fourier_probe_abs = xp.abs(fourier_probe) - sampling = self.sampling - energy = self._energy - - fitted_angle, _ = fit_aberration_surface( - fourier_probe, - sampling, - energy, - max_angular_order, - max_radial_order, - xp=xp, - ) - - fourier_probe = fourier_probe_abs * xp.exp(-1.0j * fitted_angle) - current_probe = xp.fft.ifft2(fourier_probe) - - return current_probe - - def _positions_center_of_mass_constraint(self, current_positions): - """ - Ptychographic position center of mass constraint. - Additionally updates vectorized indices used in _overlap_projection. - - Parameters - ---------- - current_positions: np.ndarray - Current positions estimate - - Returns - -------- - constrained_positions: np.ndarray - CoM constrained positions estimate - """ - xp = self._xp - - current_positions -= xp.mean(current_positions, axis=0) - self._positions_px_com - self._positions_px_fractional = current_positions - xp.round(current_positions) - - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - return current_positions - - def _positions_affine_transformation_constraint( - self, initial_positions, current_positions - ): - """ - Constrains the updated positions to be an affine transformation of the initial scan positions, - composing of two scale factors, a shear, and a rotation angle. - - Uses RANSAC to estimate the global transformation robustly. - Stores the AffineTransformation in self._tf. 
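The position center-of-mass constraint deleted above removes rigid drift by pinning the mean scan position; the full method also refreshes the fractional offsets and vectorized patch indices, which this sketch omits (illustrative names):

    import numpy as np

    def recenter_positions(positions_px, target_com):
        # Shift all positions so their mean coincides with target_com.
        return positions_px - positions_px.mean(axis=0) + np.asarray(target_com)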
- - Parameters - ---------- - initial_positions: np.ndarray - Initial scan positions - current_positions: np.ndarray - Current positions estimate - - Returns - ------- - constrained_positions: np.ndarray - Affine-transform constrained positions estimate - """ - - xp = self._xp - - tf, _ = estimate_global_transformation_ransac( - positions0=initial_positions, - positions1=current_positions, - origin=self._positions_px_com, - translation_allowed=True, - min_sample=self._num_diffraction_patterns // 10, - xp=xp, - ) - - self._tf = tf - current_positions = tf(initial_positions, origin=self._positions_px_com, xp=xp) - - return current_positions diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py deleted file mode 100644 index c8cc5ee3e..000000000 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ /dev/null @@ -1,3510 +0,0 @@ -""" -Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, -namely joint ptychography. -""" - -import warnings -from typing import Mapping, Sequence, Tuple - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import make_axes_locatable -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg - -try: - import cupy as cp -except (ImportError, ModuleNotFoundError): - cp = np - -from emdfile import Custom, tqdmnd -from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.utils import ( - ComplexProbe, - fft_shift, - generate_batches, - polar_aliases, - polar_symbols, -) -from py4DSTEM.process.utils import get_CoM, get_shifted_ar - -warnings.simplefilter(action="always", category=UserWarning) - - -class SimultaneousPtychographicReconstruction(PtychographicReconstruction): - """ - Iterative Simultaneous Ptychographic Reconstruction Class. - - Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) (for each measurement) - Reconstructed probe dimensions : (Sx,Sy) - Reconstructed electrostatic dimensions : (Px,Py) - Reconstructed magnetic dimensions : (Px,Py) - - such that (Sx,Sy) is the region-of-interest (ROI) size of our probe - and (Px,Py) is the padded-object size we position our ROI around in. - - Parameters - ---------- - datacube: Sequence[DataCube] - Tuple of input 4D diffraction pattern intensities - energy: float - The electron energy of the wave functions in eV - simultaneous_measurements_mode: str, optional - One of '-+', '-0+', '0+', where -/0/+ refer to the sign of the magnetic potential - semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess in mrad - semiangle_cutoff_pixels: float, optional - Semiangle cutoff for the initial probe guess in pixels - rolloff: float, optional - Semiangle rolloff for the initial probe guess - vacuum_probe_intensity: np.ndarray, optional - Vacuum probe to use as intensity aperture for initial probe guess - polar_parameters: dict, optional - Mapping from aberration symbols to their corresponding values. All aberration - magnitudes should be given in Å and angles should be given in radians. 
- object_padding_px: Tuple[int,int], optional - Pixel dimensions to pad objects with - If None, the padding is set to half the probe ROI dimensions - positions_mask: np.ndarray, optional - Boolean real space mask to select positions in datacube to skip for reconstruction - initial_object_guess: np.ndarray, optional - Initial guess for complex-valued object of dimensions (Px,Py) - If None, initialized to 1.0j - initial_probe_guess: np.ndarray, optional - Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, - initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations - initial_scan_positions: np.ndarray, optional - Probe positions in Å for each diffraction intensity - If None, initialized to a grid scan - verbose: bool, optional - If True, class methods will inherit this and print additional information - device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' - object_type: str, optional - The object can be reconstructed as a real potential ('potential') or a complex - object ('complex') - name: str, optional - Class name - kwargs: - Provide the aberration coefficients as keyword arguments. - """ - - # Class-specific Metadata - _class_specific_metadata = ("_simultaneous_measurements_mode",) - - def __init__( - self, - energy: float, - datacube: Sequence[DataCube] = None, - simultaneous_measurements_mode: str = "-+", - semiangle_cutoff: float = None, - semiangle_cutoff_pixels: float = None, - rolloff: float = 2.0, - vacuum_probe_intensity: np.ndarray = None, - polar_parameters: Mapping[str, float] = None, - object_padding_px: Tuple[int, int] = None, - positions_mask: np.ndarray = None, - initial_object_guess: np.ndarray = None, - initial_probe_guess: np.ndarray = None, - initial_scan_positions: np.ndarray = None, - object_type: str = "complex", - verbose: bool = True, - device: str = "cpu", - name: str = "simultaneous_ptychographic_reconstruction", - **kwargs, - ): - Custom.__init__(self, name=name) - - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from scipy.special import erf - - self._erf = erf - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from cupyx.scipy.special import erf - - self._erf = erf - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - - for key in kwargs.keys(): - if (key not in polar_symbols) and (key not in polar_aliases.keys()): - raise ValueError("{} not a recognized parameter".format(key)) - - self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) - - if polar_parameters is None: - polar_parameters = {} - - polar_parameters.update(kwargs) - self._set_polar_parameters(polar_parameters) - - if object_type != "potential" and object_type != "complex": - raise ValueError( - f"object_type must be either 'potential' or 'complex', not {object_type}" - ) - - self.set_save_defaults() - - # Data - self._datacube = datacube - self._object = initial_object_guess - self._probe = initial_probe_guess - - # Common Metadata - self._vacuum_probe_intensity = vacuum_probe_intensity - self._scan_positions = initial_scan_positions - self._energy = energy - self._semiangle_cutoff = semiangle_cutoff - self._semiangle_cutoff_pixels = semiangle_cutoff_pixels - self._rolloff = rolloff - self._object_type = object_type - self._object_padding_px = 
object_padding_px - self._positions_mask = positions_mask - self._verbose = verbose - self._device = device - self._preprocessed = False - - # Class-specific Metadata - self._simultaneous_measurements_mode = simultaneous_measurements_mode - - def preprocess( - self, - diffraction_intensities_shape: Tuple[int, int] = None, - reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, - dp_mask: np.ndarray = None, - fit_function: str = "plane", - plot_rotation: bool = True, - maximize_divergence: bool = False, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), - plot_probe_overlaps: bool = True, - force_com_rotation: float = None, - force_com_transpose: float = None, - force_com_shifts: float = None, - force_scan_sampling: float = None, - force_angular_sampling: float = None, - force_reciprocal_sampling: float = None, - object_fov_mask: np.ndarray = None, - crop_patterns: bool = False, - **kwargs, - ): - """ - Ptychographic preprocessing step. - Calls the base class methods: - - _extract_intensities_and_calibrations_from_datacube, - _compute_center_of_mass(), - _solve_CoM_rotation(), - _normalize_diffraction_intensities() - _calculate_scan_positions_in_px() - - Additionally, it initializes an (Px,Py) array of 1.0j - and a complex probe using the specified polar parameters. - - Parameters - ---------- - diffraction_intensities_shape: Tuple[int,int], optional - Pixel dimensions (Qx',Qy') of the resampled diffraction intensities - If None, no resampling of diffraction intenstities is performed - reshaping_method: str, optional - Method to use for reshaping, either 'bin', 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional - Padded diffraction intensities shape. - If None, no padding is performed - dp_mask: ndarray, optional - Mask for datacube intensities (Qx,Qy) - fit_function: str, optional - 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' - plot_rotation: bool, optional - If True, the CoM curl minimization search result will be displayed - maximize_divergence: bool, optional - If True, the divergence of the CoM gradient vector field is maximized - rotation_angles_deg: np.darray, optional - Array of angles in degrees to perform curl minimization over - plot_probe_overlaps: bool, optional - If True, initial probe overlaps scanned over the object will be displayed - force_com_rotation: float (degrees), optional - Force relative rotation angle between real and reciprocal space - force_com_transpose: bool, optional - Force whether diffraction intensities need to be transposed. - force_com_shifts: sequence of tuples of ndarrays (CoMx, CoMy) - Amplitudes come from diffraction patterns shifted with - the CoM in the upper left corner for each probe unless - shift is overwritten. - force_scan_sampling: float, optional - Override DataCube real space scan pixel size calibrations, in Angstrom - force_angular_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in mrad - force_reciprocal_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in A^-1 - object_fov_mask: np.ndarray (boolean) - Boolean mask of FOV. 
Used to calculate additional shrinkage of object - If None, probe_overlap intensity is thresholded - crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns when centering - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - xp = self._xp - asnumpy = self._asnumpy - - # set additional metadata - self._diffraction_intensities_shape = diffraction_intensities_shape - self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape - self._dp_mask = dp_mask - - if self._datacube is None: - raise ValueError( - ( - "The preprocess() method requires a DataCube. " - "Please run ptycho.attach_datacube(DataCube) first." - ) - ) - - if self._simultaneous_measurements_mode == "-+": - self._sim_recon_mode = 0 - self._num_sim_measurements = 2 - if self._verbose: - print( - ( - "Magnetic vector potential sign in first meaurement assumed to be negative.\n" - "Magnetic vector potential sign in second meaurement assumed to be positive." - ) - ) - if len(self._datacube) != 2: - raise ValueError( - f"datacube must be a set of two measurements, not length {len(self._datacube)}." - ) - if self._datacube[0].shape != self._datacube[1].shape: - raise ValueError("datacube intensities must be the same size.") - elif self._simultaneous_measurements_mode == "-0+": - self._sim_recon_mode = 1 - self._num_sim_measurements = 3 - if self._verbose: - print( - ( - "Magnetic vector potential sign in first meaurement assumed to be negative.\n" - "Magnetic vector potential assumed to be zero in second meaurement.\n" - "Magnetic vector potential sign in third meaurement assumed to be positive." - ) - ) - if len(self._datacube) != 3: - raise ValueError( - f"datacube must be a set of three measurements, not length {len(self._datacube)}." - ) - if ( - self._datacube[0].shape != self._datacube[1].shape - or self._datacube[0].shape != self._datacube[2].shape - ): - raise ValueError("datacube intensities must be the same size.") - elif self._simultaneous_measurements_mode == "0+": - self._sim_recon_mode = 2 - self._num_sim_measurements = 2 - if self._verbose: - print( - ( - "Magnetic vector potential assumed to be zero in first meaurement.\n" - "Magnetic vector potential sign in second meaurement assumed to be positive." - ) - ) - if len(self._datacube) != 2: - raise ValueError( - f"datacube must be a set of two measurements, not length {len(self._datacube)}." 
- ) - if self._datacube[0].shape != self._datacube[1].shape: - raise ValueError("datacube intensities must be the same size.") - else: - raise ValueError( - f"simultaneous_measurements_mode must be either '-+', '-0+', or '0+', not {self._simultaneous_measurements_mode}" - ) - - if self._positions_mask is not None: - self._positions_mask = np.asarray(self._positions_mask) - - if self._positions_mask.ndim == 2: - warnings.warn( - "2D `positions_mask` assumed the same for all measurements.", - UserWarning, - ) - self._positions_mask = np.tile( - self._positions_mask, (self._num_sim_measurements, 1, 1) - ) - - if self._positions_mask.dtype != "bool": - warnings.warn( - "`positions_mask` converted to `bool` array.", - UserWarning, - ) - self._positions_mask = self._positions_mask.astype("bool") - else: - self._positions_mask = [None] * self._num_sim_measurements - - if force_com_shifts is None: - force_com_shifts = [None, None, None] - elif len(force_com_shifts) == self._num_sim_measurements: - force_com_shifts = list(force_com_shifts) - else: - raise ValueError( - ( - "force_com_shifts must be a sequence of tuples " - "with the same length as the datasets." - ) - ) - - # Ensure plot_center_of_mass is not in kwargs - kwargs.pop("plot_center_of_mass", None) - - # 1st measurement sets rotation angle and transposition - ( - measurement_0, - self._vacuum_probe_intensity, - self._dp_mask, - force_com_shifts[0], - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[0], - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=self._vacuum_probe_intensity, - dp_mask=self._dp_mask, - com_shifts=force_com_shifts[0], - ) - - intensities_0 = self._extract_intensities_and_calibrations_from_datacube( - measurement_0, - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - com_measured_x_0, - com_measured_y_0, - com_fitted_x_0, - com_fitted_y_0, - com_normalized_x_0, - com_normalized_y_0, - ) = self._calculate_intensities_center_of_mass( - intensities_0, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts[0], - ) - - ( - self._rotation_best_rad, - self._rotation_best_transpose, - _com_x_0, - _com_y_0, - com_x_0, - com_y_0, - ) = self._solve_for_center_of_mass_relative_rotation( - com_measured_x_0, - com_measured_y_0, - com_normalized_x_0, - com_normalized_y_0, - rotation_angles_deg=rotation_angles_deg, - plot_rotation=plot_rotation, - plot_center_of_mass=False, - maximize_divergence=maximize_divergence, - force_com_rotation=force_com_rotation, - force_com_transpose=force_com_transpose, - **kwargs, - ) - - ( - amplitudes_0, - mean_diffraction_intensity_0, - ) = self._normalize_diffraction_intensities( - intensities_0, - com_fitted_x_0, - com_fitted_y_0, - crop_patterns, - self._positions_mask[0], - ) - - # explicitly delete namescapes - del ( - intensities_0, - com_measured_x_0, - com_measured_y_0, - com_fitted_x_0, - com_fitted_y_0, - com_normalized_x_0, - com_normalized_y_0, - _com_x_0, - _com_y_0, - com_x_0, - com_y_0, - ) - - # 2nd measurement - ( - measurement_1, - _, - _, - force_com_shifts[1], - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[1], - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - 
probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=None, - dp_mask=None, - com_shifts=force_com_shifts[1], - ) - - intensities_1 = self._extract_intensities_and_calibrations_from_datacube( - measurement_1, - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - com_measured_x_1, - com_measured_y_1, - com_fitted_x_1, - com_fitted_y_1, - com_normalized_x_1, - com_normalized_y_1, - ) = self._calculate_intensities_center_of_mass( - intensities_1, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts[1], - ) - - ( - _, - _, - _com_x_1, - _com_y_1, - com_x_1, - com_y_1, - ) = self._solve_for_center_of_mass_relative_rotation( - com_measured_x_1, - com_measured_y_1, - com_normalized_x_1, - com_normalized_y_1, - rotation_angles_deg=rotation_angles_deg, - plot_rotation=plot_rotation, - plot_center_of_mass=False, - maximize_divergence=maximize_divergence, - force_com_rotation=np.rad2deg(self._rotation_best_rad), - force_com_transpose=self._rotation_best_transpose, - **kwargs, - ) - - ( - amplitudes_1, - mean_diffraction_intensity_1, - ) = self._normalize_diffraction_intensities( - intensities_1, - com_fitted_x_1, - com_fitted_y_1, - crop_patterns, - self._positions_mask[1], - ) - - # explicitly delete namescapes - del ( - intensities_1, - com_measured_x_1, - com_measured_y_1, - com_fitted_x_1, - com_fitted_y_1, - com_normalized_x_1, - com_normalized_y_1, - _com_x_1, - _com_y_1, - com_x_1, - com_y_1, - ) - - # Optionally, 3rd measurement - if self._num_sim_measurements == 3: - ( - measurement_2, - _, - _, - force_com_shifts[2], - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[2], - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=None, - dp_mask=None, - com_shifts=force_com_shifts[2], - ) - - intensities_2 = self._extract_intensities_and_calibrations_from_datacube( - measurement_2, - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - com_measured_x_2, - com_measured_y_2, - com_fitted_x_2, - com_fitted_y_2, - com_normalized_x_2, - com_normalized_y_2, - ) = self._calculate_intensities_center_of_mass( - intensities_2, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts[2], - ) - - ( - _, - _, - _com_x_2, - _com_y_2, - com_x_2, - com_y_2, - ) = self._solve_for_center_of_mass_relative_rotation( - com_measured_x_2, - com_measured_y_2, - com_normalized_x_2, - com_normalized_y_2, - rotation_angles_deg=rotation_angles_deg, - plot_rotation=plot_rotation, - plot_center_of_mass=False, - maximize_divergence=maximize_divergence, - force_com_rotation=np.rad2deg(self._rotation_best_rad), - force_com_transpose=self._rotation_best_transpose, - **kwargs, - ) - - ( - amplitudes_2, - mean_diffraction_intensity_2, - ) = self._normalize_diffraction_intensities( - intensities_2, - com_fitted_x_2, - com_fitted_y_2, - crop_patterns, - self._positions_mask[2], - ) - - # explicitly delete namescapes - del ( - intensities_2, - com_measured_x_2, - com_measured_y_2, - com_fitted_x_2, - com_fitted_y_2, - com_normalized_x_2, - com_normalized_y_2, - _com_x_2, - _com_y_2, - com_x_2, - com_y_2, - ) - - self._amplitudes = (amplitudes_0, amplitudes_1, 
amplitudes_2) - self._mean_diffraction_intensity = ( - mean_diffraction_intensity_0 - + mean_diffraction_intensity_1 - + mean_diffraction_intensity_2 - ) / 3 - - del amplitudes_0, amplitudes_1, amplitudes_2 - - else: - self._amplitudes = (amplitudes_0, amplitudes_1) - self._mean_diffraction_intensity = ( - mean_diffraction_intensity_0 + mean_diffraction_intensity_1 - ) / 2 - - del amplitudes_0, amplitudes_1 - - # explicitly delete namespace - self._num_diffraction_patterns = self._amplitudes[0].shape[0] - self._region_of_interest_shape = np.array(self._amplitudes[0].shape[-2:]) - - self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions, self._positions_mask[0] - ) # TO-DO: generaltize to per-dataset probe positions - - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - if self._object_type == "potential": - object_e = xp.zeros((p, q), dtype=xp.float32) - elif self._object_type == "complex": - object_e = xp.ones((p, q), dtype=xp.complex64) - object_m = xp.zeros((p, q), dtype=xp.float32) - else: - if self._object_type == "potential": - object_e = xp.asarray(self._object[0], dtype=xp.float32) - elif self._object_type == "complex": - object_e = xp.asarray(self._object[0], dtype=xp.complex64) - object_m = xp.asarray(self._object[1], dtype=xp.float32) - - self._object = (object_e, object_m) - self._object_initial = (object_e.copy(), object_m.copy()) - self._object_type_initial = self._object_type - self._object_shape = self._object[0].shape - - self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - self._positions_px_initial = self._positions_px.copy() - self._positions_initial = self._positions_px_initial.copy() - self._positions_initial[:, 0] *= self.sampling[0] - self._positions_initial[:, 1] *= self.sampling[1] - - # Vectorized Patches - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - # Probe Initialization - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, device=self._device - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - self._probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - 
vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) - - else: - if isinstance(self._probe, ComplexProbe): - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - self._probe = self._probe._array - else: - self._probe._xp = xp - self._probe = self._probe.build()._array - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - self._mean_diffraction_intensity / probe_intensity - ) - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) - - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) - - self._known_aberrations_array = ComplexProbe( - energy=self._energy, - gpts=self._region_of_interest_shape, - sampling=self.sampling, - parameters=self._polar_parameters, - device=self._device, - )._evaluate_ctf() - - # overlaps - shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) - - if object_fov_mask is None: - self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) - else: - self._object_fov_mask = np.asarray(object_fov_mask) - self._object_fov_mask_inverse = np.invert(self._object_fov_mask) - - if plot_probe_overlaps: - figsize = kwargs.pop("figsize", (9, 4)) - chroma_boost = kwargs.pop("chroma_boost", 1) - - # initial probe - complex_probe_rgb = Complex2RGB( - self.probe_centered, - power=2, - chroma_boost=chroma_boost, - ) - - extent = [ - 0, - self.sampling[1] * self._object_shape[1], - self.sampling[0] * self._object_shape[0], - 0, - ] - - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) - - ax1.imshow( - complex_probe_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax1) - cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, - chroma_boost=chroma_boost, - ) - ax1.set_ylabel("x [A]") - ax1.set_xlabel("y [A]") - ax1.set_title("Initial probe intensity") - - ax2.imshow( - asnumpy(probe_overlap), - extent=extent, - cmap="Greys_r", - ) - ax2.scatter( - self.positions[:, 1], - self.positions[:, 0], - s=2.5, - color=(1, 0, 0, 1), - ) - ax2.set_ylabel("x [A]") - ax2.set_xlabel("y [A]") - ax2.set_xlim((extent[0], extent[1])) - ax2.set_ylim((extent[2], extent[3])) - ax2.set_title("Object field of view") - - fig.tight_layout() - - self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _warmup_overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. 
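The initialization deleted above scales the starting probe so that its far-field power matches the mean measured diffraction intensity. A short sketch of that normalization step (illustrative names):

    import numpy as np

    def normalize_probe_power(probe, mean_diffraction_intensity):
        # Match total |FFT(probe)|^2 to the mean measured pattern intensity.
        probe_power = np.sum(np.abs(np.fft.fft2(probe)) ** 2)
        return probe * np.sqrt(mean_diffraction_intensity / probe_power)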
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - shifted_probes * object_patches - """ - - xp = self._xp - - shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - - electrostatic_obj, _ = current_object - - if self._object_type == "potential": - complex_object = xp.exp(1j * electrostatic_obj) - else: - complex_object = electrostatic_obj - - electrostatic_obj_patches = complex_object[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - object_patches = (electrostatic_obj_patches, None) - overlap = (shifted_probes * electrostatic_obj_patches, None) - - return shifted_probes, object_patches, overlap - - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - shifted_probes * object_patches - """ - - xp = self._xp - - shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - - electrostatic_obj, magnetic_obj = current_object - - if self._object_type == "potential": - complex_object_e = xp.exp(1j * electrostatic_obj) - else: - complex_object_e = electrostatic_obj - - complex_object_m = xp.exp(1j * magnetic_obj) - - electrostatic_obj_patches = complex_object_e[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - magnetic_obj_patches = complex_object_m[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - object_patches = (electrostatic_obj_patches, magnetic_obj_patches) - - if self._sim_recon_mode == 0: - overlap_reverse = ( - shifted_probes - * electrostatic_obj_patches - * xp.conj(magnetic_obj_patches) - ) - overlap_forward = ( - shifted_probes * electrostatic_obj_patches * magnetic_obj_patches - ) - overlap = (overlap_reverse, overlap_forward) - elif self._sim_recon_mode == 1: - overlap_reverse = ( - shifted_probes - * electrostatic_obj_patches - * xp.conj(magnetic_obj_patches) - ) - overlap_neutral = shifted_probes * electrostatic_obj_patches - overlap_forward = ( - shifted_probes * electrostatic_obj_patches * magnetic_obj_patches - ) - overlap = (overlap_reverse, overlap_neutral, overlap_forward) - else: - overlap_neutral = shifted_probes * electrostatic_obj_patches - overlap_forward = ( - shifted_probes * electrostatic_obj_patches * magnetic_obj_patches - ) - overlap = (overlap_neutral, overlap_forward) - - return shifted_probes, object_patches, overlap - - def _warmup_gradient_descent_fourier_projection(self, amplitudes, overlap): - """ - Ptychographic fourier projection method for GD method. 
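For a potential-type object, the overlap projection deleted above forms exit waves from the electrostatic phase exp(iV) and the magnetic phase exp(iA), conjugating the latter for the reversed-field measurement. A compact sketch of the '-0+' case (illustrative names):

    import numpy as np

    def magnetic_overlaps(probe_patches, V_patches, A_patches):
        e = np.exp(1j * V_patches)   # electrostatic phase object
        m = np.exp(1j * A_patches)   # magnetic phase object
        overlap_reverse = probe_patches * e * np.conj(m)  # negative magnetic potential
        overlap_neutral = probe_patches * e               # zero magnetic potential
        overlap_forward = probe_patches * e * m           # positive magnetic potential
        return overlap_reverse, overlap_neutral, overlap_forward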
- - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - - Returns - -------- - exit_waves:np.ndarray - Difference between modified and estimated exit waves - error: float - Reconstruction error - """ - - xp = self._xp - - fourier_overlap = xp.fft.fft2(overlap[0]) - error = xp.sum(xp.abs(amplitudes[0] - xp.abs(fourier_overlap)) ** 2) - - fourier_modified_overlap = amplitudes[0] * xp.exp( - 1j * xp.angle(fourier_overlap) - ) - modified_overlap = xp.fft.ifft2(fourier_modified_overlap) - - exit_waves = (modified_overlap - overlap[0],) + (None,) * ( - self._num_sim_measurements - 1 - ) - - return exit_waves, error - - def _gradient_descent_fourier_projection(self, amplitudes, overlap): - """ - Ptychographic fourier projection method for GD method. - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - - Returns - -------- - exit_waves:np.ndarray - Difference between modified and estimated exit waves - error: float - Reconstruction error - """ - - xp = self._xp - - error = 0.0 - exit_waves = [] - for amp, overl in zip(amplitudes, overlap): - fourier_overl = xp.fft.fft2(overl) - error += xp.sum(xp.abs(amp - xp.abs(fourier_overl)) ** 2) - - fourier_modified_overl = amp * xp.exp(1j * xp.angle(fourier_overl)) - modified_overl = xp.fft.ifft2(fourier_modified_overl) - - exit_waves.append(modified_overl - overl) - - error /= len(exit_waves) - exit_waves = tuple(exit_waves) - - return exit_waves, error - - def _warmup_projection_sets_fourier_projection( - self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. - Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - exit_wave = exit_waves[0] - - if exit_wave is None: - exit_wave = overlap[0].copy() - - fourier_overlap = xp.fft.fft2(overlap[0]) - error = xp.sum(xp.abs(amplitudes[0] - xp.abs(fourier_overlap)) ** 2) - - factor_to_be_projected = projection_c * overlap[0] + projection_y * exit_wave - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amplitudes[0] * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_wave = ( - projection_x * exit_wave - + projection_a * overlap[0] - + projection_b * projected_factor - ) - - exit_waves = (exit_wave,) + (None,) * (self._num_sim_measurements - 1) - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. 
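The gradient-descent Fourier projection deleted above is an amplitude-replacement step: keep the estimated far-field phase, impose the measured amplitudes, and record the squared-amplitude mismatch as the error. A minimal sketch (illustrative names):

    import numpy as np

    def amplitude_projection(measured_amplitudes, overlap):
        F = np.fft.fft2(overlap)
        error = np.sum(np.abs(measured_amplitudes - np.abs(F)) ** 2)
        modified = np.fft.ifft2(measured_amplitudes * np.exp(1j * np.angle(F)))
        return modified - overlap, error  # exit-wave difference and error contribution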
- Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - error = 0.0 - _exit_waves = [] - for amp, overl, exit_wave in zip(amplitudes, overlap, exit_waves): - if exit_wave is None: - exit_wave = overl.copy() - - fourier_overl = xp.fft.fft2(overl) - error += xp.sum(xp.abs(amp - xp.abs(fourier_overl)) ** 2) - - factor_to_be_projected = projection_c * overl + projection_y * exit_wave - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amp * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - _exit_waves.append( - projection_x * exit_wave - + projection_a * overl - + projection_b * projected_factor - ) - - error /= len(_exit_waves) - exit_waves = tuple(_exit_waves) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - warmup_iteration, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). 
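The projection-set update deleted above is parameterized by (a, b, c), with x = 1 - a - b and y = 1 - c, so DM_AP, RAAR, RRR and SUPERFLIP all reduce to the same exit-wave recursion. A sketch of that recursion for a single measurement (illustrative names):

    import numpy as np

    def generalized_projection_update(amplitudes, overlap, exit_wave, a, b, c):
        # e.g. DM_AP(alpha): (-alpha, 1, 1 + alpha); RAAR(beta): (1 - 2*beta, beta, 2);
        # RRR(gamma): (-gamma, gamma, 2); SUPERFLIP: (0, 1, 2).
        x, y = 1 - a - b, 1 - c

        # Fourier-amplitude projection of the combined factor.
        F = np.fft.fft2(c * overlap + y * exit_wave)
        projected = np.fft.ifft2(amplitudes * np.exp(1j * np.angle(F)))

        return x * exit_wave + a * overlap + b * projected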
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - object * probe overlap - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - if warmup_iteration: - shifted_probes, object_patches, overlap = self._warmup_overlap_projection( - current_object, current_probe - ) - if use_projection_scheme: - exit_waves, error = self._warmup_projection_sets_fourier_projection( - amplitudes, - overlap, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._warmup_gradient_descent_fourier_projection( - amplitudes, overlap - ) - - else: - shifted_probes, object_patches, overlap = self._overlap_projection( - current_object, current_probe - ) - if use_projection_scheme: - exit_waves, error = self._projection_sets_fourier_projection( - amplitudes, - overlap, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, overlap - ) - - return shifted_probes, object_patches, overlap, exit_waves, error - - def _warmup_gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
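The adjoint steps deleted below weight each gradient by a regularized reciprocal of the summed probe (or object) intensity, controlled by normalization_min. A sketch of that weighting alone (illustrative names):

    import numpy as np

    def regularized_inverse(summed_intensity, normalization_min=0.1):
        # Smoothly floors the normalization at a fraction of its maximum
        # to avoid dividing by near-zero overlap intensity.
        return 1.0 / np.sqrt(
            1e-16
            + ((1 - normalization_min) * summed_intensity) ** 2
            + (normalization_min * summed_intensity.max()) ** 2
        )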
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - electrostatic_obj, _ = current_object - electrostatic_obj_patches, _ = object_patches - - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._object_type == "potential": - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(electrostatic_obj_patches) - * xp.conj(shifted_probes) - * exit_waves[0] - ) - ) - * probe_normalization - ) - elif self._object_type == "complex": - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes) * exit_waves[0] - ) - * probe_normalization - ) - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(electrostatic_obj_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += step_size * ( - xp.sum( - xp.conj(electrostatic_obj_patches) * exit_waves[0], - axis=0, - ) - * object_normalization - ) - - return (electrostatic_obj, None), current_probe - - def _gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - electrostatic_obj, magnetic_obj = current_object - probe_conj = xp.conj(shifted_probes) - - electrostatic_obj_patches, magnetic_obj_patches = object_patches - electrostatic_conj = xp.conj(electrostatic_obj_patches) - magnetic_conj = xp.conj(magnetic_obj_patches) - - probe_electrostatic_abs = xp.abs(shifted_probes * electrostatic_obj_patches) - probe_magnetic_abs = xp.abs(shifted_probes * magnetic_obj_patches) - - probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts( - probe_electrostatic_abs**2 - ) - probe_electrostatic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_electrostatic_normalization) ** 2 - + (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2 - ) - - probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( - probe_magnetic_abs**2 - ) - probe_magnetic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_magnetic_normalization) ** 2 - + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 - ) - - if self._sim_recon_mode > 0: - probe_abs = xp.abs(shifted_probes) - probe_normalization = self._sum_overlapping_patches_bincounts( - probe_abs**2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._sim_recon_mode == 0: - exit_waves_reverse, exit_waves_forward = exit_waves - - if self._object_type == "potential": - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_obj_patches - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_reverse - ) - ) - * probe_magnetic_normalization - ) - / 2 - ) - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * probe_magnetic_normalization - ) - / 2 - ) - - elif self._object_type == "complex": - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_obj_patches * exit_waves_reverse - ) - * probe_magnetic_normalization - ) - / 2 - ) - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_conj * exit_waves_forward - ) - * probe_magnetic_normalization - / 2 - ) - - magnetic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - 1j - * magnetic_obj_patches - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_reverse - ) - ) - * probe_electrostatic_normalization - ) - / 2 - ) - magnetic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * 
probe_electrostatic_normalization - ) - / 2 - ) - - elif self._sim_recon_mode == 1: - exit_waves_reverse, exit_waves_neutral, exit_waves_forward = exit_waves - - if self._object_type == "potential": - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_obj_patches - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_reverse - ) - ) - * probe_magnetic_normalization - ) - / 3 - ) - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_neutral - ) - ) - * probe_normalization - ) - / 3 - ) - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * probe_magnetic_normalization - ) - / 3 - ) - - elif self._object_type == "complex": - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_obj_patches * exit_waves_reverse - ) - * probe_magnetic_normalization - ) - / 3 - ) - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - probe_conj * exit_waves_neutral - ) - * probe_normalization - / 3 - ) - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_conj * exit_waves_forward - ) - * probe_magnetic_normalization - / 3 - ) - - magnetic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - 1j - * magnetic_obj_patches - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_reverse - ) - ) - * probe_electrostatic_normalization - ) - / 2 - ) - magnetic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * probe_electrostatic_normalization - ) - / 2 - ) - - else: - exit_waves_neutral, exit_waves_forward = exit_waves - - if self._object_type == "potential": - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_neutral - ) - ) - * probe_normalization - ) - / 2 - ) - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * probe_magnetic_normalization - ) - / 2 - ) - - elif self._object_type == "complex": - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - probe_conj * exit_waves_neutral - ) - * probe_normalization - / 2 - ) - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_conj * exit_waves_forward - ) - * probe_magnetic_normalization - / 2 - ) - - magnetic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * probe_electrostatic_normalization - ) - / 3 - ) - - if not fix_probe: - electrostatic_magnetic_abs = xp.abs( - electrostatic_obj_patches * magnetic_obj_patches - ) - electrostatic_magnetic_normalization = xp.sum( - electrostatic_magnetic_abs**2, - axis=0, - ) - electrostatic_magnetic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * electrostatic_magnetic_normalization) ** 2 - + 
(normalization_min * xp.max(electrostatic_magnetic_normalization)) - ** 2 - ) - - if self._sim_recon_mode > 0: - electrostatic_abs = xp.abs(electrostatic_obj_patches) - electrostatic_normalization = xp.sum( - electrostatic_abs**2, - axis=0, - ) - electrostatic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * electrostatic_normalization) ** 2 - + (normalization_min * xp.max(electrostatic_normalization)) ** 2 - ) - - if self._sim_recon_mode == 0: - current_probe += step_size * ( - xp.sum( - electrostatic_conj * magnetic_obj_patches * exit_waves_reverse, - axis=0, - ) - * electrostatic_magnetic_normalization - / 2 - ) - - current_probe += step_size * ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, - ) - * electrostatic_magnetic_normalization - / 2 - ) - - elif self._sim_recon_mode == 1: - current_probe += step_size * ( - xp.sum( - electrostatic_conj * magnetic_obj_patches * exit_waves_reverse, - axis=0, - ) - * electrostatic_magnetic_normalization - / 3 - ) - - current_probe += step_size * ( - xp.sum( - electrostatic_conj * exit_waves_neutral, - axis=0, - ) - * electrostatic_normalization - / 3 - ) - - current_probe += step_size * ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, - ) - * electrostatic_magnetic_normalization - / 3 - ) - else: - current_probe += step_size * ( - xp.sum( - electrostatic_conj * exit_waves_neutral, - axis=0, - ) - * electrostatic_normalization - / 2 - ) - - current_probe += step_size * ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, - ) - * electrostatic_magnetic_normalization - / 2 - ) - - current_object = (electrostatic_obj, magnetic_obj) - - return current_object, current_probe - - def _warmup_projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. 
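# Aside: the _sim_recon_mode branches above pair each simultaneous measurement with its own
# exit wave. A compact reading aid (plain Python, illustrative only) for how the exit_waves
# tuple is unpacked in the three modes, mirroring the branches above:
def unpack_exit_waves(exit_waves, sim_recon_mode):
    """Return (reverse, neutral, forward) exit waves, with None for measurements not acquired."""
    if sim_recon_mode == 0:
        exit_reverse, exit_forward = exit_waves
        exit_neutral = None
    elif sim_recon_mode == 1:
        exit_reverse, exit_neutral, exit_forward = exit_waves
    else:
        exit_neutral, exit_forward = exit_waves
        exit_reverse = None
    return exit_reverse, exit_neutral, exit_forward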
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - electrostatic_obj, _ = current_object - electrostatic_obj_patches, _ = object_patches - - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - electrostatic_obj = ( - self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes) * exit_waves[0] - ) - * probe_normalization - ) - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(electrostatic_obj_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - xp.conj(electrostatic_obj_patches) * exit_waves[0], - axis=0, - ) - * object_normalization - ) - - return (electrostatic_obj, None), current_probe - - def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. 
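# Aside: note the structural difference between the two adjoint flavours. The gradient-descent
# adjoint *increments* the current estimates (object += step_size * ...), while the DM_AP/RAAR
# projection-set adjoint *replaces* them outright (object = ...). A schematic NumPy sketch of
# the two object updates; sum_patches stands in for _sum_overlapping_patches_bincounts and all
# names are illustrative:
import numpy as np

def gradient_descent_object_update(obj, probes, exit_waves, step_size, norm, sum_patches):
    """Incremental update used by the gradient-descent adjoint."""
    return obj + step_size * sum_patches(np.conj(probes) * exit_waves) * norm

def projection_set_object_update(probes, exit_waves, norm, sum_patches):
    """Replacement update used by the DM_AP / RAAR projection-set adjoint."""
    return sum_patches(np.conj(probes) * exit_waves) * norm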
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - xp = self._xp - - electrostatic_obj, magnetic_obj = current_object - probe_conj = xp.conj(shifted_probes) - - electrostatic_obj_patches, magnetic_obj_patches = object_patches - electrostatic_conj = xp.conj(electrostatic_obj_patches) - magnetic_conj = xp.conj(magnetic_obj_patches) - - probe_electrostatic_abs = xp.abs(shifted_probes * electrostatic_obj_patches) - probe_magnetic_abs = xp.abs(shifted_probes * magnetic_obj_patches) - - probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts( - probe_electrostatic_abs**2 - ) - probe_electrostatic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_electrostatic_normalization) ** 2 - + (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2 - ) - - probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( - probe_magnetic_abs**2 - ) - probe_magnetic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_magnetic_normalization) ** 2 - + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 - ) - - if self._sim_recon_mode > 0: - probe_abs = xp.abs(shifted_probes) - - probe_normalization = self._sum_overlapping_patches_bincounts( - probe_abs**2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._sim_recon_mode == 0: - exit_waves_reverse, exit_waves_forward = exit_waves - - electrostatic_obj = ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_obj_patches * exit_waves_reverse - ) - * probe_magnetic_normalization - / 2 - ) - - electrostatic_obj += ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_conj * exit_waves_forward - ) - * probe_magnetic_normalization - / 2 - ) - - magnetic_obj = xp.conj( - self._sum_overlapping_patches_bincounts( - probe_conj * electrostatic_conj * exit_waves_reverse - ) - * probe_electrostatic_normalization - / 2 - ) - - magnetic_obj += ( - self._sum_overlapping_patches_bincounts( - probe_conj * electrostatic_conj * exit_waves_forward - ) - * probe_electrostatic_normalization - / 2 - ) - - elif self._sim_recon_mode == 1: - exit_waves_reverse, exit_waves_neutral, exit_waves_forward = exit_waves - - electrostatic_obj = ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_obj_patches * exit_waves_reverse - ) - * probe_magnetic_normalization - / 3 - ) - - electrostatic_obj += ( - self._sum_overlapping_patches_bincounts(probe_conj * exit_waves_neutral) - * probe_normalization - / 3 - ) - - electrostatic_obj += ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_conj * exit_waves_forward - ) - * probe_magnetic_normalization - / 3 - ) - - magnetic_obj = xp.conj( - self._sum_overlapping_patches_bincounts( - probe_conj * electrostatic_conj * exit_waves_reverse - ) - * probe_electrostatic_normalization - / 2 - ) - - magnetic_obj += ( - 
self._sum_overlapping_patches_bincounts( - probe_conj * electrostatic_conj * exit_waves_forward - ) - * probe_electrostatic_normalization - / 2 - ) - - else: - raise NotImplementedError() - - if not fix_probe: - electrostatic_magnetic_abs = xp.abs( - electrostatic_obj_patches * magnetic_obj_patches - ) - - electrostatic_magnetic_normalization = xp.sum( - (electrostatic_magnetic_abs**2), - axis=0, - ) - electrostatic_magnetic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * electrostatic_magnetic_normalization) ** 2 - + (normalization_min * xp.max(electrostatic_magnetic_normalization)) - ** 2 - ) - - if self._sim_recon_mode > 0: - electrostatic_abs = xp.abs(electrostatic_obj_patches) - electrostatic_normalization = xp.sum( - (electrostatic_abs**2), - axis=0, - ) - electrostatic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * electrostatic_normalization) ** 2 - + (normalization_min * xp.max(electrostatic_normalization)) ** 2 - ) - - if self._sim_recon_mode == 0: - current_probe = ( - xp.sum( - electrostatic_conj * magnetic_obj_patches * exit_waves_reverse, - axis=0, - ) - * electrostatic_magnetic_normalization - / 2 - ) - - current_probe += ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, - ) - * electrostatic_magnetic_normalization - / 2 - ) - - elif self._sim_recon_mode == 1: - current_probe = ( - xp.sum( - electrostatic_conj * magnetic_obj_patches * exit_waves_reverse, - axis=0, - ) - * electrostatic_magnetic_normalization - / 3 - ) - - current_probe += ( - xp.sum( - electrostatic_conj * exit_waves_neutral, - axis=0, - ) - * electrostatic_normalization - / 3 - ) - - current_probe += ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, - ) - * electrostatic_magnetic_normalization - / 3 - ) - else: - current_probe = ( - xp.sum( - electrostatic_conj * exit_waves_neutral, - axis=0, - ) - * electrostatic_normalization - / 2 - ) - - current_probe += ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, - ) - * electrostatic_magnetic_normalization - / 2 - ) - - current_object = (electrostatic_obj, magnetic_obj) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - warmup_iteration: bool, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if warmup_iteration: - if use_projection_scheme: - current_object, current_probe = self._warmup_projection_sets_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._warmup_gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - else: - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - - def _constraints( - self, - current_object, - current_probe, - current_positions, - pure_phase_object, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma_e, - gaussian_filter_sigma_m, - butterworth_filter, - q_lowpass_e, - q_lowpass_m, - q_highpass_e, - q_highpass_m, - butterworth_order, - tv_denoise, - tv_denoise_weight, - tv_denoise_inner_iter, - warmup_iteration, - object_positivity, - shrinkage_rad, - object_mask, - ): - """ - Ptychographic constraints operator. 
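# Aside: _forward, _adjoint and _constraints are the three operators that reconstruct() below
# wires together for each batch of probe positions, every iteration. A highly schematic sketch
# of that control flow; every helper here is a placeholder, not this class's API:
def one_iteration(obj, probe, positions, amplitudes, exit_waves, ops, update_positions):
    """Forward model, adjoint update, optional position correction, then constraints."""
    probes, patches, overlap, exit_waves, error = ops.forward(obj, probe, amplitudes, exit_waves)
    obj, probe = ops.adjoint(obj, probe, patches, probes, exit_waves)
    if update_positions:
        positions = ops.position_correction(obj, probes, overlap, amplitudes, positions)
    obj, probe, positions = ops.constraints(obj, probe, positions)
    return obj, probe, positions, exit_waves, error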
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - pure_phase_object: bool - If True, object amplitude is set to unity - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe Fourier amplitude is replaced by initial probe aperture. - initial_probe_aperture: np.ndarray, - Initial probe aperture to use in replacing probe Fourier amplitude. - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma_e: float - Standard deviation of gaussian kernel for electrostatic object in A - gaussian_filter_sigma_m: float - Standard deviation of gaussian kernel for magnetic object in A - probe_gaussian_filter: bool - If True, applies reciprocal-space gaussian filtering on residual aberrations - probe_gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A^-1 - probe_gaussian_filter_fix_amplitude: bool - If True, only the probe phase is smoothed - butterworth_filter: bool - If True, applies high-pass butterworth filter - q_lowpass_e: float - Cut-off frequency in A^-1 for low-pass filtering electrostatic object - q_lowpass_m: float - Cut-off frequency in A^-1 for low-pass filtering magnetic object - q_highpass_e: float - Cut-off frequency in A^-1 for high-pass filtering electrostatic object - q_highpass_m: float - Cut-off frequency in A^-1 for high-pass filtering magnetic object - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising.
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - warmup_iteration: bool - If True, constraints electrostatic object only - object_positivity: bool - If True, clips negative potential values - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - electrostatic_obj, magnetic_obj = current_object - - if gaussian_filter: - electrostatic_obj = self._object_gaussian_constraint( - electrostatic_obj, gaussian_filter_sigma_e, pure_phase_object - ) - if not warmup_iteration: - magnetic_obj = self._object_gaussian_constraint( - magnetic_obj, - gaussian_filter_sigma_m, - pure_phase_object, - ) - - if butterworth_filter: - electrostatic_obj = self._object_butterworth_constraint( - electrostatic_obj, - q_lowpass_e, - q_highpass_e, - butterworth_order, - ) - if not warmup_iteration: - magnetic_obj = self._object_butterworth_constraint( - magnetic_obj, - q_lowpass_m, - q_highpass_m, - butterworth_order, - ) - - if self._object_type == "complex": - magnetic_obj = magnetic_obj.real - if tv_denoise: - electrostatic_obj = self._object_denoise_tv_pylops( - electrostatic_obj, tv_denoise_weight, tv_denoise_inner_iter - ) - - if not warmup_iteration: - magnetic_obj = self._object_denoise_tv_pylops( - magnetic_obj, tv_denoise_weight, tv_denoise_inner_iter - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - electrostatic_obj = self._object_shrinkage_constraint( - electrostatic_obj, - shrinkage_rad, - object_mask, - ) - - if self._object_type == "complex": - electrostatic_obj = self._object_threshold_constraint( - electrostatic_obj, pure_phase_object - ) - elif object_positivity: - electrostatic_obj = self._object_positivity_constraint(electrostatic_obj) - - current_object = (electrostatic_obj, magnetic_obj) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - - def reconstruct( - self, - max_iter: int = 64, - reconstruction_method: str = "gradient-descent", - reconstruction_parameter: float = 1.0, - reconstruction_parameter_a: float = None, - reconstruction_parameter_b: float = 
None, - reconstruction_parameter_c: float = None, - max_batch_size: int = None, - seed_random: int = None, - step_size: float = 0.5, - normalization_min: float = 1, - positions_step_size: float = 0.9, - pure_phase_object_iter: int = 0, - fix_com: bool = True, - fix_probe_iter: int = 0, - warmup_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, - constrain_probe_amplitude_relative_radius: float = 0.5, - constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, - constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, - constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, - global_affine_transformation: bool = True, - gaussian_filter_sigma_e: float = None, - gaussian_filter_sigma_m: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, - fit_probe_aberrations_max_angular_order: int = 4, - fit_probe_aberrations_max_radial_order: int = 4, - butterworth_filter_iter: int = np.inf, - q_lowpass_e: float = None, - q_lowpass_m: float = None, - q_highpass_e: float = None, - q_highpass_m: float = None, - butterworth_order: float = 2, - tv_denoise_iter: int = np.inf, - tv_denoise_weight: float = None, - tv_denoise_inner_iter: float = 40, - object_positivity: bool = True, - shrinkage_rad: float = 0.0, - fix_potential_baseline: bool = True, - switch_object_iter: int = np.inf, - store_iterations: bool = False, - progress_bar: bool = True, - reset: bool = None, - ): - """ - Ptychographic reconstruction main method. - - Parameters - -------- - max_iter: int, optional - Maximum number of iterations to run - reconstruction_method: str, optional - Specifies which reconstruction algorithm to use, one of: - "generalized-projections", - "DM_AP" (or "difference-map_alternating-projections"), - "RAAR" (or "relaxed-averaged-alternating-reflections"), - "RRR" (or "relax-reflect-reflect"), - "SUPERFLIP" (or "charge-flipping"), or - "GD" (or "gradient_descent") - reconstruction_parameter: float, optional - Reconstruction parameter for various reconstruction methods above. - reconstruction_parameter_a: float, optional - Reconstruction parameter a for reconstruction_method='generalized-projections'. - reconstruction_parameter_b: float, optional - Reconstruction parameter b for reconstruction_method='generalized-projections'. - reconstruction_parameter_c: float, optional - Reconstruction parameter c for reconstruction_method='generalized-projections'. 
- max_batch_size: int, optional - Max number of probes to update at once - seed_random: int, optional - Seeds the random number generator, only applicable when max_batch_size is not None - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - positions_step_size: float, optional - Positions update step size - pure_phase_object_iter: float, optional - Number of iterations where object amplitude is set to unity - fix_com: bool, optional - If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A - global_affine_transformation: bool, optional - If True, positions are assumed to be a global affine transform from initial scan - gaussian_filter_sigma_e: float - Standard deviation of gaussian kernel for electrostatic object in A - gaussian_filter_sigma_m: float - Standard deviation of gaussian kernel for magnetic object in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butterworth filter - q_lowpass_e: float - Cut-off frequency in A^-1 for low-pass filtering electrostatic object - q_lowpass_m: float - Cut-off frequency in A^-1 for low-pass filtering magnetic object - q_highpass_e: float - Cut-off frequency in A^-1 for high-pass filtering electrostatic object - q_highpass_m: float - Cut-off frequency in A^-1 for high-pass filtering magnetic object - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - tv_denoise_iter: int, optional - Number of iterations to run using tv denoise filter on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising.
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - object_positivity: bool, optional - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - fix_potential_baseline: bool - If true, the potential mean outside the FOV is forced to zero at each iteration - switch_object_iter: int, optional - Iteration to switch object type between 'complex' and 'potential' or between - 'potential' and 'complex' - store_iterations: bool, optional - If True, reconstructed objects and probes are stored at each iteration - progress_bar: bool, optional - If True, reconstruction progress is displayed - reset: bool, optional - If True, previous reconstructions are ignored - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - asnumpy = self._asnumpy - xp = self._xp - - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." - ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - "'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." 
- ) - ) - - if use_projection_scheme and self._sim_recon_mode == 2: - raise NotImplementedError( - "simultaneous_measurements_mode == '0+' and projection set algorithms are currently incompatible." - ) - - if self._verbose: - if switch_object_iter > max_iter: - first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - else: - switch_object_type = ( - "complex" if self._object_type == "potential" else "potential" - ) - first_line = ( - f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " - f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " - ) - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - first_line + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) - - # Batching - shuffled_indices = np.arange(self._num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) - - if max_batch_size is not None: - xp.random.seed(seed_random) - else: - max_batch_size = self._num_diffraction_patterns - - # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self._object = ( - self._object_initial[0].copy(), - self._object_initial[1].copy(), - ) - self._probe = self._probe_initial.copy() - self.error_iterations = [] - self._positions_px = self._positions_px_initial.copy() - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - self._exit_waves = (None,) * self._num_sim_measurements - self._object_type = self._object_type_initial - if hasattr(self, "_tf"): - del self._tf - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." 
- ), - UserWarning, - ) - else: - self.error_iterations = [] - self._exit_waves = (None,) * self._num_sim_measurements - - if gaussian_filter_sigma_m is None: - gaussian_filter_sigma_m = gaussian_filter_sigma_e - - if q_lowpass_m is None: - q_lowpass_m = q_lowpass_e - - # main loop - for a0 in tqdmnd( - max_iter, - desc="Reconstructing object and probe", - unit=" iter", - disable=not progress_bar, - ): - error = 0.0 - - if a0 == switch_object_iter: - if self._object_type == "potential": - self._object_type = "complex" - self._object = (xp.exp(1j * self._object[0]), self._object[1]) - elif self._object_type == "complex": - self._object_type = "potential" - self._object = (xp.angle(self._object[0]), self._object[1]) - - if a0 == warmup_iter: - self._object = (self._object[0], self._object_initial[1].copy()) - - # randomize - if not use_projection_scheme: - np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - self._num_diffraction_patterns - ) - positions_px = self._positions_px.copy()[shuffled_indices] - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - amps = [] - for amplitudes in self._amplitudes: - amps.append(amplitudes[shuffled_indices[start:end]]) - amplitudes = tuple(amps) - - # forward operator - ( - shifted_probes, - object_patches, - overlap, - self._exit_waves, - batch_error, - ) = self._forward( - self._object, - self._probe, - amplitudes, - self._exit_waves, - warmup_iteration=a0 < warmup_iter, - use_projection_scheme=use_projection_scheme, - projection_a=projection_a, - projection_b=projection_b, - projection_c=projection_c, - ) - - # adjoint operator - self._object, self._probe = self._adjoint( - self._object, - self._probe, - object_patches, - shifted_probes, - self._exit_waves, - warmup_iteration=a0 < warmup_iter, - use_projection_scheme=use_projection_scheme, - step_size=step_size, - normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, - ) - - # position correction - if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( - self._object[0], - shifted_probes, - overlap[0], - amplitudes[0], - self._positions_px, - positions_step_size, - constrain_position_distance, - ) - - error += batch_error - - # Normalize Error - error /= self._mean_diffraction_intensity * self._num_diffraction_patterns - - # constraints - self._positions_px = positions_px.copy()[unshuffled_indices] - self._object, self._probe, self._positions_px = self._constraints( - self._object, - self._probe, - self._positions_px, - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < 
fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - global_affine_transformation=global_affine_transformation, - warmup_iteration=a0 < warmup_iter, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma_m is not None, - gaussian_filter_sigma_e=gaussian_filter_sigma_e, - gaussian_filter_sigma_m=gaussian_filter_sigma_m, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass_m is not None or q_highpass_m is not None), - q_lowpass_e=q_lowpass_e, - q_lowpass_m=q_lowpass_m, - q_highpass_e=q_highpass_e, - q_highpass_m=q_highpass_m, - butterworth_order=butterworth_order, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, - tv_denoise_weight=tv_denoise_weight, - tv_denoise_inner_iter=tv_denoise_inner_iter, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 - else None, - pure_phase_object=a0 < pure_phase_object_iter - and self._object_type == "complex", - ) - - self.error_iterations.append(error.item()) - if store_iterations: - if a0 < warmup_iter: - self.object_iterations.append( - (asnumpy(self._object[0].copy()), None) - ) - else: - self.object_iterations.append( - ( - asnumpy(self._object[0].copy()), - asnumpy(self._object[1].copy()), - ) - ) - self.probe_iterations.append(self.probe_centered) - - # store result - if a0 < warmup_iter: - self.object = (asnumpy(self._object[0]), None) - else: - self.object = (asnumpy(self._object[0]), asnumpy(self._object[1])) - self.probe = self.probe_centered - self.error = error.item() - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax: None, - cbar: bool, - padding: int = 0, - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. 
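# Aside: the reconstruction_method parsing near the top of reconstruct() above maps each named
# algorithm onto generalized-projection coefficients (a, b, c); the gradient-descent path
# bypasses them. A self-contained restatement of that mapping, where beta plays the role of
# the reconstruction_parameter argument (function name illustrative):
def generalized_projection_coefficients(method, beta=1.0):
    """Return (a, b, c) for the named projection-set algorithm."""
    if method in ("DM_AP", "difference-map_alternating-projections"):
        return -beta, 1, 1 + beta
    if method in ("RAAR", "relaxed-averaged-alternating-reflections"):
        return 1 - 2 * beta, beta, 2
    if method in ("RRR", "relax-reflect-reflect"):
        return -beta, beta, 2
    if method in ("SUPERFLIP", "charge-flipping"):
        return 0, 1, 2
    raise ValueError(f"unknown projection-set method: {method}")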
- - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - cmap = kwargs.pop("cmap", "magma") - - if self._object_type == "complex": - obj = np.angle(self.object[0]) - else: - obj = self.object[0] - - rotated_object = self._crop_rotate_object_fov(obj, padding=padding) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) - - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - padding: int, - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - figsize = kwargs.pop("figsize", (12, 5)) - cmap_e = kwargs.pop("cmap_e", "magma") - cmap_m = kwargs.pop("cmap_m", "PuOr") - - if self._object_type == "complex": - obj_e = np.angle(self.object[0]) - obj_m = self.object[1] - else: - obj_e, obj_m = self.object - - rotated_electrostatic = self._crop_rotate_object_fov(obj_e, padding=padding) - rotated_magnetic = self._crop_rotate_object_fov(obj_m, padding=padding) - rotated_shape = rotated_electrostatic.shape - - min_e = rotated_electrostatic.min() - max_e = rotated_electrostatic.max() - max_m = np.abs(rotated_magnetic).max() - min_m = -max_m - - vmin_e = kwargs.pop("vmin_e", min_e) - vmax_e = kwargs.pop("vmax_e", max_e) - vmin_m = kwargs.pop("vmin_m", min_m) - vmax_m = kwargs.pop("vmax_m", max_m) - - chroma_boost = kwargs.pop("chroma_boost", 1) - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - 
if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=3, - nrows=2, - height_ratios=[4, 1], - hspace=0.15, - width_ratios=[ - 1, - 1, - (probe_extent[1] / probe_extent[2]) / (extent[1] / extent[2]), - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=2, nrows=2, height_ratios=[4, 1], hspace=0.15) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=3, - nrows=1, - width_ratios=[ - 1, - 1, - (probe_extent[1] / probe_extent[2]) / (extent[1] / extent[2]), - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=2, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - if plot_probe or plot_fourier_probe: - # Electrostatic Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_electrostatic, - extent=extent, - cmap=cmap_e, - vmin=vmin_e, - vmax=vmax_e, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed electrostatic potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed electrostatic phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Magnetic Object - ax = fig.add_subplot(spec[0, 1]) - im = ax.imshow( - rotated_magnetic, - extent=extent, - cmap=cmap_m, - vmin=vmin_m, - vmax=vmax_m, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Reconstructed magnetic potential") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Probe - ax = fig.add_subplot(spec[0, 2]) - if plot_fourier_probe: - if remove_initial_probe_aberrations: - probe_array = self.probe_fourier_residual - else: - probe_array = self.probe_fourier - - probe_array = Complex2RGB( - probe_array, - chroma_boost=chroma_boost, - ) - ax.set_title("Reconstructed Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - self.probe, power=2, chroma_boost=chroma_boost - ) - ax.set_title("Reconstructed probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - - else: - # Electrostatic Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_electrostatic, - extent=extent, - cmap=cmap_e, - vmin=vmin_e, - vmax=vmax_e, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed electrostatic potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed electrostatic phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Magnetic Object - ax = fig.add_subplot(spec[0, 1]) - im = ax.imshow( - rotated_magnetic, - extent=extent, - cmap=cmap_m, - vmin=vmin_m, - vmax=vmax_m, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Reconstructed magnetic potential") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if plot_convergence and hasattr(self, 
"error_iterations"): - errors = np.array(self.error_iterations) - ax = fig.add_subplot(spec[1, :]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") - spec.tight_layout(fig) - - def _visualize_all_iterations( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - iterations_grid: Tuple[int, int], - padding: int, - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - raise NotImplementedError() - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - padding: int = 0, - **kwargs, - ): - """ - Displays reconstructed object and probe. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - - return self - - @property - def self_consistency_errors(self): - """Compute the self-consistency errors for each probe position""" - - xp = self._xp - asnumpy = self._asnumpy - - # Re-initialize fractional positions and vector patches, max_batch_size = None - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - ( - 
self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - # Overlaps - _, _, overlap = self._warmup_overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap[0]) - - # Normalized mean-squared errors - error = xp.sum( - xp.abs(self._amplitudes[0] - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) - ) - error /= self._mean_diffraction_intensity - - return asnumpy(error) - - def _return_self_consistency_errors( - self, - max_batch_size=None, - ): - """Compute the self-consistency errors for each probe position""" - - xp = self._xp - asnumpy = self._asnumpy - - # Batch-size - if max_batch_size is None: - max_batch_size = self._num_diffraction_patterns - - # Re-initialize fractional positions and vector patches - errors = np.array([]) - positions_px = self._positions_px.copy() - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[0][start:end] - - # Overlaps - _, _, overlap = self._warmup_overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap[0]) - - # Normalized mean-squared errors - batch_errors = xp.sum( - xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) - ) - errors = np.hstack((errors, batch_errors)) - - self._positions_px = positions_px.copy() - errors /= self._mean_diffraction_intensity - - return asnumpy(errors) - - def _return_projected_cropped_potential( - self, - ): - """Utility function to accommodate multiple classes""" - if self._object_type == "complex": - projected_cropped_potential = np.angle(self.object_cropped[0]) - else: - projected_cropped_potential = self.object_cropped[0] - - return projected_cropped_potential - - @property - def object_cropped(self): - """Cropped and rotated object""" - - obj_e, obj_m = self._object - obj_e = self._crop_rotate_object_fov(obj_e) - obj_m = self._crop_rotate_object_fov(obj_m) - return (obj_e, obj_m) diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py deleted file mode 100644 index 36baac21e..000000000 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ /dev/null @@ -1,2226 +0,0 @@ -""" -Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, -namely (single-slice) ptychography. 
-""" - -import warnings -from typing import Mapping, Tuple - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg - -try: - import cupy as cp -except (ImportError, ModuleNotFoundError): - cp = np - -from emdfile import Custom, tqdmnd -from py4DSTEM.datacube import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.utils import ( - ComplexProbe, - fft_shift, - generate_batches, - polar_aliases, - polar_symbols, -) -from py4DSTEM.process.utils import get_CoM, get_shifted_ar - -warnings.simplefilter(action="always", category=UserWarning) - - -class SingleslicePtychographicReconstruction(PtychographicReconstruction): - """ - Iterative Ptychographic Reconstruction Class. - - Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) - Reconstructed probe dimensions : (Sx,Sy) - Reconstructed object dimensions : (Px,Py) - - such that (Sx,Sy) is the region-of-interest (ROI) size of our probe - and (Px,Py) is the padded-object size we position our ROI around in. - - Parameters - ---------- - energy: float - The electron energy of the wave functions in eV - datacube: DataCube - Input 4D diffraction pattern intensities - semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess in mrad - semiangle_cutoff_pixels: float, optional - Semiangle cutoff for the initial probe guess in pixels - rolloff: float, optional - Semiangle rolloff for the initial probe guess - vacuum_probe_intensity: np.ndarray, optional - Vacuum probe to use as intensity aperture for initial probe guess - polar_parameters: dict, optional - Mapping from aberration symbols to their corresponding values. All aberration - magnitudes should be given in Å and angles should be given in radians. - object_padding_px: Tuple[int,int], optional - Pixel dimensions to pad object with - If None, the padding is set to half the probe ROI dimensions - initial_object_guess: np.ndarray, optional - Initial guess for complex-valued object of dimensions (Px,Py) - If None, initialized to 1.0j - initial_probe_guess: np.ndarray, optional - Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, - initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations - initial_scan_positions: np.ndarray, optional - Probe positions in Å for each diffraction intensity - If None, initialized to a grid scan - verbose: bool, optional - If True, class methods will inherit this and print additional information - device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' - object_type: str, optional - The object can be reconstructed as a real potential ('potential') or a complex - object ('complex') - positions_mask: np.ndarray, optional - Boolean real space mask to select positions in datacube to skip for reconstruction - name: str, optional - Class name - kwargs: - Provide the aberration coefficients as keyword arguments. 
- """ - - # Class-specific Metadata - _class_specific_metadata = () - - def __init__( - self, - energy: float, - datacube: DataCube = None, - semiangle_cutoff: float = None, - semiangle_cutoff_pixels: float = None, - rolloff: float = 2.0, - vacuum_probe_intensity: np.ndarray = None, - polar_parameters: Mapping[str, float] = None, - initial_object_guess: np.ndarray = None, - initial_probe_guess: np.ndarray = None, - initial_scan_positions: np.ndarray = None, - object_padding_px: Tuple[int, int] = None, - object_type: str = "complex", - positions_mask: np.ndarray = None, - verbose: bool = True, - device: str = "cpu", - name: str = "ptychographic_reconstruction", - **kwargs, - ): - Custom.__init__(self, name=name) - - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from scipy.special import erf - - self._erf = erf - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from cupyx.scipy.special import erf - - self._erf = erf - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - - for key in kwargs.keys(): - if (key not in polar_symbols) and (key not in polar_aliases.keys()): - raise ValueError("{} not a recognized parameter".format(key)) - - self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) - - if polar_parameters is None: - polar_parameters = {} - - polar_parameters.update(kwargs) - self._set_polar_parameters(polar_parameters) - - if object_type != "potential" and object_type != "complex": - raise ValueError( - f"object_type must be either 'potential' or 'complex', not {object_type}" - ) - - self.set_save_defaults() - - # Data - self._datacube = datacube - self._object = initial_object_guess - self._probe = initial_probe_guess - - # Common Metadata - self._vacuum_probe_intensity = vacuum_probe_intensity - self._scan_positions = initial_scan_positions - self._energy = energy - self._semiangle_cutoff = semiangle_cutoff - self._semiangle_cutoff_pixels = semiangle_cutoff_pixels - self._rolloff = rolloff - self._object_type = object_type - self._object_padding_px = object_padding_px - self._positions_mask = positions_mask - self._verbose = verbose - self._device = device - self._preprocessed = False - - # Class-specific Metadata - - def preprocess( - self, - diffraction_intensities_shape: Tuple[int, int] = None, - reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, - dp_mask: np.ndarray = None, - fit_function: str = "plane", - plot_center_of_mass: str = "default", - plot_rotation: bool = True, - maximize_divergence: bool = False, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), - plot_probe_overlaps: bool = True, - force_com_rotation: float = None, - force_com_transpose: float = None, - force_com_shifts: float = None, - force_scan_sampling: float = None, - force_angular_sampling: float = None, - force_reciprocal_sampling: float = None, - object_fov_mask: np.ndarray = None, - crop_patterns: bool = False, - **kwargs, - ): - """ - Ptychographic preprocessing step. 
- Calls the base class methods: - - _extract_intensities_and_calibrations_from_datacube, - _compute_center_of_mass(), - _solve_CoM_rotation(), - _normalize_diffraction_intensities() - _calculate_scan_positions_in_px() - - Additionally, it initializes an (Px,Py) array of 1.0j - and a complex probe using the specified polar parameters. - - Parameters - ---------- - diffraction_intensities_shape: Tuple[int,int], optional - Pixel dimensions (Qx',Qy') of the resampled diffraction intensities - If None, no resampling of diffraction intenstities is performed - reshaping_method: str, optional - Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional - Padded diffraction intensities shape. - If None, no padding is performed - dp_mask: ndarray, optional - Mask for datacube intensities (Qx,Qy) - fit_function: str, optional - 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' - plot_center_of_mass: str, optional - If 'default', the corrected CoM arrays will be displayed - If 'all', the computed and fitted CoM arrays will be displayed - plot_rotation: bool, optional - If True, the CoM curl minimization search result will be displayed - maximize_divergence: bool, optional - If True, the divergence of the CoM gradient vector field is maximized - rotation_angles_deg: np.darray, optional - Array of angles in degrees to perform curl minimization over - plot_probe_overlaps: bool, optional - If True, initial probe overlaps scanned over the object will be displayed - force_com_rotation: float (degrees), optional - Force relative rotation angle between real and reciprocal space - force_com_transpose: bool, optional - Force whether diffraction intensities need to be transposed. - force_com_shifts: tuple of ndarrays (CoMx, CoMy) - Amplitudes come from diffraction patterns shifted with - the CoM in the upper left corner for each probe unless - shift is overwritten. - force_scan_sampling: float, optional - Override DataCube real space scan pixel size calibrations, in Angstrom - force_angular_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in mrad - force_reciprocal_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in A^-1 - object_fov_mask: np.ndarray (boolean) - Boolean mask of FOV. Used to calculate additional shrinkage of object - If None, probe_overlap intensity is thresholded - crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns when centering - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - xp = self._xp - asnumpy = self._asnumpy - - # set additional metadata - self._diffraction_intensities_shape = diffraction_intensities_shape - self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape - self._dp_mask = dp_mask - - if self._datacube is None: - raise ValueError( - ( - "The preprocess() method requires a DataCube. " - "Please run ptycho.attach_datacube(DataCube) first." 
- ) - ) - - if self._positions_mask is not None and self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) - self._positions_mask = np.asarray(self._positions_mask, dtype="bool") - - ( - self._datacube, - self._vacuum_probe_intensity, - self._dp_mask, - force_com_shifts, - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube, - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=self._vacuum_probe_intensity, - dp_mask=self._dp_mask, - com_shifts=force_com_shifts, - ) - - self._intensities = self._extract_intensities_and_calibrations_from_datacube( - self._datacube, - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - self._com_measured_x, - self._com_measured_y, - self._com_fitted_x, - self._com_fitted_y, - self._com_normalized_x, - self._com_normalized_y, - ) = self._calculate_intensities_center_of_mass( - self._intensities, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts, - ) - - ( - self._rotation_best_rad, - self._rotation_best_transpose, - self._com_x, - self._com_y, - self.com_x, - self.com_y, - ) = self._solve_for_center_of_mass_relative_rotation( - self._com_measured_x, - self._com_measured_y, - self._com_normalized_x, - self._com_normalized_y, - rotation_angles_deg=rotation_angles_deg, - plot_rotation=plot_rotation, - plot_center_of_mass=plot_center_of_mass, - maximize_divergence=maximize_divergence, - force_com_rotation=force_com_rotation, - force_com_transpose=force_com_transpose, - **kwargs, - ) - - ( - self._amplitudes, - self._mean_diffraction_intensity, - ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, - crop_patterns, - self._positions_mask, - ) - - # explicitly delete namespace - self._num_diffraction_patterns = self._amplitudes.shape[0] - self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) - del self._intensities - - self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions, self._positions_mask - ) - - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - if self._object_type == "potential": - self._object = xp.zeros((p, q), dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.ones((p, q), dtype=xp.complex64) - else: - if self._object_type == "potential": - self._object = xp.asarray(self._object, dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.asarray(self._object, dtype=xp.complex64) - - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type - self._object_shape = self._object.shape - - self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) - self._positions_px_com = xp.mean(self._positions_px, axis=0) - 
self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - self._positions_px_initial = self._positions_px.copy() - self._positions_initial = self._positions_px_initial.copy() - self._positions_initial[:, 0] *= self.sampling[0] - self._positions_initial[:, 1] *= self.sampling[1] - - # Vectorized Patches - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - # Probe Initialization - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, - device=self._device, - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - self._probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) - - else: - if isinstance(self._probe, ComplexProbe): - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - self._probe = self._probe._array - else: - self._probe._xp = xp - self._probe = self._probe.build()._array - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - self._mean_diffraction_intensity / probe_intensity - ) - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) - - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) - self._known_aberrations_array = ComplexProbe( - energy=self._energy, - gpts=self._region_of_interest_shape, - sampling=self.sampling, - parameters=self._polar_parameters, - device=self._device, - )._evaluate_ctf() - - # overlaps - shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) - - if object_fov_mask is None: - self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) - else: - self._object_fov_mask = np.asarray(object_fov_mask) - self._object_fov_mask_inverse = np.invert(self._object_fov_mask) - - if plot_probe_overlaps: - figsize = kwargs.pop("figsize", (9, 4)) - chroma_boost = kwargs.pop("chroma_boost", 1) - - # initial probe - complex_probe_rgb = Complex2RGB( - self.probe_centered, - power=2, - chroma_boost=chroma_boost, - ) - - extent = [ - 0, - self.sampling[1] * self._object_shape[1], - self.sampling[0] * self._object_shape[0], - 0, - ] - - probe_extent = [ - 0, - 
self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) - - ax1.imshow( - complex_probe_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax1) - cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax1, chroma_boost=chroma_boost) - ax1.set_ylabel("x [A]") - ax1.set_xlabel("y [A]") - ax1.set_title("Initial probe intensity") - - ax2.imshow( - asnumpy(probe_overlap), - extent=extent, - cmap="gray", - ) - ax2.scatter( - self.positions[:, 1], - self.positions[:, 0], - s=2.5, - color=(1, 0, 0, 1), - ) - ax2.set_ylabel("x [A]") - ax2.set_xlabel("y [A]") - ax2.set_xlim((extent[0], extent[1])) - ax2.set_ylim((extent[2], extent[3])) - ax2.set_title("Object field of view") - - fig.tight_layout() - - self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - shifted_probes * object_patches - """ - - xp = self._xp - - shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - object_patches = complex_object[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - overlap = shifted_probes * object_patches - - return shifted_probes, object_patches, overlap - - def _gradient_descent_fourier_projection(self, amplitudes, overlap): - """ - Ptychographic fourier projection method for GD method. - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - - Returns - -------- - exit_waves:np.ndarray - Difference between modified and estimated exit waves - error: float - Reconstruction error - """ - - xp = self._xp - fourier_overlap = xp.fft.fft2(overlap) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2) - - fourier_modified_overlap = amplitudes * xp.exp(1j * xp.angle(fourier_overlap)) - modified_overlap = xp.fft.ifft2(fourier_modified_overlap) - - exit_waves = modified_overlap - overlap - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. 
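A standalone NumPy sketch of the gradient-descent Fourier (modulus) projection implemented above, shown here only for reference (the function name is illustrative):

import numpy as np


def gradient_descent_fourier_projection(amplitudes, overlap):
    """Keep the estimated phase, impose the measured amplitudes, and return
    the exit-wave difference together with the (unnormalized) error."""
    fourier_overlap = np.fft.fft2(overlap)
    error = np.sum(np.abs(amplitudes - np.abs(fourier_overlap)) ** 2)
    modified_overlap = np.fft.ifft2(
        amplitudes * np.exp(1j * np.angle(fourier_overlap))
    )
    return modified_overlap - overlap, error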
- Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = overlap.copy() - - fourier_overlap = xp.fft.fft2(overlap) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2) - - factor_to_be_projected = projection_c * overlap + projection_y * exit_waves - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amplitudes * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * overlap - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - object * probe overlap - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - shifted_probes, object_patches, overlap = self._overlap_projection( - current_object, current_probe - ) - if use_projection_scheme: - exit_waves, error = self._projection_sets_fourier_projection( - amplitudes, - overlap, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, overlap - ) - - return shifted_probes, object_patches, overlap, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
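The (a, b, c) parameterizations listed in the docstring above can be summarized in a small helper; this is illustrative only, since the class resolves the same mappings internally in reconstruct():

def generalized_projection_parameters(method, param=1.0):
    """Return (a, b, c) for the generalized projection; `param` plays the
    role of alpha, beta or gamma for the respective methods."""
    if method == "DM_AP":      # DM: param=1.0, AP: param=0.0
        return -param, 1.0, 1.0 + param
    if method == "RAAR":       # DM: param=1.0
        return 1.0 - 2.0 * param, param, 2.0
    if method == "RRR":        # DM: param=1.0
        return -param, param, 2.0
    if method == "SUPERFLIP":  # parameter-free
        return 0.0, 1.0, 2.0
    raise ValueError(f"unknown generalized-projection method: {method}")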
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._object_type == "potential": - current_object += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(object_patches) - * xp.conj(shifted_probes) - * exit_waves - ) - ) - * probe_normalization - ) - elif self._object_type == "complex": - current_object += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes) * exit_waves - ) - * probe_normalization - ) - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(object_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += step_size * ( - xp.sum( - xp.conj(object_patches) * exit_waves, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. 
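Both the probe and object normalizations above use the same regularized reciprocal; a minimal sketch of that expression (helper name illustrative):

import numpy as np


def regularized_reciprocal(normalization, normalization_min, eps=1e-16):
    """Regularized 1/x used for the update normalizations above: overlap
    intensities below `normalization_min` times the maximum are damped
    rather than amplified by the division."""
    return 1 / np.sqrt(
        eps
        + ((1 - normalization_min) * normalization) ** 2
        + (normalization_min * np.max(normalization)) ** 2
    )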
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._object_type == "potential": - current_object = ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(object_patches) - * xp.conj(shifted_probes) - * exit_waves - ) - ) - * probe_normalization - ) - elif self._object_type == "complex": - current_object = ( - self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes) * exit_waves - ) - * probe_normalization - ) - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(object_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - xp.conj(object_patches) * exit_waves, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator. - Computes object and probe update steps. 
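A toy NumPy illustration of how the two adjoint flavors above differ for the object update (dummy arrays and values, purely illustrative): gradient descent increments the current estimate by a normalized step, while the projection-set (DM_AP/RAAR/RRR/SUPERFLIP) adjoint replaces it outright.

import numpy as np

current_object = np.ones((4, 4), dtype=np.complex64)
patch_sums = np.full((4, 4), 0.5 + 0.5j)    # stand-in for summed conj(probe) * exit_waves
probe_normalization = np.full((4, 4), 0.1)  # regularized reciprocal of summed probe intensity
step_size = 0.5

object_gradient_descent = current_object + step_size * patch_sums * probe_normalization
object_projection_sets = patch_sums * probe_normalization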
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - - def _constraints( - self, - current_object, - current_probe, - current_positions, - pure_phase_object, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - tv_denoise, - tv_denoise_weight, - tv_denoise_inner_iter, - object_positivity, - shrinkage_rad, - object_mask, - ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - pure_phase_object: bool - If True, object amplitude is set to unity - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe Fourier amplitude is replaced by initial probe aperture. 
- initial_probe_aperture: np.ndarray, - Initial probe aperture to use in replacing probe Fourier amplitude. - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - butterworth_filter: bool - If True, applies high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - object_positivity: bool - If True, clips negative potential values - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - - if tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, tv_denoise_weight, tv_denoise_inner_iter - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - if self._object_type == "complex": - current_object = self._object_threshold_constraint( - current_object, pure_phase_object - ) - elif object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - - def reconstruct( - self, - max_iter: int = 64, - reconstruction_method: str = "gradient-descent", - reconstruction_parameter: float = 1.0, - reconstruction_parameter_a: 
float = None, - reconstruction_parameter_b: float = None, - reconstruction_parameter_c: float = None, - max_batch_size: int = None, - seed_random: int = None, - step_size: float = 0.5, - normalization_min: float = 1, - positions_step_size: float = 0.9, - pure_phase_object_iter: int = 0, - fix_com: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, - constrain_probe_amplitude_relative_radius: float = 0.5, - constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, - constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, - constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, - global_affine_transformation: bool = True, - gaussian_filter_sigma: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, - fit_probe_aberrations_max_angular_order: int = 4, - fit_probe_aberrations_max_radial_order: int = 4, - butterworth_filter_iter: int = np.inf, - q_lowpass: float = None, - q_highpass: float = None, - butterworth_order: float = 2, - tv_denoise_iter: int = np.inf, - tv_denoise_weight: float = None, - tv_denoise_inner_iter: float = 40, - object_positivity: bool = True, - shrinkage_rad: float = 0.0, - fix_potential_baseline: bool = True, - switch_object_iter: int = np.inf, - store_iterations: bool = False, - progress_bar: bool = True, - reset: bool = None, - ): - """ - Ptychographic reconstruction main method. - - Parameters - -------- - max_iter: int, optional - Maximum number of iterations to run - reconstruction_method: str, optional - Specifies which reconstruction algorithm to use, one of: - "generalized-projections", - "DM_AP" (or "difference-map_alternating-projections"), - "RAAR" (or "relaxed-averaged-alternating-reflections"), - "RRR" (or "relax-reflect-reflect"), - "SUPERFLIP" (or "charge-flipping"), or - "GD" (or "gradient_descent") - reconstruction_parameter: float, optional - Reconstruction parameter for various reconstruction methods above. - reconstruction_parameter_a: float, optional - Reconstruction parameter a for reconstruction_method='generalized-projections'. - reconstruction_parameter_b: float, optional - Reconstruction parameter b for reconstruction_method='generalized-projections'. - reconstruction_parameter_c: float, optional - Reconstruction parameter c for reconstruction_method='generalized-projections'. - max_batch_size: int, optional - Max number of probes to update at once - seed_random: int, optional - Seeds the random number generator, only applicable when max_batch_size is not None - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - positions_step_size: float, optional - Positions update step size - pure_phase_object_iter: int, optional - Number of iterations where object amplitude is set to unity - fix_com: bool, optional - If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. 
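An illustrative reconstruct() call with placeholder values, assuming `ptycho` has already been preprocessed; note that mini-batching via max_batch_size is only compatible with the gradient-descent path, as the method itself enforces:

ptycho.reconstruct(
    max_iter=64,
    reconstruction_method="GD",
    max_batch_size=256,          # stochastic mini-batches (GD only)
    step_size=0.5,
    normalization_min=1,
    gaussian_filter_sigma=0.5,   # object smoothness constraint, in A (placeholder)
    q_highpass=0.02,             # high-pass Butterworth cut-off, in A^-1 (placeholder)
    store_iterations=True,
    progress_bar=True,
)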
- constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float, optional - Distance to constrain position correction within original - field of view in A - global_affine_transformation: bool, optional - If True, positions are assumed to be a global affine transform from initial scan - gaussian_filter_sigma: float, optional - Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - tv_denoise_iter: int, optional - Number of iterations to run using tv denoise filter on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - object_positivity: bool, optional - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - fix_potential_baseline: bool - If true, the potential mean outside the FOV is forced to zero at each iteration - switch_object_iter: int, optional - Iteration to switch object type between 'complex' and 'potential' or between - 'potential' and 'complex' - store_iterations: bool, optional - If True, reconstructed objects and probes are stored at each iteration - progress_bar: bool, optional - If True, reconstruction progress is displayed - reset: bool, optional - If True, previous reconstructions are ignored - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - asnumpy = self._asnumpy - xp = self._xp - - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - "'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) - - if self._verbose: - if switch_object_iter > max_iter: - first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - else: - switch_object_type = ( - "complex" if self._object_type == "potential" else "potential" - ) - first_line = ( - f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " - f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " - ) - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - first_line + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." 
- ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) - - # Batching - shuffled_indices = np.arange(self._num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) - - if max_batch_size is not None: - xp.random.seed(seed_random) - else: - max_batch_size = self._num_diffraction_patterns - - # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self.error_iterations = [] - self._object = self._object_initial.copy() - self._probe = self._probe_initial.copy() - self._positions_px = self._positions_px_initial.copy() - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - self._exit_waves = None - self._object_type = self._object_type_initial - if hasattr(self, "_tf"): - del self._tf - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." - ), - UserWarning, - ) - else: - self.error_iterations = [] - self._exit_waves = None - - # main loop - for a0 in tqdmnd( - max_iter, - desc="Reconstructing object and probe", - unit=" iter", - disable=not progress_bar, - ): - error = 0.0 - - if a0 == switch_object_iter: - if self._object_type == "potential": - self._object_type = "complex" - self._object = xp.exp(1j * self._object) - elif self._object_type == "complex": - self._object_type = "potential" - self._object = xp.angle(self._object) - - # randomize - if not use_projection_scheme: - np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - self._num_diffraction_patterns - ) - positions_px = self._positions_px.copy()[shuffled_indices] - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[shuffled_indices[start:end]] - - # forward operator - ( - shifted_probes, - object_patches, - overlap, - self._exit_waves, - batch_error, - ) = self._forward( - self._object, - self._probe, - amplitudes, - self._exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ) - - # adjoint operator - self._object, self._probe = self._adjoint( - self._object, - self._probe, - object_patches, - shifted_probes, - self._exit_waves, - use_projection_scheme=use_projection_scheme, - step_size=step_size, - normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, - ) - - # position correction - if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( - self._object, - shifted_probes, - overlap, - amplitudes, 
- self._positions_px, - positions_step_size, - constrain_position_distance, - ) - - error += batch_error - - # Normalize Error - error /= self._mean_diffraction_intensity * self._num_diffraction_patterns - - # constraints - self._positions_px = positions_px.copy()[unshuffled_indices] - self._object, self._probe, self._positions_px = self._constraints( - self._object, - self._probe, - self._positions_px, - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, - gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass is not None or q_highpass is not None), - q_lowpass=q_lowpass, - q_highpass=q_highpass, - butterworth_order=butterworth_order, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, - tv_denoise_weight=tv_denoise_weight, - tv_denoise_inner_iter=tv_denoise_inner_iter, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 - else None, - pure_phase_object=a0 < pure_phase_object_iter - and self._object_type == "complex", - ) - - self.error_iterations.append(error.item()) - if store_iterations: - self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(self.probe_centered) - - # store result - self.object = asnumpy(self._object) - self.probe = self.probe_centered - self.error = error.item() - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax: None, - cbar: bool, - padding: int = 0, - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. 
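A condensed sketch of the error normalization applied at the end of each iteration above (function name illustrative):

import numpy as np


def normalized_mse(amplitudes, overlap, mean_diffraction_intensity, num_patterns):
    """Squared residual between the measured amplitudes and the modulus of the
    propagated exit waves, divided by the mean diffraction intensity and the
    number of diffraction patterns."""
    residual = amplitudes - np.abs(np.fft.fft2(overlap))
    return np.sum(np.abs(residual) ** 2) / (mean_diffraction_intensity * num_patterns)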
- - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - cmap = kwargs.pop("cmap", "magma") - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov(obj, padding=padding) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) - - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - padding: int, - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - figsize = kwargs.pop("figsize", (8, 5)) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov(obj, padding=padding) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=2, - height_ratios=[4, 1], - hspace=0.15, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=1, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / 
probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - if plot_probe or plot_fourier_probe: - # Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Probe - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - - ax = fig.add_subplot(spec[0, 1]) - if plot_fourier_probe: - if remove_initial_probe_aberrations: - probe_array = self.probe_fourier_residual - else: - probe_array = self.probe_fourier - - probe_array = Complex2RGB( - probe_array, - chroma_boost=chroma_boost, - ) - - ax.set_title("Reconstructed Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - self.probe, - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title("Reconstructed probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - - else: - ax = fig.add_subplot(spec[0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if plot_convergence and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = np.array(self.error_iterations) - if plot_probe: - ax = fig.add_subplot(spec[1, :]) - else: - ax = fig.add_subplot(spec[1]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") - spec.tight_layout(fig) - - def _visualize_all_iterations( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - iterations_grid: Tuple[int, int], - padding: int, - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - asnumpy = self._asnumpy - - if not hasattr(self, "object_iterations"): - raise ValueError( - ( - "Object and probe iterations were not saved during reconstruction. " - "Please re-run using store_iterations=True." - ) - ) - - if iterations_grid == "auto": - num_iter = len(self.error_iterations) - - if num_iter == 1: - return self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - cbar=cbar, - padding=padding, - **kwargs, - ) - elif plot_probe or plot_fourier_probe: - iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) - else: - iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) - else: - if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: - raise ValueError() - - auto_figsize = ( - (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) - if plot_convergence - else (3 * iterations_grid[1], 3 * iterations_grid[0]) - ) - figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - errors = np.array(self.error_iterations) - - objects = [] - object_type = [] - - for obj in self.object_iterations: - if np.iscomplexobj(obj): - obj = np.angle(obj) - object_type.append("phase") - else: - object_type.append("potential") - objects.append(self._crop_rotate_object_fov(obj, padding=padding)) - - if plot_probe or plot_fourier_probe: - total_grids = (np.prod(iterations_grid) / 2).astype("int") - probes = self.probe_iterations - else: - total_grids = np.prod(iterations_grid) - max_iter = len(objects) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) - - extent = [ - 0, - self.sampling[1] * objects[0].shape[1], - self.sampling[0] * objects[0].shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=2) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - grid = ImageGrid( - fig, - spec[0], - nrows_ncols=(1, iterations_grid[1]) - if (plot_probe or plot_fourier_probe) - else iterations_grid, - axes_pad=(0.75, 0.5) if 
cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - im = ax.imshow( - objects[grid_range[n]], - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: - grid.cbar_axes[n].colorbar(im) - - if plot_probe or plot_fourier_probe: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - - grid = ImageGrid( - fig, - spec[1], - nrows_ncols=(1, iterations_grid[1]), - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - if plot_fourier_probe: - probe_array = asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]], - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) - ax.set_title(f"Iter: {grid_range[n]} Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - - else: - probe_array = Complex2RGB( - probes[grid_range[n]], - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title(f"Iter: {grid_range[n]} probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - add_colorbar_arg( - grid.cbar_axes[n], - chroma_boost=chroma_boost, - ) - - if plot_convergence: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - if plot_probe: - ax2 = fig.add_subplot(spec[2]) - else: - ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration number") - ax2.yaxis.tick_right() - - spec.tight_layout(fig) - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - padding: int = 0, - **kwargs, - ): - """ - Displays reconstructed object and probe. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - cbar=cbar, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - padding=padding, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - - return self diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py new file mode 100644 index 000000000..7249a9064 --- /dev/null +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -0,0 +1,1628 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely magnetic ptychographic tomography. 
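For orientation, a minimal usage sketch of the class defined in this module follows; `datacubes` (one DataCube per tilt) and `tilt_matrices` (the matching 3x3 orientation matrices) are assumed to exist already, and every numeric value is an illustrative placeholder rather than a recommended setting.

from py4DSTEM.process.phase import MagneticPtychographicTomography

# `datacubes` and `tilt_matrices` are assumed to be defined elsewhere.
ptycho = MagneticPtychographicTomography(
    datacube=datacubes,                      # one DataCube per tilt
    energy=300e3,                            # electron energy in eV (placeholder)
    num_slices=4,                            # super-slices in the forward model (placeholder)
    tilt_orientation_matrices=tilt_matrices, # one orientation matrix per tilt
    semiangle_cutoff=20,                     # initial probe semiangle in mrad (placeholder)
    device="cpu",
)
ptycho.preprocess(plot_probe_overlaps=False)
ptycho.reconstruct(
    num_iter=8,
    # only gradient descent is implemented for this class
    reconstruction_method="gradient-descent",
    collective_measurement_updates=True,
    store_iterations=True,
)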
+""" + +import warnings +from typing import Mapping, Sequence, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.visualize.vis_special import ( + Complex2RGB, + add_colorbar_arg, + return_scaled_histogram_ordering, +) + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( + Object2p5DConstraintsMixin, + Object3DConstraintsMixin, + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) +from py4DSTEM.process.phase.ptychographic_methods import ( + MultipleMeasurementsMethodsMixin, + Object2p5DMethodsMixin, + Object2p5DProbeMethodsMixin, + Object3DMethodsMixin, + ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, +) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, + project_vector_field_divergence_periodic_3D, +) + + +class MagneticPtychographicTomography( + VisualizationsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + Object3DConstraintsMixin, + Object2p5DConstraintsMixin, + ObjectNDConstraintsMixin, + MultipleMeasurementsMethodsMixin, + Object2p5DProbeMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, + Object3DMethodsMixin, + Object2p5DMethodsMixin, + ObjectNDMethodsMixin, + PtychographicReconstruction, +): + """ + Magnetic Ptychographic Tomography Reconstruction Class. + + List of diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (Sx,Sy) + Reconstructed object dimensions : (Px,Py,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our probe + and (Px,Py,Py) is the padded-object electrostatic potential volume, + where x-axis is the tilt. + + Parameters + ---------- + datacube: List of DataCubes + Input list of 4D diffraction pattern intensities for different tilts + energy: float + The electron energy of the wave functions in eV + num_slices: int + Number of super-slices to use in the forward model + tilt_orientation_matrices: Sequence[np.ndarray] + List of orientation matrices for each tilt + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (Px,Py,Py) + If None, initialized to 1.0 + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). 
If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: list of np.ndarray, optional + Probe positions in Å for each diffraction intensity per tilt + If None, initialized to a grid scan centered along tilt axis + verbose: bool, optional + If True, class methods will inherit this and print additional information + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction + name: str, optional + Class name + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + kwargs: + Provide the aberration coefficients as keyword arguments. + """ + + # Class-specific Metadata + _class_specific_metadata = ( + "_num_slices", + "_tilt_orientation_matrices", + "_num_measurements", + ) + + def __init__( + self, + energy: float, + num_slices: int, + tilt_orientation_matrices: Sequence[np.ndarray], + datacube: Sequence[DataCube] = None, + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + object_type: str = "potential", + positions_mask: np.ndarray = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: Sequence[np.ndarray] = None, + verbose: bool = True, + device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, + name: str = "magnetic-ptychographic-tomography_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if storage is None: + storage = device + + self.set_device(device, clear_fft_cache) + self.set_storage(storage) + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + num_tilts = len(tilt_orientation_matrices) + if initial_scan_positions is None: + initial_scan_positions = [None] * num_tilts + + if object_type != "potential": + raise NotImplementedError() + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe_init = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._object_padding_px = object_padding_px + self._positions_mask = positions_mask + self._verbose = verbose + self._preprocessed = False + + # Class-specific Metadata + self._num_slices = num_slices + self._tilt_orientation_matrices = tuple(tilt_orientation_matrices) + self._num_measurements = num_tilts + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + 
        reshaping_method: str = "fourier",
+        padded_diffraction_intensities_shape: Tuple[int, int] = None,
+        region_of_interest_shape: Tuple[int, int] = None,
+        dp_mask: np.ndarray = None,
+        fit_function: str = "plane",
+        plot_probe_overlaps: bool = True,
+        rotation_real_space_degrees: float = None,
+        diffraction_patterns_rotate_degrees: float = None,
+        diffraction_patterns_transpose: bool = None,
+        force_com_shifts: Sequence[float] = None,
+        vectorized_com_calculation: bool = True,
+        progress_bar: bool = True,
+        force_scan_sampling: float = None,
+        force_angular_sampling: float = None,
+        force_reciprocal_sampling: float = None,
+        object_fov_mask: np.ndarray = None,
+        crop_patterns: bool = False,
+        device: str = None,
+        clear_fft_cache: bool = None,
+        max_batch_size: int = None,
+        **kwargs,
+    ):
+        """
+        Ptychographic preprocessing step.
+
+        Additionally, it initializes a (Px,Py,Py) array of 1.0
+        and a complex probe using the specified polar parameters.
+
+        Parameters
+        ----------
+        diffraction_intensities_shape: Tuple[int,int], optional
+            Pixel dimensions (Qx',Qy') of the resampled diffraction intensities.
+            If None, no resampling of diffraction intensities is performed
+        reshaping_method: str, optional
+            Method to use for reshaping, either 'bin', 'bilinear', or 'fourier' (default)
+        padded_diffraction_intensities_shape: (int,int), optional
+            Padded diffraction intensities shape.
+            If None, no padding is performed
+        region_of_interest_shape: (int,int), optional
+            If not None, explicitly sets region_of_interest_shape and resamples exit_waves
+            at the diffraction plane to allow comparison with experimental data
+        dp_mask: ndarray, optional
+            Mask for datacube intensities (Qx,Qy)
+        fit_function: str, optional
+            2D fitting function for CoM fitting. One of 'plane', 'parabola', 'bezier_two'
+        plot_probe_overlaps: bool, optional
+            If True, initial probe overlaps scanned over the object will be displayed
+        rotation_real_space_degrees: float (degrees), optional
+            In-plane rotation around the z axis between the x axis and the tilt axis in
+            real space (forced to be in the xy plane)
+        diffraction_patterns_rotate_degrees: float, optional
+            Relative rotation angle between real and reciprocal space
+        diffraction_patterns_transpose: bool, optional
+            Whether diffraction intensities need to be transposed.
+        force_com_shifts: list of tuple of ndarrays (CoMx, CoMy)
+            Amplitudes come from diffraction patterns shifted with
+            the CoM in the upper left corner for each probe unless
+            shift is overwritten. One tuple per tilt.
+        vectorized_com_calculation: bool, optional
+            If True (default), the memory-intensive CoM calculation is vectorized
+        progress_bar: bool, optional
+            If True, preprocessing progress is displayed
+        force_scan_sampling: float, optional
+            Override DataCube real-space scan pixel size calibrations, in Angstrom
+        force_angular_sampling: float, optional
+            Override DataCube reciprocal pixel size calibration, in mrad
+        force_reciprocal_sampling: float, optional
+            Override DataCube reciprocal pixel size calibration, in A^-1
+        object_fov_mask: np.ndarray (boolean)
+            Boolean mask of FOV. Used to calculate additional shrinkage of object.
+            If None, probe_overlap intensity is thresholded
+        crop_patterns: bool
+            If True, crop patterns to avoid wrap-around of patterns when centering
+        device: str, optional
+            If not None, overwrites self._device to set the device preprocessing will be performed on.
+ clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps + + Returns + -------- + self: OverlapTomographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first." + ) + ) + + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + if self._positions_mask.ndim == 2: + warnings.warn( + "2D `positions_mask` assumed the same for all measurements.", + UserWarning, + ) + self._positions_mask = np.tile( + self._positions_mask, (self._num_measurements, 1, 1) + ) + + num_probes_per_measurement = np.insert( + self._positions_mask.sum(axis=(-2, -1)), 0, 0 + ) + + else: + self._positions_mask = [None] * self._num_measurements + num_probes_per_measurement = [0] + [dc.R_N for dc in self._datacube] + num_probes_per_measurement = np.array(num_probes_per_measurement) + + # prepopulate relevant arrays + self._mean_diffraction_intensity = [] + self._num_diffraction_patterns = num_probes_per_measurement.sum() + self._cum_probes_per_measurement = np.cumsum(num_probes_per_measurement) + self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) + + # calculate roi_shape + roi_shape = self._datacube[0].Qshape + if diffraction_intensities_shape is not None: + roi_shape = diffraction_intensities_shape + if padded_diffraction_intensities_shape is not None: + roi_shape = tuple( + max(q, s) + for q, s in zip(roi_shape, padded_diffraction_intensities_shape) + ) + + self._amplitudes = xp_storage.empty( + (self._num_diffraction_patterns,) + roi_shape + ) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + # TO-DO: generalize this + if force_com_shifts is None: + force_com_shifts = [None] * self._num_measurements + + self._rotation_best_rad = np.deg2rad(diffraction_patterns_rotate_degrees) + self._rotation_best_transpose = diffraction_patterns_transpose + + # loop over DPs for preprocessing + for index in tqdmnd( + self._num_measurements, + desc="Preprocessing data", + unit="tilt", + disable=not progress_bar, + ): + # preprocess datacube, vacuum and masks only for first tilt + if index == 0: + ( + self._datacube[index], + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts[index], + ) 
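# Note on the tilt loop above: the vacuum probe and diffraction-pattern mask are only
# preprocessed for the first tilt (index == 0); the remaining tilts pass None for both
# and reuse the stored self._vacuum_probe_intensity and self._dp_mask.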
+ + else: + ( + self._datacube[index], + _, + _, + force_com_shifts[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=None, + dp_mask=None, + com_shifts=force_com_shifts[index], + ) + + # calibrations + intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube[index], + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + # calculate CoM + ( + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts[index], + vectorized_calculation=vectorized_com_calculation, + ) + + # corner-center amplitudes + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + ( + amplitudes, + mean_diffraction_intensity_temp, + self._crop_mask, + ) = self._normalize_diffraction_intensities( + intensities, + com_fitted_x, + com_fitted_y, + self._positions_mask[index], + crop_patterns, + ) + + self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) + + # explicitly transfer arrays to storage + self._amplitudes[idx_start:idx_end] = copy_to_device(amplitudes, storage) + + del ( + intensities, + amplitudes, + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) + + # initialize probe positions + ( + self._positions_px_all[idx_start:idx_end], + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions[index], + self._positions_mask[index], + self._object_padding_px, + ) + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # initialize object + obj = self._initialize_object( + self._object, + self._positions_px_all, + self._object_type, + main_tilt_axis=None, + ) + + if self._object is None: + self._object = xp.full((4,) + obj.shape, obj) + else: + self._object = obj + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + self._num_voxels = self._object.shape[1] + + # center probe positions + self._positions_px_all = xp_storage.asarray( + self._positions_px_all, dtype=xp_storage.float32 + ) + + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + + positions_px = self._positions_px_all[idx_start:idx_end] + positions_px_com = positions_px.mean(0) + positions_px -= positions_px_com - xp_storage.array(self._object_shape) / 2 + self._positions_px_all[idx_start:idx_end] = positions_px.copy() + + self._positions_px_initial_all = self._positions_px_all.copy() + self._positions_initial_all = self._positions_px_initial_all.copy() + self._positions_initial_all[:, 0] *= self.sampling[0] + self._positions_initial_all[:, 1] *= self.sampling[1] + + self._positions_initial = self._return_average_positions() + if self._positions_initial is not None: + 
self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # initialize probe + self._probes_all = [] + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + list_Q = isinstance(self._probe_init, (list, tuple)) + + for index in range(self._num_measurements): + _probe, self._semiangle_cutoff = self._initialize_probe( + self._probe_init[index] if list_Q else self._probe_init, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity[index], + self._semiangle_cutoff, + crop_patterns, + ) + + self._probes_all.append(_probe) + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) + + del self._probe_init + + # initialize aberrations + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + # Precomputed propagator arrays + thickness_h = self._object_shape[1] * self.sampling[1] + thickness_v = self._object_shape[0] * self.sampling[0] + thickness = max(thickness_h, thickness_v) + + self._slice_thicknesses = np.tile( + thickness / self._num_slices, self._num_slices - 1 + ) + self._propagator_arrays = self._precompute_propagator_arrays( + self._region_of_interest_shape, + self.sampling, + self._energy, + self._slice_thicknesses, + ) + + # overlaps + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + if object_fov_mask is None: + probe_overlap_3D = xp.zeros_like(self._object[0]) + old_rot_matrix = np.eye(3) # identity + + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + + rot_matrix = self._tilt_orientation_matrices[index] + + probe_overlap_3D = self._rotate_zxy_volume( + probe_overlap_3D, + rot_matrix @ old_rot_matrix.T, + ) + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + num_diffraction_patterns = idx_end - idx_start + shuffled_indices = np.arange(idx_start, idx_end) + + for start, end in generate_batches( + num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px_all[batch_indices] + positions_px_fractional = positions_px - xp_storage.round( + positions_px + ) + + shifted_probes = fft_shift( + self._probes_all[index], positions_px_fractional, xp + ) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + + probe_overlap_3D += probe_overlap[None] + old_rot_matrix = rot_matrix + + probe_overlap_3D = self._rotate_zxy_volume( + probe_overlap_3D, + old_rot_matrix.T, + ) + + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_3D_blurred = gaussian_filter(probe_overlap_3D, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_3D_blurred > 0.25 * probe_overlap_3D_blurred.max() + ) + + else: + self._object_fov_mask = np.asarray(object_fov_mask) + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._cum_probes_per_measurement[1], max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px_all[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift( + self._probes_all[0], positions_px_fractional, xp + ) + probe_overlap 
+= self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + probe_overlap = asnumpy(probe_overlap) + + # plot probe overlaps + if plot_probe_overlaps: + figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered[0], + power=power, + chroma_boost=chroma_boost, + ) + + # propagated + propagated_probe = self._probes_all[0].copy() + + for s in range(self._num_slices - 1): + propagated_probe = self._propagate_array( + propagated_probe, self._propagator_arrays[s] + ) + complex_propagated_rgb = Complex2RGB( + asnumpy(self._return_centered_probe(propagated_probe)), + power=power, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax1, + chroma_boost=chroma_boost, + ) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe intensity") + + ax2.imshow( + complex_propagated_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax2) + cax2 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax2, + chroma_boost=chroma_boost, + ) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_title("Propagated probe intensity") + + ax3.imshow( + probe_overlap, + extent=extent, + cmap="Greys_r", + ) + ax3.scatter( + self.positions[0, :, 1], + self.positions[0, :, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax3.set_ylabel("x [A]") + ax3.set_xlabel("y [A]") + ax3.set_xlim((extent[0], extent[1])) + ax3.set_ylim((extent[2], extent[3])) + ax3.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def _object_constraints_vector( + self, + current_object, + gaussian_filter, + gaussian_filter_sigma_e, + gaussian_filter_sigma_m, + butterworth_filter, + butterworth_order, + q_lowpass_e, + q_lowpass_m, + q_highpass_e, + q_highpass_m, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ): + """Calls Object3DConstraints _object_constraints for each object.""" + xp = self._xp + + # electrostatic + current_object[0] = self._object_constraints( + current_object[0], + gaussian_filter, + gaussian_filter_sigma_e, + butterworth_filter, + butterworth_order, + q_lowpass_e, + q_highpass_e, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ) + + # magnetic + for index in range(1, 4): + current_object[index] = self._object_constraints( + current_object[index], + gaussian_filter, + gaussian_filter_sigma_m, + butterworth_filter, + butterworth_order, + q_lowpass_m, + q_highpass_m, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + False, + 0.0, + None, + **kwargs, + ) + + # divergence-free + current_object[1:] = 
project_vector_field_divergence_periodic_3D( + current_object[1:], xp=xp + ) + + return current_object + + def _constraints(self, current_object, current_probe, current_positions, **kwargs): + """Wrapper function to bypass _object_constraints""" + + current_object = self._object_constraints_vector(current_object, **kwargs) + current_probe = self._probe_constraints(current_probe, **kwargs) + current_positions = self._positions_constraints(current_positions, **kwargs) + + return current_object, current_probe, current_positions + + def reconstruct( + self, + num_iter: int = 8, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + fix_probe_com: bool = True, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude: bool = False, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions: bool = True, + fix_positions_com: bool = True, + max_position_update_distance: float = None, + max_position_total_distance: float = None, + global_affine_transformation: bool = False, + gaussian_filter_sigma_e: float = None, + gaussian_filter_sigma_m: float = None, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, + butterworth_filter: bool = True, + q_lowpass_e: float = None, + q_lowpass_m: float = None, + q_highpass_e: float = None, + q_highpass_m: float = None, + butterworth_order: float = 2, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + tv_denoise: bool = True, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, + collective_measurement_updates: bool = True, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + device: str = None, + clear_fft_cache: bool = None, + ): + """ + Ptychographic reconstruction main method. + + Parameters + -------- + num_iter: int, optional + Number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. 
+ max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + fix_probe_com: bool, optional + If True, fixes center of mass of probe + fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vaccum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_positions: bool, optional + If True, probe-positions are fixed + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov + max_position_update_distance: float, optional + Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma_e: float + Standard deviation of gaussian kernel for electrostatic object in A + gaussian_filter_sigma_m: float + Standard deviation of gaussian kernel for magnetic object in A + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. 
            Smaller gives a smoother filter
+        object_positivity: bool, optional
+            If True, forces object to be positive
+        tv_denoise: bool, optional
+            If True and tv_denoise_weights is not None, object is smoothed using TV denoising
+        tv_denoise_weights: [float,float]
+            Denoising weights [z weight, r weight]. The greater `weight`,
+            the more denoising.
+        tv_denoise_inner_iter: float
+            Number of iterations to run in inner loop of TV denoising
+        collective_measurement_updates: bool
+            If True, perform collective tilt updates
+        shrinkage_rad: float
+            Phase shift in radians to be subtracted from the potential at each iteration
+        store_iterations: bool, optional
+            If True, reconstructed objects and probes are stored at each iteration
+        progress_bar: bool, optional
+            If True, reconstruction progress is displayed
+        reset: bool, optional
+            If True, previous reconstructions are ignored
+        device: str, optional
+            If not None, overwrites self._device to set the device the reconstruction will be performed on.
+        clear_fft_cache: bool, optional
+            If True, and device = 'gpu', clears the cached fft plan at the end of function calls
+
+        Returns
+        --------
+        self: MagneticPtychographicTomography
+            Self to accommodate chaining
+        """
+        # handle device/storage
+        self.set_device(device, clear_fft_cache)
+
+        if device is not None:
+            attrs = [
+                "_known_aberrations_array",
+                "_object",
+                "_object_initial",
+                "_probes_all",
+                "_probes_all_initial",
+                "_probes_all_initial_aperture",
+                "_propagator_arrays",
+            ]
+            self.copy_attributes_to_device(attrs, device)
+
+        xp = self._xp
+        xp_storage = self._xp_storage
+        device = self._device
+        asnumpy = self._asnumpy
+
+        if not collective_measurement_updates and self._verbose:
+            warnings.warn(
+                "Magnetic ptychography is much more robust with `collective_measurement_updates=True`.",
+                UserWarning,
+            )
+
+        # set and report reconstruction method
+        (
+            use_projection_scheme,
+            projection_a,
+            projection_b,
+            projection_c,
+            reconstruction_parameter,
+            step_size,
+        ) = self._set_reconstruction_method_parameters(
+            reconstruction_method,
+            reconstruction_parameter,
+            reconstruction_parameter_a,
+            reconstruction_parameter_b,
+            reconstruction_parameter_c,
+            step_size,
+        )
+
+        if use_projection_scheme:
+            raise NotImplementedError(
+                "Magnetic ptychographic tomography is currently only implemented for gradient descent."
+ ) + + if self._verbose: + self._report_reconstruction_summary( + num_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + max_batch_size, + step_size, + ) + + if max_batch_size is not None: + np.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + # initialization + self._reset_reconstruction(store_iterations, reset, use_projection_scheme) + + if gaussian_filter_sigma_m is None: + gaussian_filter_sigma_m = gaussian_filter_sigma_e + + if q_lowpass_m is None: + q_lowpass_m = q_lowpass_e + + # main loop + for a0 in tqdmnd( + num_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + if collective_measurement_updates: + collective_object = xp.zeros_like(self._object) + + indices = np.arange(self._num_measurements) + np.random.shuffle(indices) + + old_rot_matrix = np.eye(3) # identity + + for index in indices: + self._active_measurement_index = index + + measurement_error = 0.0 + + rot_matrix = self._tilt_orientation_matrices[ + self._active_measurement_index + ] + self._object = self._rotate_zxy_volume_util( + self._object, + rot_matrix @ old_rot_matrix.T, + ) + object_V = self._object[0] + + # last transformation matrix row + weight_x, weight_y, weight_z = rot_matrix[-1] + object_A = ( + weight_x * self._object[2] + + weight_y * self._object[3] + + weight_z * self._object[1] + ) + + object_sliced = self._project_sliced_object( + object_V + object_A, self._num_slices + ) + + _probe = self._probes_all[self._active_measurement_index] + _probe_initial_aperture = self._probes_all_initial_aperture[ + self._active_measurement_index + ] + + if not use_projection_scheme: + object_sliced_old = object_sliced.copy() + + start_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + ] + end_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + 1 + ] + + num_diffraction_patterns = end_idx - start_idx + shuffled_indices = np.arange(start_idx, end_idx) + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + + for start, end in generate_batches( + num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px_all[batch_indices] + positions_px_initial = self._positions_px_initial_all[batch_indices] + positions_px_fractional = positions_px - xp_storage.round( + positions_px + ) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + _probe, + positions_px_fractional, + amplitudes_device, + self._exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) + + # adjoint operator + object_sliced, _probe = self._adjoint( + object_sliced, + _probe, + object_patches, + shifted_probes, + positions_px, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=fix_probe, + ) + + # position correction + if not fix_positions and a0 > 0: + self._positions_px_all[ + batch_indices + ] = 
self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) + + measurement_error += batch_error + + if not use_projection_scheme: + object_sliced -= object_sliced_old + + object_update = self._expand_sliced_object( + object_sliced, self._num_voxels + ) + + weights = (1, weight_z, weight_x, weight_y) + for index, weight in zip(range(4), weights): + if collective_measurement_updates: + collective_object[index] += self._rotate_zxy_volume( + object_update * weight, + rot_matrix.T, + ) + else: + self._object[index] += object_update * weight + + old_rot_matrix = rot_matrix + + # Normalize Error + measurement_error /= ( + self._mean_diffraction_intensity[self._active_measurement_index] + * num_diffraction_patterns + ) + error += measurement_error + + # constraints + + if collective_measurement_updates: + # probe and positions + _probe = self._probe_constraints( + _probe, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + ) + + self._positions_px_all[batch_indices] = self._positions_constraints( + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + ) + + else: + # object, probe, and positions + ( + self._object, + _probe, + self._positions_px_all[batch_indices], + ) = self._constraints( + self._object, + _probe, + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + 
fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + gaussian_filter=gaussian_filter + and gaussian_filter_sigma_m is not None, + gaussian_filter_sigma_e=gaussian_filter_sigma_e, + gaussian_filter_sigma_m=gaussian_filter_sigma_m, + butterworth_filter=butterworth_filter + and (q_lowpass_m is not None or q_highpass_m is not None), + q_lowpass_e=q_lowpass_e, + q_lowpass_m=q_lowpass_m, + q_highpass_e=q_highpass_e, + q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline + and self._object_fov_mask_inverse.sum() > 0 + else None, + tv_denoise=tv_denoise and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + ) + + self._object = self._rotate_zxy_volume_util(self._object, old_rot_matrix.T) + + # Normalize Error Over Tilts + error /= self._num_measurements + + if collective_measurement_updates: + self._object += collective_object / self._num_measurements + + # object only + self._object = self._object_constraints_vector( + self._object, + gaussian_filter=gaussian_filter + and gaussian_filter_sigma_m is not None, + gaussian_filter_sigma_e=gaussian_filter_sigma_e, + gaussian_filter_sigma_m=gaussian_filter_sigma_m, + butterworth_filter=butterworth_filter + and (q_lowpass_m is not None or q_highpass_m is not None), + q_lowpass_e=q_lowpass_e, + q_lowpass_m=q_lowpass_m, + q_highpass_e=q_highpass_e, + q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline + and self._object_fov_mask_inverse.sum() > 0 + else None, + tv_denoise=tv_denoise and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + ) + + self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def _visualize_all_iterations(self, **kwargs): + raise NotImplementedError() + + def _visualize_last_iteration( + self, + fig, + cbar: bool, + plot_convergence: bool, + orientation_matrix=None, + **kwargs, + ): + """ + Displays last reconstructed object and probe iterations. 
+ + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool, optional + If true, the reconstructed complex probe is displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + """ + + asnumpy = self._asnumpy + + # get scaled arrays + + if orientation_matrix is not None: + ordered_obj = self._rotate_zxy_volume_vector( + self._object, + orientation_matrix, + ) + + # V(z,x,y), Ax(z,x,y), Ay(z,x,y), Az(z,x,y) + ordered_obj = asnumpy(ordered_obj) + ordered_obj[1:] = np.roll(ordered_obj[1:], -1, axis=0) + + else: + # V(z,x,y), Ax(z,x,y), Ay(z,x,y), Az(z,x,y) + ordered_obj = self.object.copy() + ordered_obj[1:] = np.roll(ordered_obj[1:], -1, axis=0) + + _, nz, nx, ny = ordered_obj.shape + img_array = np.zeros((nx + nx + nz, ny * 4), dtype=ordered_obj.dtype) + + axes = [1, 2, 0] + transposes = [False, True, False] + labels = [("z [A]", "y [A]"), ("x [A]", "z [A]"), ("x [A]", "y [A]")] + limits_v = [(0, nz), (nz, nz + nx), (nz + nx, nz + nx + nx)] + limits_h = [(0, ny), (0, nz), (0, ny)] + + titles = [ + [ + r"$V$ projected along $\hat{x}$", + r"$A_x$ projected along $\hat{x}$", + r"$A_y$ projected along $\hat{x}$", + r"$A_z$ projected along $\hat{x}$", + ], + [ + r"$V$ projected along $\hat{y}$", + r"$A_x$ projected along $\hat{y}$", + r"$A_y$ projected along $\hat{y}$", + r"$A_z$ projected along $\hat{y}$", + ], + [ + r"$V$ projected along $\hat{z}$", + r"$A_x$ projected along $\hat{z}$", + r"$A_y$ projected along $\hat{z}$", + r"$A_z$ projected along $\hat{z}$", + ], + ] + + for index in range(4): + for axis, transpose, limit_v, limit_h in zip( + axes, transposes, limits_v, limits_h + ): + start_v, end_v = limit_v + start_h, end_h = np.array(limit_h) + index * ny + + subarray = ordered_obj[index].sum(axis) + if transpose: + subarray = subarray.T + + img_array[start_v:end_v, start_h:end_h] = subarray + + if plot_convergence: + auto_figsize = (ny * 4 * 4 / nx, (nx + nx + nz) * 3.5 / nx + 1) + else: + auto_figsize = (ny * 4 * 4 / nx, (nx + nx + nz) * 3.5 / nx) + + figsize = kwargs.pop("figsize", auto_figsize) + cmap_e = kwargs.pop("cmap_e", "magma") + cmap_m = kwargs.pop("cmap_m", "PuOr") + vmin_e = kwargs.pop("vmin_e", None) + vmax_e = kwargs.pop("vmax_e", None) + + # remove common unused kwargs + kwargs.pop("plot_probe", None) + kwargs.pop("plot_fourier_probe", None) + kwargs.pop("remove_initial_probe_aberrations", None) + kwargs.pop("vertical_lims", None) + kwargs.pop("horizontal_lims", None) + + _, vmin_e, vmax_e = return_scaled_histogram_ordering( + img_array[:, :ny], vmin_e, vmax_e + ) + + _, _, _vmax_m = return_scaled_histogram_ordering(np.abs(img_array[:, ny:])) + vmin_m = kwargs.pop("vmin_m", -_vmax_m) + vmax_m = kwargs.pop("vmax_m", _vmax_m) + + if plot_convergence: + spec = GridSpec( + ncols=4, + nrows=4, + height_ratios=[nx, nz, nx, nx / 4], + hspace=0.15, + wspace=0.35, + ) + else: + spec = GridSpec( + ncols=4, nrows=3, height_ratios=[nx, nz, nx], hspace=0.15, wspace=0.35 + ) + + if fig is None: + fig = plt.figure(figsize=figsize) + + for sp in spec: + row, col = np.unravel_index(sp.num1, (4, 4)) + + if row < 3: + ax = fig.add_subplot(sp) + + start_v, end_v = limits_v[row] + start_h, end_h = 
np.array(limits_h[row]) + col * ny + subarray = img_array[start_v:end_v, start_h:end_h] + + extent = [ + 0, + self.sampling[1] * subarray.shape[1], + self.sampling[0] * subarray.shape[0], + 0, + ] + + im = ax.imshow( + subarray, + cmap=cmap_e if sp.is_first_col() else cmap_m, + vmin=vmin_e if sp.is_first_col() else vmin_m, + vmax=vmax_e if sp.is_first_col() else vmax_m, + extent=extent, + **kwargs, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + ax.set_title(titles[row][col]) + + y_label, x_label = labels[row] + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + + if plot_convergence and hasattr(self, "error_iterations"): + errors = np.array(self.error_iterations) + + ax = fig.add_subplot(spec[-1, :]) + ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax.set_ylabel("NMSE") + ax.set_xlabel("Iteration number") + ax.yaxis.tick_right() + + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") + spec.tight_layout(fig) + + def _rotate_zxy_volume_util( + self, + current_object, + rot_matrix, + ): + """ """ + for index in range(4): + current_object[index] = self._rotate_zxy_volume( + current_object[index], rot_matrix + ) + + return current_object + + def _rotate_zxy_volume_vector(self, current_object, rot_matrix): + """Rotates vector field consistently. Note this is very expensive""" + + xp = self._xp + swap_zxy_to_xyz = self._swap_zxy_to_xyz + + if xp is np: + from scipy.interpolate import RegularGridInterpolator + + current_object = self._asnumpy(current_object) + else: + try: + from cupyx.scipy.interpolate import RegularGridInterpolator + except ModuleNotFoundError: + from scipy.interpolate import RegularGridInterpolator + + xp = np # force xp to np for cupy <12.0 + current_object = self._asnumpy(current_object) + + _, nz, nx, ny = current_object.shape + + z, x, y = [xp.linspace(-1, 1, s, endpoint=False) for s in (nx, ny, nz)] + Z, X, Y = xp.meshgrid(z, x, y, indexing="ij") + coords = xp.array([Z.ravel(), X.ravel(), Y.ravel()]) + + tf = xp.asarray(swap_zxy_to_xyz.T @ rot_matrix @ swap_zxy_to_xyz) + rotated_vecs = tf.T.dot(coords).T + + Az = RegularGridInterpolator( + (z, x, y), current_object[1], bounds_error=False, fill_value=0 + ) + Ax = RegularGridInterpolator( + (z, x, y), current_object[2], bounds_error=False, fill_value=0 + ) + Ay = RegularGridInterpolator( + (z, x, y), current_object[3], bounds_error=False, fill_value=0 + ) + + xp = self._xp # switch back to device + obj = xp.zeros_like(current_object) + obj[0] = self._rotate_zxy_volume(xp.asarray(current_object[0]), rot_matrix) + + obj[1] = xp.asarray(Az(rotated_vecs).reshape(nz, nx, ny)) + obj[2] = xp.asarray(Ax(rotated_vecs).reshape(nz, nx, ny)) + obj[3] = xp.asarray(Ay(rotated_vecs).reshape(nz, nx, ny)) + + return obj diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py new file mode 100644 index 000000000..d718b1a9e --- /dev/null +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -0,0 +1,1889 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely magnetic ptychography. 
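A similar minimal usage sketch for the class defined in this module; the two DataCubes are assumed to correspond to the '-' and '+' entries of `magnetic_contribution_sign` and to be loaded already, and all numeric values are illustrative placeholders.

from py4DSTEM.process.phase import MagneticPtychography

# `dc_minus` and `dc_plus` are assumed to be defined elsewhere.
mag_ptycho = MagneticPtychography(
    datacube=(dc_minus, dc_plus),    # one DataCube per magnetic contribution
    energy=300e3,                    # electron energy in eV (placeholder)
    magnetic_contribution_sign="-+", # ordering of the input measurements
    semiangle_cutoff=20,             # initial probe semiangle in mrad (placeholder)
    object_type="complex",
    device="cpu",
)
mag_ptycho.preprocess(plot_rotation=False, plot_probe_overlaps=False)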
+""" + +import warnings +from typing import Mapping, Sequence, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.visualize.vis_special import ( + Complex2RGB, + add_colorbar_arg, + return_scaled_histogram_ordering, +) + +try: + import cupy as cp +except (ImportError, ModuleNotFoundError): + cp = np + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) +from py4DSTEM.process.phase.ptychographic_methods import ( + MultipleMeasurementsMethodsMixin, + ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, +) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, +) + + +class MagneticPtychography( + VisualizationsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + ObjectNDConstraintsMixin, + MultipleMeasurementsMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, + ObjectNDMethodsMixin, + PtychographicReconstruction, +): + """ + Iterative Magnetic Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) (for each measurement) + Reconstructed probe dimensions : (Sx,Sy) + Reconstructed electrostatic dimensions : (Px,Py) + Reconstructed magnetic dimensions : (Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our probe + and (Px,Py) is the padded-object size we position our ROI around in. + + Parameters + ---------- + datacube: Sequence[DataCube] + Tuple of input 4D diffraction pattern intensities + energy: float + The electron energy of the wave functions in eV + magnetic_contribution_sign: str, optional + One of '-+', '-0+', '0+' + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad objects with + If None, the padding is set to half the probe ROI dimensions + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (2,Px,Py) + If None, initialized to 1.0j for complex objects and 0.0 for potential objects + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). 
If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Å for each diffraction intensity + If None, initialized to a grid scan + verbose: bool, optional + If True, class methods will inherit this and print additional information + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. + """ + + # Class-specific Metadata + _class_specific_metadata = ("_magnetic_contribution_sign",) + + def __init__( + self, + energy: float, + datacube: Sequence[DataCube] = None, + magnetic_contribution_sign: str = "-+", + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + positions_mask: np.ndarray = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + object_type: str = "complex", + verbose: bool = True, + device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, + name: str = "magnetic_ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if storage is None: + storage = device + + self.set_device(device, clear_fft_cache) + self.set_storage(storage) + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + if object_type != "potential" and object_type != "complex": + raise ValueError( + f"object_type must be either 'potential' or 'complex', not {object_type}" + ) + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe_init = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._object_padding_px = object_padding_px + self._positions_mask = positions_mask + self._verbose = verbose + self._preprocessed = False + + # Class-specific Metadata + self._magnetic_contribution_sign = magnetic_contribution_sign + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "fourier", + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = None, + plot_probe_overlaps: bool = True, + 
        force_com_rotation: float = None,
+        force_com_transpose: float = None,
+        force_com_shifts: float = None,
+        vectorized_com_calculation: bool = True,
+        force_scan_sampling: float = None,
+        force_angular_sampling: float = None,
+        force_reciprocal_sampling: float = None,
+        progress_bar: bool = True,
+        object_fov_mask: np.ndarray = None,
+        crop_patterns: bool = False,
+        device: str = None,
+        clear_fft_cache: bool = None,
+        max_batch_size: int = None,
+        **kwargs,
+    ):
+        """
+        Ptychographic preprocessing step.
+        Calls the base class methods:
+
+        _extract_intensities_and_calibrations_from_datacube,
+        _compute_center_of_mass(),
+        _solve_CoM_rotation(),
+        _normalize_diffraction_intensities()
+        _calculate_scan_positions_in_px()
+
+        Additionally, it initializes an (Px,Py) array of 1.0j
+        and a complex probe using the specified polar parameters.
+
+        Parameters
+        ----------
+        diffraction_intensities_shape: Tuple[int,int], optional
+            Pixel dimensions (Qx',Qy') of the resampled diffraction intensities
+            If None, no resampling of diffraction intensities is performed
+        reshaping_method: str, optional
+            Method to use for reshaping, either 'bin', 'bilinear', or 'fourier' (default)
+        padded_diffraction_intensities_shape: (int,int), optional
+            Padded diffraction intensities shape.
+            If None, no padding is performed
+        region_of_interest_shape: (int,int), optional
+            If not None, explicitly sets region_of_interest_shape and resamples exit_waves
+            at the diffraction plane to allow comparison with experimental data
+        dp_mask: ndarray, optional
+            Mask for datacube intensities (Qx,Qy)
+        fit_function: str, optional
+            2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two'
+        plot_rotation: bool, optional
+            If True, the CoM curl minimization search result will be displayed
+        maximize_divergence: bool, optional
+            If True, the divergence of the CoM gradient vector field is maximized
+        rotation_angles_deg: np.ndarray, optional
+            Array of angles in degrees to perform curl minimization over
+        plot_probe_overlaps: bool, optional
+            If True, initial probe overlaps scanned over the object will be displayed
+        force_com_rotation: float (degrees), optional
+            Force relative rotation angle between real and reciprocal space
+        force_com_transpose: bool, optional
+            Force whether diffraction intensities need to be transposed.
+        force_com_shifts: sequence of tuples of ndarrays (CoMx, CoMy)
+            Amplitudes come from diffraction patterns shifted with
+            the CoM in the upper left corner for each probe unless
+            shift is overwritten.
+        vectorized_com_calculation: bool, optional
+            If True (default), the memory-intensive CoM calculation is vectorized
+        force_scan_sampling: float, optional
+            Override DataCube real space scan pixel size calibrations, in Angstrom
+        force_angular_sampling: float, optional
+            Override DataCube reciprocal pixel size calibration, in mrad
+        force_reciprocal_sampling: float, optional
+            Override DataCube reciprocal pixel size calibration, in A^-1
+        object_fov_mask: np.ndarray (boolean)
+            Boolean mask of FOV. Used to calculate additional shrinkage of object.
+            If None, probe_overlap intensity is thresholded
+        crop_patterns: bool
+            If True, crop patterns to avoid wrap-around of patterns when centering
+        device: str, optional
+            If not None, overwrites self._device to set the device preprocess will be performed on.
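# A minimal usage sketch of the MagneticPtychography class and the preprocess()
# call documented above. The two DataCube objects (dc_reversed, dc_forward) are
# placeholders for calibrated 4D-STEM scans acquired with reversed and forward
# magnetic contributions; they are assumptions of this sketch, not py4DSTEM API,
# and the numerical values are illustrative only.
from py4DSTEM.process.phase.magnetic_ptychography import MagneticPtychography

ptycho = MagneticPtychography(
    energy=300e3,                        # eV
    datacube=(dc_reversed, dc_forward),  # order must match magnetic_contribution_sign
    magnetic_contribution_sign="-+",
    semiangle_cutoff=20.0,               # mrad
    verbose=True,
    device="cpu",
).preprocess(
    plot_rotation=False,
    plot_probe_overlaps=False,
)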
+        clear_fft_cache: bool, optional
+            If True, and device = 'gpu', clears the cached fft plan at the end of function calls
+        max_batch_size: int, optional
+            Max number of probes to use at once in computing probe overlaps
+
+        Returns
+        --------
+        self: PtychographicReconstruction
+            Self to accommodate chaining
+        """
+        # handle device/storage
+        self.set_device(device, clear_fft_cache)
+
+        xp = self._xp
+        device = self._device
+        xp_storage = self._xp_storage
+        storage = self._storage
+        asnumpy = self._asnumpy
+
+        # set additional metadata
+        self._diffraction_intensities_shape = diffraction_intensities_shape
+        self._reshaping_method = reshaping_method
+        self._padded_diffraction_intensities_shape = (
+            padded_diffraction_intensities_shape
+        )
+        self._dp_mask = dp_mask
+
+        if self._datacube is None:
+            raise ValueError(
+                (
+                    "The preprocess() method requires a DataCube. "
+                    "Please run ptycho.attach_datacube(DataCube) first."
+                )
+            )
+
+        if self._magnetic_contribution_sign == "-+":
+            self._recon_mode = 0
+            self._num_measurements = 2
+            magnetic_contribution_msg = (
+                "Magnetic vector potential sign in first measurement assumed to be negative.\n"
+                "Magnetic vector potential sign in second measurement assumed to be positive."
+            )
+
+        elif self._magnetic_contribution_sign == "-0+":
+            self._recon_mode = 1
+            self._num_measurements = 3
+            magnetic_contribution_msg = (
+                "Magnetic vector potential sign in first measurement assumed to be negative.\n"
+                "Magnetic vector potential assumed to be zero in second measurement.\n"
+                "Magnetic vector potential sign in third measurement assumed to be positive."
+            )
+
+        elif self._magnetic_contribution_sign == "0+":
+            self._recon_mode = 2
+            self._num_measurements = 2
+            magnetic_contribution_msg = (
+                "Magnetic vector potential assumed to be zero in first measurement.\n"
+                "Magnetic vector potential sign in second measurement assumed to be positive."
+            )
+        else:
+            raise ValueError(
+                f"magnetic_contribution_sign must be either '-+', '-0+', or '0+', not {self._magnetic_contribution_sign}"
+            )
+
+        if self._verbose:
+            warnings.warn(
+                magnetic_contribution_msg,
+                UserWarning,
+            )
+
+        if len(self._datacube) != self._num_measurements:
+            raise ValueError(
+                f"datacube must be the same length as magnetic_contribution_sign, not length {len(self._datacube)}."
+ ) + + dc_shapes = [dc.shape for dc in self._datacube] + if dc_shapes.count(dc_shapes[0]) != self._num_measurements: + raise ValueError("datacube intensities must be the same size.") + + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + if self._positions_mask.ndim == 2: + warnings.warn( + "2D `positions_mask` assumed the same for all measurements.", + UserWarning, + ) + self._positions_mask = np.tile( + self._positions_mask, (self._num_measurements, 1, 1) + ) + + num_probes_per_measurement = np.insert( + self._positions_mask.sum(axis=(-2, -1)), 0, 0 + ) + + else: + self._positions_mask = [None] * self._num_measurements + num_probes_per_measurement = [0] + [dc.R_N for dc in self._datacube] + num_probes_per_measurement = np.array(num_probes_per_measurement) + + # prepopulate relevant arrays + self._mean_diffraction_intensity = [] + self._num_diffraction_patterns = num_probes_per_measurement.sum() + self._cum_probes_per_measurement = np.cumsum(num_probes_per_measurement) + self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) + + # calculate roi_shape + roi_shape = self._datacube[0].Qshape + if diffraction_intensities_shape is not None: + roi_shape = diffraction_intensities_shape + if padded_diffraction_intensities_shape is not None: + roi_shape = tuple( + max(q, s) + for q, s in zip(roi_shape, padded_diffraction_intensities_shape) + ) + + self._amplitudes = xp_storage.empty( + (self._num_diffraction_patterns,) + roi_shape + ) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + # TO-DO: generalize this + if force_com_shifts is None: + force_com_shifts = [None] * self._num_measurements + + if self._scan_positions is None: + self._scan_positions = [None] * self._num_measurements + + # Ensure plot_center_of_mass is not in kwargs + kwargs.pop("plot_center_of_mass", None) + + # loop over DPs for preprocessing + for index in tqdmnd( + self._num_measurements, + desc="Preprocessing data", + unit="measurement", + disable=not progress_bar, + ): + # preprocess datacube, vacuum and masks only for first measurement + if index == 0: + ( + self._datacube[index], + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts[index], + ) + + else: + ( + self._datacube[index], + _, + _, + force_com_shifts[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=None, + dp_mask=None, + com_shifts=force_com_shifts[index], + ) + + # calibrations + intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube[index], + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + 
force_reciprocal_sampling=force_reciprocal_sampling, + ) + + # calculate CoM + ( + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts[index], + vectorized_calculation=vectorized_com_calculation, + ) + + # estimate rotation / transpose using first measurement + if index == 0: + # silence warnings to play nice with progress bar + verbose = self._verbose + self._verbose = False + + ( + self._rotation_best_rad, + self._rotation_best_transpose, + _com_x, + _com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + com_measured_x, + com_measured_y, + com_normalized_x, + com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=False, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + self._verbose = verbose + + # corner-center amplitudes + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + + ( + amplitudes, + mean_diffraction_intensity_temp, + self._crop_mask, + ) = self._normalize_diffraction_intensities( + intensities, + com_fitted_x, + com_fitted_y, + self._positions_mask[index], + crop_patterns, + ) + + self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) + + # explicitly transfer arrays to storage + self._amplitudes[idx_start:idx_end] = copy_to_device(amplitudes, storage) + + del ( + intensities, + amplitudes, + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) + + # initialize probe positions + ( + self._positions_px_all[idx_start:idx_end], + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions[index], + self._positions_mask[index], + self._object_padding_px, + ) + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # Object Initialization + obj = self._initialize_object( + self._object, + self._positions_px_all, + self._object_type, + ) + + if self._object is None: + self._object = xp.full((2,) + obj.shape, obj) + else: + self._object = obj + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + + # center probe positions + self._positions_px_all = xp_storage.asarray( + self._positions_px_all, dtype=xp_storage.float32 + ) + + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + + positions_px = self._positions_px_all[idx_start:idx_end] + positions_px_com = positions_px.mean(0) + positions_px -= positions_px_com - xp_storage.array(self._object_shape) / 2 + self._positions_px_all[idx_start:idx_end] = positions_px.copy() + + self._positions_px_initial_all = self._positions_px_all.copy() + self._positions_initial_all = self._positions_px_initial_all.copy() + self._positions_initial_all[:, 0] *= self.sampling[0] + self._positions_initial_all[:, 1] *= self.sampling[1] + + self._positions_initial = self._return_average_positions() + if self._positions_initial is not None: + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] 
*= self.sampling[1] + + # initialize probe + self._probes_all = [] + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + list_Q = isinstance(self._probe_init, (list, tuple)) + + for index in range(self._num_measurements): + _probe, self._semiangle_cutoff = self._initialize_probe( + self._probe_init[index] if list_Q else self._probe_init, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity[index], + self._semiangle_cutoff, + crop_patterns, + ) + + self._probes_all.append(_probe) + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) + + del self._probe_init + + # initialize aberrations + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + # overlaps + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._cum_probes_per_measurement[1], max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px_all[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift(self._probes_all[0], positions_px_fractional, xp) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + + # initialize object_fov_mask + if object_fov_mask is None: + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) + del probe_overlap_blurred + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + probe_overlap = asnumpy(probe_overlap) + + # plot probe overlaps + if plot_probe_overlaps: + figsize = kwargs.pop("figsize", (9, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered[0], + power=power, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(cax1, chroma_boost=chroma_boost) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe intensity") + + ax2.imshow( + probe_overlap, + extent=extent, + cmap="gray", + ) + ax2.scatter( + self.positions[0, :, 1], + self.positions[0, :, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_xlim((extent[0], extent[1])) + ax2.set_ylim((extent[2], extent[3])) + ax2.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + 
shifted_probes, + ): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + shifted_probes:np.ndarray + fractionally-shifted probes + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + shifted_probes * object_patches + """ + + xp = self._xp + + object_patches = xp.empty( + (self._num_measurements,) + shifted_probes.shape, dtype=current_object.dtype + ) + object_patches[0] = current_object[ + 0, vectorized_patch_indices_row, vectorized_patch_indices_col + ] + object_patches[1] = current_object[ + 1, vectorized_patch_indices_row, vectorized_patch_indices_col + ] + + if self._object_type == "potential": + object_patches = xp.exp(1j * object_patches) + + overlap_base = shifted_probes * object_patches[0] + + match (self._recon_mode, self._active_measurement_index): + case (0, 0) | (1, 0): # reverse + overlap = overlap_base * xp.conj(object_patches[1]) + case (0, 1) | (1, 2) | (2, 1): # forward + overlap = overlap_base * object_patches[1] + case (1, 1) | (2, 0): # neutral + overlap = overlap_base + case _: + raise ValueError() + + return shifted_probes, object_patches, overlap + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_conj = xp.conj(shifted_probes) # P* + electrostatic_conj = xp.conj(object_patches[0]) # V* = exp(-i v) + + probe_electrostatic_abs = xp.abs(shifted_probes * object_patches[0]) + probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts( + probe_electrostatic_abs**2, + positions_px, + ) + probe_electrostatic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_electrostatic_normalization) ** 2 + + (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2 + ) + + probe_magnetic_abs = xp.abs(shifted_probes * object_patches[1]) + probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( + probe_magnetic_abs**2, + positions_px, + ) + probe_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_magnetic_normalization) ** 2 + + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 + ) + + if not fix_probe: + electrostatic_magnetic_abs = xp.abs(object_patches[0] * object_patches[1]) + electrostatic_magnetic_normalization = xp.sum( + electrostatic_magnetic_abs**2, + axis=0, + ) + electrostatic_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * electrostatic_magnetic_normalization) ** 2 + + (normalization_min * xp.max(electrostatic_magnetic_normalization)) + ** 2 + ) + + if 
self._recon_mode > 0: + electrostatic_abs = xp.abs(object_patches[0]) + electrostatic_normalization = xp.sum( + electrostatic_abs**2, + axis=0, + ) + electrostatic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * electrostatic_normalization) ** 2 + + (normalization_min * xp.max(electrostatic_normalization)) ** 2 + ) + + match (self._recon_mode, self._active_measurement_index): + case (0, 0) | (1, 0): # reverse + if self._object_type == "potential": + # -i exp(-i v) exp(i m) P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * object_patches[1] + * electrostatic_conj + * probe_conj + * exit_waves + ), + positions_px, + ) + + # i exp(-i v) exp(i m) P* + magnetic_update = -electrostatic_update + + else: + # M P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + probe_conj * object_patches[1] * exit_waves, + positions_px, + ) + + # V* P* + magnetic_update = xp.conj( + self._sum_overlapping_patches_bincounts( + probe_conj * electrostatic_conj * exit_waves, + positions_px, + ) + ) + + current_object[0] += ( + step_size * electrostatic_update * probe_magnetic_normalization + ) + current_object[1] += ( + step_size * magnetic_update * probe_electrostatic_normalization + ) + + if not fix_probe: + # M V* + current_probe += step_size * ( + xp.sum( + electrostatic_conj * object_patches[1] * exit_waves, + axis=0, + ) + * electrostatic_magnetic_normalization + ) + + case (0, 1) | (1, 2) | (2, 1): # forward + magnetic_conj = xp.conj(object_patches[1]) # M* = exp(-i m) + + if self._object_type == "potential": + # -i exp(-i v) exp(-i m) P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * magnetic_conj + * electrostatic_conj + * probe_conj + * exit_waves + ), + positions_px, + ) + + # -i exp(-i v) exp(-i m) P* + magnetic_update = electrostatic_update + + else: + # M* P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + probe_conj * magnetic_conj * exit_waves, + positions_px, + ) + + # V* P* + magnetic_update = self._sum_overlapping_patches_bincounts( + probe_conj * electrostatic_conj * exit_waves, + positions_px, + ) + + current_object[0] += ( + step_size * electrostatic_update * probe_magnetic_normalization + ) + current_object[1] += ( + step_size * magnetic_update * probe_electrostatic_normalization + ) + + if not fix_probe: + # M* V* + current_probe += step_size * ( + xp.sum( + electrostatic_conj * magnetic_conj * exit_waves, + axis=0, + ) + * electrostatic_magnetic_normalization + ) + + case (1, 1) | (2, 0): # neutral + probe_abs = xp.abs(shifted_probes) + probe_normalization = self._sum_overlapping_patches_bincounts( + probe_abs**2, + positions_px, + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + if self._object_type == "potential": + # -i exp(-i v) P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + xp.real(-1j * electrostatic_conj * probe_conj * exit_waves), + positions_px, + ) + + else: + # P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + probe_conj * exit_waves, + positions_px, + ) + + current_object[0] += ( + step_size * electrostatic_update * probe_normalization + ) + + if not fix_probe: + # V* + current_probe += step_size * ( + xp.sum( + electrostatic_conj * exit_waves, + axis=0, + ) + * electrostatic_normalization + ) + + case _: + raise ValueError() + + return current_object, current_probe + + def 
_object_constraints( + self, + current_object, + pure_phase_object, + gaussian_filter, + gaussian_filter_sigma_e, + gaussian_filter_sigma_m, + butterworth_filter, + butterworth_order, + q_lowpass_e, + q_lowpass_m, + q_highpass_e, + q_highpass_m, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ): + """MagneticObjectNDConstraints wrapper function""" + + # smoothness + if gaussian_filter: + current_object[0] = self._object_gaussian_constraint( + current_object[0], gaussian_filter_sigma_e, pure_phase_object + ) + current_object[1] = self._object_gaussian_constraint( + current_object[1], gaussian_filter_sigma_m, True + ) + if butterworth_filter: + current_object[0] = self._object_butterworth_constraint( + current_object[0], + q_lowpass_e, + q_highpass_e, + butterworth_order, + ) + current_object[1] = self._object_butterworth_constraint( + current_object[1], + q_lowpass_m, + q_highpass_m, + butterworth_order, + ) + if tv_denoise: + current_object[0] = self._object_denoise_tv_pylops( + current_object[0], tv_denoise_weight, tv_denoise_inner_iter + ) + + # L1-norm pushing vacuum to zero + if shrinkage_rad > 0.0 or object_mask is not None: + current_object[0] = self._object_shrinkage_constraint( + current_object[0], + shrinkage_rad, + object_mask, + ) + + # amplitude threshold (complex) or positivity (potential) + if self._object_type == "complex": + current_object[0] = self._object_threshold_constraint( + current_object[0], pure_phase_object + ) + current_object[1] = self._object_threshold_constraint( + current_object[1], True + ) + elif object_positivity: + current_object[0] = self._object_positivity_constraint(current_object[0]) + + return current_object + + def reconstruct( + self, + num_iter: int = 8, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + pure_phase_object: bool = False, + fix_probe_com: bool = True, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude: bool = False, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions: bool = True, + fix_positions_com: bool = True, + max_position_update_distance: float = None, + max_position_total_distance: float = None, + global_affine_transformation: bool = False, + gaussian_filter_sigma_e: float = None, + gaussian_filter_sigma_m: float = None, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, + butterworth_filter: bool = True, + q_lowpass_e: float = None, + q_lowpass_m: float = None, + q_highpass_e: float = None, + q_highpass_m: float = None, + butterworth_order: float = 2, + tv_denoise: bool = True, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, + object_positivity: bool = True, + 
        shrinkage_rad: float = 0.0,
+        fix_potential_baseline: bool = True,
+        store_iterations: bool = False,
+        collective_measurement_updates: bool = True,
+        progress_bar: bool = True,
+        reset: bool = None,
+        device: str = None,
+        clear_fft_cache: bool = None,
+        object_type: str = None,
+    ):
+        """
+        Ptychographic reconstruction main method.
+
+        Parameters
+        --------
+        num_iter: int, optional
+            Number of iterations to run
+        reconstruction_method: str, optional
+            Specifies which reconstruction algorithm to use, one of:
+            "generalized-projections",
+            "DM_AP" (or "difference-map_alternating-projections"),
+            "RAAR" (or "relaxed-averaged-alternating-reflections"),
+            "RRR" (or "relax-reflect-reflect"),
+            "SUPERFLIP" (or "charge-flipping"), or
+            "GD" (or "gradient_descent")
+        reconstruction_parameter: float, optional
+            Reconstruction parameter for various reconstruction methods above.
+        reconstruction_parameter_a: float, optional
+            Reconstruction parameter a for reconstruction_method='generalized-projections'.
+        reconstruction_parameter_b: float, optional
+            Reconstruction parameter b for reconstruction_method='generalized-projections'.
+        reconstruction_parameter_c: float, optional
+            Reconstruction parameter c for reconstruction_method='generalized-projections'.
+        max_batch_size: int, optional
+            Max number of probes to update at once
+        seed_random: int, optional
+            Seeds the random number generator, only applicable when max_batch_size is not None
+        step_size: float, optional
+            Update step size
+        normalization_min: float, optional
+            Probe normalization minimum as a fraction of the maximum overlap intensity
+        positions_step_size: float, optional
+            Positions update step size
+        pure_phase_object: bool, optional
+            If True, object amplitude is set to unity
+        fix_probe_com: bool, optional
+            If True, fixes center of mass of probe
+        fix_probe: bool, optional
+            If True, probe is fixed
+        fix_probe_aperture: bool, optional
+            If True, vacuum probe is used to fix Fourier amplitude
+        constrain_probe_amplitude: bool, optional
+            If True, real-space probe is constrained with a top-hat support.
+        constrain_probe_amplitude_relative_radius: float
+            Relative location of top-hat inflection point, between 0 and 0.5
+        constrain_probe_amplitude_relative_width: float
+            Relative width of top-hat sigmoid, between 0 and 0.5
+        constrain_probe_fourier_amplitude: bool, optional
+            If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency
+        constrain_probe_fourier_amplitude_max_width_pixels: float
+            Maximum pixel width of fitted sigmoid functions.
+        constrain_probe_fourier_amplitude_constant_intensity: bool
+            If True, the probe aperture is additionally constrained to a constant intensity.
+        fix_positions: bool, optional
+            If True, probe-positions are fixed
+        fix_positions_com: bool, optional
+            If True, fixes the positions CoM to the middle of the FOV
+        max_position_update_distance: float, optional
+            Maximum allowed distance for update in A
+        max_position_total_distance: float, optional
+            Maximum allowed distance from initial positions
+        global_affine_transformation: bool, optional
+            If True, positions are assumed to be a global affine transform from initial scan
+        gaussian_filter_sigma_e: float
+            Standard deviation of gaussian kernel for electrostatic object in A
+        gaussian_filter_sigma_m: float
+            Standard deviation of gaussian kernel for magnetic object in A
+        gaussian_filter: bool, optional
+            If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering
+        fit_probe_aberrations: bool, optional
+            If True, probe aberrations are fitted to a low-order expansion
+        fit_probe_aberrations_max_angular_order: int
+            Max angular order of probe aberrations basis functions
+        fit_probe_aberrations_max_radial_order: int
+            Max radial order of probe aberrations basis functions
+        fit_probe_aberrations_remove_initial: bool
+            If True, initial probe aberrations are removed before fitting
+        fit_probe_aberrations_using_scikit_image: bool
+            If True, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads
+            to a documented bug where the kernel hangs.
+            If False, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations.
+        butterworth_filter: bool, optional
+            If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering
+        q_lowpass_e: float
+            Cut-off frequency in A^-1 for low-pass filtering electrostatic object
+        q_lowpass_m: float
+            Cut-off frequency in A^-1 for low-pass filtering magnetic object
+        q_highpass_e: float
+            Cut-off frequency in A^-1 for high-pass filtering electrostatic object
+        q_highpass_m: float
+            Cut-off frequency in A^-1 for high-pass filtering magnetic object
+        butterworth_order: float
+            Butterworth filter order. Smaller gives a smoother filter
+        tv_denoise: bool, optional
+            If True and tv_denoise_weight is not None, object is smoothed using TV denoising
+        tv_denoise_weight: float
+            Denoising weight. The greater `weight`, the more denoising.
+        tv_denoise_inner_iter: float
+            Number of iterations to run in inner loop of TV denoising
+        object_positivity: bool, optional
+            If True, forces object to be positive
+        shrinkage_rad: float
+            Phase shift in radians to be subtracted from the potential at each iteration
+        fix_potential_baseline: bool
+            If True, the potential mean outside the FOV is forced to zero at each iteration
+        store_iterations: bool, optional
+            If True, reconstructed objects and probes are stored at each iteration
+        collective_measurement_updates: bool
+            If True, performs collective updates for all measurements
+        progress_bar: bool, optional
+            If True, reconstruction progress is displayed
+        reset: bool, optional
+            If True, previous reconstructions are ignored
+        device: str, optional
+            If not None, overwrites self._device to set the device the reconstruction will be performed on.
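# A hedged sketch of a typical reconstruct() call, continuing the `ptycho` object
# from the preprocess sketch earlier in this file. All parameter values below are
# illustrative starting points, not recommendations from the py4DSTEM authors.
ptycho = ptycho.reconstruct(
    num_iter=32,
    reconstruction_method="gradient-descent",
    step_size=0.5,
    gaussian_filter_sigma_e=0.3,  # A, smoothing of the electrostatic channel
    gaussian_filter_sigma_m=0.5,  # A, smoothing of the magnetic channel
    collective_measurement_updates=True,
    store_iterations=True,
    progress_bar=True,
)

# object_cropped stacks the electrostatic and magnetic reconstructions
electrostatic, magnetic = ptycho.object_cropped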
+ clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + Overwrites self._object_type + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probes_all", + "_probes_all_initial", + "_probes_all_initial_aperture", + ] + self.copy_attributes_to_device(attrs, device) + + if object_type is not None: + self._switch_object_type(object_type) + + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + if not collective_measurement_updates and self._verbose: + warnings.warn( + "Magnetic ptychography is much more robust with `collective_measurement_updates=True`.", + UserWarning, + ) + + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) + + if use_projection_scheme: + raise NotImplementedError( + "Magnetic ptychography is currently only implemented for gradient descent." + ) + + if self._verbose: + self._report_reconstruction_summary( + num_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + step_size, + max_batch_size, + ) + + if max_batch_size is not None: + np.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + # initialization + self._reset_reconstruction(store_iterations, reset, use_projection_scheme) + + if gaussian_filter_sigma_m is None: + gaussian_filter_sigma_m = gaussian_filter_sigma_e + + if q_lowpass_m is None: + q_lowpass_m = q_lowpass_e + + # main loop + for a0 in tqdmnd( + num_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + if collective_measurement_updates: + collective_object = xp.zeros_like(self._object) + + # randomize + measurement_indices = np.arange(self._num_measurements) + np.random.shuffle(measurement_indices) + + for measurement_index in measurement_indices: + self._active_measurement_index = measurement_index + + measurement_error = 0.0 + + _probe = self._probes_all[self._active_measurement_index] + _probe_initial_aperture = self._probes_all_initial_aperture[ + self._active_measurement_index + ] + + start_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + ] + end_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + 1 + ] + + num_diffraction_patterns = end_idx - start_idx + shuffled_indices = np.arange(start_idx, end_idx) + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + + for start, end in generate_batches( + num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px_all[batch_indices] + positions_px_initial = self._positions_px_initial_all[batch_indices] + positions_px_fractional = positions_px - xp_storage.round( + positions_px + ) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = 
self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + _probe, + positions_px_fractional, + amplitudes_device, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + projection_a=projection_a, + projection_b=projection_b, + projection_c=projection_c, + ) + + # adjoint operator + object_update, _probe = self._adjoint( + self._object.copy(), + _probe, + object_patches, + shifted_probes, + positions_px, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=fix_probe, + ) + + object_update -= self._object + + # position correction + if not fix_positions and a0 > 0: + self._positions_px_all[ + batch_indices + ] = self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) + + measurement_error += batch_error + + if collective_measurement_updates: + collective_object += object_update + else: + self._object += object_update + + # Normalize Error + measurement_error /= ( + self._mean_diffraction_intensity[self._active_measurement_index] + * num_diffraction_patterns + ) + error += measurement_error + + # constraints + + if collective_measurement_updates: + # probe and positions + _probe = self._probe_constraints( + _probe, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + ) + + self._positions_px_all[batch_indices] = self._positions_constraints( + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + ) + + else: + # object, probe, and positions + ( + self._object, + _probe, + self._positions_px_all[batch_indices], + ) = self._constraints( + self._object, + _probe, + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not 
fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + gaussian_filter=gaussian_filter + and gaussian_filter_sigma_m is not None, + gaussian_filter_sigma_e=gaussian_filter_sigma_e, + gaussian_filter_sigma_m=gaussian_filter_sigma_m, + butterworth_filter=butterworth_filter + and (q_lowpass_m is not None or q_highpass_m is not None), + q_lowpass_e=q_lowpass_e, + q_lowpass_m=q_lowpass_m, + q_highpass_e=q_highpass_e, + q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, + tv_denoise=tv_denoise and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline + and self._object_fov_mask_inverse.sum() > 0 + else None, + pure_phase_object=pure_phase_object + and self._object_type == "complex", + ) + + # Normalize Error Over Tilts + error /= self._num_measurements + + if collective_measurement_updates: + self._object += collective_object / self._num_measurements + + # object only + self._object = self._object_constraints( + self._object, + gaussian_filter=gaussian_filter + and gaussian_filter_sigma_m is not None, + gaussian_filter_sigma_e=gaussian_filter_sigma_e, + gaussian_filter_sigma_m=gaussian_filter_sigma_m, + butterworth_filter=butterworth_filter + and (q_lowpass_m is not None or q_highpass_m is not None), + q_lowpass_e=q_lowpass_e, + q_lowpass_m=q_lowpass_m, + q_highpass_e=q_highpass_e, + q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, + tv_denoise=tv_denoise and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline + and self._object_fov_mask_inverse.sum() > 0 + else None, + pure_phase_object=pure_phase_object + and self._object_type == "complex", + ) + + self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(self._device, 
self._clear_fft_cache) + + return self + + def _visualize_all_iterations(self, **kwargs): + raise NotImplementedError() + + def _visualize_last_iteration( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, + **kwargs, + ): + """ + Displays last reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool, optional + If true, the reconstructed complex probe is displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + """ + + asnumpy = self._asnumpy + + figsize = kwargs.pop("figsize", (12, 5)) + cmap_e = kwargs.pop("cmap_e", "magma") + cmap_m = kwargs.pop("cmap_m", "PuOr") + chroma_boost = kwargs.pop("chroma_boost", 1) + + # get scaled arrays + probe = self._return_single_probe() + obj = self.object_cropped + if self._object_type == "complex": + obj = np.angle(obj) + + vmin_e = kwargs.pop("vmin_e", None) + vmax_e = kwargs.pop("vmax_e", None) + obj[0], vmin_e, vmax_e = return_scaled_histogram_ordering( + obj[0], vmin_e, vmax_e + ) + + _, _, _vmax_m = return_scaled_histogram_ordering(np.abs(obj[1])) + vmin_m = kwargs.pop("vmin_m", -_vmax_m) + vmax_m = kwargs.pop("vmax_m", _vmax_m) + + extent = [ + 0, + self.sampling[1] * obj.shape[2], + self.sampling[0] * obj.shape[1], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=3, + nrows=2, + height_ratios=[4, 1], + hspace=0.15, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + + else: + spec = GridSpec(ncols=2, nrows=2, height_ratios=[4, 1], hspace=0.15) + + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=3, + nrows=1, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + + else: + spec = GridSpec(ncols=2, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + if plot_probe or plot_fourier_probe: + # Object_e + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + obj[0], + extent=extent, + cmap=cmap_e, + vmin=vmin_e, + vmax=vmax_e, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if self._object_type == "potential": + ax.set_title("Electrostatic potential") + elif self._object_type == "complex": + ax.set_title("Electrostatic phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # 
Object_m + ax = fig.add_subplot(spec[0, 1]) + im = ax.imshow( + obj[1], + extent=extent, + cmap=cmap_m, + vmin=vmin_m, + vmax=vmax_m, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if self._object_type == "potential": + ax.set_title("Magnetic potential") + elif self._object_type == "complex": + ax.set_title("Magnetic phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Probe + ax = fig.add_subplot(spec[0, 2]) + if plot_fourier_probe: + probe = asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + + probe_array = Complex2RGB( + probe, + chroma_boost=chroma_boost, + ) + + ax.set_title("Reconstructed Fourier probe") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + asnumpy(self._return_centered_probe(probe)), + power=2, + chroma_boost=chroma_boost, + ) + ax.set_title("Reconstructed probe intensity") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + else: + # Object_e + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + obj[0], + extent=extent, + cmap=cmap_e, + vmin=vmin_e, + vmax=vmax_e, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if self._object_type == "potential": + ax.set_title("Electrostatic potential") + elif self._object_type == "complex": + ax.set_title("Electrostatic phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Object_e + ax = fig.add_subplot(spec[0, 1]) + im = ax.imshow( + obj[1], + extent=extent, + cmap=cmap_m, + vmin=vmin_m, + vmax=vmax_m, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if self._object_type == "potential": + ax.set_title("Magnetic potential") + elif self._object_type == "complex": + ax.set_title("Magnetic phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if plot_convergence and hasattr(self, "error_iterations"): + errors = np.array(self.error_iterations) + + ax = fig.add_subplot(spec[1, :]) + ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax.set_ylabel("NMSE") + ax.set_xlabel("Iteration number") + ax.yaxis.tick_right() + + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") + spec.tight_layout(fig) + + @property + def object_cropped(self): + """Cropped and rotated object""" + avg_pos = self._return_average_positions() + cropped_e = self._crop_rotate_object_fov(self._object[0], positions_px=avg_pos) + cropped_m = self._crop_rotate_object_fov(self._object[1], positions_px=avg_pos) + + return np.array([cropped_e, cropped_m]) diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py new file mode 100644 index 000000000..47dd67dd3 --- /dev/null +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -0,0 +1,1088 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely multislice ptychography. 
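# Illustrative only: plotting the two channels returned by object_cropped with
# matplotlib, mirroring the color conventions used in _visualize_last_iteration
# above (magma for the electrostatic channel, symmetric PuOr for the magnetic one).
# `ptycho` is the reconstructed MagneticPtychography instance from the sketches above.
import matplotlib.pyplot as plt
import numpy as np

electrostatic, magnetic = ptycho.object_cropped
if np.iscomplexobj(electrostatic):  # 'complex' objects are displayed as phases
    electrostatic, magnetic = np.angle(electrostatic), np.angle(magnetic)

fig, (ax_e, ax_m) = plt.subplots(1, 2, figsize=(8, 4))
ax_e.imshow(electrostatic, cmap="magma")
ax_e.set_title("Electrostatic")
vlim = np.abs(magnetic).max()
ax_m.imshow(magnetic, cmap="PuOr", vmin=-vlim, vmax=vlim)
ax_m.set_title("Magnetic")
fig.tight_layout()
plt.show()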
+""" + +from typing import Mapping, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = None + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( + Object2p5DConstraintsMixin, + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + ProbeMixedConstraintsMixin, +) +from py4DSTEM.process.phase.ptychographic_methods import ( + Object2p5DMethodsMixin, + Object2p5DProbeMixedMethodsMixin, + ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ObjectNDProbeMixedMethodsMixin, + ProbeMethodsMixin, + ProbeMixedMethodsMixin, +) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, +) + + +class MixedstateMultislicePtychography( + VisualizationsMixin, + PositionsConstraintsMixin, + ProbeMixedConstraintsMixin, + ProbeConstraintsMixin, + Object2p5DConstraintsMixin, + ObjectNDConstraintsMixin, + Object2p5DProbeMixedMethodsMixin, + ObjectNDProbeMixedMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMixedMethodsMixin, + ProbeMethodsMixin, + Object2p5DMethodsMixin, + ObjectNDMethodsMixin, + PtychographicReconstruction, +): + """ + Mixed-State Multislice Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (N,Sx,Sy) + Reconstructed object dimensions : (T,Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our N probes + and (Px,Py) is the padded-object size we position our ROI around in + each of the T slices. + + Parameters + ---------- + energy: float + The electron energy of the wave functions in eV + num_probes: int, optional + Number of mixed-state probes + num_slices: int + Number of slices to use in the forward model + slice_thicknesses: float or Sequence[float] + Slice thicknesses in angstroms. If float, all slices are assigned the same thickness + datacube: DataCube, optional + Input 4D diffraction pattern intensities + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (Px,Py) + If None, initialized to 1.0j + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). 
If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Å for each diffraction intensity + If None, initialized to a grid scan + theta_x: float + x tilt of propagator in mrad + theta_y: float + y tilt of propagator in mrad + middle_focus: bool + if True, adds half the sample thickness to the defocus + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction + verbose: bool, optional + If True, class methods will inherit this and print additional information + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. + """ + + # Class-specific Metadata + _class_specific_metadata = ( + "_num_probes", + "_num_slices", + "_slice_thicknesses", + "_theta_x", + "_theta_y", + ) + + def __init__( + self, + energy: float, + num_slices: int, + slice_thicknesses: Union[float, Sequence[float]], + num_probes: int = None, + datacube: DataCube = None, + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + theta_x: float = 0, + theta_y: float = 0, + middle_focus: bool = False, + object_type: str = "complex", + positions_mask: np.ndarray = None, + verbose: bool = True, + device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, + name: str = "multi-slice_ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if storage is None: + storage = device + + self.set_device(device, clear_fft_cache) + self.set_storage(storage) + + if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe): + if num_probes is None: + raise ValueError( + ( + "If initial_probe_guess is None, or a ComplexProbe object, " + "num_probes must be specified." + ) + ) + else: + if len(initial_probe_guess.shape) != 3: + raise ValueError( + "Specified initial_probe_guess must have dimensions (N,Sx,Sy)." + ) + num_probes = initial_probe_guess.shape[0] + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + slice_thicknesses = np.array(slice_thicknesses) + if slice_thicknesses.shape == (): + slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1) + elif slice_thicknesses.shape[0] != (num_slices - 1): + raise ValueError( + ( + f"slice_thicknesses must have length {num_slices - 1}, " + f"not {slice_thicknesses.shape[0]}." 
+ ) + ) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + if middle_focus: + half_thickness = slice_thicknesses.mean() * num_slices / 2 + self._polar_parameters["C10"] -= half_thickness + + if object_type != "potential" and object_type != "complex": + raise ValueError( + f"object_type must be either 'potential' or 'complex', not {object_type}" + ) + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._positions_mask = positions_mask + self._object_padding_px = object_padding_px + self._verbose = verbose + self._preprocessed = False + + # Class-specific Metadata + self._num_probes = num_probes + self._num_slices = num_slices + self._slice_thicknesses = slice_thicknesses + self._theta_x = theta_x + self._theta_y = theta_y + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "fourier", + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_center_of_mass: str = "default", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = None, + plot_probe_overlaps: bool = True, + force_com_rotation: float = None, + force_com_transpose: float = None, + force_com_shifts: float = None, + vectorized_com_calculation: bool = True, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, + device: str = None, + clear_fft_cache: bool = None, + max_batch_size: int = None, + **kwargs, + ): + """ + Ptychographic preprocessing step. + Calls the base class methods: + + _extract_intensities_and_calibrations_from_datacube, + _compute_center_of_mass(), + _solve_CoM_rotation(), + _normalize_diffraction_intensities() + _calculate_scan_positions_in_px() + + Additionally, it initializes an (T,Px,Py) array of 1.0j + and a complex probe using the specified polar parameters. + + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + padded_diffraction_intensities_shape: (int,int), optional + Padded diffraction intensities shape. + If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. 
One of 'plane','parabola','bezier_two' + plot_center_of_mass: str, optional + If 'default', the corrected CoM arrays will be displayed + If 'all', the computed and fitted CoM arrays will be displayed + plot_rotation: bool, optional + If True, the CoM curl minimization search result will be displayed + maximize_divergence: bool, optional + If True, the divergence of the CoM gradient vector field is maximized + rotation_angles_deg: np.ndarray, optional + Array of angles in degrees to perform curl minimization over + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + force_com_rotation: float (degrees), optional + Force relative rotation angle between real and reciprocal space + force_com_transpose: bool, optional + Force whether diffraction intensities need to be transposed. + force_com_shifts: tuple of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + crop_patterns: bool + If True, crop patterns to avoid wrap around of patterns when centering + device: str, optional + If not None, overwrites self._device to set device preprocess will be performed on. + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps + + Returns + -------- + self: MixedstateMultislicePtychography + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first."
+ ) + ) + + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + # preprocess datacube + ( + self._datacube, + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts, + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube, + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts, + ) + + # calibrations + _intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube, + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # calculate CoM + ( + self._com_measured_x, + self._com_measured_y, + self._com_fitted_x, + self._com_fitted_y, + self._com_normalized_x, + self._com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + _intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts, + vectorized_calculation=vectorized_com_calculation, + ) + + # estimate rotation / transpose + ( + self._rotation_best_rad, + self._rotation_best_transpose, + self._com_x, + self._com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + self._com_measured_x, + self._com_measured_y, + self._com_normalized_x, + self._com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=plot_center_of_mass, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + + # explicitly transfer arrays to storage + attrs = [ + "_com_measured_x", + "_com_measured_y", + "_com_fitted_x", + "_com_fitted_y", + "_com_normalized_x", + "_com_normalized_y", + "_com_x", + "_com_y", + ] + self.copy_attributes_to_device(attrs, storage) + + # corner-center amplitudes + ( + self._amplitudes, + self._mean_diffraction_intensity, + self._crop_mask, + ) = self._normalize_diffraction_intensities( + _intensities, + self._com_fitted_x, + self._com_fitted_y, + self._positions_mask, + crop_patterns, + ) + + # explicitly transfer arrays to storage + self._amplitudes = copy_to_device(self._amplitudes, storage) + del _intensities + + self._num_diffraction_patterns = self._amplitudes.shape[0] + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + # initialize probe positions + ( + self._positions_px, + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions, + self._positions_mask, + self._object_padding_px, + ) + + # initialize object + self._object = self._initialize_object( + self._object, + self._num_slices, + self._positions_px, + self._object_type, + ) + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + + # center probe positions + self._positions_px = xp_storage.asarray( 
+ self._positions_px, dtype=xp_storage.float32 + ) + self._positions_px_initial_com = self._positions_px.mean(0) + self._positions_px -= ( + self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) + + self._positions_px_initial = self._positions_px.copy() + self._positions_initial = self._positions_px_initial.copy() + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # initialize probe + self._probe, self._semiangle_cutoff = self._initialize_probe( + self._probe, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity, + self._semiangle_cutoff, + crop_patterns, + ) + + # initialize aberrations + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + + # precompute propagator arrays + self._propagator_arrays = self._precompute_propagator_arrays( + self._region_of_interest_shape, + self.sampling, + self._energy, + self._slice_thicknesses, + self._theta_x, + self._theta_y, + ) + + # overlaps + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift(self._probe[0], positions_px_fractional, xp) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + + if object_fov_mask is None: + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) + del probe_overlap_blurred + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + probe_overlap = asnumpy(probe_overlap) + + # plot probe overlaps + if plot_probe_overlaps: + figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered[0], + power=power, + chroma_boost=chroma_boost, + ) + + # propagated + propagated_probe = self._probe[0].copy() + + for s in range(self._num_slices - 1): + propagated_probe = self._propagate_array( + propagated_probe, self._propagator_arrays[s] + ) + complex_propagated_rgb = Complex2RGB( + asnumpy(self._return_centered_probe(propagated_probe)), + power=power, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax1, + 
chroma_boost=chroma_boost, + ) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe[0] intensity") + + ax2.imshow( + complex_propagated_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax2) + cax2 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(cax2, chroma_boost=chroma_boost) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_title("Propagated probe[0] intensity") + + ax3.imshow( + probe_overlap, + extent=extent, + cmap="Greys_r", + ) + ax3.scatter( + self.positions[:, 1], + self.positions[:, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax3.set_ylabel("x [A]") + ax3.set_xlabel("y [A]") + ax3.set_xlim((extent[0], extent[1])) + ax3.set_ylim((extent[2], extent[3])) + ax3.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def reconstruct( + self, + num_iter: int = 8, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + fix_probe_com: bool = True, + orthogonalize_probe: bool = True, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude: bool = False, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions: bool = True, + fix_positions_com: bool = True, + max_position_update_distance: float = None, + max_position_total_distance: float = None, + global_affine_transformation: bool = False, + gaussian_filter_sigma: float = None, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, + butterworth_filter: bool = True, + q_lowpass: float = None, + q_highpass: float = None, + butterworth_order: float = 2, + kz_regularization_filter: bool = True, + kz_regularization_gamma: Union[float, np.ndarray] = None, + identical_slices: bool = False, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + pure_phase_object: bool = False, + tv_denoise_chambolle: bool = True, + tv_denoise_weight_chambolle=None, + tv_denoise_pad_chambolle=1, + tv_denoise: bool = True, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + device: str = None, + clear_fft_cache: bool = None, + object_type: str = None, + ): + """ + Ptychographic reconstruction main method. 
+ + Parameters + -------- + num_iter: int, optional + Maximum number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. + max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + fix_probe_com: bool, optional + If True, fixes center of mass of probe + fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vacuum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_positions: bool, optional + If True, probe-positions are fixed + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov + max_position_update_distance: float, optional + Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma: float, optional + Standard deviation of gaussian kernel in A + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion + fit_probe_aberrations_max_angular_order: int + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: int + Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.
+ If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + kz_regularization_filter: bool, optional + If True and kz_regularization_gamma is not None, applies kz regularization filter + kz_regularization_gamma, float, optional + kz regularization strength + identical_slices: bool, optional + If True, object forced to identical slices + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + pure_phase_object: bool, optional + If True, object amplitude is set to unity + tv_denoise_chambolle: bool + If True and tv_denoise_weight_chambolle is not None, object is smoothed using TV denoisining + tv_denoise_weight_chambolle: float + weight of tv denoising constraint + tv_denoise_pad_chambolle: int + If not None, pads object at top and bottom with this many zeros before applying denoising + tv_denoise: bool + If True and tv_denoise_weights is not None, object is smoothed using TV denoising + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + object_type: str, optional + Overwrites self._object_type + + Returns + -------- + self: MultislicePtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probe", + "_probe_initial", + "_probe_initial_aperture", + "_propagator_arrays", + ] + self.copy_attributes_to_device(attrs, device) + + if object_type is not None: + self._switch_object_type(object_type) + + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) + + if self._verbose: + self._report_reconstruction_summary( + num_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + max_batch_size, + step_size, + ) + + # batching + shuffled_indices = np.arange(self._num_diffraction_patterns) + + if max_batch_size is not None: + np.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + # initialization + 
self._reset_reconstruction(store_iterations, reset) + + # main loop + for a0 in tqdmnd( + num_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px[batch_indices] + positions_px_initial = self._positions_px_initial[batch_indices] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + self._probe, + positions_px_fractional, + amplitudes_device, + self._exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) + + # adjoint operator + self._object, self._probe = self._adjoint( + self._object, + self._probe, + object_patches, + shifted_probes, + positions_px, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=fix_probe, + ) + + # position correction + if not fix_positions: + self._positions_px[batch_indices] = self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + self._probe, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) + + error += batch_error + + # Normalize Error + error /= self._mean_diffraction_intensity * self._num_diffraction_patterns + + # constraints + self._object, self._probe, self._positions_px = self._constraints( + self._object, + self._probe, + self._positions_px, + self._positions_px_initial, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=self._probe_initial_aperture, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + gaussian_filter=gaussian_filter and gaussian_filter_sigma is not None, + 
gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=butterworth_filter + and (q_lowpass is not None or q_highpass is not None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + kz_regularization_filter=kz_regularization_filter + and kz_regularization_gamma is not None, + kz_regularization_gamma=kz_regularization_gamma[a0] + if kz_regularization_gamma is not None + and isinstance(kz_regularization_gamma, np.ndarray) + else kz_regularization_gamma, + identical_slices=identical_slices, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 + else None, + pure_phase_object=pure_phase_object and self._object_type == "complex", + tv_denoise_chambolle=tv_denoise_chambolle + and tv_denoise_weight_chambolle is not None, + tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, + tv_denoise=tv_denoise and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + orthogonalize_probe=orthogonalize_probe, + ) + + self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py new file mode 100644 index 000000000..6fbb72b5d --- /dev/null +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -0,0 +1,967 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely mixed-state ptychography. +""" + +from typing import Mapping, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + ProbeMixedConstraintsMixin, +) +from py4DSTEM.process.phase.ptychographic_methods import ( + ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ObjectNDProbeMixedMethodsMixin, + ProbeMethodsMixin, + ProbeMixedMethodsMixin, +) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, +) + + +class MixedstatePtychography( + VisualizationsMixin, + PositionsConstraintsMixin, + ProbeMixedConstraintsMixin, + ProbeConstraintsMixin, + ObjectNDConstraintsMixin, + ObjectNDProbeMixedMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMixedMethodsMixin, + ProbeMethodsMixin, + ObjectNDMethodsMixin, + PtychographicReconstruction, +): + """ + Mixed-State Ptychographic Reconstruction Class. 
+ + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (N,Sx,Sy) + Reconstructed object dimensions : (Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our N probes + and (Px,Py) is the padded-object size we position our ROI around in. + + Parameters + ---------- + energy: float + The electron energy of the wave functions in eV + datacube: DataCube + Input 4D diffraction pattern intensities + num_probes: int, optional + Number of mixed-state probes + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (Px,Py) + If None, initialized to 1.0j + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Å for each diffraction intensity + If None, initialized to a grid scan + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction + verbose: bool, optional + If True, class methods will inherit this and print additional information + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. + """ + + # Class-specific Metadata + _class_specific_metadata = ("_num_probes",) + + def __init__( + self, + energy: float, + datacube: DataCube = None, + num_probes: int = None, + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + object_type: str = "complex", + positions_mask: np.ndarray = None, + verbose: bool = True, + device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, + name: str = "mixed-state_ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if storage is None: + storage = device + + self.set_device(device, clear_fft_cache) + self.set_storage(storage) + + if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe): + if num_probes is None: + raise ValueError( + ( + "If initial_probe_guess is None, or a ComplexProbe object, " + "num_probes must be specified." + ) + ) + else: + if len(initial_probe_guess.shape) != 3: + raise ValueError( + "Specified initial_probe_guess must have dimensions (N,Sx,Sy)." 
+ ) + num_probes = initial_probe_guess.shape[0] + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + if object_type != "potential" and object_type != "complex": + raise ValueError( + f"object_type must be either 'potential' or 'complex', not {object_type}" + ) + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._object_padding_px = object_padding_px + self._positions_mask = positions_mask + self._verbose = verbose + self._preprocessed = False + + # Class-specific Metadata + self._num_probes = num_probes + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "fourier", + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_center_of_mass: str = "default", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = None, + plot_probe_overlaps: bool = True, + force_com_rotation: float = None, + force_com_transpose: float = None, + force_com_shifts: float = None, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, + device: str = None, + clear_fft_cache: bool = None, + max_batch_size: int = None, + **kwargs, + ): + """ + Ptychographic preprocessing step. + Calls the base class methods: + + _extract_intensities_and_calibrations_from_datacube, + _compute_center_of_mass(), + _solve_CoM_rotation(), + _normalize_diffraction_intensities() + _calculate_scan_positions_in_px() + + Additionally, it initializes an (Px,Py) array of 1.0j + and a complex probe using the specified polar parameters. + + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + padded_diffraction_intensities_shape: (int,int), optional + Padded diffraction intensities shape. + If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. 
One of 'plane','parabola','bezier_two' + plot_center_of_mass: str, optional + If 'default', the corrected CoM arrays will be displayed + If 'all', the computed and fitted CoM arrays will be displayed + plot_rotation: bool, optional + If True, the CoM curl minimization search result will be displayed + maximize_divergence: bool, optional + If True, the divergence of the CoM gradient vector field is maximized + rotation_angles_deg: np.ndarray, optional + Array of angles in degrees to perform curl minimization over + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + force_com_rotation: float (degrees), optional + Force relative rotation angle between real and reciprocal space + force_com_transpose: bool, optional + Force whether diffraction intensities need to be transposed. + force_com_shifts: tuple of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + crop_patterns: bool + If True, crop patterns to avoid wrap around of patterns when centering + device: str, optional + If not None, overwrites self._device to set device preprocess will be performed on. + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps + + Returns + -------- + self: MixedstatePtychography + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first."
+ ) + ) + + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + # preprocess datacube + ( + self._datacube, + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts, + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube, + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts, + ) + + # calibrations + _intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube, + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # calculate CoM + ( + self._com_measured_x, + self._com_measured_y, + self._com_fitted_x, + self._com_fitted_y, + self._com_normalized_x, + self._com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + _intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts, + ) + + # estimate rotation / transpose + ( + self._rotation_best_rad, + self._rotation_best_transpose, + self._com_x, + self._com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + self._com_measured_x, + self._com_measured_y, + self._com_normalized_x, + self._com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=plot_center_of_mass, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + + # explicitly transfer arrays to storage + attrs = [ + "_com_measured_x", + "_com_measured_y", + "_com_fitted_x", + "_com_fitted_y", + "_com_normalized_x", + "_com_normalized_y", + "_com_x", + "_com_y", + ] + self.copy_attributes_to_device(attrs, storage) + + # corner-center amplitudes + ( + self._amplitudes, + self._mean_diffraction_intensity, + self._crop_mask, + ) = self._normalize_diffraction_intensities( + _intensities, + self._com_fitted_x, + self._com_fitted_y, + self._positions_mask, + crop_patterns, + ) + + # explicitly transfer arrays to storage + self._amplitudes = copy_to_device(self._amplitudes, storage) + del _intensities + + self._num_diffraction_patterns = self._amplitudes.shape[0] + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + # initialize probe positions + ( + self._positions_px, + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions, + self._positions_mask, + self._object_padding_px, + ) + + # initialize object + self._object = self._initialize_object( + self._object, + self._positions_px, + self._object_type, + ) + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape + + # center probe positions + self._positions_px = xp_storage.asarray( + self._positions_px, dtype=xp_storage.float32 + ) + 
self._positions_px_initial_com = self._positions_px.mean(0) + self._positions_px -= ( + self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) + + self._positions_px_initial = self._positions_px.copy() + self._positions_initial = self._positions_px_initial.copy() + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # initialize probe + self._probe, self._semiangle_cutoff = self._initialize_probe( + self._probe, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity, + self._semiangle_cutoff, + crop_patterns, + ) + + # initialize aberrations + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + + # overlaps + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift(self._probe[0], positions_px_fractional, xp) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + + if object_fov_mask is None: + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) + del probe_overlap_blurred + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + probe_overlap = asnumpy(probe_overlap) + + # plot probe overlaps + if plot_probe_overlaps: + figsize = kwargs.pop("figsize", (4.5 * self._num_probes + 4, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered, + power=power, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, axs = plt.subplots(1, self._num_probes + 1, figsize=figsize) + + for i in range(self._num_probes): + axs[i].imshow( + complex_probe_rgb[i], + extent=probe_extent, + ) + axs[i].set_ylabel("x [A]") + axs[i].set_xlabel("y [A]") + axs[i].set_title(f"Initial probe[{i}] intensity") + + divider = make_axes_locatable(axs[i]) + cax = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(cax, chroma_boost=chroma_boost) + + axs[-1].imshow( + probe_overlap, + extent=extent, + cmap="Greys_r", + ) + axs[-1].scatter( + self.positions[:, 1], + self.positions[:, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + axs[-1].set_ylabel("x [A]") + axs[-1].set_xlabel("y [A]") + axs[-1].set_xlim((extent[0], extent[1])) + axs[-1].set_ylim((extent[2], extent[3])) + axs[-1].set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + 
self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def reconstruct( + self, + num_iter: int = 8, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + pure_phase_object: bool = False, + fix_probe_com: bool = True, + orthogonalize_probe: bool = True, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude: bool = False, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions: bool = True, + fix_positions_com: bool = True, + global_affine_transformation: bool = False, + max_position_update_distance: float = None, + max_position_total_distance: float = None, + gaussian_filter_sigma: float = None, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, + butterworth_filter: bool = True, + q_lowpass: float = None, + q_highpass: float = None, + butterworth_order: float = 2, + tv_denoise: bool = True, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + device: str = None, + clear_fft_cache: bool = None, + object_type: str = None, + ): + """ + Ptychographic reconstruction main method. + + Parameters + -------- + num_iter: int, optional + Number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. 
+ max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + pure_phase_object: bool, optional + If True, object amplitude is set to unity + fix_probe: bool, optional + If True, probe is fixed + fix_probe_com: bool, optional + If True, fixes center of mass of probe + fix_probe_aperture: bool, optional + If True, vacuum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_positions: bool, optional + If True, probe-positions are fixed + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov + max_position_update_distance: float, optional + Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma: float, optional + Standard deviation of gaussian kernel in A + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion + fit_probe_aberrations_max_angular_order: int + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: int + Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool, optional + If True and tv_denoise_weight is not None, object is smoothed using TV denoising + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising.
+ tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + device: str, optional + If not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + If true, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + Overwrites self._object_type + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probe", + "_probe_initial", + "_probe_initial_aperture", + ] + self.copy_attributes_to_device(attrs, device) + + if object_type is not None: + self._switch_object_type(object_type) + + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) + + if self._verbose: + self._report_reconstruction_summary( + num_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + max_batch_size, + step_size, + ) + + # Batching + shuffled_indices = np.arange(self._num_diffraction_patterns) + + if max_batch_size is not None: + np.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + # initialization + self._reset_reconstruction(store_iterations, reset) + + # main loop + for a0 in tqdmnd( + num_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px[batch_indices] + positions_px_initial = self._positions_px_initial[batch_indices] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + self._probe, + positions_px_fractional, + amplitudes_device, + self._exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) 
+ + # adjoint operator + self._object, self._probe = self._adjoint( + self._object, + self._probe, + object_patches, + shifted_probes, + positions_px, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=fix_probe, + ) + + # position correction + if not fix_positions: + self._positions_px[batch_indices] = self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) + + error += batch_error + + # Normalize Error + error /= self._mean_diffraction_intensity * self._num_diffraction_patterns + + # constraints + self._object, self._probe, self._positions_px = self._constraints( + self._object, + self._probe, + self._positions_px, + self._positions_px_initial, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=self._probe_initial_aperture, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + gaussian_filter=gaussian_filter and gaussian_filter_sigma is not None, + gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=butterworth_filter + and (q_lowpass is not None or q_highpass is not None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + orthogonalize_probe=orthogonalize_probe, + tv_denoise=tv_denoise and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 + else None, + pure_phase_object=pure_phase_object and self._object_type == "complex", + ) + + self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self diff --git a/py4DSTEM/process/phase/multislice_ptychography.py 
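As a usage sketch for the reconstruct() method defined above: preprocess() runs first, then reconstruct() with either the default gradient-descent update or one of the projection schemes listed in the docstring. The surrounding hunks do not name the class this method belongs to, so MixedstatePtychography, the file name, and the numeric values below are illustrative assumptions.

    import py4DSTEM

    # hypothetical input file; any calibrated 4D-STEM DataCube works here
    datacube = py4DSTEM.read("experiment.h5")

    # illustrative constructor; assumes the mixed-state class takes num_probes
    ptycho = py4DSTEM.process.phase.MixedstatePtychography(
        datacube=datacube,
        energy=300e3,         # eV
        num_probes=4,
        semiangle_cutoff=20,  # mrad
        device="cpu",
    ).preprocess()

    # mini-batched gradient descent; seed_random only applies when batching
    ptycho.reconstruct(
        num_iter=32,
        reconstruction_method="GD",
        max_batch_size=256,
        seed_random=0,
        step_size=0.5,
        store_iterations=True,
    )

    # or a projection scheme, e.g. difference map with a relaxation parameter
    ptycho.reconstruct(
        num_iter=16,
        reconstruction_method="DM_AP",
        reconstruction_parameter=0.9,
        reset=True,
    )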
b/py4DSTEM/process/phase/multislice_ptychography.py new file mode 100644 index 000000000..03e636f57 --- /dev/null +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -0,0 +1,1063 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely multislice ptychography. +""" + +from typing import Mapping, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( + Object2p5DConstraintsMixin, + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) +from py4DSTEM.process.phase.ptychographic_methods import ( + Object2p5DMethodsMixin, + Object2p5DProbeMethodsMixin, + ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, +) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, +) + + +class MultislicePtychography( + VisualizationsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + Object2p5DConstraintsMixin, + ObjectNDConstraintsMixin, + Object2p5DProbeMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, + Object2p5DMethodsMixin, + ObjectNDMethodsMixin, + PtychographicReconstruction, +): + """ + Multislice Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (Sx,Sy) + Reconstructed object dimensions : (T,Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our probe + and (Px,Py) is the padded-object size we position our ROI around in + each of the T slices. + + Parameters + ---------- + energy: float + The electron energy of the wave functions in eV + num_slices: int + Number of slices to use in the forward model + slice_thicknesses: float or Sequence[float] + Slice thicknesses in angstroms. If float, all slices are assigned the same thickness + datacube: DataCube, optional + Input 4D diffraction pattern intensities + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (Px,Py) + If None, initialized to 1.0j + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). 
If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Å for each diffraction intensity + If None, initialized to a grid scan + theta_x: float + x tilt of propagator in mrad + theta_y: float + y tilt of propagator in mrad + middle_focus: bool + if True, adds half the sample thickness to the defocus + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction + verbose: bool, optional + If True, class methods will inherit this and print additional information + device: str, optional + Device calculation will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. + """ + + # Class-specific Metadata + _class_specific_metadata = ( + "_num_slices", + "_slice_thicknesses", + "_theta_x", + "_theta_y", + ) + + def __init__( + self, + energy: float, + num_slices: int, + slice_thicknesses: Union[float, Sequence[float]], + datacube: DataCube = None, + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + theta_x: float = None, + theta_y: float = None, + middle_focus: bool = False, + object_type: str = "complex", + positions_mask: np.ndarray = None, + verbose: bool = True, + device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, + name: str = "multi-slice_ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if storage is None: + storage = device + + self.set_device(device, clear_fft_cache) + self.set_storage(storage) + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + slice_thicknesses = np.array(slice_thicknesses) + if slice_thicknesses.shape == (): + slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1) + elif slice_thicknesses.shape[0] != (num_slices - 1): + raise ValueError( + ( + f"slice_thicknesses must have length {num_slices - 1}, " + f"not {slice_thicknesses.shape[0]}." 
+ ) + ) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + if middle_focus: + half_thickness = slice_thicknesses.mean() * num_slices / 2 + self._polar_parameters["C10"] -= half_thickness + + if object_type != "potential" and object_type != "complex": + raise ValueError( + f"object_type must be either 'potential' or 'complex', not {object_type}" + ) + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._positions_mask = positions_mask + self._object_padding_px = object_padding_px + self._verbose = verbose + self._preprocessed = False + + # Class-specific Metadata + self._num_slices = num_slices + self._slice_thicknesses = slice_thicknesses + self._theta_x = theta_x + self._theta_y = theta_y + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "fourier", + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_center_of_mass: str = "default", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = None, + plot_probe_overlaps: bool = True, + force_com_rotation: float = None, + force_com_transpose: float = None, + force_com_shifts: float = None, + vectorized_com_calculation: bool = True, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, + device: str = None, + clear_fft_cache: bool = None, + max_batch_size: int = None, + **kwargs, + ): + """ + Ptychographic preprocessing step. + Calls the base class methods: + + _extract_intensities_and_calibrations_from_datacube, + _compute_center_of_mass(), + _solve_CoM_rotation(), + _normalize_diffraction_intensities() + _calculate_scan_positions_in_px() + + Additionally, it initializes an (T,Px,Py) array of 1.0j + and a complex probe using the specified polar parameters. + + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + padded_diffraction_intensities_shape: (int,int), optional + Padded diffraction intensities shape. + If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. 
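Looking back at the constructor above, the slice-thickness broadcasting and the middle_focus defocus shift can be reproduced in a few lines of plain NumPy; the values are illustrative and the snippet stands alone, outside py4DSTEM.

    import numpy as np

    num_slices = 10
    slice_thicknesses = np.array(20.0)  # scalar thickness in angstroms

    # a scalar thickness is tiled to the (num_slices - 1) propagation distances
    if slice_thicknesses.shape == ():
        slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1)
    assert slice_thicknesses.shape == (num_slices - 1,)

    # middle_focus=True moves the focal plane to the sample mid-plane
    # by subtracting half the total thickness from the C10 coefficient
    polar_parameters = {"C10": 0.0}
    half_thickness = slice_thicknesses.mean() * num_slices / 2
    polar_parameters["C10"] -= half_thickness
    print(polar_parameters)  # {'C10': -100.0}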
One of 'plane','parabola','bezier_two' + plot_center_of_mass: str, optional + If 'default', the corrected CoM arrays will be displayed + If 'all', the computed and fitted CoM arrays will be displayed + plot_rotation: bool, optional + If True, the CoM curl minimization search result will be displayed + maximize_divergence: bool, optional + If True, the divergence of the CoM gradient vector field is maximized + rotation_angles_deg: np.darray, optional + Array of angles in degrees to perform curl minimization over + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + force_com_rotation: float (degrees), optional + Force relative rotation angle between real and reciprocal space + force_com_transpose: bool, optional + Force whether diffraction intensities need to be transposed. + force_com_shifts: tuple of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering + device: str, optional + If not None, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps + + Returns + -------- + self: MultislicePtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first." 
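Putting the constructor and preprocess() together, a typical call chain might look like the sketch below. The file name, slice count, thicknesses, and aberration values are placeholder assumptions rather than values taken from this patch; ms_ptycho is reused in a later reconstruction sketch.

    import py4DSTEM

    datacube = py4DSTEM.read("multislice_scan.h5")  # hypothetical dataset

    ms_ptycho = py4DSTEM.process.phase.MultislicePtychography(
        datacube=datacube,
        energy=300e3,           # eV
        num_slices=10,
        slice_thicknesses=20,   # angstroms, broadcast to num_slices - 1 values
        semiangle_cutoff=21.4,  # mrad
        defocus=150,            # aberration coefficient passed through **kwargs
        object_type="potential",
        device="cpu",
    )

    ms_ptycho.preprocess(
        plot_rotation=False,
        plot_probe_overlaps=True,
        force_com_rotation=0,   # degrees, if the rotation is known a priori
    )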
+ ) + ) + + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + # preprocess datacube + ( + self._datacube, + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts, + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube, + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts, + ) + + # calibrations + _intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube, + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # calculate CoM + ( + self._com_measured_x, + self._com_measured_y, + self._com_fitted_x, + self._com_fitted_y, + self._com_normalized_x, + self._com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + _intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts, + vectorized_calculation=vectorized_com_calculation, + ) + + # estimate rotation / transpose + ( + self._rotation_best_rad, + self._rotation_best_transpose, + self._com_x, + self._com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + self._com_measured_x, + self._com_measured_y, + self._com_normalized_x, + self._com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=plot_center_of_mass, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + + # explicitly transfer arrays to storage + attrs = [ + "_com_measured_x", + "_com_measured_y", + "_com_fitted_x", + "_com_fitted_y", + "_com_normalized_x", + "_com_normalized_y", + "_com_x", + "_com_y", + ] + self.copy_attributes_to_device(attrs, storage) + + # corner-center amplitudes + ( + self._amplitudes, + self._mean_diffraction_intensity, + self._crop_mask, + ) = self._normalize_diffraction_intensities( + _intensities, + self._com_fitted_x, + self._com_fitted_y, + self._positions_mask, + crop_patterns, + ) + + # explicitly transfer arrays to storage + self._amplitudes = copy_to_device(self._amplitudes, storage) + del _intensities + + self._num_diffraction_patterns = self._amplitudes.shape[0] + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + # initialize probe positions + ( + self._positions_px, + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions, + self._positions_mask, + self._object_padding_px, + ) + + # initialize object + self._object = self._initialize_object( + self._object, + self._num_slices, + self._positions_px, + self._object_type, + ) + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + + # center probe positions + self._positions_px = xp_storage.asarray( 
+ self._positions_px, dtype=xp_storage.float32 + ) + self._positions_px_initial_com = self._positions_px.mean(0) + self._positions_px -= ( + self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) + + self._positions_px_initial = self._positions_px.copy() + self._positions_initial = self._positions_px_initial.copy() + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # initialize probe + self._probe, self._semiangle_cutoff = self._initialize_probe( + self._probe, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity, + self._semiangle_cutoff, + crop_patterns, + ) + + # initialize aberrations + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + + # precompute propagator arrays + self._propagator_arrays = self._precompute_propagator_arrays( + self._region_of_interest_shape, + self.sampling, + self._energy, + self._slice_thicknesses, + self._theta_x, + self._theta_y, + ) + + # overlaps + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift(self._probe, positions_px_fractional, xp) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + + if object_fov_mask is None: + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) + del probe_overlap_blurred + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + probe_overlap = asnumpy(probe_overlap) + + # plot probe overlaps + if plot_probe_overlaps: + figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered, + power=power, + chroma_boost=chroma_boost, + ) + + # propagated + propagated_probe = self._probe.copy() + + for s in range(self._num_slices - 1): + propagated_probe = self._propagate_array( + propagated_probe, self._propagator_arrays[s] + ) + complex_propagated_rgb = Complex2RGB( + asnumpy(self._return_centered_probe(propagated_probe)), + power=power, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(cax1, chroma_boost=chroma_boost) + 
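The object_fov_mask logic just above reduces to blurring the summed probe intensities and keeping pixels above 25% of the blurred maximum. A self-contained NumPy/SciPy sketch, with a synthetic overlap map standing in for the accumulated |probe|^2 values:

    import numpy as np
    from scipy.ndimage import gaussian_filter

    # synthetic probe-overlap map: a bright blob inside the padded object array
    yy, xx = np.mgrid[:128, :128]
    probe_overlap = np.exp(-((xx - 64) ** 2 + (yy - 64) ** 2) / (2 * 30**2))

    # blur lightly, then keep pixels above 25% of the blurred maximum
    probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0)
    object_fov_mask = probe_overlap_blurred > 0.25 * probe_overlap_blurred.max()
    object_fov_mask_inverse = np.invert(object_fov_mask)

    print(object_fov_mask.sum(), "pixels inside the estimated field of view")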
ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe intensity") + + ax2.imshow( + complex_propagated_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax2) + cax2 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax2, + chroma_boost=chroma_boost, + ) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_title("Propagated probe intensity") + + ax3.imshow( + probe_overlap, + extent=extent, + cmap="Greys_r", + ) + ax3.scatter( + self.positions[:, 1], + self.positions[:, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax3.set_ylabel("x [A]") + ax3.set_xlabel("y [A]") + ax3.set_xlim((extent[0], extent[1])) + ax3.set_ylim((extent[2], extent[3])) + ax3.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def reconstruct( + self, + num_iter: int = 8, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + fix_probe_com: bool = True, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude: bool = False, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions: bool = True, + fix_positions_com: bool = True, + max_position_update_distance: float = None, + max_position_total_distance: float = None, + global_affine_transformation: bool = False, + gaussian_filter_sigma: float = None, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, + butterworth_filter: bool = True, + q_lowpass: float = None, + q_highpass: float = None, + butterworth_order: float = 2, + kz_regularization_filter: bool = True, + kz_regularization_gamma: float = None, + identical_slices: bool = False, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + pure_phase_object: bool = False, + tv_denoise_chambolle: bool = True, + tv_denoise_weight_chambolle=None, + tv_denoise_pad_chambolle=1, + tv_denoise: bool = True, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + device: str = None, + clear_fft_cache: bool = None, + object_type: str = None, + ): + """ + Ptychographic reconstruction main method. 
+ + Parameters + -------- + num_iter: int, optional + Number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. + max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + fix_probe_com: bool, optional + If True, fixes center of mass of probe + fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vacuum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_positions: bool, optional + If True, probe-positions are fixed + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov + max_position_update_distance: float, optional + Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma: float, optional + Standard deviation of gaussian kernel in A + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. 
This is more stable, but occasionally leads + to a documented bug where the kernel hangs. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + kz_regularization_filter: bool, optional + If True and kz_regularization_gamma is not None, applies kz regularization filter + kz_regularization_gamma: float, optional + kz regularization strength + identical_slices: bool, optional + If True, object forced to identical slices + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + pure_phase_object: bool, optional + If True, object amplitude is set to unity + tv_denoise_chambolle: bool + If True and tv_denoise_weight_chambolle is not None, object is smoothed using TV denoising + tv_denoise_weight_chambolle: float + weight of tv denoising constraint + tv_denoise_pad_chambolle: int + If not None, pads object at top and bottom with this many zeros before applying denoising + tv_denoise: bool + If True and tv_denoise_weights is not None, object is smoothed using TV denoising + tv_denoise_weights: [float,float] + Denoising weights [z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + device: str, optional + If not None, overwrites self._device to set device the reconstruction will be performed on.
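Continuing the multislice sketch started after preprocess() above, the multislice-specific options documented here (kz regularization, identical slices, and the two TV-denoising variants) would be exercised roughly as follows; the numeric values are placeholders, not recommendations from this patch.

    ms_ptycho.reconstruct(
        num_iter=24,
        reconstruction_method="GD",
        max_batch_size=512,
        step_size=0.5,
        # multislice-specific regularization along the beam (z) direction
        kz_regularization_filter=True,
        kz_regularization_gamma=1.0,
        identical_slices=False,
        # anisotropic TV denoising: [z weight, r weight]
        tv_denoise=True,
        tv_denoise_weights=[0.05, 0.01],
        tv_denoise_inner_iter=40,
        store_iterations=True,
    )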
+ clear_fft_cache: bool, optional + If true, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + Overwrites self._object_type + + Returns + -------- + self: MultislicePtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probe", + "_probe_initial", + "_probe_initial_aperture", + "_propagator_arrays", + ] + self.copy_attributes_to_device(attrs, device) + + if object_type is not None: + self._switch_object_type(object_type) + + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) + + if self._verbose: + self._report_reconstruction_summary( + num_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + max_batch_size, + step_size, + ) + + # Batching + shuffled_indices = np.arange(self._num_diffraction_patterns) + + if max_batch_size is not None: + np.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + # initialization + self._reset_reconstruction(store_iterations, reset) + + # main loop + for a0 in tqdmnd( + num_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px[batch_indices] + positions_px_initial = self._positions_px_initial[batch_indices] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + self._probe, + positions_px_fractional, + amplitudes_device, + self._exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) + + # adjoint operator + self._object, self._probe = self._adjoint( + self._object, + self._probe, + object_patches, + shifted_probes, + positions_px, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=fix_probe, + ) + + # position correction + if not fix_positions: + self._positions_px[batch_indices] = self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + self._probe, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) + + error += batch_error + + # 
Normalize Error + error /= self._mean_diffraction_intensity * self._num_diffraction_patterns + + # constraints + self._object, self._probe, self._positions_px = self._constraints( + self._object, + self._probe, + self._positions_px, + self._positions_px_initial, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=self._probe_initial_aperture, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + gaussian_filter=gaussian_filter and gaussian_filter_sigma is not None, + gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=butterworth_filter + and (q_lowpass is not None or q_highpass is not None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + kz_regularization_filter=kz_regularization_filter + and kz_regularization_gamma is not None, + kz_regularization_gamma=kz_regularization_gamma, + identical_slices=identical_slices, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 + else None, + pure_phase_object=pure_phase_object and self._object_type == "complex", + tv_denoise_chambolle=tv_denoise_chambolle + and tv_denoise_weight_chambolle is not None, + tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, + tv_denoise=tv_denoise and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + ) + + self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/parallax.py similarity index 93% rename from py4DSTEM/process/phase/iterative_parallax.py rename to py4DSTEM/process/phase/parallax.py index 34454088a..58181d812 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -14,7 +14,7 @@ from mpl_toolkits.axes_grid1.axes_divider import 
make_axes_locatable from py4DSTEM import Calibration, DataCube from py4DSTEM.preprocess.utils import get_shifted_ar -from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction from py4DSTEM.process.phase.utils import ( AffineTransform, bilinear_kernel_density_estimate, @@ -36,8 +36,6 @@ except (ModuleNotFoundError, ImportError): cp = np -warnings.simplefilter(action="always", category=UserWarning) - _aberration_names = { (1, 0): "C1 ", (1, 2): "stig ", @@ -56,7 +54,7 @@ } -class ParallaxReconstruction(PhaseReconstruction): +class Parallax(PhaseReconstruction): """ Iterative parallax reconstruction class. @@ -79,27 +77,23 @@ def __init__( self, energy: float, datacube: DataCube = None, - verbose: bool = False, + verbose: bool = True, object_padding_px: Tuple[int, int] = (32, 32), device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, name: str = "parallax_reconstruction", ): Custom.__init__(self, name=name) - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter + if storage is None: + storage = device - self._gaussian_filter = gaussian_filter - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter + if storage != device: + raise NotImplementedError() - self._gaussian_filter = gaussian_filter - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_device(device, clear_fft_cache) + self.set_storage(storage) self.set_save_defaults() @@ -109,7 +103,6 @@ def __init__( # Metadata self._energy = energy self._verbose = verbose - self._device = device self._object_padding_px = object_padding_px self._preprocessed = False @@ -271,6 +264,9 @@ def preprocess( plot_average_bf: bool = True, realspace_mask: np.ndarray = None, apply_realspace_mask_to_stack: bool = True, + vectorized_com_calculation: bool = True, + device: str = None, + clear_fft_cache: bool = None, **kwargs, ): """ @@ -306,6 +302,12 @@ def preprocess( apply_realspace_mask_to_stack: bool, optional If this value is set to true, output BF images will be masked by the edge filter and realspace_mask if it is passed in. + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. 
+ clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- @@ -313,7 +315,11 @@ def preprocess( Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + xp = self._xp + device = self._device asnumpy = self._asnumpy if self._datacube is None: @@ -330,6 +336,8 @@ def preprocess( require_calibrations=True, ) + self._intensities = xp.asarray(self._intensities) + self._region_of_interest_shape = np.array(self._intensities.shape[-2:]) self._scan_shape = np.array(self._intensities.shape[:2]) @@ -348,6 +356,7 @@ def preprocess( fit_function=descan_correction_fit_function, com_shifts=None, com_measured=None, + vectorized_calculation=vectorized_com_calculation, ) com_fitted_x = asnumpy(com_fitted_x) @@ -355,7 +364,6 @@ def preprocess( intensities = asnumpy(self._intensities) intensities_shifted = np.zeros_like(intensities) - # center_x, center_y = self._region_of_interest_shape / 2 center_x = com_fitted_x.mean() center_y = com_fitted_y.mean() @@ -717,178 +725,17 @@ def preprocess( ax[1].set_ylabel(r"$k_x$ [$A^{-1}$]") ax[1].set_xlabel(r"$k_y$ [$A^{-1}$]") plt.tight_layout() - self._preprocessed = True - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self._preprocessed = True + self.clear_device_mem(self._device, self._clear_fft_cache) return self - def tune_angle_and_defocus( - self, - angle_guess=None, - defocus_guess=None, - angle_step_size=5, - defocus_step_size=100, - num_angle_values=5, - num_defocus_values=5, - return_values=False, - plot_reconstructions=True, - plot_convergence=True, - **kwargs, - ): - """ - Run parallax reconstruction over a parameters space of pre-determined angles - and defocus - - Parameters - ---------- - angle_guess: float (degrees), optional - initial starting guess for rotation angle between real and reciprocal space - if None, uses 0 - defocus_guess: float (A), optional - initial starting guess for defocus (defocus dF) - if None, uses 0 - angle_step_size: float (degrees), optional - size of change of rotation angle between real and reciprocal space for - each step in parameter space - defocus_step_size: float (A), optional - size of change of defocus for each step in parameter space - num_angle_values: int, optional - number of values of angle to test, must be >= 1. 
- num_defocus_values: int,optional - number of values of defocus to test, must be >= 1 - plot_reconstructions: bool, optional - if True, plot phase of reconstructed objects - plot_convergence: bool, optional - if True, makes 2D plot of error metrix - return_values: bool, optional - if True, returns objects, convergence - - Returns - ------- - objects: list - reconstructed objects - convergence: np.ndarray - array of convergence values from reconstructions - """ - asnumpy = self._asnumpy - - if angle_guess is None: - angle_guess = 0 - if defocus_guess is None: - defocus_guess = 0 - - if num_angle_values == 1: - angle_step_size = 0 - - if num_defocus_values == 1: - defocus_step_size = 0 - - angles = np.linspace( - angle_guess - angle_step_size * (num_angle_values - 1) / 2, - angle_guess + angle_step_size * (num_angle_values - 1) / 2, - num_angle_values, - ) - - defocus_values = np.linspace( - defocus_guess - defocus_step_size * (num_defocus_values - 1) / 2, - defocus_guess + defocus_step_size * (num_defocus_values - 1) / 2, - num_defocus_values, - ) - if return_values or plot_convergence: - recon_BF = [] - convergence = [] - - if plot_reconstructions: - spec = GridSpec( - ncols=num_defocus_values, - nrows=num_angle_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_defocus_values, 4 * num_angle_values) - ) - - fig = plt.figure(figsize=figsize) - - # run loop and plot along the way - self._verbose = False - for flat_index, (angle, defocus) in enumerate( - tqdmnd(angles, defocus_values, desc="Tuning angle and defocus") - ): - self.preprocess( - defocus_guess=defocus, - rotation_guess=angle, - plot_average_bf=False, - **kwargs, - ) - - if plot_reconstructions: - row_index, col_index = np.unravel_index( - flat_index, (num_angle_values, num_defocus_values) - ) - object_ax = fig.add_subplot(spec[row_index, col_index]) - self._visualize_figax( - fig, - ax=object_ax, - **kwargs, - ) - - object_ax.set_title( - f" angle = {angle:.1f} °, defocus = {defocus:.1f} A \n error = {self._recon_error[0]:.3e}" - ) - object_ax.set_xticks([]) - object_ax.set_yticks([]) - - if return_values: - recon_BF.append(self.recon_BF) - if return_values or plot_convergence: - convergence.append(asnumpy(self._recon_error[0])) - - if plot_convergence: - fig, ax = plt.subplots() - ax.set_title("convergence") - im = ax.imshow( - np.array(convergence).reshape(angles.shape[0], defocus_values.shape[0]), - cmap="magma", - ) - - if angles.shape[0] > 1: - ax.set_ylabel("angles") - ax.set_yticks(np.arange(angles.shape[0])) - ax.set_yticklabels([f"{angle:.1f} °" for angle in angles]) - else: - ax.set_yticks([]) - ax.set_ylabel(f"angle {angles[0]:.1f}") - - if defocus_values.shape[0] > 1: - ax.set_xlabel("defocus values") - ax.set_xticks(np.arange(defocus_values.shape[0])) - ax.set_xticklabels([f"{df:.1f}" for df in defocus_values]) - else: - ax.set_xticks([]) - ax.set_xlabel(f"defocus value: {defocus_values[0]:.1f}") - - divider = make_axes_locatable(ax) - cax = divider.append_axes("right", size="5%", pad=0.05) - fig.colorbar(im, cax=cax) - - fig.tight_layout() - - if return_values: - convergence = np.array(convergence).reshape( - angles.shape[0], defocus_values.shape[0] - ) - return recon_BF, convergence - def reconstruct( self, max_alignment_bin: int = None, min_alignment_bin: int = 1, - max_iter_at_min_bin: int = 2, + num_iter_at_min_bin: int = 2, alignment_bin_values: list = None, cross_correlation_upsample_factor: int = 8, regularizer_matrix_size: Tuple[int, int] = (1, 1), @@ -898,6 +745,8 @@ 
def reconstruct( plot_aligned_bf: bool = True, plot_convergence: bool = True, reset: bool = None, + device: str = None, + clear_fft_cache: bool = None, **kwargs, ): """ @@ -910,7 +759,7 @@ def reconstruct( If None, the bright field disk radius is used min_alignment_bin: int, optional Minimum bin size for bright field alignment - max_iter_at_min_bin: int, optional + num_iter_at_min_bin: int, optional Number of iterations to run at the smallest bin size alignment_bin_values: list, optional If not None, explicitly sets the iteration bin values @@ -930,6 +779,10 @@ def reconstruct( If True, the convergence error is also plotted reset: bool, optional If True, the reconstruction is reset + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- @@ -937,6 +790,9 @@ def reconstruct( Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + xp = self._xp asnumpy = self._asnumpy @@ -947,6 +803,7 @@ def reconstruct( self._stack_mask = self._stack_mask_initial.copy() self._recon_mask = self._recon_mask_initial.copy() self._xy_shifts = self._xy_shifts_initial.copy() + elif reset is None: if hasattr(self, "error_iterations"): warnings.warn( @@ -1011,9 +868,9 @@ def reconstruct( bin_max = np.ceil(np.log(max_alignment_bin) / np.log(2)) bin_vals = 2 ** np.arange(bin_min, bin_max)[::-1] - if max_iter_at_min_bin > 1: + if num_iter_at_min_bin > 1: bin_vals = np.hstack( - (bin_vals, np.repeat(bin_vals[-1], max_iter_at_min_bin - 1)) + (bin_vals, np.repeat(bin_vals[-1], num_iter_at_min_bin - 1)) ) if plot_aligned_bf: @@ -1191,9 +1048,7 @@ def reconstruct( self.recon_BF = asnumpy(self._recon_BF) - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(self._device, self._clear_fft_cache) return self @@ -1268,7 +1123,7 @@ def subpixel_alignment( """ xp = self._xp asnumpy = self._asnumpy - gaussian_filter = self._gaussian_filter + gaussian_filter = self._scipy.ndimage.gaussian_filter BF_sampling = 1 / asnumpy(self._kr).max() / 2 DF_sampling = 1 / ( @@ -1281,7 +1136,7 @@ def subpixel_alignment( if self._DF_upsample_limit < 1: warnings.warn( ( - f"Dark-field upsampling limit of {self._DF_upsampling_limit:.2f} " + f"Dark-field upsampling limit of {self._DF_upsample_limit:.2f} " "is less than 1, implying a scan step-size smaller than Nyquist. " "setting to 1." 
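The renamed num_iter_at_min_bin argument controls the tail of the coarse-to-fine alignment schedule built above: bin sizes halve from max_alignment_bin down to min_alignment_bin, and the finest bin is then repeated. A standalone sketch of that schedule, with illustrative values and with the bin_min line assumed to mirror the bin_max line shown in the hunk:

    import numpy as np

    max_alignment_bin = 24
    min_alignment_bin = 1
    num_iter_at_min_bin = 3

    # powers of two spanning the requested bin range, coarsest first
    bin_min = np.ceil(np.log(min_alignment_bin) / np.log(2))
    bin_max = np.ceil(np.log(max_alignment_bin) / np.log(2))
    bin_vals = 2 ** np.arange(bin_min, bin_max)[::-1]

    # repeat the finest bin so it runs for num_iter_at_min_bin iterations in total
    if num_iter_at_min_bin > 1:
        bin_vals = np.hstack(
            (bin_vals, np.repeat(bin_vals[-1], num_iter_at_min_bin - 1))
        )

    print(bin_vals)  # 16, 8, 4, 2, then the finest bin (1) three times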
), @@ -1899,6 +1754,8 @@ def subpixel_alignment( spec.tight_layout(fig) + self.clear_device_mem(self._device, self._clear_fft_cache) + def _interpolate_array( self, image, @@ -1933,7 +1790,7 @@ def _kernel_density_estimate( """ """ xp = self._xp - gaussian_filter = self._gaussian_filter + gaussian_filter = self._scipy.ndimage.gaussian_filter if lanczos_alpha is not None: return lanczos_kernel_density_estimate( @@ -2532,9 +2389,7 @@ def score_CTF(coefs): + str(np.round(self._aberrations_coefs[a0]).astype("int")) ) - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(self._device, self._clear_fft_cache) def _calculate_CTF(self, alpha_shape, sampling, *coefs): xp = self._xp @@ -2700,10 +2555,6 @@ def aberration_correct( self._recon_phase_corrected = xp.real(xp.fft.ifft2(im_fft_corr)) self.recon_phase_corrected = asnumpy(self._recon_phase_corrected) - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - # plotting if plot_corrected_phase: figsize = kwargs.pop("figsize", (6, 6)) @@ -2733,6 +2584,8 @@ def aberration_correct( ax.set_xlabel("y [A]") ax.set_title("Parallax-Corrected Phase Image") + self.clear_device_mem(self._device, self._clear_fft_cache) + def depth_section( self, depth_angstroms=np.arange(-250, 260, 100), @@ -3010,7 +2863,8 @@ def show_shifts( dp_mask_ind = xp.nonzero(self._dp_mask) yy, xx = xp.meshgrid( - xp.arange(self._dp_mean.shape[1]), xp.arange(self._dp_mean.shape[0]) + xp.arange(self._region_of_interest_shape[1]), + xp.arange(self._region_of_interest_shape[0]), ) freq_mask = xp.logical_and(xx % plot_arrow_freq == 0, yy % plot_arrow_freq == 0) masked_ind = xp.logical_and(freq_mask, self._dp_mask) diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index 91a71cb30..aa369872d 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -1,10 +1,11 @@ from functools import partial +from itertools import product from typing import Callable, Union import matplotlib.pyplot as plt import numpy as np from matplotlib.gridspec import GridSpec -from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction from py4DSTEM.process.phase.utils import AffineTransform from skopt import gp_minimize from skopt.plots import plot_convergence as skopt_plot_convergence @@ -102,6 +103,151 @@ def __init__( self._set_optimizer_defaults() + def _generate_inclusive_boundary_grid( + self, + parameter, + n_points, + ): + """ """ + + # Categorical + if hasattr(parameter, "categories"): + return np.array(parameter.categories) + + # Real or Integer + else: + return np.unique( + np.linspace(parameter.low, parameter.high, n_points).astype( + parameter.dtype + ) + ) + + def grid_search( + self, + n_points: Union[tuple, int] = 3, + error_metric: Union[Callable, str] = "log", + plot_reconstructed_objects: bool = True, + return_reconstructed_objects: bool = False, + **kwargs: dict, + ): + """ + Run optimizer + + Parameters + ---------- + n_initial_points: int + Number of uniformly spaced trial points to run on a grid + error_metric: Callable or str + Function used to compute the reconstruction error. 
+ When passed as a string, may be one of: + 'log': log(NMSE) of final object + 'linear': NMSE of final object + 'log-converged': log(NMSE) of final object if + NMSE is decreasing, 0 if NMSE increasing + 'linear-converged': NMSE of final object if + NMSE is decreasing, 1 if NMSE increasing + 'TV': sum( abs( grad( object ) ) ) / sum( abs( object ) ) + 'std': negative standard deviation of cropped object + 'std-phase': negative standard deviation of + phase of the cropped object + 'entropy-phase': entropy of the phase of the + cropped object + When passed as a Callable, a function that takes the + PhaseReconstruction object as its only argument + and returns the error metric as a single float + + """ + + num_params = len(self._parameter_list) + + if isinstance(n_points, int): + n_points = [n_points] * num_params + elif len(n_points) != num_params: + raise ValueError("n_points must have one entry per optimization parameter.") + + params_grid = [ + self._generate_inclusive_boundary_grid(param, n_pts) + for param, n_pts in zip(self._parameter_list, n_points) + ] + params_grid = list(product(*params_grid)) + num_evals = len(params_grid) + + error_metric = self._get_error_metric(error_metric) + pbar = tqdm(total=num_evals, desc="Searching parameters") + + def evaluation_callback(ptycho): + if plot_reconstructed_objects or return_reconstructed_objects: + pbar.update(1) + return ( + ptycho._return_projected_cropped_potential(), + error_metric(ptycho), + ) + else: + pbar.update(1) + return error_metric(ptycho) + + self._grid_search_function = self._get_optimization_function( + self._reconstruction_type, + self._parameter_list, + self._init_static_args, + self._affine_static_args, + self._preprocess_static_args, + self._reconstruction_static_args, + self._init_optimize_args, + self._affine_optimize_args, + self._preprocess_optimize_args, + self._reconstruction_optimize_args, + evaluation_callback, + ) + + grid_search_res = list(map(self._grid_search_function, params_grid)) + pbar.close() + + if plot_reconstructed_objects: + if len(n_points) == 2: + nrows, ncols = n_points + else: + nrows = kwargs.pop("nrows", int(np.sqrt(num_evals))) + ncols = kwargs.pop("ncols", int(np.ceil(num_evals / nrows))) + if nrows * ncols < num_evals: + raise ValueError("nrows * ncols must be at least the number of grid evaluations.") + + spec = GridSpec( + ncols=ncols, + nrows=nrows, + hspace=0.15, + wspace=0.15, + ) + + sx, sy = grid_search_res[0][0].shape + + separator = kwargs.pop("separator", "\n") + cmap = kwargs.pop("cmap", "magma") + figsize = kwargs.pop("figsize", (2.5 * ncols, 3 / sy * sx * nrows)) + fig = plt.figure(figsize=figsize) + + for index, (params, res) in enumerate(zip(params_grid, grid_search_res)): + row_index, col_index = np.unravel_index(index, (nrows, ncols)) + + ax = fig.add_subplot(spec[row_index, col_index]) + ax.imshow(res[0], cmap=cmap) + + title_substrings = [ + f"{param.name}: {val}" + for param, val in zip(self._parameter_list, params) + ] + title_substrings.append(f"error: {res[1]:.3e}") + title = separator.join(title_substrings) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title(title) + spec.tight_layout(fig) + + if return_reconstructed_objects: + return grid_search_res + else: + return grid_search_res + def optimize( self, n_calls: int = 50, diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/phase_base_class.py similarity index 61% rename from py4DSTEM/process/phase/iterative_base_class.py rename to py4DSTEM/process/phase/phase_base_class.py index 8b836eae2..a67640d9b 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ 
-2,17 +2,18 @@ Module for reconstructing phase objects from 4DSTEM datasets using iterative methods. """ +import sys import warnings import matplotlib.pyplot as plt import numpy as np -from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid -from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex -from scipy.ndimage import rotate +from py4DSTEM.visualize import show_complex +from scipy.ndimage import zoom try: import cupy as cp + from cupy.fft.config import get_plan_cache except (ModuleNotFoundError, ImportError): cp = np @@ -20,12 +21,10 @@ from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube from py4DSTEM.process.calibration import fit_origin -from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( - PtychographicConstraints, -) from py4DSTEM.process.phase.utils import ( AffineTransform, - generate_batches, + copy_to_device, + get_array_module, polar_aliases, ) from py4DSTEM.process.utils import ( @@ -34,7 +33,8 @@ get_shifted_ar, ) -warnings.simplefilter(action="always", category=UserWarning) +warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) +warnings.simplefilter("always", UserWarning) class PhaseReconstruction(Custom): @@ -43,6 +43,94 @@ class PhaseReconstruction(Custom): Defines various common functions and properties for subclasses to inherit. """ + def set_device(self, device, clear_fft_cache): + """ + Sets calculation device. + + Parameters + ---------- + device: str + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + + Returns + -------- + self: PhaseReconstruction + Self to enable chaining + """ + + if clear_fft_cache is not None: + self._clear_fft_cache = clear_fft_cache + + if device is None: + return self + + if device == "cpu": + import scipy + + self._xp = np + self._scipy = scipy + + elif device == "gpu": + from cupyx import scipy + + self._xp = cp + self._scipy = scipy + + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + + self._device = device + + return self + + def set_storage(self, storage): + """ + Sets storage device. + + Parameters + ---------- + storage: str + Device arrays will be stored on. Must be 'cpu' or 'gpu' + + Returns + -------- + self: PhaseReconstruction + Self to enable chaining + """ + + if storage == "cpu": + self._xp_storage = np + + elif storage == "gpu": + if self._xp is np: + raise ValueError("storage='gpu' and device='cpu' is not supported") + self._xp_storage = cp + + else: + raise ValueError(f"storage must be either 'cpu' or 'gpu', not {storage}") + + self._asnumpy = copy_to_device + self._storage = storage + + return self + + def clear_device_mem(self, device, clear_fft_cache): + """ """ + if device == "gpu": + if clear_fft_cache: + cache = get_plan_cache() + cache.clear() + + xp = self._xp + xp._default_memory_pool.free_all_blocks() + xp._default_pinned_memory_pool.free_all_blocks() + + def copy_attributes_to_device(self, attrs, device): + """Utility function to copy a set of attrs to device""" + for attr in attrs: + array = copy_to_device(getattr(self, attr), device) + setattr(self, attr, array) + def attach_datacube(self, datacube: DataCube): """ Attaches a datacube to a class initialized without one. 
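Editor's note: the following stand-alone sketch is not part of the diff. It mirrors the numpy/cupy dispatch pattern that the new set_device / set_storage / copy_to_device machinery above implements; the helper names used here (select_array_module, to_device) are hypothetical stand-ins rather than py4DSTEM API.

import numpy as np

try:
    import cupy as cp
except (ImportError, ModuleNotFoundError):
    cp = None  # CPU-only fallback


def select_array_module(device="cpu"):
    # Return the array module ("xp") for the requested compute device.
    if device == "cpu":
        return np
    if device == "gpu":
        if cp is None:
            raise ValueError("device='gpu' requested but cupy is not installed")
        return cp
    raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}")


def to_device(array, device="cpu"):
    # Move an array to the requested device, mirroring the storage logic above.
    if device == "gpu":
        return select_array_module("gpu").asarray(array)
    if cp is not None and isinstance(array, cp.ndarray):
        return cp.asnumpy(array)
    return np.asarray(array)


xp = select_array_module("cpu")
data = xp.random.rand(4, 4).astype(xp.float32)  # lives on the chosen device
host_copy = to_device(data, "cpu")              # always a numpy array on return

Keeping compute ("device") and storage separate in this way is what allows large, infrequently-used arrays to stay in host memory while only the working set is promoted to the GPU.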
@@ -60,7 +148,13 @@ def attach_datacube(self, datacube: DataCube): self._datacube = datacube return self - def reinitialize_parameters(self, device: str = None, verbose: bool = None): + def reinitialize_parameters( + self, + device: str = None, + storage: str = None, + clear_fft_cache: bool = None, + verbose: bool = None, + ): """ Reinitializes common parameters. This is useful when loading a previously-saved reconstruction (which set device='cpu' and verbose=True for compatibility) , @@ -69,7 +163,11 @@ def reinitialize_parameters(self, device: str = None, verbose: bool = None): Parameters ---------- device: str, optional - If not None, imports and assigns appropriate device modules + If not None, assigns appropriate device modules + storage: str, optional + If not None, assigns appropriate storage modules + clear_fft_cache: bool, optional + If not None, sets the FFT caching parameter verbose: bool, optional If not None, sets the verbosity to verbose @@ -80,27 +178,10 @@ def reinitialize_parameters(self, device: str = None, verbose: bool = None): """ if device is not None: - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from scipy.special import erf - - self._erf = erf - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter + self.set_device(device, clear_fft_cache) - self._gaussian_filter = gaussian_filter - from cupyx.scipy.special import erf - - self._erf = erf - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - self._device = device + if storage is not None: + self.set_storage(storage) if verbose is not None: self._verbose = verbose @@ -144,7 +225,7 @@ def _preprocess_datacube_and_vacuum_probe( datacube, diffraction_intensities_shape=None, reshaping_method="fourier", - probe_roi_shape=None, + padded_diffraction_intensities_shape=None, vacuum_probe_intensity=None, dp_mask=None, com_shifts=None, @@ -161,13 +242,10 @@ def _preprocess_datacube_and_vacuum_probe( Note this does not affect the maximum scattering wavevector (Qx*dkx,Qy*dky) = (Sx*dkx',Sy*dky'), and thus the real-space sampling stays fixed. - The real space sampling, (dx, dy), combined with the resampled diffraction_intensities_shape, - sets the real-space probe region of interest (ROI) extent (dx*Sx, dy*Sy). - Occasionally, one may also want to specify a larger probe ROI extent, e.g when the probe - does not comfortably fit without self-ovelap artifacts, or when the scan step sizes are much - smaller than the real-space sampling (dx,dy). This can be achieved by specifying a - probe_roi_shape, which is larger than diffraction_intensities_shape, which will result in - zero-padding of the diffraction intensities. + Additionally, one may wish to zero-pad the diffraction intensity data. Note this does not increase + the information or resolution, but might be beneficial in a limited number of cases, e.g. when the + scan step sizes are much smaller than the real-space sampling (dx,dy). This can be achieved by specifying + a padded_diffraction_intensities_shape which is larger than diffraction_intensities_shape. 
Parameters ---------- @@ -178,7 +256,7 @@ def _preprocess_datacube_and_vacuum_probe( If None, no resampling is performed reshaping_method: str, optional Reshaping method to use, one of 'bin', 'bilinear' or 'fourier' (default) - probe_roi_shape, (int,int), optional + padded_diffraction_intensities_shape, (int,int), optional Padded diffraction intensities shape. If None, no padding is performed vacuum_probe_intensity, np.ndarray, optional @@ -193,11 +271,12 @@ def _preprocess_datacube_and_vacuum_probe( datacube: Datacube Resampled and Padded datacube """ + if com_shifts is not None: if np.isscalar(com_shifts[0]): com_shifts = ( - np.ones(self._datacube.Rshape) * com_shifts[0], - np.ones(self._datacube.Rshape) * com_shifts[1], + np.ones(datacube.Rshape) * com_shifts[0], + np.ones(datacube.Rshape) * com_shifts[1], ) if diffraction_intensities_shape is not None: @@ -227,12 +306,31 @@ def _preprocess_datacube_and_vacuum_probe( datacube = datacube.bin_Q(N=bin_factor) if vacuum_probe_intensity is not None: - vacuum_probe_intensity = vacuum_probe_intensity[ - ::bin_factor, ::bin_factor - ] + # crop edges if necessary + if Qx % bin_factor != 0: + vacuum_probe_intensity = vacuum_probe_intensity[ + : -(Qx % bin_factor), : + ] + if Qy % bin_factor != 0: + vacuum_probe_intensity = vacuum_probe_intensity[ + :, : -(Qy % bin_factor) + ] + + vacuum_probe_intensity = vacuum_probe_intensity.reshape( + Qx // bin_factor, bin_factor, Qy // bin_factor, bin_factor + ).sum(axis=(1, 3)) if dp_mask is not None: - dp_mask = dp_mask[::bin_factor, ::bin_factor] - else: + # crop edges if necessary + if Qx % bin_factor != 0: + dp_mask = dp_mask[: -(Qx % bin_factor), :] + if Qy % bin_factor != 0: + dp_mask = dp_mask[:, : -(Qy % bin_factor)] + + dp_mask = dp_mask.reshape( + Qx // bin_factor, bin_factor, Qy // bin_factor, bin_factor + ).sum(axis=(1, 3)) + + elif reshaping_method == "fourier": datacube = datacube.resample_Q( N=resampling_factor_x, method=reshaping_method ) @@ -249,10 +347,33 @@ def _preprocess_datacube_and_vacuum_probe( force_nonnegative=True, ) - if probe_roi_shape is not None: + elif reshaping_method == "bilinear": + datacube = datacube.resample_Q( + N=resampling_factor_x, method=reshaping_method + ) + if vacuum_probe_intensity is not None: + vacuum_probe_intensity = zoom( + vacuum_probe_intensity, + (resampling_factor_x, resampling_factor_x), + order=1, + ) + if dp_mask is not None: + dp_mask = zoom( + dp_mask, (resampling_factor_x, resampling_factor_x), order=1 + ) + + else: + raise ValueError( + ( + "reshaping_method needs to be one of 'bilinear', 'fourier', or 'bin', " + f"not {reshaping_method}."
+ ) + ) + + if padded_diffraction_intensities_shape is not None: Qx, Qy = datacube.shape[-2:] - Sx, Sy = probe_roi_shape - datacube = datacube.pad_Q(output_size=probe_roi_shape) + Sx, Sy = padded_diffraction_intensities_shape + datacube = datacube.pad_Q(output_size=padded_diffraction_intensities_shape) if vacuum_probe_intensity is not None or dp_mask is not None: pad_kx = Sx - Qx @@ -328,10 +449,11 @@ def _extract_intensities_and_calibrations_from_datacube( If require_calibrations is False and calibrations are not set """ - # Copies intensities to device casting to float32 - xp = self._xp + # explicit read-only self attributes up-front + verbose = self._verbose + energy = self._energy - intensities = xp.asarray(datacube.data, dtype=xp.float32) + intensities = np.asarray(datacube.data, dtype=np.float32) self._grid_scan_shape = intensities.shape[:2] # Extracts calibrations @@ -348,7 +470,7 @@ def _extract_intensities_and_calibrations_from_datacube( if require_calibrations: raise ValueError("Real-space calibrations must be given in 'A'") - if self._verbose: + if verbose: warnings.warn( ( "Iterative reconstruction will not be quantitative unless you specify " @@ -373,35 +495,36 @@ def _extract_intensities_and_calibrations_from_datacube( # Reciprocal-space if force_angular_sampling is not None or force_reciprocal_sampling is not None: - # there is no xor keyword in Python! - angular = force_angular_sampling is not None - reciprocal = force_reciprocal_sampling is not None - assert (angular and not reciprocal) or ( - not angular and reciprocal - ), "Only one of angular or reciprocal calibration can be forced!" + if ( + force_angular_sampling is not None + and force_reciprocal_sampling is not None + ): + raise ValueError( + "Only one of angular or reciprocal calibration can be forced." 
+ ) # angular calibration specified - if angular: + if force_angular_sampling is not None: self._angular_sampling = (force_angular_sampling,) * 2 self._angular_units = ("mrad",) * 2 - if self._energy is not None: + if energy is not None: self._reciprocal_sampling = ( force_angular_sampling - / electron_wavelength_angstrom(self._energy) + / electron_wavelength_angstrom(energy) / 1e3, ) * 2 self._reciprocal_units = ("A^-1",) * 2 # reciprocal calibration specified - if reciprocal: + if force_reciprocal_sampling is not None: self._reciprocal_sampling = (force_reciprocal_sampling,) * 2 self._reciprocal_units = ("A^-1",) * 2 - if self._energy is not None: + if energy is not None: self._angular_sampling = ( force_reciprocal_sampling - * electron_wavelength_angstrom(self._energy) + * electron_wavelength_angstrom(energy) * 1e3, ) * 2 self._angular_units = ("mrad",) * 2 @@ -413,7 +536,7 @@ def _extract_intensities_and_calibrations_from_datacube( "Reciprocal-space calibrations must be given in in 'A^-1' or 'mrad'" ) - if self._verbose: + if verbose: warnings.warn( ( "Iterative reconstruction will not be quantitative unless you specify " @@ -432,11 +555,9 @@ def _extract_intensities_and_calibrations_from_datacube( self._reciprocal_sampling = (reciprocal_size,) * 2 self._reciprocal_units = ("A^-1",) * 2 - if self._energy is not None: + if energy is not None: self._angular_sampling = ( - reciprocal_size - * electron_wavelength_angstrom(self._energy) - * 1e3, + reciprocal_size * electron_wavelength_angstrom(energy) * 1e3, ) * 2 self._angular_units = ("mrad",) * 2 @@ -445,9 +566,9 @@ def _extract_intensities_and_calibrations_from_datacube( self._angular_sampling = (angular_size,) * 2 self._angular_units = ("mrad",) * 2 - if self._energy is not None: + if energy is not None: self._reciprocal_sampling = ( - angular_size / electron_wavelength_angstrom(self._energy) / 1e3, + angular_size / electron_wavelength_angstrom(energy) / 1e3, ) * 2 self._reciprocal_units = ("A^-1",) * 2 else: @@ -467,6 +588,7 @@ def _calculate_intensities_center_of_mass( fit_function: str = "plane", com_shifts: np.ndarray = None, com_measured: np.ndarray = None, + vectorized_calculation=True, ): """ Common preprocessing function to compute and fit diffraction intensities CoM @@ -483,6 +605,8 @@ def _calculate_intensities_center_of_mass( If not None, com_shifts are fitted on the measured CoM values. 
com_measured: tuple of ndarrays (CoMx measured, CoMy measured) If not None, com_measured are passed as com_measured_x, com_measured_y + vectorized_calculation: bool, optional + If True (default), the calculation is vectorized Returns ------- @@ -500,20 +624,17 @@ def _calculate_intensities_center_of_mass( Normalized vertical center of mass gradient """ + # explicit read-only self attributes up-front xp = self._xp + device = self._device asnumpy = self._asnumpy - # for ptycho + reciprocal_sampling = self._reciprocal_sampling + if com_measured: com_measured_x, com_measured_y = com_measured else: - # Coordinates - kx = xp.arange(intensities.shape[-2], dtype=xp.float32) - ky = xp.arange(intensities.shape[-1], dtype=xp.float32) - kya, kxa = xp.meshgrid(ky, kx) - - # calculate CoM if dp_mask is not None: if dp_mask.shape != intensities.shape[-2:]: raise ValueError( @@ -522,19 +643,57 @@ def _calculate_intensities_center_of_mass( f"not {dp_mask.shape}" ) ) - intensities_mask = intensities * xp.asarray(dp_mask, dtype=xp.float32) - else: - intensities_mask = intensities + dp_mask = xp.asarray(dp_mask, dtype=xp.float32) - intensities_sum = xp.sum(intensities_mask, axis=(-2, -1)) - com_measured_x = ( - xp.sum(intensities_mask * kxa[None, None], axis=(-2, -1)) - / intensities_sum - ) - com_measured_y = ( - xp.sum(intensities_mask * kya[None, None], axis=(-2, -1)) - / intensities_sum - ) + # Coordinates + kx = xp.arange(intensities.shape[-2], dtype=xp.float32) + ky = xp.arange(intensities.shape[-1], dtype=xp.float32) + kya, kxa = xp.meshgrid(ky, kx) + + if vectorized_calculation: + # copy to device + intensities = copy_to_device(intensities, device) + + # calculate CoM + if dp_mask is not None: + intensities_mask = intensities * dp_mask + else: + intensities_mask = intensities + + intensities_sum = xp.sum(intensities_mask, axis=(-2, -1)) + + com_measured_x = ( + xp.sum(intensities_mask * kxa[None, None], axis=(-2, -1)) + / intensities_sum + ) + com_measured_y = ( + xp.sum(intensities_mask * kya[None, None], axis=(-2, -1)) + / intensities_sum + ) + + else: + sx, sy = intensities.shape[:2] + com_measured_x = xp.zeros((sx, sy), dtype=xp.float32) + com_measured_y = xp.zeros((sx, sy), dtype=xp.float32) + + # loop of dps + for rx, ry in tqdmnd( + sx, + sy, + desc="Fitting center of mass", + unit="probe position", + disable=not self._verbose, + ): + masked_intensity = copy_to_device(intensities[rx, ry], device) + if dp_mask is not None: + masked_intensity *= dp_mask + summed_intensity = masked_intensity.sum() + com_measured_x[rx, ry] = ( + xp.sum(masked_intensity * kxa) / summed_intensity + ) + com_measured_y[rx, ry] = ( + xp.sum(masked_intensity * kya) / summed_intensity + ) if com_shifts is None: com_measured_x_np = asnumpy(com_measured_x) @@ -553,10 +712,10 @@ def _calculate_intensities_center_of_mass( # fix CoM units com_normalized_x = ( - xp.nan_to_num(com_measured_x - com_fitted_x) * self._reciprocal_sampling[0] + xp.nan_to_num(com_measured_x - com_fitted_x) * reciprocal_sampling[0] ) com_normalized_y = ( - xp.nan_to_num(com_measured_y - com_fitted_y) * self._reciprocal_sampling[1] + xp.nan_to_num(com_measured_y - com_fitted_y) * reciprocal_sampling[1] ) return ( @@ -574,7 +733,7 @@ def _solve_for_center_of_mass_relative_rotation( _com_measured_y: np.ndarray, _com_normalized_x: np.ndarray, _com_normalized_y: np.ndarray, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + rotation_angles_deg: np.ndarray = None, plot_rotation: bool = True, plot_center_of_mass: str = "default", 
maximize_divergence: bool = False, @@ -636,15 +795,22 @@ def _solve_for_center_of_mass_relative_rotation( Summary statistics """ + # explicit read-only self attributes up-front xp = self._xp asnumpy = self._asnumpy + verbose = self._verbose + scan_sampling = self._scan_sampling + scan_units = self._scan_units + + if rotation_angles_deg is None: + rotation_angles_deg = np.arange(-89.0, 90.0, 1.0) if force_com_rotation is not None: # Rotation known _rotation_best_rad = np.deg2rad(force_com_rotation) - if self._verbose: + if verbose: warnings.warn( ( "Best fit rotation forced to " @@ -658,7 +824,7 @@ def _solve_for_center_of_mass_relative_rotation( _rotation_best_transpose = force_com_transpose - if self._verbose: + if verbose: warnings.warn( f"Transpose of intensities forced to {force_com_transpose}.", UserWarning, @@ -709,11 +875,12 @@ def _solve_for_center_of_mass_relative_rotation( else: _rotation_best_transpose = rotation_curl_transpose < rotation_curl - if self._verbose: + if verbose: if _rotation_best_transpose: - print("Diffraction intensities should be transposed.") - else: - print("No need to transpose diffraction intensities.") + warnings.warn( + "Diffraction intensities should be transposed.", + UserWarning, + ) else: # Rotation unknown @@ -722,7 +889,7 @@ def _solve_for_center_of_mass_relative_rotation( _rotation_best_transpose = force_com_transpose - if self._verbose: + if verbose: warnings.warn( f"Transpose of intensities forced to {force_com_transpose}.", UserWarning, @@ -817,8 +984,11 @@ def _solve_for_center_of_mass_relative_rotation( rotation_best_deg = rotation_angles_deg[ind_min] _rotation_best_rad = rotation_angles_rad[ind_min] - if self._verbose: - print(("Best fit rotation = " f"{rotation_best_deg:.0f} degrees.")) + if verbose: + warnings.warn( + f"Best fit rotation = {rotation_best_deg:.0f} degrees.", + UserWarning, + ) if plot_rotation: figsize = kwargs.get("figsize", (8, 2)) @@ -959,8 +1129,6 @@ def _solve_for_center_of_mass_relative_rotation( # Minimize Curl ind_min = xp.argmin(rotation_curl).item() ind_trans_min = xp.argmin(rotation_curl_transpose).item() - self._rotation_curl = rotation_curl - self._rotation_curl_transpose = rotation_curl_transpose if rotation_curl[ind_min] <= rotation_curl_transpose[ind_trans_min]: rotation_best_deg = rotation_angles_deg[ind_min] _rotation_best_rad = rotation_angles_rad[ind_min] @@ -971,13 +1139,18 @@ def _solve_for_center_of_mass_relative_rotation( _rotation_best_transpose = True self._rotation_angles_deg = rotation_angles_deg + # Print summary - if self._verbose: - print(("Best fit rotation = " f"{rotation_best_deg:.0f} degrees.")) + if verbose: + warnings.warn( + f"Best fit rotation = {rotation_best_deg:.0f} degrees.", + UserWarning, + ) if _rotation_best_transpose: - print("Diffraction intensities should be transposed.") - else: - print("No need to transpose diffraction intensities.") + warnings.warn( + "Diffraction intensities should be transposed.", + UserWarning, + ) # Plot Curl/Div rotation if plot_rotation: @@ -1049,18 +1222,14 @@ def _solve_for_center_of_mass_relative_rotation( + xp.cos(_rotation_best_rad) * _com_normalized_y ) - # 'Public'-facing attributes as numpy arrays - com_x = asnumpy(_com_x) - com_y = asnumpy(_com_y) - # Optionally, plot CoM if plot_center_of_mass == "all": figsize = kwargs.pop("figsize", (8, 12)) cmap = kwargs.pop("cmap", "RdBu_r") extent = [ 0, - self._scan_sampling[1] * _com_measured_x.shape[1], - self._scan_sampling[0] * _com_measured_x.shape[0], + scan_sampling[1] * 
_com_measured_x.shape[1], + scan_sampling[0] * _com_measured_x.shape[0], 0, ] @@ -1074,8 +1243,8 @@ def _solve_for_center_of_mass_relative_rotation( _com_measured_y, _com_normalized_x, _com_normalized_y, - com_x, - com_y, + _com_x, + _com_y, ], [ "CoM_x", @@ -1087,18 +1256,18 @@ def _solve_for_center_of_mass_relative_rotation( ], ): ax.imshow(asnumpy(arr), extent=extent, cmap=cmap, **kwargs) - ax.set_ylabel(f"x [{self._scan_units[0]}]") - ax.set_xlabel(f"y [{self._scan_units[1]}]") + ax.set_ylabel(f"x [{scan_units[0]}]") + ax.set_xlabel(f"y [{scan_units[1]}]") ax.set_title(title) - elif plot_center_of_mass == "default": + elif plot_center_of_mass == "default" or plot_center_of_mass is True: figsize = kwargs.pop("figsize", (8, 4)) cmap = kwargs.pop("cmap", "RdBu_r") extent = [ 0, - self._scan_sampling[1] * com_x.shape[1], - self._scan_sampling[0] * com_x.shape[0], + scan_sampling[1] * _com_x.shape[1], + scan_sampling[0] * _com_x.shape[0], 0, ] @@ -1108,17 +1277,17 @@ def _solve_for_center_of_mass_relative_rotation( for ax, arr, title in zip( grid, [ - com_x, - com_y, + _com_x, + _com_y, ], [ "Corrected CoM_x", "Corrected CoM_y", ], ): - ax.imshow(arr, extent=extent, cmap=cmap, **kwargs) - ax.set_ylabel(f"x [{self._scan_units[0]}]") - ax.set_xlabel(f"y [{self._scan_units[1]}]") + ax.imshow(asnumpy(arr), extent=extent, cmap=cmap, **kwargs) + ax.set_ylabel(f"x [{scan_units[0]}]") + ax.set_xlabel(f"y [{scan_units[1]}]") ax.set_title(title) return ( @@ -1126,8 +1295,6 @@ def _solve_for_center_of_mass_relative_rotation( _rotation_best_transpose, _com_x, _com_y, - com_x, - com_y, ) def _normalize_diffraction_intensities( @@ -1135,8 +1302,8 @@ def _normalize_diffraction_intensities( diffraction_intensities, com_fitted_x, com_fitted_y, - crop_patterns, positions_mask, + crop_patterns, ): """ Fix diffraction intensities CoM, shift to origin, and take square root @@ -1149,11 +1316,11 @@ def _normalize_diffraction_intensities( Best fit horizontal center of mass gradient com_fitted_y: (Rx,Ry) xp.ndarray Best fit vertical center of mass gradient + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering - positions_mask: np.ndarray, optional - Boolean real space mask to select positions in datacube to skip for reconstruction Returns ------- @@ -1163,15 +1330,21 @@ def _normalize_diffraction_intensities( Mean intensity value """ - xp = self._xp + # explicit read-only self attributes up-front + asnumpy = self._asnumpy + mean_intensity = 0 - diffraction_intensities = self._asnumpy(diffraction_intensities) + diffraction_intensities = asnumpy(diffraction_intensities) + com_fitted_x = asnumpy(com_fitted_x) + com_fitted_y = asnumpy(com_fitted_y) + if positions_mask is not None: number_of_patterns = np.count_nonzero(positions_mask.ravel()) else: number_of_patterns = np.prod(diffraction_intensities.shape[:2]) + # Aggressive cropping for when off-centered high scattering angle data was recorded if crop_patterns: crop_x = int( np.minimum( @@ -1202,17 +1375,14 @@ def _normalize_diffraction_intensities( crop_mask[-crop_w:, :crop_w] = True crop_mask[:crop_w:, -crop_w:] = True crop_mask[-crop_w:, -crop_w:] = True - self._crop_mask = crop_mask else: + crop_mask = None region_of_interest_shape = diffraction_intensities.shape[-2:] amplitudes = np.zeros( (number_of_patterns,) + region_of_interest_shape, dtype=np.float32 ) - com_fitted_x = self._asnumpy(com_fitted_x) - 
com_fitted_y = self._asnumpy(com_fitted_y) - counter = 0 for rx in range(diffraction_intensities.shape[0]): for ry in range(diffraction_intensities.shape[1]): @@ -1236,10 +1406,9 @@ def _normalize_diffraction_intensities( amplitudes[counter] = np.sqrt(np.maximum(intensities, 0)) counter += 1 - amplitudes = xp.asarray(amplitudes) mean_intensity /= amplitudes.shape[0] - return amplitudes, mean_intensity + return amplitudes, mean_intensity, crop_mask def show_complex_CoM( self, @@ -1268,13 +1437,18 @@ def show_complex_CoM( default is scan sampling """ + # explicit read-only self attributes up-front + asnumpy = self._asnumpy + scan_sampling = self._scan_sampling + scan_units = self._scan_units + if com is None: - com = (self.com_x, self.com_y) + com = (self._com_x, self._com_y) if pixelsize is None: - pixelsize = self._scan_sampling[0] + pixelsize = scan_sampling[0] if pixelunits is None: - pixelunits = self._scan_units[0] + pixelunits = scan_units[0] figsize = kwargs.pop("figsize", (6, 6)) fig, ax = plt.subplots(figsize=figsize) @@ -1282,7 +1456,7 @@ def show_complex_CoM( complex_com = com[0] + 1j * com[1] show_complex( - complex_com, + asnumpy(complex_com), cbar=cbar, figax=(fig, ax), scalebar=scalebar, @@ -1293,10 +1467,10 @@ def show_complex_CoM( ) -class PtychographicReconstruction(PhaseReconstruction, PtychographicConstraints): +class PtychographicReconstruction(PhaseReconstruction): """ Base ptychographic reconstruction class. - Inherits from PhaseReconstruction and PtychographicConstraints. + Inherits from PhaseReconstruction. Defines various common functions and properties for subclasses to inherit. """ @@ -1331,6 +1505,8 @@ def to_h5(self, group): "object_type": self._object_type, "verbose": self._verbose, "device": self._device, + "storage": self._storage, + "clear_fft_cache": self._clear_fft_cache, "name": self.name, "vacuum_probe_intensity": vacuum_probe_intensity, "positions": scan_positions, @@ -1373,15 +1549,7 @@ def to_h5(self, group): # reconstruction metadata is_stack = self._save_iterations and hasattr(self, "object_iterations") - if is_stack: - num_iterations = len(self.object_iterations) - iterations = list(range(0, num_iterations, self._save_iterations_frequency)) - if num_iterations - 1 not in iterations: - iterations.append(num_iterations - 1) - - error = [self.error_iterations[i] for i in iterations] - else: - error = getattr(self, "error", 0.0) + error = self.error_iterations self.metadata = Metadata( name="reconstruction_metadata", @@ -1406,6 +1574,8 @@ def to_h5(self, group): self._probe_emd = Array(name="reconstruction_probe", data=asnumpy(self._probe)) if is_stack: + num_iterations = len(self.object_iterations) + iterations = list(range(0, num_iterations, self._save_iterations_frequency)) iterations_labels = [f"iteration_{i:03}" for i in iterations] # object @@ -1485,6 +1655,8 @@ def _get_constructor_args(cls, group): "polar_parameters": polar_params, "verbose": True, # for compatibility "device": "cpu", # for compatibility + "storage": "cpu", # for compatibility + "clear_fft_cache": True, # for compatibility } class_specific_kwargs = {} @@ -1525,13 +1697,12 @@ def _populate_instance(self, group): self._exit_waves = None # Check if stack - if hasattr(error, "__len__"): + if "_object_iterations_emd" in dict_data.keys(): self.object_iterations = list(dict_data["_object_iterations_emd"].data) self.probe_iterations = list(dict_data["_probe_iterations_emd"].data) - self.error_iterations = error - self.error = error[-1] - else: - self.error = error + + 
self.error_iterations = error + self.error = error[-1] # Slim preprocessing to enable visualize self._positions_px_com = xp.mean(self._positions_px, axis=0) @@ -1539,6 +1710,29 @@ def _populate_instance(self, group): self.probe = self.probe_centered self._preprocessed = True + def _switch_object_type(self, object_type): + """ + Switches object type to/from "potential"/"complex" + + Returns + -------- + self: PhaseReconstruction + Self to enable chaining + """ + xp = self._xp + + match (self._object_type, object_type): + case ("potential", "complex"): + self._object_type = "complex" + self._object = xp.exp(1j * self._object, dtype=xp.complex64) + case ("complex", "potential"): + self._object_type = "potential" + self._object = xp.angle(self._object) + case _: + self._object_type = self._object_type + + return self + def _set_polar_parameters(self, parameters: dict): """ Set the probe aberrations dictionary. @@ -1563,7 +1757,10 @@ def _set_polar_parameters(self, parameters: dict): raise ValueError("{} not a recognized parameter".format(symbol)) def _calculate_scan_positions_in_pixels( - self, positions: np.ndarray, positions_mask + self, + positions: np.ndarray, + positions_mask, + object_padding_px, ): """ Method to compute the initial guess of scan positions in pixels. @@ -1575,16 +1772,25 @@ def _calculate_scan_positions_in_pixels( If None, a raster scan using experimental parameters is constructed. positions_mask: np.ndarray, optional Boolean real space mask to select positions in datacube to skip for reconstruction + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions Returns ------- positions_in_px: (J,2) np.ndarray Initial guess of scan positions in pixels + object_padding_px: Tupe[int,int] + Updated object_padding_px """ + # explicit read-only self attributes up-front grid_scan_shape = self._grid_scan_shape rotation_angle = self._rotation_best_rad + transpose = self._rotation_best_transpose step_sizes = self._scan_sampling + region_of_interest_shape = self._region_of_interest_shape + sampling = self.sampling if positions is None: if grid_scan_shape is not None: @@ -1599,47 +1805,51 @@ def _calculate_scan_positions_in_pixels( else: raise ValueError() - if self._rotation_best_transpose: - x = (x - np.ptp(x) / 2) / self.sampling[1] - y = (y - np.ptp(y) / 2) / self.sampling[0] + if transpose: + x = (x - np.ptp(x) / 2) / sampling[1] + y = (y - np.ptp(y) / 2) / sampling[0] else: - x = (x - np.ptp(x) / 2) / self.sampling[0] - y = (y - np.ptp(y) / 2) / self.sampling[1] + x = (x - np.ptp(x) / 2) / sampling[0] + y = (y - np.ptp(y) / 2) / sampling[1] x, y = np.meshgrid(x, y, indexing="ij") + if positions_mask is not None: x = x[positions_mask] y = y[positions_mask] else: positions -= np.mean(positions, axis=0) - x = positions[:, 0] / self.sampling[1] - y = positions[:, 1] / self.sampling[0] + x = positions[:, 0] / sampling[1] + y = positions[:, 1] / sampling[0] if rotation_angle is not None: x, y = x * np.cos(rotation_angle) + y * np.sin(rotation_angle), -x * np.sin( rotation_angle ) + y * np.cos(rotation_angle) - if self._rotation_best_transpose: + if transpose: positions = np.array([y.ravel(), x.ravel()]).T else: positions = np.array([x.ravel(), y.ravel()]).T + positions -= np.min(positions, axis=0) - if self._object_padding_px is None: - float_padding = self._region_of_interest_shape / 2 - self._object_padding_px = (float_padding, float_padding) - elif np.isscalar(self._object_padding_px[0]): - 
self._object_padding_px = ( - (self._object_padding_px[0],) * 2, - (self._object_padding_px[1],) * 2, + if object_padding_px is None: + float_padding = region_of_interest_shape / 2 + object_padding_px = (float_padding, float_padding) + elif np.isscalar(object_padding_px[0]): + object_padding_px = ( + (object_padding_px[0],) * 2, + (object_padding_px[1],) * 2, ) - positions[:, 0] += self._object_padding_px[0][0] - positions[:, 1] += self._object_padding_px[1][0] + positions[:, 0] += object_padding_px[0][0] + positions[:, 1] += object_padding_px[1][0] - return positions + return positions, object_padding_px - def _sum_overlapping_patches_bincounts_base(self, patches: np.ndarray): + def _sum_overlapping_patches_bincounts_base( + self, patches: np.ndarray, positions_px + ): """ Base bincouts overlapping patches sum function, operating on real-valued arrays. Note this assumes the probe is corner-centered. @@ -1654,28 +1864,28 @@ def _sum_overlapping_patches_bincounts_base(self, patches: np.ndarray): out_array: (Px,Py) np.ndarray Summed array """ - xp = self._xp - x0 = xp.round(self._positions_px[:, 0]).astype("int") - y0 = xp.round(self._positions_px[:, 1]).astype("int") - + # explicit read-only self attributes up-front + xp = get_array_module(patches) roi_shape = self._region_of_interest_shape + object_shape = self._object_shape + + x0 = xp.round(positions_px[:, 0]).astype("int") + y0 = xp.round(positions_px[:, 1]).astype("int") + x_ind = xp.fft.fftfreq(roi_shape[0], d=1 / roi_shape[0]).astype("int") y_ind = xp.fft.fftfreq(roi_shape[1], d=1 / roi_shape[1]).astype("int") flat_weights = patches.ravel() - indices = ( - (y0[:, None, None] + y_ind[None, None, :]) % self._object_shape[1] - ) + ( - (x0[:, None, None] + x_ind[None, :, None]) % self._object_shape[0] - ) * self._object_shape[ - 1 - ] + indices = ((y0[:, None, None] + y_ind[None, None, :]) % object_shape[1]) + ( + (x0[:, None, None] + x_ind[None, :, None]) % object_shape[0] + ) * object_shape[1] counts = xp.bincount( - indices.ravel(), weights=flat_weights, minlength=np.prod(self._object_shape) + indices.ravel(), weights=flat_weights, minlength=np.prod(object_shape) ) - return xp.reshape(counts, self._object_shape) + counts = xp.reshape(counts, object_shape).astype(xp.float32) + return counts - def _sum_overlapping_patches_bincounts(self, patches: np.ndarray): + def _sum_overlapping_patches_bincounts(self, patches: np.ndarray, positions_px): """ Sum overlapping patches defined into object shaped array using bincounts. Calls _sum_overlapping_patches_bincounts_base on Real and Imaginary parts. @@ -1691,15 +1901,21 @@ def _sum_overlapping_patches_bincounts(self, patches: np.ndarray): Summed array """ - xp = self._xp - if xp.iscomplexobj(patches): - real = self._sum_overlapping_patches_bincounts_base(xp.real(patches)) - imag = self._sum_overlapping_patches_bincounts_base(xp.imag(patches)) + if np.iscomplexobj(patches): + real = self._sum_overlapping_patches_bincounts_base( + patches.real, positions_px + ) + imag = self._sum_overlapping_patches_bincounts_base( + patches.imag, positions_px + ) return real + 1.0j * imag else: - return self._sum_overlapping_patches_bincounts_base(patches) + return self._sum_overlapping_patches_bincounts_base(patches, positions_px) - def _extract_vectorized_patch_indices(self): + def _extract_vectorized_patch_indices( + self, + positions_px, + ): """ Sets the vectorized row/col indices used for the overlap projection Note this assumes the probe is corner-centered. 
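Editor's note: the sketch below is not part of the diff. It is a stand-alone NumPy version of the corner-centered bincount accumulation performed by _sum_overlapping_patches_bincounts_base above, included to illustrate how overlapping probe patches are scattered back into the object array with periodic wrap-around.

import numpy as np


def sum_overlapping_patches(patches, positions_px, object_shape):
    # patches: (N, sx, sy) real-valued array, positions_px: (N, 2) pixel positions
    sx, sy = patches.shape[-2:]
    x0 = np.round(positions_px[:, 0]).astype(int)
    y0 = np.round(positions_px[:, 1]).astype(int)

    # fftfreq ordering keeps each patch corner-centered, matching the probe convention
    x_ind = np.fft.fftfreq(sx, d=1 / sx).astype(int)
    y_ind = np.fft.fftfreq(sy, d=1 / sy).astype(int)

    # flat indices into the raveled object, with periodic wrap-around
    indices = (
        (y0[:, None, None] + y_ind[None, None, :]) % object_shape[1]
        + ((x0[:, None, None] + x_ind[None, :, None]) % object_shape[0]) * object_shape[1]
    )

    counts = np.bincount(
        indices.ravel(), weights=patches.ravel(), minlength=int(np.prod(object_shape))
    )
    return counts.reshape(object_shape).astype(np.float32)


rng = np.random.default_rng(0)
patches = rng.random((5, 8, 8)).astype(np.float32)
positions = rng.uniform(0, 32, size=(5, 2))
obj = sum_overlapping_patches(patches, positions, (32, 32))  # accumulated (32, 32) object

Complex-valued patches are handled in the class above by applying the same accumulation separately to the real and imaginary parts, which is exactly what _sum_overlapping_patches_bincounts does.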
@@ -1711,15 +1927,17 @@ def _extract_vectorized_patch_indices(self): self._vectorized_patch_indices_col: np.ndarray Column indices for probe patches inside object array """ - xp = self._xp - x0 = xp.round(self._positions_px[:, 0]).astype("int") - y0 = xp.round(self._positions_px[:, 1]).astype("int") - + # explicit read-only self attributes up-front + xp_storage = self._xp_storage roi_shape = self._region_of_interest_shape - x_ind = xp.fft.fftfreq(roi_shape[0], d=1 / roi_shape[0]).astype("int") - y_ind = xp.fft.fftfreq(roi_shape[1], d=1 / roi_shape[1]).astype("int") - obj_shape = self._object_shape + + x0 = xp_storage.round(positions_px[:, 0]).astype("int") + y0 = xp_storage.round(positions_px[:, 1]).astype("int") + + x_ind = xp_storage.fft.fftfreq(roi_shape[0], d=1 / roi_shape[0]).astype("int") + y_ind = xp_storage.fft.fftfreq(roi_shape[1], d=1 / roi_shape[1]).astype("int") + vectorized_patch_indices_row = ( x0[:, None, None] + x_ind[None, :, None] ) % obj_shape[0] @@ -1729,903 +1947,197 @@ def _extract_vectorized_patch_indices(self): return vectorized_patch_indices_row, vectorized_patch_indices_col - def _crop_rotate_object_fov( + def _set_reconstruction_method_parameters( self, - array, - padding=0, + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, ): - """ - Crops and rotated object to FOV bounded by current pixel positions. + """""" - Parameters - ---------- - array: np.ndarray - Object array to crop and rotate. Only operates on numpy arrays for comptatibility. - padding: int, optional - Optional padding outside pixel positions - - Returns - cropped_rotated_array: np.ndarray - Cropped and rotated object array - """ + if reconstruction_method == "generalized-projections": + if ( + reconstruction_parameter_a is None + or reconstruction_parameter_b is None + or reconstruction_parameter_c is None + ): + raise ValueError( + ( + "reconstruction_parameter_a/b/c must all be specified " + "when using reconstruction_method='generalized-projections'." 
+ ) + ) - asnumpy = self._asnumpy - angle = ( - self._rotation_best_rad - if self._rotation_best_transpose - else -self._rotation_best_rad - ) + use_projection_scheme = True + projection_a = reconstruction_parameter_a + projection_b = reconstruction_parameter_b + projection_c = reconstruction_parameter_c + reconstruction_parameter = None + step_size = None + elif ( + reconstruction_method == "DM_AP" + or reconstruction_method == "difference-map_alternating-projections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = 1 + projection_c = 1 + reconstruction_parameter + step_size = None + elif ( + reconstruction_method == "RAAR" + or reconstruction_method == "relaxed-averaged-alternating-reflections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = 1 - 2 * reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "RRR" + or reconstruction_method == "relax-reflect-reflect" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: + raise ValueError("reconstruction_parameter must be between 0-2.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "SUPERFLIP" + or reconstruction_method == "charge-flipping" + ): + use_projection_scheme = True + projection_a = 0 + projection_b = 1 + projection_c = 2 + reconstruction_parameter = None + step_size = None + elif ( + reconstruction_method == "GD" or reconstruction_method == "gradient-descent" + ): + use_projection_scheme = False + projection_a = None + projection_b = None + projection_c = None + reconstruction_parameter = None + else: + raise ValueError( + ( + "reconstruction_method must be one of 'generalized-projections', " + "'DM_AP' (or 'difference-map_alternating-projections'), " + "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " + "'RRR' (or 'relax-reflect-reflect'), " + "'SUPERFLIP' (or 'charge-flipping'), " + f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." 
+ ) + ) - tf = AffineTransform(angle=angle) - rotated_points = tf( - asnumpy(self._positions_px), origin=asnumpy(self._positions_px_com), xp=np + return ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, ) - min_x, min_y = np.floor(np.amin(rotated_points, axis=0) - padding).astype("int") - min_x = min_x if min_x > 0 else 0 - min_y = min_y if min_y > 0 else 0 - max_x, max_y = np.ceil(np.amax(rotated_points, axis=0) + padding).astype("int") - - rotated_array = rotate( - asnumpy(array), np.rad2deg(-angle), reshape=False, axes=(-2, -1) - )[..., min_x:max_x, min_y:max_y] - - if self._rotation_best_transpose: - rotated_array = rotated_array.swapaxes(-2, -1) - - return rotated_array - - def tune_angle_and_defocus( + def _report_reconstruction_summary( self, - angle_guess=None, - defocus_guess=None, - transpose=None, - angle_step_size=1, - defocus_step_size=20, - num_angle_values=5, - num_defocus_values=5, - max_iter=5, - plot_reconstructions=True, - plot_convergence=True, - return_values=False, - **kwargs, + max_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + max_batch_size, + step_size, ): - """ - Run reconstructions over a parameters space of angles and - defocus values. Should be run after preprocess step. + """ """ - Parameters - ---------- - angle_guess: float (degrees), optional - initial starting guess for rotation angle between real and reciprocal space - if None, uses current initialized values - defocus_guess: float (A), optional - initial starting guess for defocus - if None, uses current initialized values - angle_step_size: float (degrees), optional - size of change of rotation angle between real and reciprocal space for - each step in parameter space - defocus_step_size: float (A), optional - size of change of defocus for each step in parameter space - num_angle_values: int, optional - number of values of angle to test, must be >= 1. - num_defocus_values: int,optional - number of values of defocus to test, must be >= 1 - max_iter: int, optional - number of iterations to run in ptychographic reconstruction - plot_reconstructions: bool, optional - if True, plot phase of reconstructed objects - plot_convergence: bool, optional - if True, plots error for each iteration for each reconstruction. 
- return_values: bool, optional - if True, returns objects, convergence + # object type + first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - Returns - ------- - objects: list - reconstructed objects - convergence: np.ndarray - array of convergence values from reconstructions - """ - # calculate angles and defocus values to test - if angle_guess is None: - angle_guess = self._rotation_best_rad * 180 / np.pi - if defocus_guess is None: - defocus_guess = -self._polar_parameters["C10"] - if transpose is None: - transpose = self._rotation_best_transpose - - if num_angle_values == 1: - angle_step_size = 0 - - if num_defocus_values == 1: - defocus_step_size = 0 - - angles = np.linspace( - angle_guess - angle_step_size * (num_angle_values - 1) / 2, - angle_guess + angle_step_size * (num_angle_values - 1) / 2, - num_angle_values, - ) - - defocus_values = np.linspace( - defocus_guess - defocus_step_size * (num_defocus_values - 1) / 2, - defocus_guess + defocus_step_size * (num_defocus_values - 1) / 2, - num_defocus_values, - ) - - if return_values: - convergence = [] - objects = [] - - # current initialized values - current_verbose = self._verbose - current_defocus = -self._polar_parameters["C10"] - current_rotation_deg = self._rotation_best_rad * 180 / np.pi - current_transpose = self._rotation_best_transpose - - # Gridspec to plot on - if plot_reconstructions: - if plot_convergence: - spec = GridSpec( - ncols=num_defocus_values, - nrows=num_angle_values * 2, - height_ratios=[1, 1 / 4] * num_angle_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_defocus_values, 5 * num_angle_values) + # stochastic gradient descent + if max_batch_size is not None: + if use_projection_scheme: + raise ValueError( + ( + "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " + "Use reconstruction_method='GD' or set max_batch_size=None." + ) ) else: - spec = GridSpec( - ncols=num_defocus_values, - nrows=num_angle_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_defocus_values, 4 * num_angle_values) + warnings.warn( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}, " + f"in batches of max {max_batch_size} measurements." + ), + UserWarning, ) - fig = plt.figure(figsize=figsize) - - progress_bar = kwargs.pop("progress_bar", False) - # run loop and plot along the way - self._verbose = False - for flat_index, (angle, defocus) in enumerate( - tqdmnd(angles, defocus_values, desc="Tuning angle and defocus") - ): - self._polar_parameters["C10"] = -defocus - self._probe = None - self._object = None - self.preprocess( - force_com_rotation=angle, - force_com_transpose=transpose, - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - ) - - self.reconstruct( - reset=True, - store_iterations=True, - max_iter=max_iter, - progress_bar=progress_bar, - **kwargs, - ) - - if plot_reconstructions: - row_index, col_index = np.unravel_index( - flat_index, (num_angle_values, num_defocus_values) + else: + # named projection set method + if reconstruction_parameter is not None: + warnings.warn( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." 
+ ), + UserWarning, ) - if plot_convergence: - object_ax = fig.add_subplot(spec[row_index * 2, col_index]) - convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=convergence_ax, - cbar=True, - ) - convergence_ax.yaxis.tick_right() - else: - object_ax = fig.add_subplot(spec[row_index, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=None, - cbar=True, - ) - - object_ax.set_title( - f" angle = {angle:.1f} °, defocus = {defocus:.1f} A \n error = {self.error:.3e}" + # generalized projections (or the even more rare charge-flipping) + elif projection_a is not None: + warnings.warn( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and (a,b,c): " + f"{projection_a, projection_b, projection_c}." + ), + UserWarning, ) - object_ax.set_xticks([]) - object_ax.set_yticks([]) - - if return_values: - objects.append(self.object) - convergence.append(self.error_iterations.copy()) - - # initialize back to pre-tuning values - self._polar_parameters["C10"] = -current_defocus - self._probe = None - self._object = None - self.preprocess( - force_com_rotation=current_rotation_deg, - force_com_transpose=current_transpose, - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - ) - self._verbose = current_verbose - - if plot_reconstructions: - spec.tight_layout(fig) - if return_values: - return objects, convergence + # gradient descent + else: + warnings.warn( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}." + ), + UserWarning, + ) - def _position_correction( + def _constraints( self, - relevant_object, - relevant_probes, - relevant_overlap, - relevant_amplitudes, + current_object, + current_probe, current_positions, - positions_step_size, - constrain_position_distance, - ): - """ - Position correction using estimated intensity gradient. 
- - Parameters - -------- - relevant_object: np.ndarray - Current object estimate - relevant_probes:np.ndarray - fractionally-shifted probes - relevant_overlap: np.ndarray - object * probe overlap - relevant_amplitudes: np.ndarray - Measured amplitudes - current_positions: np.ndarray - Current positions estimate - positions_step_size: float - Positions step size - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A - - Returns - -------- - updated_positions: np.ndarray - Updated positions estimate - """ - - xp = self._xp - - if self._object_type == "potential": - complex_object = xp.exp(1j * relevant_object) - else: - complex_object = relevant_object - - obj_rolled_x_patches = complex_object[ - (self._vectorized_patch_indices_row + 1) % self._object_shape[0], - self._vectorized_patch_indices_col, - ] - obj_rolled_y_patches = complex_object[ - self._vectorized_patch_indices_row, - (self._vectorized_patch_indices_col + 1) % self._object_shape[1], - ] - - overlap_fft = xp.fft.fft2(relevant_overlap) - - exit_waves_dx_fft = overlap_fft - xp.fft.fft2( - obj_rolled_x_patches * relevant_probes - ) - exit_waves_dy_fft = overlap_fft - xp.fft.fft2( - obj_rolled_y_patches * relevant_probes - ) - - overlap_fft_conj = xp.conj(overlap_fft) - estimated_intensity = xp.abs(overlap_fft) ** 2 - measured_intensity = relevant_amplitudes**2 - - flat_shape = (relevant_overlap.shape[0], -1) - difference_intensity = (measured_intensity - estimated_intensity).reshape( - flat_shape - ) - - partial_intensity_dx = 2 * xp.real( - exit_waves_dx_fft * overlap_fft_conj - ).reshape(flat_shape) - partial_intensity_dy = 2 * xp.real( - exit_waves_dy_fft * overlap_fft_conj - ).reshape(flat_shape) - - coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) - - # positions_update = xp.einsum( - # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity - # ) - - coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) - positions_update = ( - xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) - @ coefficients_matrix_T - @ difference_intensity[..., None] - ) - - if constrain_position_distance is not None: - constrain_position_distance /= xp.sqrt( - self.sampling[0] ** 2 + self.sampling[1] ** 2 - ) - x1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 0 - ] - y1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 1 - ] - x0 = self._positions_px_initial[:, 0] - y0 = self._positions_px_initial[:, 1] - if self._rotation_best_transpose: - x0, y0 = xp.array([y0, x0]) - x1, y1 = xp.array([y1, x1]) - - if self._rotation_best_rad is not None: - rotation_angle = self._rotation_best_rad - x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( - -rotation_angle - ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) - x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( - -rotation_angle - ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) - - outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( - x1 < (xp.min(x0) - constrain_position_distance) - ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( - y1 < (xp.min(y0) - constrain_position_distance) - ) > 0 - - positions_update[..., 0][outlier_ind] = 0 - - current_positions -= positions_step_size * positions_update[..., 0] - return current_positions - - def plot_position_correction( - self, - scale_arrows=1, - plot_arrow_freq=1, - verbose=True, - **kwargs, - ): - """ - Function to 
plot changes to probe positions during ptychography reconstruciton - - Parameters - ---------- - scale_arrows: float, optional - scaling factor to be applied on vectors prior to plt.quiver call - verbose: bool, optional - if True, prints AffineTransformation if positions have been updated - """ - if verbose: - if hasattr(self, "_tf"): - print(self._tf) - - asnumpy = self._asnumpy - - extent = [ - 0, - self.sampling[1] * self._object_shape[1], - self.sampling[0] * self._object_shape[0], - 0, - ] - - initial_pos = asnumpy(self._positions_initial) - pos = self.positions - - figsize = kwargs.pop("figsize", (6, 6)) - color = kwargs.pop("color", (1, 0, 0, 1)) - - fig, ax = plt.subplots(figsize=figsize) - ax.quiver( - initial_pos[::plot_arrow_freq, 1], - initial_pos[::plot_arrow_freq, 0], - (pos[::plot_arrow_freq, 1] - initial_pos[::plot_arrow_freq, 1]) - * scale_arrows, - (pos[::plot_arrow_freq, 0] - initial_pos[::plot_arrow_freq, 0]) - * scale_arrows, - scale_units="xy", - scale=1, - color=color, - **kwargs, - ) - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_xlim((extent[0], extent[1])) - ax.set_ylim((extent[2], extent[3])) - ax.set_aspect("equal") - ax.set_title("Probe positions correction") - - def _return_fourier_probe( - self, - probe=None, - remove_initial_probe_aberrations=False, - ): - """ - Returns complex fourier probe shifted to center of array from - corner-centered complex real space probe - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses self._probe - remove_initial_probe_aberrations: bool, optional - If True, removes initial probe aberrations from Fourier probe - - Returns - ------- - fourier_probe: np.ndarray - Fourier-transformed and center-shifted probe. - """ - xp = self._xp - - if probe is None: - probe = self._probe - else: - probe = xp.asarray(probe, dtype=xp.complex64) - - fourier_probe = xp.fft.fft2(probe) - - if remove_initial_probe_aberrations: - fourier_probe *= xp.conjugate(self._known_aberrations_array) - - return xp.fft.fftshift(fourier_probe, axes=(-2, -1)) - - def _return_fourier_probe_from_centered_probe( - self, - probe=None, - remove_initial_probe_aberrations=False, - ): - """ - Returns complex fourier probe shifted to center of array from - centered complex real space probe - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses self._probe - remove_initial_probe_aberrations: bool, optional - If True, removes initial probe aberrations from Fourier probe - - Returns - ------- - fourier_probe: np.ndarray - Fourier-transformed and center-shifted probe. - """ - xp = self._xp - return self._return_fourier_probe( - xp.fft.ifftshift(probe, axes=(-2, -1)), - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - - def _return_centered_probe( - self, - probe=None, - ): - """ - Returns complex probe centered in middle of the array. - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses self._probe - - Returns - ------- - centered_probe: np.ndarray - Center-shifted probe. 
- """ - xp = self._xp - - if probe is None: - probe = self._probe - else: - probe = xp.asarray(probe, dtype=xp.complex64) - - return xp.fft.fftshift(probe, axes=(-2, -1)) - - def _return_object_fft( - self, - obj=None, - ): - """ - Returns absolute value of obj fft shifted to center of array - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - - Returns - ------- - object_fft_amplitude: np.ndarray - Amplitude of Fourier-transformed and center-shifted obj. - """ - asnumpy = self._asnumpy - - if obj is None: - obj = self._object - - obj = self._crop_rotate_object_fov(asnumpy(obj)) - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) - - def _return_self_consistency_errors( - self, - max_batch_size=None, - ): - """Compute the self-consistency errors for each probe position""" - - xp = self._xp - asnumpy = self._asnumpy - - # Batch-size - if max_batch_size is None: - max_batch_size = self._num_diffraction_patterns - - # Re-initialize fractional positions and vector patches - errors = np.array([]) - positions_px = self._positions_px.copy() - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[start:end] - - # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap) - - # Normalized mean-squared errors - batch_errors = xp.sum( - xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) - ) - errors = np.hstack((errors, batch_errors)) - - self._positions_px = positions_px.copy() - errors /= self._mean_diffraction_intensity - - return asnumpy(errors) - - def _return_projected_cropped_potential( - self, - ): - """Utility function to accommodate multiple classes""" - if self._object_type == "complex": - projected_cropped_potential = np.angle(self.object_cropped) - else: - projected_cropped_potential = self.object_cropped - - return projected_cropped_potential - - def show_uncertainty_visualization( - self, - errors=None, - max_batch_size=None, - projected_cropped_potential=None, - kde_sigma=None, - plot_histogram=True, - plot_contours=False, - **kwargs, - ): - """Plot uncertainty visualization using self-consistency errors""" - - if errors is None: - errors = self._return_self_consistency_errors(max_batch_size=max_batch_size) - - if projected_cropped_potential is None: - projected_cropped_potential = self._return_projected_cropped_potential() - - if kde_sigma is None: - kde_sigma = 0.5 * self._scan_sampling[0] / self.sampling[0] - - xp = self._xp - asnumpy = self._asnumpy - gaussian_filter = self._gaussian_filter - - ## Kernel Density Estimation - - # rotated basis - angle = ( - self._rotation_best_rad - if self._rotation_best_transpose - else -self._rotation_best_rad - ) - - tf = AffineTransform(angle=angle) - rotated_points = tf(self._positions_px, origin=self._positions_px_com, xp=xp) - - padding = xp.min(rotated_points, axis=0).astype("int") - - # bilinear sampling - pixel_output = np.array(projected_cropped_potential.shape) + asnumpy( - 2 * padding - ) - pixel_size = pixel_output.prod() - - xa = rotated_points[:, 0] - ya = rotated_points[:, 1] - - # bilinear sampling - xF = xp.floor(xa).astype("int") - yF = 
xp.floor(ya).astype("int") - dx = xa - xF - dy = ya - yF - - # resampling - inds_1D = xp.ravel_multi_index( - xp.hstack( - [ - [xF, yF], - [xF + 1, yF], - [xF, yF + 1], - [xF + 1, yF + 1], - ] - ), - pixel_output, - mode=["wrap", "wrap"], - ) - - weights = xp.hstack( - ( - (1 - dx) * (1 - dy), - (dx) * (1 - dy), - (1 - dx) * (dy), - (dx) * (dy), - ) - ) - - pix_count = xp.reshape( - xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output - ) - - pix_output = xp.reshape( - xp.bincount( - inds_1D, - weights=weights * xp.tile(xp.asarray(errors), 4), - minlength=pixel_size, - ), - pixel_output, - ) - - # kernel density estimate - pix_count = gaussian_filter(pix_count, kde_sigma, mode="wrap") - pix_count[pix_count == 0.0] = np.inf - pix_output = gaussian_filter(pix_output, kde_sigma, mode="wrap") - pix_output /= pix_count - pix_output = pix_output[padding[0] : -padding[0], padding[1] : -padding[1]] - pix_output, _, _ = return_scaled_histogram_ordering( - pix_output.get(), normalize=True - ) - - ## Visualization - if plot_histogram: - spec = GridSpec( - ncols=1, - nrows=2, - height_ratios=[1, 4], - hspace=0.15, - ) - auto_figsize = (4, 5.25) - else: - spec = GridSpec( - ncols=1, - nrows=1, - ) - auto_figsize = (4, 4) - - figsize = kwargs.pop("figsize", auto_figsize) - - fig = plt.figure(figsize=figsize) - - if plot_histogram: - ax_hist = fig.add_subplot(spec[0]) - - counts, bins = np.histogram(errors, bins=50) - ax_hist.hist(bins[:-1], bins, weights=counts, color="#5ac8c8", alpha=0.5) - ax_hist.set_ylabel("Counts") - ax_hist.set_xlabel("Normalized Squared Error") - - ax = fig.add_subplot(spec[-1]) - - cmap = kwargs.pop("cmap", "magma") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - - projected_cropped_potential, vmin, vmax = return_scaled_histogram_ordering( - projected_cropped_potential, - vmin=vmin, - vmax=vmax, - ) - - extent = [ - 0, - self.sampling[1] * projected_cropped_potential.shape[1], - self.sampling[0] * projected_cropped_potential.shape[0], - 0, - ] - - ax.imshow( - projected_cropped_potential, - vmin=vmin, - vmax=vmax, - extent=extent, - alpha=1 - pix_output, - cmap=cmap, - **kwargs, - ) - - if plot_contours: - aligned_points = asnumpy(rotated_points - padding) - aligned_points[:, 0] *= self.sampling[0] - aligned_points[:, 1] *= self.sampling[1] - - ax.tricontour( - aligned_points[:, 1], - aligned_points[:, 0], - errors, - colors="grey", - levels=5, - # linestyles='dashed', - linewidths=0.5, - ) - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_xlim((extent[0], extent[1])) - ax.set_ylim((extent[2], extent[3])) - ax.xaxis.set_ticks_position("bottom") - - spec.tight_layout(fig) - - def show_fourier_probe( - self, - probe=None, - remove_initial_probe_aberrations=False, - cbar=True, - scalebar=True, - pixelsize=None, - pixelunits=None, + initial_positions, **kwargs, ): - """ - Plot probe in fourier space - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses the `probe_fourier` property - remove_initial_probe_aberrations: bool, optional - If True, removes initial probe aberrations from Fourier probe - cbar: bool, optional - if True, adds colorbar - scalebar: bool, optional - if True, adds scalebar to probe - pixelunits: str, optional - units for scalebar, default is A^-1 - pixelsize: float, optional - default is probe reciprocal sampling - """ - asnumpy = self._asnumpy - - probe = asnumpy( - self._return_fourier_probe( - probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations - ) - ) 
- - if pixelsize is None: - pixelsize = self._reciprocal_sampling[1] - if pixelunits is None: - pixelunits = r"$\AA^{-1}$" - - figsize = kwargs.pop("figsize", (6, 6)) - chroma_boost = kwargs.pop("chroma_boost", 1) - - fig, ax = plt.subplots(figsize=figsize) - show_complex( - probe, - cbar=cbar, - figax=(fig, ax), - scalebar=scalebar, - pixelsize=pixelsize, - pixelunits=pixelunits, - ticks=False, - chroma_boost=chroma_boost, - **kwargs, - ) - - def show_object_fft(self, obj=None, **kwargs): - """ - Plot FFT of reconstructed object - - Parameters - ---------- - obj: complex array, optional - if None is specified, uses the `object_fft` property - """ - if obj is None: - object_fft = self.object_fft - else: - object_fft = self._return_object_fft(obj) + """Wrapper function for all classes to inherit""" - figsize = kwargs.pop("figsize", (6, 6)) - cmap = kwargs.pop("cmap", "magma") - - pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) - show( - object_fft, - figsize=figsize, - cmap=cmap, - scalebar=True, - pixelsize=pixelsize, - ticks=False, - pixelunits=r"$\AA^{-1}$", - **kwargs, + current_object = self._object_constraints(current_object, **kwargs) + current_probe = self._probe_constraints(current_probe, **kwargs) + current_positions = self._positions_constraints( + current_positions, initial_positions, **kwargs ) - @property - def probe_fourier(self): - """Current probe estimate in Fourier space""" - if not hasattr(self, "_probe"): - return None - - asnumpy = self._asnumpy - return asnumpy(self._return_fourier_probe(self._probe)) - - @property - def probe_fourier_residual(self): - """Current probe estimate in Fourier space""" - if not hasattr(self, "_probe"): - return None - - asnumpy = self._asnumpy - return asnumpy( - self._return_fourier_probe( - self._probe, remove_initial_probe_aberrations=True - ) - ) - - @property - def probe_centered(self): - """Current probe estimate shifted to the center""" - if not hasattr(self, "_probe"): - return None - - asnumpy = self._asnumpy - return asnumpy(self._return_centered_probe(self._probe)) - - @property - def object_fft(self): - """Fourier transform of current object estimate""" - - if not hasattr(self, "_object"): - return None - - return self._return_object_fft(self._object) + return current_object, current_probe, current_positions @property def angular_sampling(self): @@ -2658,9 +2170,3 @@ def positions(self): positions[:, 1] *= self.sampling[1] return asnumpy(positions) - - @property - def object_cropped(self): - """Cropped and rotated object""" - - return self._crop_rotate_object_fov(self._object) diff --git a/py4DSTEM/process/phase/ptychographic_constraints.py b/py4DSTEM/process/phase/ptychographic_constraints.py new file mode 100644 index 000000000..8a10b4df2 --- /dev/null +++ b/py4DSTEM/process/phase/ptychographic_constraints.py @@ -0,0 +1,1374 @@ +import warnings + +import numpy as np +from py4DSTEM.process.phase.utils import ( + array_slice, + estimate_global_transformation_ransac, + fft_shift, + fit_aberration_surface, + regularize_probe_amplitude, +) +from py4DSTEM.process.utils import get_CoM + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + import os + + # make sure pylops doesn't try to use cupy + os.environ["CUPY_PYLOPS"] = "0" +import pylops # this must follow the exception + + +class ObjectNDConstraintsMixin: + """ + Mixin class for object constraints applicable to 2D,2.5D, and 3D objects. 
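+
+ The individual constraint methods below are combined by the
+ `_object_constraints` wrapper at the end of this class, which subclasses
+ call with boolean flags selecting which constraints to apply.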
+ """ + + def _object_threshold_constraint(self, current_object, pure_phase_object): + """ + Ptychographic threshold constraint. + Used for avoiding the scaling ambiguity between probe and object. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + pure_phase_object: bool + If True, object amplitude is set to unity + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + + if self._object_type == "complex": + phase = xp.angle(current_object) + + if pure_phase_object: + amplitude = 1.0 + else: + amplitude = xp.minimum(xp.abs(current_object), 1.0) + + return amplitude * xp.exp(1.0j * phase) + else: + return current_object + + def _object_shrinkage_constraint(self, current_object, shrinkage_rad, object_mask): + """ + Ptychographic shrinkage constraint. + Used to ensure electrostatic potential is positive. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + object_mask: np.ndarray (boolean) + If not None, used to calculate additional shrinkage using masked-mean of object + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + + if self._object_type == "complex": + phase = xp.angle(current_object) + amp = xp.abs(current_object) + + if object_mask is not None: + shrinkage_rad += phase[..., object_mask].mean() + + phase -= shrinkage_rad + + current_object = amp * xp.exp(1.0j * phase) + else: + if object_mask is not None: + shrinkage_rad += current_object[..., object_mask].mean() + + current_object -= shrinkage_rad + + return current_object + + def _object_positivity_constraint(self, current_object): + """ + Ptychographic positivity constraint. + Used to ensure potential is positive. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + if self._object_type == "complex": + return current_object + else: + return current_object.clip(0.0) + + def _object_gaussian_constraint( + self, current_object, gaussian_filter_sigma, pure_phase_object + ): + """ + Ptychographic smoothness constraint. + Used for blurring object. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + gaussian_filter_sigma: float + Standard deviation of gaussian kernel in A + pure_phase_object: bool + If True, gaussian blur performed on phase only + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + gaussian_filter = self._scipy.ndimage.gaussian_filter + gaussian_filter_sigma /= self.sampling[0] + + if not pure_phase_object or self._object_type == "potential": + current_object = gaussian_filter(current_object, gaussian_filter_sigma) + else: + phase = xp.angle(current_object) + phase = gaussian_filter(phase, gaussian_filter_sigma) + current_object = xp.exp(1.0j * phase) + + return current_object + + def _object_butterworth_constraint( + self, + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ): + """ + Ptychographic butterworth filter. + Used for low/high-pass filtering object. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + qx = xp.fft.fftfreq(current_object.shape[-2], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[-1], self.sampling[1]) + + qya, qxa = xp.meshgrid(qy, qx) + qra = xp.sqrt(qxa**2 + qya**2) + + env = xp.ones_like(qra) + if q_highpass: + env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) + if q_lowpass: + env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) + + current_object_mean = xp.mean(current_object, axis=(-2, -1), keepdims=True) + current_object -= current_object_mean + current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env) + current_object += current_object_mean + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_denoise_tv_pylops(self, current_object, weight, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weight : float + Denoising weight. The greater `weight`, the more denoising (at + the expense of fidelity to `input`). + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + if self._object_type == "complex": + current_object_tv = current_object + warnings.warn( + ( + "TV denoising is currently only supported for object_type=='potential'." + ), + UserWarning, + ) + + else: + nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny) + xy_laplacian = pylops.Laplacian( + (nx, ny), axes=(0, 1), edge=False, kind="backward" + ) + + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weight], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + current_object_tv = current_object_tv.reshape(current_object.shape) + + return current_object_tv + + def _object_denoise_tv_chambolle( + self, + current_object, + weight, + axis, + padding, + eps=2.0e-4, + max_num_iter=200, + scaling=None, + ): + """ + Perform total-variation denoising on n-dimensional images. + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weight : float, optional + Denoising weight. The greater `weight`, the more denoising (at + the expense of fidelity to `input`). + axis: int or tuple + Axis for denoising, if None uses all axes + pad_object: bool + if True, pads object with zeros along axes of blurring + eps : float, optional + Relative difference of the value of the cost function that determines + the stop criterion. The algorithm stops when: + + (E_(n-1) - E_n) < eps * E_0 + + max_num_iter : int, optional + Maximal number of iterations used for the optimization. + scaling : tuple, optional + Scale weight of tv denoise on different axes + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + Notes + ----- + Rudin, Osher and Fatemi algorithm. 
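+ A dual (projection) formulation is used: `p` holds the dual variable, and
+ each iteration alternates a divergence update of the estimate with a
+ projected gradient step on `p`.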
+ Adapted skimage.restoration.denoise_tv_chambolle. + """ + xp = self._xp + + if self._object_type == "complex": + updated_object = current_object + warnings.warn( + ( + "TV denoising is currently only supported for object_type=='potential'." + ), + UserWarning, + ) + + else: + current_object_sum = xp.sum(current_object) + + if axis is None: + ndim = xp.arange(current_object.ndim).tolist() + elif isinstance(axis, tuple): + ndim = list(axis) + else: + ndim = [axis] + + if padding is not None: + pad_width = ((0, 0),) * current_object.ndim + pad_width = list(pad_width) + + for ax in range(len(ndim)): + pad_width[ndim[ax]] = (padding, padding) + + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + p = xp.zeros( + (current_object.ndim,) + current_object.shape, + dtype=current_object.dtype, + ) + g = xp.zeros_like(p) + d = xp.zeros_like(current_object) + + i = 0 + while i < max_num_iter: + if i > 0: + # d will be the (negative) divergence of p + d = -p.sum(0) + slices_d = [ + slice(None), + ] * current_object.ndim + slices_p = [ + slice(None), + ] * (current_object.ndim + 1) + for ax in range(len(ndim)): + slices_d[ndim[ax]] = slice(1, None) + slices_p[ndim[ax] + 1] = slice(0, -1) + slices_p[0] = ndim[ax] + d[tuple(slices_d)] += p[tuple(slices_p)] + slices_d[ndim[ax]] = slice(None) + slices_p[ndim[ax] + 1] = slice(None) + updated_object = current_object + d + else: + updated_object = current_object + E = (d**2).sum() + + # g stores the gradients of updated_object along each axis + # e.g. g[0] is the first order finite difference along axis 0 + slices_g = [ + slice(None), + ] * (current_object.ndim + 1) + for ax in range(len(ndim)): + slices_g[ndim[ax] + 1] = slice(0, -1) + slices_g[0] = ndim[ax] + g[tuple(slices_g)] = xp.diff(updated_object, axis=ndim[ax]) + slices_g[ndim[ax] + 1] = slice(None) + if scaling is not None: + scaling /= xp.max(scaling) + g *= xp.array(scaling)[:, xp.newaxis, xp.newaxis] + norm = xp.sqrt((g**2).sum(axis=0))[xp.newaxis, ...] 
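+ # accumulate the TV energy term and take one dual (projection) step;
+ # tau = 1 / (2 * number of denoised axes), following the scikit-image
+ # implementation this routine is adapted from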
+ E += weight * norm.sum() + tau = 1.0 / (2.0 * len(ndim)) + norm *= tau / weight + norm += 1.0 + p -= tau * g + p /= norm + E /= float(current_object.size) + if i == 0: + E_init = E + E_previous = E + else: + if xp.abs(E_previous - E) < eps * E_init: + break + else: + E_previous = E + i += 1 + + if padding is not None: + for ax in range(len(ndim)): + slices = array_slice( + ndim[ax], current_object.ndim, padding, -padding + ) + updated_object = updated_object[slices] + + updated_object = ( + updated_object / xp.sum(updated_object) * current_object_sum + ) + + return updated_object + + def _object_constraints( + self, + current_object, + gaussian_filter, + gaussian_filter_sigma, + pure_phase_object, + butterworth_filter, + butterworth_order, + q_lowpass, + q_highpass, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ): + """ObjectNDConstraints wrapper function""" + + # smoothness + if gaussian_filter: + current_object = self._object_gaussian_constraint( + current_object, gaussian_filter_sigma, pure_phase_object + ) + if butterworth_filter: + current_object = self._object_butterworth_constraint( + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, tv_denoise_weight, tv_denoise_inner_iter + ) + + # L1-norm pushing vacuum to zero + if shrinkage_rad > 0.0 or object_mask is not None: + current_object = self._object_shrinkage_constraint( + current_object, + shrinkage_rad, + object_mask, + ) + + # amplitude threshold (complex) or positivity (potential) + if self._object_type == "complex": + current_object = self._object_threshold_constraint( + current_object, pure_phase_object + ) + elif object_positivity: + current_object = self._object_positivity_constraint(current_object) + + return current_object + + +class Object2p5DConstraintsMixin: + """ + Mixin class for object constraints unique to 2.5D objects. + Overwrites ObjectNDConstraintsMixin. + """ + + def _object_denoise_tv_pylops(self, current_object, weights, iterations, z_padding): + """ + Performs second order TV denoising along x and y, and first order along z + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z_weight, r_weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + z_padding: int + Symmetric padding around the first axis + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if self._object_type == "complex": + current_object_tv = current_object + warnings.warn( + ( + "TV denoising is currently only supported for object_type=='potential'." 
+ ), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((z_padding, z_padding), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[ + z_padding:-z_padding + ] + + return current_object_tv + + def _object_kz_regularization_constraint( + self, current_object, kz_regularization_gamma, z_padding + ): + """ + Arctan regularization filter + + Parameters + -------- + current_object: np.ndarray + Current object estimate + kz_regularization_gamma: float + Slice regularization strength + z_padding: int + Symmetric padding around the first axis + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + + # zero pad at top and bottom slice + pad_width = ((z_padding, z_padding), (0, 0), (0, 0)) + current_object = xp.pad(current_object, pad_width=pad_width, mode="constant") + + qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]) + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + + kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0] + + qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") + qz2 = qza**2 * kz_regularization_gamma**2 + qr2 = qxa**2 + qya**2 + + w = 1 - 2 / np.pi * xp.arctan2(qz2, qr2) + + current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * w) + current_object = current_object[z_padding:-z_padding] + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_identical_slices_constraint(self, current_object): + """ + Strong regularization forcing all slices to be identical + + Parameters + -------- + current_object: np.ndarray + Current object estimate + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + object_mean = current_object.mean(0, keepdims=True) + current_object[:] = object_mean + + return current_object + + def _object_constraints( + self, + current_object, + 
gaussian_filter, + gaussian_filter_sigma, + pure_phase_object, + butterworth_filter, + butterworth_order, + q_lowpass, + q_highpass, + identical_slices, + kz_regularization_filter, + kz_regularization_gamma, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + tv_denoise_chambolle, + tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ): + """Object2p5DConstraints wrapper function""" + + # smoothness + if gaussian_filter: + current_object = self._object_gaussian_constraint( + current_object, gaussian_filter_sigma, pure_phase_object + ) + if butterworth_filter: + current_object = self._object_butterworth_constraint( + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ) + if identical_slices: + current_object = self._object_identical_slices_constraint(current_object) + elif kz_regularization_filter: + current_object = self._object_kz_regularization_constraint( + current_object, + kz_regularization_gamma, + z_padding=1, + ) + elif tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + z_padding=1, + ) + elif tv_denoise_chambolle: + current_object = self._object_denoise_tv_chambolle( + current_object, + tv_denoise_weight_chambolle, + axis=0, + padding=tv_denoise_pad_chambolle, + ) + + # L1-norm pushing vacuum to zero + if shrinkage_rad > 0.0 or object_mask is not None: + current_object = self._object_shrinkage_constraint( + current_object, + shrinkage_rad, + object_mask, + ) + + # amplitude threshold (complex) or positivity (potential) + if self._object_type == "complex": + current_object = self._object_threshold_constraint( + current_object, pure_phase_object + ) + elif object_positivity: + current_object = self._object_positivity_constraint(current_object) + + return current_object + + +class Object3DConstraintsMixin: + """ + Mixin class for object constraints unique to 3D objects. + Overwrites ObjectNDConstraintsMixin and Object2p5DConstraintsMixin. + """ + + def _object_denoise_tv_pylops(self, current_object, weight, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weight : float + Denoising weight. The greater `weight`, the more denoising (at + the expense of fidelity to `input`). + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + if self._object_type == "complex": + current_object_tv = current_object + warnings.warn( + ( + "TV denoising is currently only supported for object_type=='potential'." 
+ ), + UserWarning, + ) + + else: + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + xyz_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(0, 1, 2), edge=False, kind="backward" + ) + + l1_regs = [xyz_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weight], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + current_object_tv = current_object_tv.reshape(current_object.shape) + + return current_object_tv + + def _object_butterworth_constraint( + self, current_object, q_lowpass, q_highpass, butterworth_order + ): + """ + Butterworth filter + + Parameters + -------- + current_object: np.ndarray + Current object estimate + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + qz = xp.fft.fftfreq(current_object.shape[0], self.sampling[1]) + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") + qra = xp.sqrt(qza**2 + qxa**2 + qya**2) + + env = xp.ones_like(qra) + if q_highpass: + env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) + if q_lowpass: + env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) + + current_object_mean = xp.mean(current_object) + current_object -= current_object_mean + current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * env) + current_object += current_object_mean + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_constraints( + self, + current_object, + gaussian_filter, + gaussian_filter_sigma, + butterworth_filter, + butterworth_order, + q_lowpass, + q_highpass, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ): + """Object3DConstraints wrapper function""" + + # smoothness + if gaussian_filter: + current_object = self._object_gaussian_constraint( + current_object, gaussian_filter_sigma, pure_phase_object=False + ) + if butterworth_filter: + current_object = self._object_butterworth_constraint( + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + # L1-norm pushing vacuum to zero + if shrinkage_rad > 0.0 or object_mask is not None: + current_object = self._object_shrinkage_constraint( + current_object, + shrinkage_rad, + object_mask, + ) + + # Positivity + if object_positivity: + current_object = self._object_positivity_constraint(current_object) + + return current_object + + +class ProbeConstraintsMixin: + """ + Mixin class for regularizations applicable to a single probe. + """ + + def _probe_center_of_mass_constraint(self, current_probe): + """ + Ptychographic center of mass constraint. + Used for centering corner-centered probe intensity. 
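+
+ The intensity center of mass is measured with `get_CoM(...,
+ corner_centered=True)` and removed with a sub-pixel Fourier shift
+ (`fft_shift`), keeping the probe corner-centered between iterations.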
+ + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Constrained probe estimate + """ + xp = self._xp + + probe_intensity = xp.abs(current_probe) ** 2 + + probe_x0, probe_y0 = get_CoM( + probe_intensity, device=self._device, corner_centered=True + ) + shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) + + return shifted_probe + + def _probe_amplitude_constraint( + self, current_probe, relative_radius, relative_width + ): + """ + Ptychographic top-hat filtering of probe. + + Parameters + ---------- + current_probe: np.ndarray + Current positions estimate + relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + + Returns + -------- + constrained_probe: np.ndarray + Constrained probe estimate + """ + xp = self._xp + erf = self._scipy.special.erf + + probe_intensity = xp.abs(current_probe) ** 2 + current_probe_sum = xp.sum(probe_intensity) + + X = xp.fft.fftfreq(current_probe.shape[0])[:, None] + Y = xp.fft.fftfreq(current_probe.shape[1])[None] + r = xp.hypot(X, Y) - relative_radius + + sigma = np.sqrt(np.pi) / relative_width + tophat_mask = 0.5 * (1 - erf(sigma * r / (1 - r**2))) + + updated_probe = current_probe * tophat_mask + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + + return updated_probe * normalization + + def _probe_fourier_amplitude_constraint( + self, + current_probe, + width_max_pixels, + enforce_constant_intensity, + ): + """ + Ptychographic top-hat filtering of Fourier probe. + + Parameters + ---------- + current_probe: np.ndarray + Current positions estimate + threshold: np.ndarray + Threshold value for current probe fourier mask. Value should + be between 0 and 1, where 1 uses the maximum amplitude to threshold. + relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + + Returns + -------- + constrained_probe: np.ndarray + Constrained probe estimate + """ + xp = self._xp + asnumpy = self._asnumpy + + current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + current_probe_fft = xp.fft.fft2(current_probe) + + updated_probe_fft, _, _, _ = regularize_probe_amplitude( + asnumpy(current_probe_fft), + width_max_pixels=width_max_pixels, + nearest_angular_neighbor_averaging=5, + enforce_constant_intensity=enforce_constant_intensity, + corner_centered=True, + ) + + updated_probe_fft = xp.asarray(updated_probe_fft) + updated_probe = xp.fft.ifft2(updated_probe_fft) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + + return updated_probe * normalization + + def _probe_aperture_constraint( + self, + current_probe, + initial_probe_aperture, + ): + """ + Ptychographic constraint to fix Fourier amplitude to initial aperture. 
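+
+ The Fourier phase of the current probe is retained, its Fourier amplitude
+ is replaced by `initial_probe_aperture`, and the result is renormalized to
+ conserve total probe intensity.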
+ + Parameters + ---------- + current_probe: np.ndarray + Current positions estimate + + Returns + -------- + constrained_probe: np.ndarray + Constrained probe estimate + """ + xp = self._xp + + current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + current_probe_fft_phase = xp.angle(xp.fft.fft2(current_probe)) + + updated_probe = xp.fft.ifft2( + xp.exp(1j * current_probe_fft_phase) * initial_probe_aperture + ) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + + return updated_probe * normalization + + def _probe_aberration_fitting_constraint( + self, + current_probe, + max_angular_order, + max_radial_order, + remove_initial_probe_aberrations, + use_scikit_image, + ): + """ + Ptychographic probe smoothing constraint. + + Parameters + ---------- + current_probe: np.ndarray + Current positions estimate + max_angular_order: bool + Max angular order of probe aberrations basis functions + max_radial_order: bool + Max radial order of probe aberrations basis functions + remove_initial_probe_aberrations: bool, optional + If true, initial probe aberrations are removed before fitting + + Returns + -------- + constrained_probe: np.ndarray + Constrained probe estimate + """ + + xp = self._xp + + fourier_probe = xp.fft.fft2(current_probe) + if remove_initial_probe_aberrations: + fourier_probe *= xp.conj(self._known_aberrations_array) + + fourier_probe_abs = xp.abs(fourier_probe) + sampling = self.sampling + energy = self._energy + + fitted_angle, _ = fit_aberration_surface( + fourier_probe, + sampling, + energy, + max_angular_order, + max_radial_order, + use_scikit_image, + xp=xp, + ) + + fourier_probe = fourier_probe_abs * xp.exp(-1.0j * fitted_angle) + if remove_initial_probe_aberrations: + fourier_probe *= self._known_aberrations_array + + current_probe = xp.fft.ifft2(fourier_probe) + + return current_probe + + def _probe_constraints( + self, + current_probe, + fix_probe_com, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image, + fix_probe_aperture, + initial_probe_aperture, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + **kwargs, + ): + """ProbeConstraints wrapper function""" + + # CoM corner-centering + if fix_probe_com: + current_probe = self._probe_center_of_mass_constraint(current_probe) + + # Fourier phase (aberrations) fitting + if fit_probe_aberrations: + current_probe = self._probe_aberration_fitting_constraint( + current_probe, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image, + ) + + # Fourier amplitude (aperture) constraints + if fix_probe_aperture: + current_probe = self._probe_aperture_constraint( + current_probe, + initial_probe_aperture, + ) + elif constrain_probe_fourier_amplitude: + current_probe = self._probe_fourier_amplitude_constraint( + current_probe, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + ) + + # Real-space amplitude constraint + if constrain_probe_amplitude: + current_probe = self._probe_amplitude_constraint( + current_probe, + 
constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + ) + + return current_probe + + +class ProbeMixedConstraintsMixin: + """ + Mixin class for regularizations unique to mixed probes. + Overwrites ProbeConstraintsMixin. + """ + + def _probe_center_of_mass_constraint(self, current_probe): + """ + Ptychographic center of mass constraint. + Used for centering corner-centered probe intensity. + + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Constrained probe estimate + """ + xp = self._xp + probe_intensity = xp.abs(current_probe[0]) ** 2 + + probe_x0, probe_y0 = get_CoM( + probe_intensity, device=self._device, corner_centered=True + ) + shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) + + return shifted_probe + + def _probe_orthogonalization_constraint(self, current_probe): + """ + Ptychographic probe-orthogonalization constraint. + Used to ensure mixed states are orthogonal to each other. + Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690 + + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Orthogonalized probe estimate + """ + xp = self._xp + n_probes = self._num_probes + + # compute upper half of P* @ P + pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype) + + for i in range(n_probes): + for j in range(i, n_probes): + pairwise_dot_product[i, j] = xp.sum( + current_probe[i].conj() * current_probe[j] + ) + + # compute eigenvectors (effectively cheaper way of computing V* from SVD) + _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U") + current_probe = xp.tensordot(evecs.T, current_probe, axes=1) + + # sort by real-space intensity + intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1)) + intensities_order = xp.argsort(intensities, axis=None)[::-1] + return current_probe[intensities_order] + + def _probe_constraints( + self, + current_probe, + fix_probe_com, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image, + fix_probe_aperture, + initial_probe_aperture, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + orthogonalize_probe, + **kwargs, + ): + """ProbeMixedConstraints wrapper function""" + + # CoM corner-centering + if fix_probe_com: + current_probe = self._probe_center_of_mass_constraint(current_probe) + + # Fourier phase (aberrations) fitting + if fit_probe_aberrations: + for probe_idx in range(self._num_probes): + current_probe[probe_idx] = self._probe_aberration_fitting_constraint( + current_probe[probe_idx], + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image, + ) + + # Fourier amplitude (aperture) constraints + if fix_probe_aperture: + current_probe[0] = self._probe_aperture_constraint( + current_probe[0], + initial_probe_aperture[0], + ) + elif constrain_probe_fourier_amplitude: + current_probe[0] = self._probe_fourier_amplitude_constraint( + current_probe[0], + 
constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + ) + + # Real-space amplitude constraint + if constrain_probe_amplitude: + for probe_idx in range(self._num_probes): + current_probe[probe_idx] = self._probe_amplitude_constraint( + current_probe[probe_idx], + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + ) + + # Probe orthogonalization + if orthogonalize_probe: + current_probe = self._probe_orthogonalization_constraint(current_probe) + + return current_probe + + +class PositionsConstraintsMixin: + """ + Mixin class for probe positions constraints. + """ + + def _positions_center_of_mass_constraint( + self, current_positions, initial_positions_com + ): + """ + Ptychographic position center of mass constraint. + Additionally updates vectorized indices used in _overlap_projection. + + Parameters + ---------- + current_positions: np.ndarray + Current positions estimate + + Returns + -------- + constrained_positions: np.ndarray + CoM constrained positions estimate + """ + current_positions -= current_positions.mean(0) - initial_positions_com + + return current_positions + + def _positions_affine_transformation_constraint( + self, initial_positions, current_positions + ): + """ + Constrains the updated positions to be an affine transformation of the initial scan positions, + composing of two scale factors, a shear, and a rotation angle. + + Uses RANSAC to estimate the global transformation robustly. + Stores the AffineTransformation in self._tf. + + Parameters + ---------- + initial_positions: np.ndarray + Initial scan positions + current_positions: np.ndarray + Current positions estimate + + Returns + ------- + constrained_positions: np.ndarray + Affine-transform constrained positions estimate + """ + + xp_storage = self._xp_storage + initial_positions_com = initial_positions.mean(0) + + tf, _ = estimate_global_transformation_ransac( + positions0=initial_positions, + positions1=current_positions, + origin=initial_positions_com, + translation_allowed=True, + min_sample=initial_positions.shape[0] // 10, + xp=xp_storage, + ) + + current_positions = tf( + initial_positions, origin=initial_positions_com, xp=xp_storage + ) + self._tf = tf + + return current_positions + + def _positions_constraints( + self, + current_positions, + initial_positions, + fix_positions, + fix_positions_com, + global_affine_transformation, + **kwargs, + ): + """PositionsConstraints wrapper function""" + + if not fix_positions: + if not fix_positions_com: + current_positions = self._positions_center_of_mass_constraint( + current_positions, initial_positions.mean(0) + ) + + if global_affine_transformation: + current_positions = self._positions_affine_transformation_constraint( + initial_positions, current_positions + ) + + return current_positions diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py new file mode 100644 index 000000000..6ab349f30 --- /dev/null +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -0,0 +1,3364 @@ +import warnings +from typing import Sequence, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.process.phase.utils import ( + AffineTransform, + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + partition_list, + rotate_point, + spatial_frequencies, + vectorized_bilinear_resample, +) +from 
py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar +from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex +from scipy.ndimage import gaussian_filter, rotate + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + + +class ObjectNDMethodsMixin: + """ + Mixin class for object methods applicable to 2D,2.5D, and 3D objects. + """ + + def _initialize_object( + self, + initial_object, + positions_px, + object_type, + ): + """ """ + # explicit read-only self attributes up-front + xp = self._xp + + object_padding_px = self._object_padding_px + region_of_interest_shape = self._region_of_interest_shape + + if initial_object is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + if object_type == "potential": + _object = xp.zeros((p, q), dtype=xp.float32) + elif object_type == "complex": + _object = xp.ones((p, q), dtype=xp.complex64) + else: + if object_type == "potential": + _object = xp.asarray(initial_object, dtype=xp.float32) + elif object_type == "complex": + _object = xp.asarray(initial_object, dtype=xp.complex64) + + return _object + + def _crop_rotate_object_fov( + self, + array, + positions_px=None, + padding=0, + ): + """ + Crops and rotated object to FOV bounded by current pixel positions. + + Parameters + ---------- + array: np.ndarray + Object array to crop and rotate. Only operates on numpy arrays for compatibility. + padding: int, optional + Optional padding outside pixel positions + + Returns + cropped_rotated_array: np.ndarray + Cropped and rotated object array + """ + + asnumpy = self._asnumpy + + angle = ( + self._rotation_best_rad + if self._rotation_best_transpose + else -self._rotation_best_rad + ) + + if positions_px is None: + positions_px = asnumpy(self._positions_px) + else: + positions_px = asnumpy(positions_px) + + tf = AffineTransform(angle=angle) + rotated_points = tf(positions_px, origin=positions_px.mean(0), xp=np) + + min_x, min_y = np.floor(np.amin(rotated_points, axis=0) - padding).astype("int") + min_x = min_x if min_x > 0 else 0 + min_y = min_y if min_y > 0 else 0 + max_x, max_y = np.ceil(np.amax(rotated_points, axis=0) + padding).astype("int") + + rotated_array = rotate( + asnumpy(array), np.rad2deg(-angle), order=1, reshape=False, axes=(-2, -1) + )[..., min_x:max_x, min_y:max_y] + + if self._rotation_best_transpose: + rotated_array = rotated_array.swapaxes(-2, -1) + + return rotated_array + + def _return_projected_cropped_potential( + self, + obj=None, + return_kwargs=False, + **kwargs, + ): + """Utility function to accommodate multiple classes""" + if obj is None: + obj = self.object_cropped + else: + obj = self._crop_rotate_object_fov(obj) + + if np.iscomplexobj(obj): + obj = np.angle(obj) + + if return_kwargs: + return obj, kwargs + else: + return obj + + def _return_object_fft( + self, + obj=None, + apply_hanning_window=False, + **kwargs, + ): + """ + Returns absolute value of obj fft shifted to center of array + + Parameters + ---------- + obj: array, optional + if None is specified, uses self._object + apply_hanning_window: bool, optional + If True, a 2D Hann window is applied to the object before FFT + + Returns + ------- + object_fft_amplitude: np.ndarray + Amplitude of Fourier-transformed and center-shifted obj. 
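+
+ Notes
+ -----
+ Applying the Hann window helps suppress the streaking that sharp
+ field-of-view edges would otherwise introduce into the FFT amplitude.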
+ """ + xp = self._xp + asnumpy = self._asnumpy + + if obj is None: + obj = self._object + + if np.iscomplexobj(obj): + obj = xp.angle(obj) + + obj = self._crop_rotate_object_fov(asnumpy(obj)) + + if apply_hanning_window: + sx, sy = obj.shape + wx = np.hanning(sx) + wy = np.hanning(sy) + obj *= wx[:, None] * wy[None, :] + + return np.abs(np.fft.fftshift(np.fft.fft2(obj))) + + def show_object_fft( + self, + obj=None, + apply_hanning_window=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, + ): + """ + Plot FFT of reconstructed object + + Parameters + ---------- + obj: complex array, optional + If None is specified, uses the `object_fft` property + apply_hanning_window: bool, optional + If True, a 2D Hann window is applied to the object before FFT + scalebar: bool, optional + if True, adds scalebar to probe + pixelunits: str, optional + units for scalebar, default is A^-1 + pixelsize: float, optional + default is object FFT sampling + """ + + object_fft = self._return_object_fft( + obj, apply_hanning_window=apply_hanning_window, **kwargs + ) + + if pixelsize is None: + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) + if pixelunits is None: + pixelunits = r"$\AA^{-1}$" + + figsize = kwargs.pop("figsize", (4, 4)) + cmap = kwargs.pop("cmap", "magma") + ticks = kwargs.pop("ticks", False) + vmin = kwargs.pop("vmin", 0.001) + vmax = kwargs.pop("vmax", 0.999) + + # remove additional 3D FFT parameters before passing to show + kwargs.pop("orientation_matrix", None) + kwargs.pop("vertical_lims", None) + kwargs.pop("horizontal_lims", None) + + show( + object_fft, + figsize=figsize, + cmap=cmap, + scalebar=scalebar, + pixelsize=pixelsize, + ticks=ticks, + pixelunits=pixelunits, + vmin=vmin, + vmax=vmax, + aspect=object_fft.shape[1] / object_fft.shape[0], + **kwargs, + ) + + def _reset_reconstruction( + self, + store_iterations, + reset, + ): + """ """ + if store_iterations and (not hasattr(self, "object_iterations") or reset): + self.object_iterations = [] + self.probe_iterations = [] + + # reset can be True, False, or None (default) + if reset is True: + self.error_iterations = [] + self._object = self._object_initial.copy() + self._probe = self._probe_initial.copy() + self._positions_px = self._positions_px_initial.copy() + self._object_type = self._object_type_initial + self._exit_waves = None + + # delete positions affine transform + if hasattr(self, "_tf"): + del self._tf + + elif reset is None: + # continued run + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + + # first start + else: + self.error_iterations = [] + self._exit_waves = None + + @property + def object_fft(self): + """Fourier transform of current object estimate""" + + if not hasattr(self, "_object"): + return None + + return self._return_object_fft(self._object) + + @property + def object_cropped(self): + """Cropped and rotated object""" + + return self._crop_rotate_object_fov(self._object) + + +class Object2p5DMethodsMixin: + """ + Mixin class for object methods unique to 2.5D objects. + Overwrites ObjectNDMethodsMixin. + """ + + def _precompute_propagator_arrays( + self, + gpts: Tuple[int, int], + sampling: Tuple[float, float], + energy: float, + slice_thicknesses: Sequence[float], + theta_x: float = None, + theta_y: float = None, + ): + """ + Precomputes propagator arrays complex wave-function will be convolved by, + for all slice thicknesses. 
+ + Parameters + ---------- + gpts: Tuple[int,int] + Wavefunction pixel dimensions + sampling: Tuple[float,float] + Wavefunction sampling in A + energy: float + The electron energy of the wave functions in eV + slice_thicknesses: Sequence[float] + Array of slice thicknesses in A + theta_x: float, optional + x tilt of propagator in mrad + theta_y: float, optional + y tilt of propagator in mrad + + Returns + ------- + propagator_arrays: np.ndarray + (T,Sx,Sy) shape array storing propagator arrays + """ + xp = self._xp + + # Frequencies + kx, ky = spatial_frequencies(gpts, sampling) + kx = xp.asarray(kx, dtype=xp.float32) + ky = xp.asarray(ky, dtype=xp.float32) + + # Propagators + wavelength = electron_wavelength_angstrom(energy) + num_slices = slice_thicknesses.shape[0] + propagators = xp.empty( + (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 + ) + + for i, dz in enumerate(slice_thicknesses): + propagators[i] = xp.exp( + 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) + ) + propagators[i] *= xp.exp( + 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) + ) + + if theta_x is not None: + propagators[i] *= xp.exp( + 1.0j * (-2 * kx[:, None] * np.pi * dz * np.tan(theta_x / 1e3)) + ) + + if theta_y is not None: + propagators[i] *= xp.exp( + 1.0j * (-2 * ky[None] * np.pi * dz * np.tan(theta_y / 1e3)) + ) + + return propagators + + def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): + """ + Propagates array by Fourier convolving array with propagator_array. + + Parameters + ---------- + array: np.ndarray + Wavefunction array to be convolved + propagator_array: np.ndarray + Propagator array to convolve array with + + Returns + ------- + propagated_array: np.ndarray + Fourier-convolved array + """ + xp = self._xp + + return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) + + def _initialize_object( + self, + initial_object, + num_slices, + positions_px, + object_type, + ): + """ """ + # explicit read-only self attributes up-front + xp = self._xp + + object_padding_px = self._object_padding_px + region_of_interest_shape = self._region_of_interest_shape + + if initial_object is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + if object_type == "potential": + _object = xp.zeros((num_slices, p, q), dtype=xp.float32) + elif object_type == "complex": + _object = xp.ones((num_slices, p, q), dtype=xp.complex64) + else: + if object_type == "potential": + _object = xp.asarray(initial_object, dtype=xp.float32) + elif object_type == "complex": + _object = xp.asarray(initial_object, dtype=xp.complex64) + + return _object + + def _return_projected_cropped_potential( + self, + obj=None, + return_kwargs=False, + **kwargs, + ): + """Utility function to accommodate multiple classes""" + + if obj is None: + obj = self.object_cropped + else: + obj = self._crop_rotate_object_fov(obj) + + if np.iscomplexobj(obj): + obj = np.angle(obj).sum(0) + else: + obj = obj.sum(0) + + if return_kwargs: + return obj, kwargs + else: + return obj + + def _return_object_fft( + self, + obj=None, + apply_hanning_window=False, + **kwargs, + ): + """ + Returns obj fft shifted to center of array + + Parameters + ---------- + obj: array, optional + if None is specified, uses self._object + apply_hanning_window: bool, optional + If True, a 2D Hann window is applied to the 
object before FFT + + Returns + ------- + object_fft_amplitude: np.ndarray + Amplitude of Fourier-transformed and center-shifted obj. + """ + xp = self._xp + + if obj is None: + obj = self._object + + if np.iscomplexobj(obj): + obj = xp.angle(obj) + + obj = self._crop_rotate_object_fov(obj.sum(axis=0)) + + if apply_hanning_window: + sx, sy = obj.shape + wx = np.hanning(sx) + wy = np.hanning(sy) + obj *= wx[:, None] * wy[None, :] + + return np.abs(np.fft.fftshift(np.fft.fft2(obj))) + + def show_depth_section( + self, + ptA: Tuple[float, float], + ptB: Tuple[float, float], + aspect_ratio: float = "auto", + plot_line_profile: bool = False, + ms_object=None, + specify_calibrated: bool = True, + gaussian_filter_sigma: float = None, + cbar: bool = True, + **kwargs, + ): + """ + Displays line profile depth section + + Parameters + ---------- + ptA: Tuple[float,float] + Starting point (x1,y1) for line profile depth section + If either is None, assumed to be array start. + Specified in Angstroms unless specify_calibrated is False + ptB: Tuple[float,float] + End point (x2,y2) for line profile depth section + If either is None, assumed to be array end. + Specified in Angstroms unless specify_calibrated is False + aspect_ratio: float, optional + aspect ratio for depth profile plot + plot_line_profile: bool + If True, also plots line profile showing where depth profile is taken + ms_object: np.array + Object to plot slices of. If None, uses current object + specify_calibrated: bool (optional) + If False, ptA and ptB points specified in pixels instead of Angstroms + gaussian_filter_sigma: float (optional) + Standard deviation of gaussian kernel in A + cbar: bool, optional + If True, displays a colorbar + """ + if ms_object is None: + ms_object = self.object_cropped + + if np.iscomplexobj(ms_object): + ms_object = np.angle(ms_object) + + x1, y1 = ptA + x2, y2 = ptB + + if x1 is None: + x1 = 0 + if y1 is None: + y1 = 0 + if x2 is None: + x2 = self.sampling[0] * ms_object.shape[1] + if y2 is None: + y2 = self.sampling[1] * ms_object.shape[2] + + if specify_calibrated: + x1 /= self.sampling[0] + x2 /= self.sampling[0] + y1 /= self.sampling[1] + y2 /= self.sampling[1] + + x1, x2 = np.array([x1, x2]).clip(0, ms_object.shape[1]) + y1, y2 = np.array([y1, y2]).clip(0, ms_object.shape[2]) + + angle = np.arctan2(x2 - x1, y2 - y1) + + x0 = ms_object.shape[1] / 2 + y0 = ms_object.shape[2] / 2 + + x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) + x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) + + rotated_object = np.roll( + rotate(ms_object, np.rad2deg(angle), reshape=False, axes=(-1, -2)), + -int(x1_0), + axis=1, + ) + + if gaussian_filter_sigma is not None: + gaussian_filter_sigma /= self.sampling[0] + rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) + + y1_0, y2_0 = ( + np.array([y1_0, y2_0]).astype("int").clip(0, rotated_object.shape[2]) + ) + plot_im = rotated_object[:, 0, y1_0:y2_0] + + # Plotting + if plot_line_profile: + ncols = 2 + else: + ncols = 1 + col_index = 0 + + spec = GridSpec(ncols=ncols, nrows=1, wspace=0.15) + + figsize = kwargs.pop("figsize", (4 * ncols, 4)) + fig = plt.figure(figsize=figsize) + cmap = kwargs.pop("cmap", "magma") + + # Line profile + if plot_line_profile: + ax = fig.add_subplot(spec[0, col_index]) + + extent_line = [ + 0, + self.sampling[1] * ms_object.shape[2], + self.sampling[0] * ms_object.shape[1], + 0, + ] + + ax.imshow(ms_object.sum(0), cmap="gray", extent=extent_line) + + ax.plot( + [y1 * self.sampling[0], y2 * self.sampling[1]], + [x1 * 
self.sampling[0], x2 * self.sampling[1]], + color="red", + ) + + ax.set_xlabel("y [A]") + ax.set_ylabel("x [A]") + ax.set_title("Multislice depth profile location") + col_index += 1 + + # Main visualization + + extent = [ + 0, + self.sampling[1] * plot_im.shape[1], + self._slice_thicknesses[0] * plot_im.shape[0], + 0, + ] + + ax = fig.add_subplot(spec[0, col_index]) + im = ax.imshow(plot_im, cmap=cmap, extent=extent) + + if aspect_ratio is not None: + if aspect_ratio == "auto": + aspect_ratio = extent[1] / extent[2] + if plot_line_profile: + aspect_ratio *= extent_line[2] / extent_line[1] + + ax.set_aspect(aspect_ratio) + cbar = False + + ax.set_xlabel("r [A]") + ax.set_ylabel("z [A]") + ax.set_title("Multislice depth profile") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + spec.tight_layout(fig) + + def show_slices( + self, + ms_object=None, + cbar: bool = True, + common_color_scale: bool = True, + padding: int = 0, + num_cols: int = 3, + show_fft: bool = False, + **kwargs, + ): + """ + Displays reconstructed slices of object + + Parameters + -------- + ms_object: nd.array, optional + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + padding: int, optional + Padding to leave uncropped + num_cols: int, optional + Number of GridSpec columns + show_fft: bool, optional + if True, plots fft of object slices + """ + + if ms_object is None: + ms_object = self._object + + rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + + if show_fft: + rotated_object = np.abs( + np.fft.fftshift( + np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) + ) + ) + + rotated_shape = rotated_object.shape + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + + extent = [ + 0, + self.sampling[1] * rotated_shape[2], + self.sampling[0] * rotated_shape[1], + 0, + ] + + num_rows = np.ceil(self._num_slices / num_cols).astype("int") + wspace = 0.35 if cbar else 0.15 + + axsize = kwargs.pop("axsize", (3, 3)) + cmap = kwargs.pop("cmap", "magma") + + if common_color_scale: + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + rotated_object, vmin, vmax = return_scaled_histogram_ordering( + rotated_object, vmin=vmin, vmax=vmax + ) + else: + vmin = None + vmax = None + + spec = GridSpec( + ncols=num_cols, + nrows=num_rows, + hspace=0.15, + wspace=wspace, + ) + + figsize = (axsize[0] * num_cols, axsize[1] * num_rows) + fig = plt.figure(figsize=figsize) + + for flat_index, obj_slice in enumerate(rotated_object): + row_index, col_index = np.unravel_index(flat_index, (num_rows, num_cols)) + ax = fig.add_subplot(spec[row_index, col_index]) + im = ax.imshow( + obj_slice, + cmap=cmap, + vmin=vmin, + vmax=vmax, + extent=extent, + **kwargs, + ) + + ax.set_title(f"Slice index: {flat_index}") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if row_index < num_rows - 1: + ax.set_xticks([]) + else: + ax.set_xlabel("y [A]") + + if col_index > 0: + ax.set_yticks([]) + else: + ax.set_ylabel("x [A]") + + spec.tight_layout(fig) + + +class Object3DMethodsMixin: + """ + Mixin class for object methods unique to 3D objects. + Overwrites ObjectNDMethodsMixin and Object2p5DMethodsMixin. 
+ """ + + _swap_zxy_to_xyz = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) + + def _project_sliced_object(self, array: np.ndarray, output_z): + """ + Projects voxel-sliced object. + + Parameters + ---------- + array: np.ndarray + 3D array to project + output_z: int + Output_dimension to project array to. + + Returns + ------- + projected_array: np.ndarray + projected array + """ + xp = self._xp + input_z = array.shape[0] + + voxels_per_slice = np.ceil(input_z / output_z).astype("int") + pad_size = voxels_per_slice * output_z - input_z + + padded_array = xp.pad(array, ((0, pad_size), (0, 0), (0, 0))) + + return xp.sum( + padded_array.reshape( + ( + -1, + voxels_per_slice, + ) + + array.shape[1:] + ), + axis=1, + ) + + def _expand_sliced_object(self, array: np.ndarray, output_z): + """ + Expands supersliced object. + + Parameters + ---------- + array: np.ndarray + 3D array to expand + output_z: int + Output_dimension to expand array to. + + Returns + ------- + expanded_array: np.ndarray + expanded array + """ + xp = self._xp + input_z = array.shape[0] + + voxels_per_slice = np.ceil(output_z / input_z).astype("int") + remainder_size = voxels_per_slice - (voxels_per_slice * input_z - output_z) + + voxels_in_slice = xp.repeat(voxels_per_slice, input_z) + voxels_in_slice[-1] = remainder_size if remainder_size > 0 else voxels_per_slice + + normalized_array = array / xp.asarray(voxels_in_slice)[:, None, None] + return xp.repeat(normalized_array, voxels_per_slice, axis=0)[:output_z] + + def _rotate_zxy_volume( + self, + volume_array, + rot_matrix, + order=3, + ): + """ """ + + xp = self._xp + affine_transform = self._scipy.ndimage.affine_transform + swap_zxy_to_xyz = self._swap_zxy_to_xyz + + volume = volume_array.copy() + volume_shape = xp.asarray(volume.shape) + tf = xp.asarray(swap_zxy_to_xyz.T @ rot_matrix.T @ swap_zxy_to_xyz) + + in_center = (volume_shape - 1) / 2 + out_center = tf @ in_center + offset = in_center - out_center + + volume = affine_transform(volume, tf, offset=offset, order=order) + + return volume + + def _initialize_object( + self, + initial_object, + positions_px, + object_type, + main_tilt_axis="vertical", + ): + """ """ + # explicit read-only self attributes up-front + xp = self._xp + object_padding_px = self._object_padding_px + region_of_interest_shape = self._region_of_interest_shape + + if initial_object is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + + if main_tilt_axis == "vertical": + _object = xp.zeros((q, p, q), dtype=xp.float32) + elif main_tilt_axis == "horizontal": + _object = xp.zeros((p, p, q), dtype=xp.float32) + else: + _object = xp.zeros((max(p, q), p, q), dtype=xp.float32) + else: + _object = xp.asarray(initial_object, dtype=xp.float32) + + return _object + + def _return_projected_cropped_potential( + self, + obj=None, + return_kwargs=False, + **kwargs, + ): + """Utility function to accommodate multiple classes""" + + asnumpy = self._asnumpy + + rot_matrix = kwargs.pop("orientation_matrix", None) + v_lims = kwargs.pop("vertical_lims", (None, None)) + h_lims = kwargs.pop("horizontal_lims", (None, None)) + + if obj is None: + obj = self._object + + if rot_matrix is not None: + obj = self._rotate_zxy_volume( + obj, + rot_matrix=rot_matrix, + ) + + start_v, end_v = v_lims + start_h, end_h = h_lims + obj = 
asnumpy(obj.sum(0)[start_v:end_v, start_h:end_h]) + + if return_kwargs: + return obj, kwargs + else: + return obj + + def _return_object_fft( + self, + obj=None, + apply_hanning_window=False, + orientation_matrix=None, + vertical_lims: Tuple[int, int] = (None, None), + horizontal_lims: Tuple[int, int] = (None, None), + **kwargs, + ): + """ + Returns obj fft shifted to center of array + + Parameters + ---------- + obj: array, optional + if None is specified, uses self._object + apply_hanning_window: bool, optional + If True, a 2D Hann window is applied to the object before FFT + orientation_matrix: np.ndarray, optional + orientation matrix to rotate zone-axis + vertical_lims: tuple(int,int), optional + min/max vertical indices + horizontal_lims: tuple(int,int), optional + min/max horizontal indices + + Returns + ------- + object_fft_amplitude: np.ndarray + Amplitude of Fourier-transformed and center-shifted obj. + """ + + xp = self._xp + asnumpy = self._asnumpy + + if obj is None: + obj = self._object + else: + obj = xp.asarray(obj, dtype=xp.float32) + + if orientation_matrix is not None: + obj = self._rotate_zxy_volume( + obj, + rot_matrix=orientation_matrix, + ) + + start_v, end_v = vertical_lims + start_h, end_h = horizontal_lims + obj = asnumpy(obj.sum(0)[start_v:end_v, start_h:end_h]) + + if apply_hanning_window: + sx, sy = obj.shape + wx = np.hanning(sx) + wy = np.hanning(sy) + obj *= wx[:, None] * wy[None, :] + + return np.abs(np.fft.fftshift(np.fft.fft2(obj))) + + @property + def object_supersliced(self): + """Returns super-sliced object""" + return self._project_sliced_object(self._object, self._num_slices) + + +class ProbeMethodsMixin: + """ + Mixin class for probe methods applicable to a single probe. + """ + + def _initialize_probe( + self, + initial_probe, + vacuum_probe_intensity, + mean_diffraction_intensity, + semiangle_cutoff, + crop_patterns, + ): + """ """ + # explicit read-only self attributes up-front + xp = self._xp + device = self._device + + crop_mask = self._crop_mask + region_of_interest_shape = self._region_of_interest_shape + sampling = self.sampling + energy = self._energy + rolloff = self._rolloff + polar_parameters = self._polar_parameters + + if initial_probe is None: + if vacuum_probe_intensity is not None: + semiangle_cutoff = np.inf + vacuum_probe_intensity = xp.asarray( + vacuum_probe_intensity, dtype=xp.float32 + ) + probe_x0, probe_y0 = get_CoM( + vacuum_probe_intensity, + device=device, + ) + vacuum_probe_intensity = get_shifted_ar( + vacuum_probe_intensity, + -probe_x0, + -probe_y0, + bilinear=True, + device=device, + ) + + if crop_patterns: + vacuum_probe_intensity = vacuum_probe_intensity[crop_mask].reshape( + region_of_interest_shape + ) + + _probe = ( + ComplexProbe( + gpts=region_of_interest_shape, + sampling=sampling, + energy=energy, + semiangle_cutoff=semiangle_cutoff, + rolloff=rolloff, + vacuum_probe_intensity=vacuum_probe_intensity, + parameters=polar_parameters, + device=device, + ) + .build() + ._array + ) + + # Normalize probe to match mean diffraction intensity + probe_intensity = xp.sum(xp.abs(xp.fft.fft2(_probe)) ** 2) + _probe *= xp.sqrt(mean_diffraction_intensity / probe_intensity) + + else: + if isinstance(initial_probe, ComplexProbe): + if initial_probe._gpts != region_of_interest_shape: + raise ValueError() + if hasattr(initial_probe, "_array"): + _probe = initial_probe._array + else: + initial_probe._xp = xp + _probe = initial_probe.build()._array + + # Normalize probe to match mean diffraction intensity + probe_intensity = 
xp.sum(xp.abs(xp.fft.fft2(_probe)) ** 2) + _probe *= xp.sqrt(mean_diffraction_intensity / probe_intensity) + else: + _probe = xp.asarray(initial_probe, dtype=xp.complex64) + + return _probe, semiangle_cutoff + + def _return_fourier_probe( + self, + probe=None, + remove_initial_probe_aberrations=False, + ): + """ + Returns complex fourier probe shifted to center of array from + corner-centered complex real space probe + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses self._probe + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe + + Returns + ------- + fourier_probe: np.ndarray + Fourier-transformed and center-shifted probe. + """ + xp = self._xp + + if probe is None: + probe = self._probe + else: + probe = xp.asarray(probe, dtype=xp.complex64) + + fourier_probe = xp.fft.fft2(probe) + + if remove_initial_probe_aberrations: + fourier_probe *= xp.conjugate(self._known_aberrations_array) + + return xp.fft.fftshift(fourier_probe, axes=(-2, -1)) + + def _return_fourier_probe_from_centered_probe( + self, + probe=None, + remove_initial_probe_aberrations=False, + ): + """ + Returns complex fourier probe shifted to center of array from + centered complex real space probe + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses self._probe + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe + + Returns + ------- + fourier_probe: np.ndarray + Fourier-transformed and center-shifted probe. + """ + xp = self._xp + return self._return_fourier_probe( + xp.fft.ifftshift(probe, axes=(-2, -1)), + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + + def _return_centered_probe( + self, + probe=None, + ): + """ + Returns complex probe centered in middle of the array. + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses self._probe + + Returns + ------- + centered_probe: np.ndarray + Center-shifted probe. + """ + xp = self._xp + + if probe is None: + probe = self._probe + else: + probe = xp.asarray(probe, dtype=xp.complex64) + + return xp.fft.fftshift(probe, axes=(-2, -1)) + + def _return_probe_intensities(self, probe): + """ + Returns probe intensities summing up to 1. 
+ """ + if probe is None: + probe = self.probe_centered + + intensity_arrays = np.abs(np.array(probe, ndmin=3)) ** 2 + probe_ratio = list(intensity_arrays.sum((-2, -1)) / intensity_arrays.sum()) + + return probe_ratio + + def show_probe( + self, + probe=None, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + W=6, + **kwargs, + ): + """ + Plot probe in real space + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses the `probe_fourier` property + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe + cbar: bool, optional + if True, adds colorbar + scalebar: bool, optional + if True, adds scalebar to probe + pixelsize: float, optional + default is probe reciprocal sampling + pixelunits: str, optional + units for scalebar, default is A^-1 + W: int, optional + if not None, sets the width of the image grid + """ + asnumpy = self._asnumpy + + if pixelsize is None: + pixelsize = self.sampling[1] + if pixelunits is None: + pixelunits = r"$\AA$" + + intensities = self._return_probe_intensities(probe) + title = [ + f"Probe {iter} intensity: {ratio*100:.1f}%" + for iter, ratio in enumerate(intensities) + ] + + axsize = kwargs.pop("axsize", (4, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + ticks = kwargs.pop("ticks", False) + title = kwargs.pop("title", title if len(title) > 1 else title[0]) + + if probe is None: + probe = list(np.array(self.probe_centered, ndmin=3)) + else: + if isinstance(probe, np.ndarray) and probe.ndim == 2: + probe = [probe] + probe = [ + asnumpy( + self._return_centered_probe( + pr, + ) + ) + for pr in probe + ] + + probe = list(partition_list(probe, W)) + probe = probe if len(probe) > 1 else probe[0] + + show_complex( + probe, + cbar=cbar, + axsize=axsize, + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=ticks, + chroma_boost=chroma_boost, + title=title, + **kwargs, + ) + + def show_fourier_probe( + self, + probe=None, + remove_initial_probe_aberrations=False, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + W=6, + **kwargs, + ): + """ + Plot probe in fourier space + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses the `probe_fourier` property + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe + cbar: bool, optional + if True, adds colorbar + scalebar: bool, optional + if True, adds scalebar to probe + pixelsize: float, optional + default is probe reciprocal sampling + pixelunits: str, optional + units for scalebar, default is A^-1 + W: int, optional + if not None, sets the width of the image grid + """ + asnumpy = self._asnumpy + + if pixelsize is None: + pixelsize = self._reciprocal_sampling[1] + if pixelunits is None: + pixelunits = r"$\AA^{-1}$" + + intensities = self._return_probe_intensities(probe) + title = [ + f"Probe {iter} intensity: {ratio*100:.1f}%" + for iter, ratio in enumerate(intensities) + ] + + axsize = kwargs.pop("axsize", (4, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + ticks = kwargs.pop("ticks", False) + title = kwargs.pop("title", title if len(title) > 1 else title[0]) + + if probe is None: + if remove_initial_probe_aberrations: + probe = self.probe_fourier_residual + else: + probe = self.probe_fourier + probe = list(np.array(probe, ndmin=3)) + else: + if isinstance(probe, np.ndarray) and probe.ndim == 2: + probe = [probe] + probe = [ + asnumpy( + self._return_fourier_probe( + 
pr, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + for pr in probe + ] + + probe = list(partition_list(probe, W)) + probe = probe if len(probe) > 1 else probe[0] + + show_complex( + probe, + cbar=cbar, + axsize=axsize, + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=ticks, + chroma_boost=chroma_boost, + title=title, + **kwargs, + ) + + def _return_single_probe(self, probe=None): + """Current probe estimate""" + xp = self._xp + + if probe is not None: + return xp.asarray(probe) + else: + if not hasattr(self, "_probe"): + return None + + return self._probe + + @property + def probe_fourier(self): + """Current probe estimate in Fourier space""" + if not hasattr(self, "_probe"): + return None + + asnumpy = self._asnumpy + return asnumpy(self._return_fourier_probe(self._probe)) + + @property + def probe_fourier_residual(self): + """Current probe estimate in Fourier space""" + if not hasattr(self, "_probe"): + return None + + asnumpy = self._asnumpy + return asnumpy( + self._return_fourier_probe( + self._probe, remove_initial_probe_aberrations=True + ) + ) + + @property + def probe_centered(self): + """Current probe estimate shifted to the center""" + if not hasattr(self, "_probe"): + return None + + asnumpy = self._asnumpy + return asnumpy(self._return_centered_probe(self._probe)) + + +class ProbeMixedMethodsMixin: + """ + Mixin class for probe methods unique to mixed probes. + Overwrites ProbeMethodsMixin. + """ + + def _initialize_probe( + self, + initial_probe, + vacuum_probe_intensity, + mean_diffraction_intensity, + semiangle_cutoff, + crop_patterns, + ): + """ """ + + # explicit read-only self attributes up-front + xp = self._xp + num_probes = self._num_probes + region_of_interest_shape = self._region_of_interest_shape + + if initial_probe is None or isinstance(initial_probe, ComplexProbe): + # calls ProbeMethodsMixin for first probe + # annoyingly can't use super() as Mixins are defined right->left + # but MRO is defined left->right.. + _probe, semiangle_cutoff = ProbeMethodsMixin._initialize_probe( + self, + initial_probe, + vacuum_probe_intensity, + mean_diffraction_intensity, + semiangle_cutoff, + crop_patterns, + ) + + sx, sy = region_of_interest_shape + _probes = xp.zeros((num_probes, sx, sy), dtype=xp.complex64) + _probes[0] = _probe + + # Randomly shift phase of other probes + for i_probe in range(1, num_probes): + shift_x = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx) + ) + shift_y = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy) + ) + _probes[i_probe] = ( + _probes[i_probe - 1] * shift_x[:, None] * shift_y[None] + ) + else: + _probes = xp.asarray(initial_probe, dtype=xp.complex64) + + return _probes, semiangle_cutoff + + def _return_single_probe(self, probe=None): + """Current probe estimate""" + xp = self._xp + + if probe is not None: + return xp.asarray(probe[0]) + else: + if not hasattr(self, "_probe"): + return None + + return self._probe[0] + + +class ObjectNDProbeMethodsMixin: + """ + Mixin class for methods applicable to 2D, 2.5D, and 3D objects using a single probe. 
+ """ + + def _return_shifted_probes(self, current_probe, positions_px_fractional): + """Simple utility to de-duplicate _overlap_projection""" + + xp = self._xp + shifted_probes = fft_shift(current_probe, positions_px_fractional, xp) + return shifted_probes + + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + shifted_probes:np.ndarray + fractionally-shifted probes + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + shifted_probes * object_patches + """ + + xp = self._xp + + object_patches = current_object[ + vectorized_patch_indices_row, vectorized_patch_indices_col + ] + + if self._object_type == "potential": + object_patches = xp.exp(1j * object_patches) + + overlap = shifted_probes * object_patches + + return shifted_probes, object_patches, overlap + + def _return_farfield_amplitudes(self, fourier_overlap): + """Small utility to de-duplicate mixed-state Fourier projection.""" + + xp = self._xp + return xp.abs(fourier_overlap) + + def _gradient_descent_fourier_projection(self, amplitudes, overlap): + """ + Ptychographic fourier projection method for GD method. + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + overlap: np.ndarray + object * probe overlap + + Returns + -------- + exit_waves:np.ndarray + Difference between modified and estimated exit waves + error: float + Reconstruction error + """ + + xp = self._xp + + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + overlap = vectorized_bilinear_resample( + overlap, output_size=amplitudes.shape[-2:], xp=xp + ) + + fourier_overlap = xp.fft.fft2(overlap) + farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) + error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) + + fourier_modified_overlap = amplitudes * xp.exp(1j * xp.angle(fourier_overlap)) + + modified_overlap = xp.fft.ifft2(fourier_modified_overlap) + exit_waves = modified_overlap - overlap + + # resample back to region_of_interest_shape, note: this needs to happen in real-space + if self._resample_exit_waves: + exit_waves = vectorized_bilinear_resample( + exit_waves, output_size=self._region_of_interest_shape, xp=xp + ) + + return exit_waves, error + + def _projection_sets_fourier_projection( + self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c + ): + """ + Ptychographic fourier projection method for DM_AP and RAAR methods. 
+ Generalized projection using three parameters: a,b,c + + DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha + DM: DM_AP(1.0), AP: DM_AP(0.0) + + RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 + DM : RAAR(1.0) + + RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 + DM: RRR(1.0) + + SUPERFLIP : a = 0, b = 1, c = 2 + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + overlap: np.ndarray + object * probe overlap + exit_waves: np.ndarray + previously estimated exit waves + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + exit_waves:np.ndarray + Updated exit_waves + error: float + Reconstruction error + """ + + xp = self._xp + projection_x = 1 - projection_a - projection_b + projection_y = 1 - projection_c + + if exit_waves is None: + exit_waves = overlap.copy() + + factor_to_be_projected = projection_c * overlap + projection_y * exit_waves + + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + factor_to_be_projected = vectorized_bilinear_resample( + factor_to_be_projected, output_size=amplitudes.shape[-2:], xp=xp + ) + + fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) + farfield_amplitudes = self._return_farfield_amplitudes(fourier_projected_factor) + error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) + + fourier_projected_factor = amplitudes * xp.exp( + 1j * xp.angle(fourier_projected_factor) + ) + + projected_factor = xp.fft.ifft2(fourier_projected_factor) + + # resample back to region_of_interest_shape, note: this needs to happen in real-space + if self._resample_exit_waves: + projected_factor = vectorized_bilinear_resample( + projected_factor, output_size=self._region_of_interest_shape, xp=xp + ) + + exit_waves = ( + projection_x * exit_waves + + projection_a * overlap + + projection_b * projected_factor + ) + + return exit_waves, error + + def _forward( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + current_probe, + positions_px_fractional, + amplitudes, + exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic forward operator. + Calls _overlap_projection() and the appropriate _fourier_projection(). 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + amplitudes: np.ndarray + Normalized measured amplitudes + exit_waves: np.ndarray + previously estimated exit waves + use_projection_scheme: bool, + If True, use generalized projection update + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + shifted_probes:np.ndarray + fractionally-shifted probes + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + object * probe overlap + exit_waves:np.ndarray + Updated exit_waves + error: float + Reconstruction error + """ + shifted_probes = self._return_shifted_probes( + current_probe, positions_px_fractional + ) + + shifted_probes, object_patches, overlap = self._overlap_projection( + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ) + + if use_projection_scheme: + exit_waves, error = self._projection_sets_fourier_projection( + amplitudes, + overlap, + exit_waves, + projection_a, + projection_b, + projection_c, + ) + + else: + exit_waves, error = self._gradient_descent_fourier_projection( + amplitudes, overlap + ) + + return shifted_probes, object_patches, overlap, exit_waves, error + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + object_update: np.ndarray + Updated object estimate + probe_update: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_normalization = self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, + positions_px, + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + if self._object_type == "potential": + current_object += step_size * ( + self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(object_patches) + * xp.conj(shifted_probes) + * exit_waves + ), + positions_px, + ) + * probe_normalization + ) + else: + current_object += step_size * ( + self._sum_overlapping_patches_bincounts( + xp.conj(shifted_probes) * exit_waves, positions_px + ) + * probe_normalization + ) + + if not fix_probe: + object_normalization = xp.sum( + (xp.abs(object_patches) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += step_size * ( + xp.sum( + xp.conj(object_patches) * exit_waves, + axis=0, + ) + * object_normalization + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. 
+ Computes object and probe update steps. + + Parameters + -------- + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_normalization = self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, + positions_px, + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + if self._object_type == "potential": + current_object = ( + self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(object_patches) + * xp.conj(shifted_probes) + * exit_waves + ), + positions_px, + ) + * probe_normalization + ) + else: + current_object = ( + self._sum_overlapping_patches_bincounts( + xp.conj(shifted_probes) * exit_waves, + positions_px, + ) + * probe_normalization + ) + + if not fix_probe: + object_normalization = xp.sum( + (xp.abs(object_patches) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + xp.conj(object_patches) * exit_waves, + axis=0, + ) + * object_normalization + ) + + return current_object, current_probe + + def _adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + use_projection_scheme: bool, + step_size: float, + normalization_min: float, + fix_probe: bool, + ): + """ + Ptychographic adjoint operator. + Computes object and probe update steps. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + use_projection_scheme: bool, + If True, use generalized projection update + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + if use_projection_scheme: + current_object, current_probe = self._projection_sets_adjoint( + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + normalization_min, + fix_probe, + ) + else: + current_object, current_probe = self._gradient_descent_adjoint( + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + step_size, + normalization_min, + fix_probe, + ) + + return current_object, current_probe + + def _position_correction( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes, + current_positions, + current_positions_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ): + """ + Position correction using estimated intensity gradient. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + shifted_probes:np.ndarray + fractionally-shifted probes + overlap: np.ndarray + object * probe overlap + amplitudes: np.ndarray + Measured amplitudes + current_positions: np.ndarray + Current positions estimate + positions_step_size: float + Positions step size + max_position_update_distance: float + Maximum allowed distance for update in A + max_position_total_distance: float + Maximum allowed distance from initial probe positions + + Returns + -------- + updated_positions: np.ndarray + Updated positions estimate + """ + + xp = self._xp + storage = self._storage + + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + overlap = vectorized_bilinear_resample( + overlap, output_size=amplitudes.shape[-2:], xp=xp + ) + + # unperturbed + overlap_fft = xp.fft.fft2(overlap) + overlap_fft_conj = xp.conj(overlap_fft) + + estimated_intensity = self._return_farfield_amplitudes(overlap_fft) ** 2 + measured_intensity = amplitudes**2 + + # book-keeping + flat_shape = (measured_intensity.shape[0], -1) + difference_intensity = (measured_intensity - estimated_intensity).reshape( + flat_shape + ) + + # dx overlap projection perturbation + _, _, overlap_dx = self._overlap_projection( + current_object, + (vectorized_patch_indices_row + 1) % self._object_shape[0], + vectorized_patch_indices_col, + shifted_probes, + ) + + # dy overlap projection perturbation + _, _, overlap_dy = self._overlap_projection( + current_object, + vectorized_patch_indices_row, + (vectorized_patch_indices_col + 1) % self._object_shape[1], + shifted_probes, + ) + + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + overlap_dx = vectorized_bilinear_resample( + overlap_dx, output_size=amplitudes.shape[-2:], xp=xp + ) + overlap_dy = vectorized_bilinear_resample( + overlap_dy, 
output_size=amplitudes.shape[-2:], xp=xp + ) + + # partial intensities + overlap_dx_fft = overlap_fft - xp.fft.fft2(overlap_dx) + overlap_dy_fft = overlap_fft - xp.fft.fft2(overlap_dy) + partial_intensity_dx = 2 * xp.real(overlap_dx_fft * overlap_fft_conj) + partial_intensity_dy = 2 * xp.real(overlap_dy_fft * overlap_fft_conj) + + # handle mixed-state, is this correct? + if partial_intensity_dx.ndim == 4: + partial_intensity_dx = partial_intensity_dx.sum(1) + partial_intensity_dy = partial_intensity_dy.sum(1) + + partial_intensity_dx = partial_intensity_dx.reshape(flat_shape) + partial_intensity_dy = partial_intensity_dy.reshape(flat_shape) + + # least-squares fit + coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) + coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) + positions_update = ( + xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) + @ coefficients_matrix_T + @ difference_intensity[..., None] + ) + + positions_update = positions_update[..., 0] * positions_step_size + + if max_position_update_distance is not None: + max_position_update_distance /= xp.sqrt( + self.sampling[0] ** 2 + self.sampling[1] ** 2 + ) + update_norms = xp.linalg.norm(positions_update, axis=1) + outlier_ind = update_norms > max_position_update_distance + positions_update[outlier_ind] /= ( + update_norms[outlier_ind, None] / max_position_update_distance + ) + + if max_position_total_distance is not None: + max_position_total_distance /= xp.sqrt( + self.sampling[0] ** 2 + self.sampling[1] ** 2 + ) + deltas = ( + xp.asarray(current_positions - current_positions_initial) + - positions_update + ) + dsts = xp.linalg.norm(deltas, axis=1) + outlier_ind = dsts > max_position_total_distance + positions_update[outlier_ind] = 0 + + current_positions -= copy_to_device(positions_update, storage) + + return current_positions + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + errors = np.array([]) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device(self._amplitudes[start:end], device) + + # Overlaps + shifted_probes = self._return_shifted_probes( + self._probe, positions_px_fractional + ) + _, _, overlap = self._overlap_projection( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ) + + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + overlap = vectorized_bilinear_resample( + overlap, output_size=amplitudes_device.shape[-2:], xp=xp + ) + + fourier_overlap = xp.fft.fft2(overlap) + farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes_device - farfield_amplitudes) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + +class 
Object2p5DProbeMethodsMixin: + """ + Mixin class for methods unique to 2.5D objects using a single probe. + Overwrites ObjectNDProbeMethodsMixin. + """ + + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes_in, + ): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + shifted_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + """ + + xp = self._xp + + object_patches = current_object[ + :, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ] + + if self._object_type == "potential": + object_patches = xp.exp(1j * object_patches) + + shifted_probes = xp.empty_like(object_patches) + shifted_probes[0] = shifted_probes_in + + for s in range(self._num_slices): + # transmit + overlap = object_patches[s] * shifted_probes[s] + + # propagate + if s + 1 < self._num_slices: + shifted_probes[s + 1] = self._propagate_array( + overlap, self._propagator_arrays[s] + ) + + return shifted_probes, object_patches, overlap + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + for s in reversed(range(self._num_slices)): + probe = shifted_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = self._sum_overlapping_patches_bincounts( + xp.abs(probe) ** 2, + positions_px, + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + if self._object_type == "potential": + current_object[s] += step_size * ( + self._sum_overlapping_patches_bincounts( + xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves), + positions_px, + ) + * probe_normalization + ) + else: + current_object[s] += step_size * ( + self._sum_overlapping_patches_bincounts( + xp.conj(probe) * exit_waves, positions_px + ) + * probe_normalization + ) + + # back-transmit + exit_waves *= xp.conj(obj) + + if s > 0: + # back-propagate + exit_waves = self._propagate_array( + exit_waves, xp.conj(self._propagator_arrays[s - 1]) + ) + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += ( + step_size + * 
xp.sum( + exit_waves, + axis=0, + ) + * object_normalization + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + # careful not to modify exit_waves in-place for projection set methods + exit_waves_copy = exit_waves.copy() + + for s in reversed(range(self._num_slices)): + probe = shifted_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = self._sum_overlapping_patches_bincounts( + xp.abs(probe) ** 2, + positions_px, + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + if self._object_type == "potential": + current_object[s] = ( + self._sum_overlapping_patches_bincounts( + xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves_copy), + positions_px, + ) + * probe_normalization + ) + else: + current_object[s] = ( + self._sum_overlapping_patches_bincounts( + xp.conj(probe) * exit_waves_copy, + positions_px, + ) + * probe_normalization + ) + + # back-transmit + exit_waves_copy *= xp.conj(obj) + + if s > 0: + # back-propagate + exit_waves_copy = self._propagate_array( + exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) + ) + + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + exit_waves_copy, + axis=0, + ) + * object_normalization + ) + + return current_object, current_probe + + def show_transmitted_probe( + self, + max_batch_size=None, + plot_fourier_probe: bool = False, + remove_initial_probe_aberrations=False, + **kwargs, + ): + """ + Plots the min, max, and mean transmitted probe after propagation and transmission. 
+ + Parameters + ---------- + max_batch_size: int, optional + Max number of probes to calculate at once + plot_fourier_probe: boolean, optional + If True, the transmitted probes are also plotted in Fourier space + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + kwargs: + Passed to show_complex + """ + + xp = self._xp + xp_storage = self._xp_storage + asnumpy = self._asnumpy + + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + mean_transmitted = xp.zeros_like(self._probe) + intensities_compare = [np.inf, 0] + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + # overlaps + shifted_probes = self._return_shifted_probes( + self._probe, positions_px_fractional + ) + _, _, overlap = self._overlap_projection( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ) + + # store relevant arrays + mean_transmitted += overlap.sum(0) + + intensities = xp.sum(xp.abs(overlap) ** 2, axis=(-2, -1)) + min_intensity = intensities.min() + max_intensity = intensities.max() + + if min_intensity < intensities_compare[0]: + min_intensity_transmitted = overlap[xp.argmin(intensities)] + intensities_compare[0] = min_intensity + + if max_intensity > intensities_compare[1]: + max_intensity_transmitted = overlap[xp.argmax(intensities)] + intensities_compare[1] = max_intensity + + mean_transmitted /= self._num_diffraction_patterns + + probes = [ + asnumpy(self._return_centered_probe(probe)) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + title = [ + "Mean transmitted probe", + "Min-intensity transmitted probe", + "Max-intensity transmitted probe", + ] + + if plot_fourier_probe: + bottom_row = [ + asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + probes = [probes, bottom_row] + + title += [ + "Mean transmitted Fourier probe", + "Min-intensity transmitted Fourier probe", + "Max-intensity transmitted Fourier probe", + ] + + title = kwargs.get("title", title) + ticks = kwargs.get("ticks", False) + axsize = kwargs.get("axsize", (4, 4)) + + show_complex( + probes, + title=title, + ticks=ticks, + axsize=axsize, + **kwargs, + ) + + self.clear_device_mem(self._device, self._clear_fft_cache) + + +class ObjectNDProbeMixedMethodsMixin: + """ + Mixin class for methods applicable to 2D, 2.5D, and 3D objects using mixed probes. + Overwrites ObjectNDProbeMethodsMixin. + """ + + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ): + """ + Ptychographic overlap projection method. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + shifted_probes:np.ndarray + fractionally-shifted probes + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + shifted_probes * object_patches + """ + + xp = self._xp + + object_patches = current_object[ + vectorized_patch_indices_row, vectorized_patch_indices_col + ] + + if self._object_type == "potential": + object_patches = xp.exp(1j * object_patches) + + overlap = shifted_probes * xp.expand_dims(object_patches, axis=1) + + return shifted_probes, object_patches, overlap + + def _return_farfield_amplitudes(self, fourier_overlap): + """Small utility to de-duplicate mixed-state Fourier projection.""" + + xp = self._xp + return xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + def _gradient_descent_fourier_projection(self, amplitudes, overlap): + """ + Ptychographic fourier projection method for GD method. + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + overlap: np.ndarray + object * probe overlap + + Returns + -------- + exit_waves:np.ndarray + Difference between modified and estimated exit waves + error: float + Reconstruction error + """ + + xp = self._xp + + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + overlap = vectorized_bilinear_resample( + overlap, output_size=amplitudes.shape[-2:], xp=xp + ) + + fourier_overlap = xp.fft.fft2(overlap) + farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) + error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) + + farfield_amplitudes[farfield_amplitudes == 0.0] = np.inf + amplitude_modification = amplitudes / farfield_amplitudes + + fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap + modified_overlap = xp.fft.ifft2(fourier_modified_overlap) + + exit_waves = modified_overlap - overlap + + # resample back to region_of_interest_shape, note: this needs to happen in real-space + if self._resample_exit_waves: + exit_waves = vectorized_bilinear_resample( + exit_waves, output_size=self._region_of_interest_shape, xp=xp + ) + + return exit_waves, error + + def _projection_sets_fourier_projection( + self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c + ): + """ + Ptychographic fourier projection method for DM_AP and RAAR methods. 
+ Generalized projection using three parameters: a,b,c + + DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha + DM: DM_AP(1.0), AP: DM_AP(0.0) + + RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 + DM : RAAR(1.0) + + RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 + DM: RRR(1.0) + + SUPERFLIP : a = 0, b = 1, c = 2 + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + overlap: np.ndarray + object * probe overlap + exit_waves: np.ndarray + previously estimated exit waves + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + exit_waves:np.ndarray + Updated exit_waves + error: float + Reconstruction error + """ + + xp = self._xp + projection_x = 1 - projection_a - projection_b + projection_y = 1 - projection_c + + if exit_waves is None: + exit_waves = overlap.copy() + + factor_to_be_projected = projection_c * overlap + projection_y * exit_waves + + # resample to match data, note: this needs to happen in real-space + if self._resample_exit_waves: + factor_to_be_projected = vectorized_bilinear_resample( + factor_to_be_projected, output_size=amplitudes.shape[-2:], xp=xp + ) + + fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) + farfield_amplitudes = self._return_farfield_amplitudes(fourier_projected_factor) + error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) + + farfield_amplitudes[farfield_amplitudes == 0.0] = np.inf + amplitude_modification = amplitudes / farfield_amplitudes + + fourier_projected_factor *= amplitude_modification[:, None] + projected_factor = xp.fft.ifft2(fourier_projected_factor) + + # resample back to region_of_interest_shape, note: this needs to happen in real-space + if self._resample_exit_waves: + projected_factor = vectorized_bilinear_resample( + projected_factor, output_size=self._region_of_interest_shape, xp=xp + ) + + exit_waves = ( + projection_x * exit_waves + + projection_a * overlap + + projection_b * projected_factor + ) + + return exit_waves, error + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_normalization = xp.zeros_like(current_object) + object_update = xp.zeros_like(current_object) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes[:, i_probe]) ** 2, + positions_px, + ) + if self._object_type == "potential": + object_update += step_size * self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(object_patches) + * xp.conj(shifted_probes[:, i_probe]) + * exit_waves[:, i_probe] + ), + positions_px, + ) + else: + object_update += step_size * self._sum_overlapping_patches_bincounts( + xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe], + positions_px, + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object += object_update * probe_normalization + + if not fix_probe: + object_normalization = xp.sum( + (xp.abs(object_patches) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += step_size * ( + xp.sum( + xp.expand_dims(xp.conj(object_patches), axis=1) * exit_waves, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_normalization = xp.zeros_like(current_object) + current_object = xp.zeros_like(current_object) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes[:, i_probe]) ** 2, + positions_px, + ) + if self._object_type == "potential": + current_object += self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(object_patches) + * xp.conj(shifted_probes[:, i_probe]) + * exit_waves[:, i_probe] + ), + positions_px, + ) + else: + current_object += self._sum_overlapping_patches_bincounts( + xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe], + positions_px, + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object *= probe_normalization + + if not fix_probe: + object_normalization = xp.sum( + (xp.abs(object_patches) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + xp.expand_dims(xp.conj(object_patches), axis=1) * exit_waves, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + +class Object2p5DProbeMixedMethodsMixin: + """ + Mixin class for methods unique to 2.5D objects using mixed probes. + Overwrites ObjectNDProbeMethodsMixin and ObjectNDProbeMixedMethodsMixin. + """ + + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes_in, + ): + """ + Ptychographic overlap projection method. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + propagated_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + """ + + xp = self._xp + + object_patches = current_object[ + :, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ] + + if self._object_type == "potential": + object_patches = xp.exp(1j * object_patches) + + num_probe_positions = object_patches.shape[1] + + shifted_shape = ( + self._num_slices, + num_probe_positions, + self._num_probes, + self._region_of_interest_shape[0], + self._region_of_interest_shape[1], + ) + + shifted_probes = xp.empty(shifted_shape, dtype=object_patches.dtype) + shifted_probes[0] = shifted_probes_in + + for s in range(self._num_slices): + # transmit + overlap = xp.expand_dims(object_patches[s], axis=1) * shifted_probes[s] + + # propagate + if s + 1 < self._num_slices: + shifted_probes[s + 1] = self._propagate_array( + overlap, self._propagator_arrays[s] + ) + + return shifted_probes, object_patches, overlap + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + for s in reversed(range(self._num_slices)): + probe = shifted_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2, + positions_px, + ) + + if self._object_type == "potential": + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves[:, i_probe] + ), + positions_px, + ) + ) + else: + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves[:, i_probe], + positions_px, + ) + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] += object_update * probe_normalization + + # back-transmit + exit_waves *= xp.expand_dims(xp.conj(obj), axis=1) + + if s > 0: + # back-propagate + exit_waves = self._propagate_array( + exit_waves, xp.conj(self._propagator_arrays[s - 1]) + ) + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization 
= 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += ( + step_size + * xp.sum( + exit_waves, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + # careful not to modify exit_waves in-place for projection set methods + exit_waves_copy = exit_waves.copy() + for s in reversed(range(self._num_slices)): + probe = shifted_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2, + positions_px, + ) + + if self._object_type == "potential": + object_update += self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves_copy[:, i_probe] + ), + positions_px, + ) + else: + object_update += self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe], + positions_px, + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] = object_update * probe_normalization + + # back-transmit + exit_waves_copy *= xp.expand_dims(xp.conj(obj), axis=1) + + if s > 0: + # back-propagate + exit_waves_copy = self._propagate_array( + exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) + ) + + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + exit_waves_copy, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def show_transmitted_probe( + self, + **kwargs, + ): + raise NotImplementedError() + + +class MultipleMeasurementsMethodsMixin: + """ + Mixin class for methods unique to classes with multiple measurements. + Overwrites various Mixins. 
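+
+    In particular, probe- and position-related properties return one entry per
+    measurement (e.g. ``probe_centered`` is a list of arrays), and
+    ``_return_single_probe`` reports the measurement-averaged probe,
+    conceptually ``sum(self._probes_all) / len(self._probes_all)``.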
+ """ + + def _reset_reconstruction( + self, + store_iterations, + reset, + use_projection_scheme, + ): + """ """ + if store_iterations and (not hasattr(self, "object_iterations") or reset): + self.object_iterations = [] + self.probe_iterations = [] + + # reset can be True, False, or None (default) + if reset is True: + self.error_iterations = [] + self._object = self._object_initial.copy() + self._probes_all = [pr.copy() for pr in self._probes_all_initial] + self._positions_px_all = self._positions_px_initial_all.copy() + self._object_type = self._object_type_initial + + if use_projection_scheme: + self._exit_waves = [None] * len(self._probes_all) + else: + self._exit_waves = None + + # delete positions affine transform + if hasattr(self, "_tf"): + del self._tf + + elif reset is None: + # continued run + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + + # first start + else: + self.error_iterations = [] + if use_projection_scheme: + self._exit_waves = [None] * len(self._probes_all) + else: + self._exit_waves = None + + def _return_single_probe(self, probe=None): + """Current probe estimate""" + xp = self._xp + + if probe is not None: + _probes = [xp.asarray(pr) for pr in probe] + else: + if not hasattr(self, "_probes_all"): + return None + _probes = self._probes_all + + probe = xp.zeros(self._region_of_interest_shape, dtype=np.complex64) + + for pr in _probes: + probe += pr + + return probe / len(_probes) + + def _return_average_positions( + self, positions=None, cum_probes_per_measurement=None + ): + """Average positions estimate""" + xp_storage = self._xp_storage + + if positions is not None: + _pos = xp_storage.asarray(positions) + else: + if not hasattr(self, "_positions_px_all"): + return None + _pos = self._positions_px_all + + if cum_probes_per_measurement is None: + cum_probes_per_measurement = self._cum_probes_per_measurement + + num_probes_per_measurement = np.diff(cum_probes_per_measurement) + num_measurements = len(num_probes_per_measurement) + + if np.any(num_probes_per_measurement != num_probes_per_measurement[0]): + return None + + avg_positions = xp_storage.zeros( + (num_probes_per_measurement[0], 2), dtype=xp_storage.float32 + ) + + for index in range(num_measurements): + start_idx = cum_probes_per_measurement[index] + end_idx = cum_probes_per_measurement[index + 1] + avg_positions += _pos[start_idx:end_idx] + + return avg_positions / num_measurements + + def _return_self_consistency_errors( + self, + **kwargs, + ): + """Compute the self-consistency errors for each probe position""" + raise NotImplementedError() + + @property + def probe_fourier(self): + """Current probe estimate in Fourier space""" + if not hasattr(self, "_probes_all"): + return None + + asnumpy = self._asnumpy + return [asnumpy(self._return_fourier_probe(pr)) for pr in self._probes_all] + + @property + def probe_fourier_residual(self): + """Current probe estimate in Fourier space""" + if not hasattr(self, "_probes_all"): + return None + + asnumpy = self._asnumpy + return [ + asnumpy( + self._return_fourier_probe(pr, remove_initial_probe_aberrations=True) + ) + for pr in self._probes_all + ] + + @property + def probe_centered(self): + """Current probe estimate shifted to the center""" + if not hasattr(self, "_probes_all"): + return None + + asnumpy = self._asnumpy + return [asnumpy(self._return_centered_probe(pr)) for pr in self._probes_all] + + @property + def positions(self): + """Probe 
positions [A]""" + + if self.angular_sampling is None: + return None + + asnumpy = self._asnumpy + positions_all = [] + + for index in range(self._num_measurements): + start_idx = self._cum_probes_per_measurement[index] + end_idx = self._cum_probes_per_measurement[index + 1] + positions = self._positions_px_all[start_idx:end_idx].copy() + positions[:, 0] *= self.sampling[0] + positions[:, 1] *= self.sampling[1] + positions_all.append(asnumpy(positions)) + + return np.asarray(positions_all) diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py new file mode 100644 index 000000000..b4a29fa5d --- /dev/null +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -0,0 +1,1265 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely joint ptychographic tomography. +""" + +import warnings +from typing import Mapping, Sequence, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( + Object2p5DConstraintsMixin, + Object3DConstraintsMixin, + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) +from py4DSTEM.process.phase.ptychographic_methods import ( + MultipleMeasurementsMethodsMixin, + Object2p5DMethodsMixin, + Object2p5DProbeMethodsMixin, + Object3DMethodsMixin, + ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, +) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, +) + + +class PtychographicTomography( + VisualizationsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + Object3DConstraintsMixin, + Object2p5DConstraintsMixin, + ObjectNDConstraintsMixin, + MultipleMeasurementsMethodsMixin, + Object2p5DProbeMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, + Object3DMethodsMixin, + Object2p5DMethodsMixin, + ObjectNDMethodsMixin, + PtychographicReconstruction, +): + """ + Ptychographic Tomography Reconstruction Class. + + List of diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (Sx,Sy) + Reconstructed object dimensions : (Px,Py,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our probe + and (Px,Py,Py) is the padded-object electrostatic potential volume, + where x-axis is the tilt. 
+
+    Parameters
+    ----------
+    datacube: List of DataCubes
+        Input list of 4D diffraction pattern intensities
+    energy: float
+        The electron energy of the wave functions in eV
+    num_slices: int
+        Number of super-slices to use in the forward model
+    tilt_orientation_matrices: Sequence[np.ndarray]
+        List of orientation matrices for each tilt
+    semiangle_cutoff: float, optional
+        Semiangle cutoff for the initial probe guess in mrad
+    semiangle_cutoff_pixels: float, optional
+        Semiangle cutoff for the initial probe guess in pixels
+    rolloff: float, optional
+        Semiangle rolloff for the initial probe guess
+    vacuum_probe_intensity: np.ndarray, optional
+        Vacuum probe to use as intensity aperture for initial probe guess
+    polar_parameters: dict, optional
+        Mapping from aberration symbols to their corresponding values. All aberration
+        magnitudes should be given in Å and angles should be given in radians.
+    object_padding_px: Tuple[int,int], optional
+        Pixel dimensions to pad object with.
+        If None, the padding is set to half the probe ROI dimensions
+    initial_object_guess: np.ndarray, optional
+        Initial guess for complex-valued object of dimensions (Px,Py,Py).
+        If None, initialized to 1.0
+    initial_probe_guess: np.ndarray, optional
+        Initial guess for complex-valued probe of dimensions (Sx,Sy). If None,
+        initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations
+    initial_scan_positions: list of np.ndarray, optional
+        Probe positions in Å for each diffraction intensity per tilt.
+        If None, initialized to a grid scan centered along the tilt axis
+    verbose: bool, optional
+        If True, class methods will inherit this and print additional information
+    object_type: str, optional
+        The object can be reconstructed as a real potential ('potential') or a complex
+        object ('complex'). Currently only 'potential' is supported.
+    positions_mask: np.ndarray, optional
+        Boolean real space mask to select positions to ignore in reconstruction
+    device: str, optional
+        Device the calculation will be performed on. Must be 'cpu' or 'gpu'
+    storage: str, optional
+        Device on which infrequently-accessed arrays will be stored. Must be 'cpu' or 'gpu'
+    clear_fft_cache: bool, optional
+        If True and device = 'gpu', clears the cached fft plan at the end of function calls
+    name: str, optional
+        Class name
+    kwargs:
+        Provide the aberration coefficients as keyword arguments.
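+
+    Examples
+    --------
+    A minimal sketch of the intended workflow; ``datacubes`` and
+    ``orientation_matrices`` are placeholder variables and all numeric values
+    are illustrative only::
+
+        ptycho = PtychographicTomography(
+            datacube=datacubes,
+            energy=300e3,
+            num_slices=5,
+            tilt_orientation_matrices=orientation_matrices,
+            semiangle_cutoff=25.0,
+            device="cpu",
+        ).preprocess(
+            plot_probe_overlaps=False,
+        ).reconstruct(
+            num_iter=8,
+            store_iterations=True,
+        )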
+ """ + + # Class-specific Metadata + _class_specific_metadata = ( + "_num_slices", + "_tilt_orientation_matrices", + "_num_measurements", + ) + + def __init__( + self, + energy: float, + num_slices: int, + tilt_orientation_matrices: Sequence[np.ndarray], + datacube: Sequence[DataCube] = None, + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + object_type: str = "potential", + positions_mask: np.ndarray = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: Sequence[np.ndarray] = None, + verbose: bool = True, + device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, + name: str = "ptychographic-tomography_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if storage is None: + storage = device + + self.set_device(device, clear_fft_cache) + self.set_storage(storage) + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + num_tilts = len(tilt_orientation_matrices) + if initial_scan_positions is None: + initial_scan_positions = [None] * num_tilts + + if object_type != "potential": + raise NotImplementedError() + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe_init = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._object_padding_px = object_padding_px + self._positions_mask = positions_mask + self._verbose = verbose + self._preprocessed = False + + # Class-specific Metadata + self._num_slices = num_slices + self._tilt_orientation_matrices = tuple(tilt_orientation_matrices) + self._num_measurements = num_tilts + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "fourier", + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_probe_overlaps: bool = True, + rotation_real_space_degrees: float = None, + diffraction_patterns_rotate_degrees: float = None, + diffraction_patterns_transpose: bool = None, + force_com_shifts: Sequence[float] = None, + vectorized_com_calculation: bool = True, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + progress_bar: bool = True, + object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, + main_tilt_axis: str = "vertical", + device: str = None, + clear_fft_cache: bool = None, + max_batch_size: int = None, + **kwargs, + ): + """ + Ptychographic preprocessing step. + + Additionally, it initializes an (Px,Py, Py) array of 1.0 + and a complex probe using the specified polar parameters. 
+ + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + padded_diffraction_intensities_shape: (int,int), optional + Padded diffraction intensities shape. + If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + rotation_real_space_degrees: float (degrees), optional + In plane rotation around z axis between x axis and tilt axis in + real space (forced to be in xy plane) + diffraction_patterns_rotate_degrees: float, optional + Relative rotation angle between real and reciprocal space + diffraction_patterns_transpose: bool, optional + Whether diffraction intensities need to be transposed. + force_com_shifts: list of tuple of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. One tuple per tilt. + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + crop_patterns: bool + If True, crop patterns to avoid wrap around of patterns when centering + main_tilt_axis: str + The default, 'vertical' (first scan dimension), results in object size (q,p,q), + 'horizontal' (second scan dimension) results in object size (p,p,q), + any other value (e.g. None) results in object size (max(p,q),p,q). + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps + + Returns + -------- + self: OverlapTomographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first." 
+ ) + ) + + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + if self._positions_mask.ndim == 2: + warnings.warn( + "2D `positions_mask` assumed the same for all measurements.", + UserWarning, + ) + self._positions_mask = np.tile( + self._positions_mask, (self._num_measurements, 1, 1) + ) + + num_probes_per_measurement = np.insert( + self._positions_mask.sum(axis=(-2, -1)), 0, 0 + ) + + else: + self._positions_mask = [None] * self._num_measurements + num_probes_per_measurement = [0] + [dc.R_N for dc in self._datacube] + num_probes_per_measurement = np.array(num_probes_per_measurement) + + # prepopulate relevant arrays + self._mean_diffraction_intensity = [] + self._num_diffraction_patterns = num_probes_per_measurement.sum() + self._cum_probes_per_measurement = np.cumsum(num_probes_per_measurement) + self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) + + # calculate roi_shape + roi_shape = self._datacube[0].Qshape + if diffraction_intensities_shape is not None: + roi_shape = diffraction_intensities_shape + if padded_diffraction_intensities_shape is not None: + roi_shape = tuple( + max(q, s) + for q, s in zip(roi_shape, padded_diffraction_intensities_shape) + ) + + self._amplitudes = xp_storage.empty( + (self._num_diffraction_patterns,) + roi_shape + ) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + # TO-DO: generalize this + if force_com_shifts is None: + force_com_shifts = [None] * self._num_measurements + + self._rotation_best_rad = np.deg2rad(diffraction_patterns_rotate_degrees) + self._rotation_best_transpose = diffraction_patterns_transpose + + # loop over DPs for preprocessing + for index in tqdmnd( + self._num_measurements, + desc="Preprocessing data", + unit="tilt", + disable=not progress_bar, + ): + # preprocess datacube, vacuum and masks only for first tilt + if index == 0: + ( + self._datacube[index], + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts[index], + ) + + else: + ( + self._datacube[index], + _, + _, + force_com_shifts[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=None, + dp_mask=None, + com_shifts=force_com_shifts[index], + ) + + # calibrations + intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube[index], + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + # calculate CoM + ( + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + intensities, + 
dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts[index], + vectorized_calculation=vectorized_com_calculation, + ) + + # corner-center amplitudes + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + ( + amplitudes, + mean_diffraction_intensity_temp, + self._crop_mask, + ) = self._normalize_diffraction_intensities( + intensities, + com_fitted_x, + com_fitted_y, + self._positions_mask[index], + crop_patterns, + ) + + self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) + + # explicitly transfer arrays to storage + self._amplitudes[idx_start:idx_end] = copy_to_device(amplitudes, storage) + + del ( + intensities, + amplitudes, + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) + + # initialize probe positions + ( + self._positions_px_all[idx_start:idx_end], + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions[index], + self._positions_mask[index], + self._object_padding_px, + ) + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # initialize object + self._object = self._initialize_object( + self._object, + self._positions_px_all, + self._object_type, + main_tilt_axis, + ) + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + self._num_voxels = self._object.shape[0] + + # center probe positions + self._positions_px_all = xp_storage.asarray( + self._positions_px_all, dtype=xp_storage.float32 + ) + + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + + positions_px = self._positions_px_all[idx_start:idx_end] + positions_px_com = positions_px.mean(0) + positions_px -= positions_px_com - xp_storage.array(self._object_shape) / 2 + self._positions_px_all[idx_start:idx_end] = positions_px.copy() + + self._positions_px_initial_all = self._positions_px_all.copy() + self._positions_initial_all = self._positions_px_initial_all.copy() + self._positions_initial_all[:, 0] *= self.sampling[0] + self._positions_initial_all[:, 1] *= self.sampling[1] + + self._positions_initial = self._return_average_positions() + if self._positions_initial is not None: + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # initialize probe + self._probes_all = [] + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + list_Q = isinstance(self._probe_init, (list, tuple)) + + for index in range(self._num_measurements): + _probe, self._semiangle_cutoff = self._initialize_probe( + self._probe_init[index] if list_Q else self._probe_init, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity[index], + self._semiangle_cutoff, + crop_patterns, + ) + + self._probes_all.append(_probe) + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) + + del self._probe_init + + # initialize aberrations + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + # Precomputed propagator arrays + if main_tilt_axis == 
"vertical": + thickness = self._object_shape[1] * self.sampling[1] + elif main_tilt_axis == "horizontal": + thickness = self._object_shape[0] * self.sampling[0] + else: + thickness_h = self._object_shape[1] * self.sampling[1] + thickness_v = self._object_shape[0] * self.sampling[0] + thickness = max(thickness_h, thickness_v) + + self._slice_thicknesses = np.tile( + thickness / self._num_slices, self._num_slices - 1 + ) + self._propagator_arrays = self._precompute_propagator_arrays( + self._region_of_interest_shape, + self.sampling, + self._energy, + self._slice_thicknesses, + ) + + # overlaps + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + if object_fov_mask is None: + probe_overlap_3D = xp.zeros_like(self._object) + old_rot_matrix = np.eye(3) # identity + + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + + rot_matrix = self._tilt_orientation_matrices[index] + + probe_overlap_3D = self._rotate_zxy_volume( + probe_overlap_3D, + rot_matrix @ old_rot_matrix.T, + ) + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + num_diffraction_patterns = idx_end - idx_start + shuffled_indices = np.arange(idx_start, idx_end) + + for start, end in generate_batches( + num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px_all[batch_indices] + positions_px_fractional = positions_px - xp_storage.round( + positions_px + ) + + shifted_probes = fft_shift( + self._probes_all[index], positions_px_fractional, xp + ) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + + probe_overlap_3D += probe_overlap[None] + old_rot_matrix = rot_matrix + + probe_overlap_3D = self._rotate_zxy_volume( + probe_overlap_3D, + old_rot_matrix.T, + ) + + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_3D_blurred = gaussian_filter(probe_overlap_3D, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_3D_blurred > 0.25 * probe_overlap_3D_blurred.max() + ) + + else: + self._object_fov_mask = np.asarray(object_fov_mask) + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._cum_probes_per_measurement[1], max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px_all[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift( + self._probes_all[0], positions_px_fractional, xp + ) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + probe_overlap = asnumpy(probe_overlap) + + # plot probe overlaps + if plot_probe_overlaps: + figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered[0], + power=power, + chroma_boost=chroma_boost, + ) + + # propagated + propagated_probe = self._probes_all[0].copy() + + for s in range(self._num_slices - 1): + propagated_probe = self._propagate_array( + propagated_probe, self._propagator_arrays[s] + ) + complex_propagated_rgb = Complex2RGB( + asnumpy(self._return_centered_probe(propagated_probe)), + power=power, + 
chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax1, + chroma_boost=chroma_boost, + ) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe intensity") + + ax2.imshow( + complex_propagated_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax2) + cax2 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax2, + chroma_boost=chroma_boost, + ) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_title("Propagated probe intensity") + + ax3.imshow( + probe_overlap, + extent=extent, + cmap="Greys_r", + ) + ax3.scatter( + self.positions[0, :, 1], + self.positions[0, :, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax3.set_ylabel("x [A]") + ax3.set_xlabel("y [A]") + ax3.set_xlim((extent[0], extent[1])) + ax3.set_ylim((extent[2], extent[3])) + ax3.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def reconstruct( + self, + num_iter: int = 8, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + fix_probe_com: bool = True, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude: bool = False, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions: bool = True, + fix_positions_com: bool = True, + max_position_update_distance: float = None, + max_position_total_distance: float = None, + global_affine_transformation: bool = False, + gaussian_filter_sigma: float = None, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, + butterworth_filter: bool = True, + q_lowpass: float = None, + q_highpass: float = None, + butterworth_order: float = 2, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + tv_denoise: bool = True, + tv_denoise_weights: float = None, + tv_denoise_inner_iter=40, + collective_measurement_updates: bool = True, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + device: str = None, + clear_fft_cache: bool = None, + ): + """ + Ptychographic reconstruction main method. 
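+
+        Each iteration visits the tilt series in random order: the volume is
+        rotated into the current tilt frame, projected into ``num_slices``
+        super-slices, updated batch-wise with the selected forward/adjoint
+        operators, and rotated back. With
+        ``collective_measurement_updates=True`` the per-tilt object updates
+        are accumulated and applied once per iteration, followed by the
+        object-only constraints.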
+ + Parameters + -------- + num_iter: int, optional + Number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. + max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + fix_probe_com: bool, optional + If True, fixes center of mass of probe + fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vaccum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_positions: bool, optional + If True, probe-positions are fixed + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov + max_position_update_distance: float, optional + Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma: float, optional + Standard deviation of gaussian kernel in A + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. 
This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + object_positivity: bool, optional + If True, forces object to be positive + tv_denoise: bool, optional + If True and tv_denoise_weight is not None, object is smoothed using TV denoising + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + collective_measurement_updates: bool + if True perform collective measurement updates (i.e. one per tilt) + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + + Returns + -------- + self: OverlapTomographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probes_all", + "_probes_all_initial", + "_probes_all_initial_aperture", + "_propagator_arrays", + ] + self.copy_attributes_to_device(attrs, device) + + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) + + if self._verbose: + self._report_reconstruction_summary( + num_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + max_batch_size, + step_size, + ) + + if max_batch_size is not None: + np.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + # initialization + self._reset_reconstruction(store_iterations, reset, use_projection_scheme) + + # main loop + for a0 in tqdmnd( + num_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + if collective_measurement_updates: + collective_object = xp.zeros_like(self._object) + + indices = np.arange(self._num_measurements) + np.random.shuffle(indices) + + old_rot_matrix = np.eye(3) # identity + + for index in indices: + self._active_measurement_index = index + + measurement_error = 0.0 + + rot_matrix = 
self._tilt_orientation_matrices[ + self._active_measurement_index + ] + self._object = self._rotate_zxy_volume( + self._object, + rot_matrix @ old_rot_matrix.T, + ) + + object_sliced = self._project_sliced_object( + self._object, self._num_slices + ) + + _probe = self._probes_all[self._active_measurement_index] + _probe_initial_aperture = self._probes_all_initial_aperture[ + self._active_measurement_index + ] + + if not use_projection_scheme: + object_sliced_old = object_sliced.copy() + + start_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + ] + end_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + 1 + ] + + num_diffraction_patterns = end_idx - start_idx + shuffled_indices = np.arange(start_idx, end_idx) + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + + for start, end in generate_batches( + num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px_all[batch_indices] + positions_px_initial = self._positions_px_initial_all[batch_indices] + positions_px_fractional = positions_px - xp_storage.round( + positions_px + ) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + _probe, + positions_px_fractional, + amplitudes_device, + self._exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) + + # adjoint operator + object_sliced, _probe = self._adjoint( + object_sliced, + _probe, + object_patches, + shifted_probes, + positions_px, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=fix_probe, + ) + + # position correction + if not fix_positions: + self._positions_px_all[ + batch_indices + ] = self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) + + measurement_error += batch_error + + if not use_projection_scheme: + object_sliced -= object_sliced_old + + object_update = self._expand_sliced_object( + object_sliced, self._num_voxels + ) + + if collective_measurement_updates: + collective_object += self._rotate_zxy_volume( + object_update, rot_matrix.T + ) + else: + self._object += object_update + + old_rot_matrix = rot_matrix + + # Normalize Error + measurement_error /= ( + self._mean_diffraction_intensity[self._active_measurement_index] + * num_diffraction_patterns + ) + error += measurement_error + + # constraints + + if collective_measurement_updates: + # probe and positions + _probe = self._probe_constraints( + _probe, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not 
fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + ) + + self._positions_px_all[batch_indices] = self._positions_constraints( + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + ) + + else: + # object, probe, and positions + ( + self._object, + _probe, + self._positions_px_all[batch_indices], + ) = self._constraints( + self._object, + _probe, + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + gaussian_filter=gaussian_filter + and gaussian_filter_sigma is not None, + gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=butterworth_filter + and (q_lowpass is not None or q_highpass is not None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline + and self._object_fov_mask_inverse.sum() > 0 + else None, + tv_denoise=tv_denoise and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + ) + + self._object = self._rotate_zxy_volume(self._object, old_rot_matrix.T) + + # Normalize Error Over Tilts + error /= self._num_measurements + + if collective_measurement_updates: + self._object += collective_object / self._num_measurements + + # object only + self._object = self._object_constraints( + self._object, + 
gaussian_filter=gaussian_filter + and gaussian_filter_sigma is not None, + gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=butterworth_filter + and (q_lowpass is not None or q_highpass is not None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline + and self._object_fov_mask_inverse.sum() > 0 + else None, + tv_denoise=tv_denoise and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + ) + + self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self diff --git a/py4DSTEM/process/phase/ptychographic_visualizations.py b/py4DSTEM/process/phase/ptychographic_visualizations.py new file mode 100644 index 000000000..58dd224cf --- /dev/null +++ b/py4DSTEM/process/phase/ptychographic_visualizations.py @@ -0,0 +1,844 @@ +from typing import Tuple + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable +from py4DSTEM.process.phase.utils import AffineTransform, copy_to_device +from py4DSTEM.visualize.vis_special import ( + Complex2RGB, + add_colorbar_arg, + return_scaled_histogram_ordering, +) + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + + +class VisualizationsMixin: + """ + Mixin class for various visualization methods. + """ + + def _visualize_last_iteration( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, + **kwargs, + ): + """ + Displays last reconstructed object and probe iterations. 
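+
+        The object panel shows the projected, cropped potential with
+        histogram-based contrast scaling; the probe panel shows the
+        measurement-averaged probe (optionally in Fourier space).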
+ + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool, optional + If true, the reconstructed complex probe is displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + """ + + asnumpy = self._asnumpy + + figsize = kwargs.pop("figsize", (8, 5)) + cmap = kwargs.pop("cmap", "magma") + chroma_boost = kwargs.pop("chroma_boost", 1) + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + + # get scaled arrays + obj, kwargs = self._return_projected_cropped_potential( + return_kwargs=True, **kwargs + ) + probe = self._return_single_probe() + + obj, vmin, vmax = return_scaled_histogram_ordering(obj, vmin, vmax) + + extent = [ + 0, + self.sampling[1] * obj.shape[1], + self.sampling[0] * obj.shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=2, + nrows=2, + height_ratios=[4, 1], + hspace=0.15, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) + + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=2, + nrows=1, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + if plot_probe or plot_fourier_probe: + # Object + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + obj, + extent=extent, + cmap=cmap, + vmin=vmin, + vmax=vmax, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Probe + ax = fig.add_subplot(spec[0, 1]) + if plot_fourier_probe: + probe = asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + + probe_array = Complex2RGB( + probe, + chroma_boost=chroma_boost, + ) + + ax.set_title("Reconstructed Fourier probe") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + asnumpy(self._return_centered_probe(probe)), + power=2, + chroma_boost=chroma_boost, + ) + ax.set_title("Reconstructed probe intensity") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = 
divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + else: + # Object + ax = fig.add_subplot(spec[0]) + im = ax.imshow( + obj, + extent=extent, + cmap=cmap, + vmin=vmin, + vmax=vmax, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if plot_convergence and hasattr(self, "error_iterations"): + errors = np.array(self.error_iterations) + + if plot_probe: + ax = fig.add_subplot(spec[1, :]) + else: + ax = fig.add_subplot(spec[1]) + + ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax.set_ylabel("NMSE") + ax.set_xlabel("Iteration number") + ax.yaxis.tick_right() + + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") + spec.tight_layout(fig) + + def _visualize_all_iterations( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, + iterations_grid: Tuple[int, int], + **kwargs, + ): + """ + Displays all reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed complex probe is displayed + plot_fourier_probe: bool + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + """ + asnumpy = self._asnumpy + + if not hasattr(self, "object_iterations"): + raise ValueError( + ( + "Object and probe iterations were not saved during reconstruction. " + "Please re-run using store_iterations=True." 
+ ) + ) + + num_iter = len(self.object_iterations) + + if iterations_grid == "auto": + if num_iter == 1: + return self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + cbar=cbar, + **kwargs, + ) + + elif plot_probe or plot_fourier_probe: + iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) + + else: + iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) + + else: + if plot_probe or plot_fourier_probe: + if iterations_grid[0] != 2: + raise ValueError() + else: + if iterations_grid[0] * iterations_grid[1] > num_iter: + raise ValueError() + + auto_figsize = ( + (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) + if plot_convergence + else (3 * iterations_grid[1], 3 * iterations_grid[0]) + ) + + figsize = kwargs.pop("figsize", auto_figsize) + cmap = kwargs.pop("cmap", "magma") + chroma_boost = kwargs.pop("chroma_boost", 1) + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + + # most recent errors + errors = np.array(self.error_iterations)[-num_iter:] + + max_iter = num_iter - 1 + if plot_probe or plot_fourier_probe: + total_grids = (np.prod(iterations_grid) / 2).astype("int") + grid_range = np.arange(0, max_iter + 1, max_iter // (total_grids - 1)) + probes = [ + self._return_single_probe(self.probe_iterations[idx]) + for idx in grid_range + ] + else: + total_grids = np.prod(iterations_grid) + grid_range = np.arange(0, max_iter + 1, max_iter // (total_grids - 1)) + + objects = [] + + for idx in grid_range: + if idx < grid_range[-1]: + obj = self._return_projected_cropped_potential( + obj=self.object_iterations[idx], + return_kwargs=False, + **kwargs, + ) + else: + obj, kwargs = self._return_projected_cropped_potential( + obj=self.object_iterations[idx], return_kwargs=True, **kwargs + ) + + objects.append(obj) + + extent = [ + 0, + self.sampling[1] * objects[0].shape[1], + self.sampling[0] * objects[0].shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) + + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=2) + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + grid = ImageGrid( + fig, + spec[0], + nrows_ncols=(1, iterations_grid[1]) + if (plot_probe or plot_fourier_probe) + else iterations_grid, + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + obj, vmin_n, vmax_n = return_scaled_histogram_ordering( + objects[n], vmin=vmin, vmax=vmax + ) + im = ax.imshow( + obj, + extent=extent, + cmap=cmap, + vmin=vmin_n, + vmax=vmax_n, + **kwargs, + ) + ax.set_title(f"Iter: {grid_range[n]} potential") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if cbar: + 
grid.cbar_axes[n].colorbar(im) + + if plot_probe or plot_fourier_probe: + grid = ImageGrid( + fig, + spec[1], + nrows_ncols=(1, iterations_grid[1]), + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + if plot_fourier_probe: + probe_array = asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[n], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + + probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) + ax.set_title(f"Iter: {grid_range[n]} Fourier probe") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + + else: + probe_array = Complex2RGB( + asnumpy(probes[n]), + power=2, + chroma_boost=chroma_boost, + ) + ax.set_title(f"Iter: {grid_range[n]} probe intensity") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + add_colorbar_arg( + grid.cbar_axes[n], + chroma_boost=chroma_boost, + ) + + if plot_convergence: + if plot_probe: + ax2 = fig.add_subplot(spec[2]) + else: + ax2 = fig.add_subplot(spec[1]) + ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax2.set_ylabel("NMSE") + ax2.set_xlabel("Iteration number") + ax2.yaxis.tick_right() + + spec.tight_layout(fig) + + def visualize( + self, + fig=None, + iterations_grid: Tuple[int, int] = None, + plot_convergence: bool = True, + plot_probe: bool = True, + plot_fourier_probe: bool = False, + remove_initial_probe_aberrations: bool = False, + cbar: bool = True, + **kwargs, + ): + """ + Displays reconstructed object and probe. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + padding : int, optional + Pixels to pad by post rotating-cropping object + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + + if iterations_grid is None: + self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + cbar=cbar, + **kwargs, + ) + + else: + self._visualize_all_iterations( + fig=fig, + plot_convergence=plot_convergence, + iterations_grid=iterations_grid, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + cbar=cbar, + **kwargs, + ) + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def show_updated_positions( + self, + pos=None, + initial_pos=None, + scale_arrows=1, + plot_arrow_freq=None, + plot_cropped_rotated_fov=True, + cbar=True, + verbose=True, + **kwargs, + ): + """ + Function to plot changes to probe positions during ptychography reconstruciton + + Parameters + ---------- + scale_arrows: float, optional + scaling factor to be applied on vectors prior to plt.quiver call + plot_arrow_freq: int, optional + 
thinning parameter to only plot a subset of probe positions + assumes grid position + verbose: bool, optional + if True, prints AffineTransformation if positions have been updated + """ + + if verbose: + if hasattr(self, "_tf"): + print(self._tf) + + asnumpy = self._asnumpy + + if pos is None: + pos = self.positions + + # handle multiple measurements + if pos.ndim == 3: + pos = pos.mean(0) + + if initial_pos is None: + initial_pos = asnumpy(self._positions_initial) + + if plot_cropped_rotated_fov: + angle = ( + self._rotation_best_rad + if self._rotation_best_transpose + else -self._rotation_best_rad + ) + + tf = AffineTransform(angle=angle) + initial_pos = tf(initial_pos, origin=np.mean(pos, axis=0)) + pos = tf(pos, origin=np.mean(pos, axis=0)) + + obj_shape = self.object_cropped.shape[-2:] + initial_pos_com = np.mean(initial_pos, axis=0) + center_shift = initial_pos_com - ( + np.array(obj_shape) / 2 * np.array(self.sampling) + ) + initial_pos -= center_shift + pos -= center_shift + + else: + obj_shape = self._object_shape + + if plot_arrow_freq is not None: + rshape = self._datacube.Rshape + (2,) + freq = plot_arrow_freq + + initial_pos = initial_pos.reshape(rshape)[::freq, ::freq].reshape(-1, 2) + pos = pos.reshape(rshape)[::freq, ::freq].reshape(-1, 2) + + deltas = pos - initial_pos + norms = np.linalg.norm(deltas, axis=1) + + extent = [ + 0, + self.sampling[1] * obj_shape[1], + self.sampling[0] * obj_shape[0], + 0, + ] + + figsize = kwargs.pop("figsize", (4, 4)) + cmap = kwargs.pop("cmap", "Reds") + + fig, ax = plt.subplots(figsize=figsize) + + im = ax.quiver( + initial_pos[:, 1], + initial_pos[:, 0], + deltas[:, 1] * scale_arrows, + deltas[:, 0] * scale_arrows, + norms, + scale_units="xy", + scale=1, + cmap=cmap, + **kwargs, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + cb = fig.colorbar(im, cax=ax_cb) + cb.set_label("Δ [A]", rotation=0, ha="left", va="bottom") + cb.ax.yaxis.set_label_coords(0.5, 1.01) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_xlim((extent[0], extent[1])) + ax.set_ylim((extent[2], extent[3])) + ax.set_aspect("equal") + ax.set_title("Updated probe positions") + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + projected_cropped_potential=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + + xp = self._xp + device = self._device + asnumpy = self._asnumpy + gaussian_filter = self._scipy.ndimage.gaussian_filter + + if errors is None: + errors = self._return_self_consistency_errors(max_batch_size=max_batch_size) + errors_xp = xp.asarray(errors) + + if projected_cropped_potential is None: + projected_cropped_potential = self._return_projected_cropped_potential() + + if kde_sigma is None: + kde_sigma = 0.5 * self._scan_sampling[0] / self.sampling[0] + + ## Kernel Density Estimation + + # rotated basis + angle = ( + self._rotation_best_rad + if self._rotation_best_transpose + else -self._rotation_best_rad + ) + + tf = AffineTransform(angle=angle) + positions_px = copy_to_device(self._positions_px, device) + rotated_points = tf(positions_px, origin=positions_px.mean(0), xp=xp) + + padding = xp.min(rotated_points, axis=0).astype("int") + + # bilinear sampling + pixel_output = np.array(projected_cropped_potential.shape) + asnumpy( + 2 * padding + ) + pixel_size = pixel_output.prod() + + xa = rotated_points[:, 0] + ya = 
rotated_points[:, 1] + + # bilinear sampling + xF = xp.floor(xa).astype("int") + yF = xp.floor(ya).astype("int") + dx = xa - xF + dy = ya - yF + + # resampling + all_inds = [ + [xF, yF], + [xF + 1, yF], + [xF, yF + 1], + [xF + 1, yF + 1], + ] + + all_weights = [ + (1 - dx) * (1 - dy), + (dx) * (1 - dy), + (1 - dx) * (dy), + (dx) * (dy), + ] + + pix_count = xp.zeros(pixel_size, dtype=xp.float32) + pix_output = xp.zeros(pixel_size, dtype=xp.float32) + + for inds, weights in zip(all_inds, all_weights): + inds_1D = xp.ravel_multi_index( + inds, + pixel_output, + mode=["wrap", "wrap"], + ) + + pix_count += xp.bincount( + inds_1D, + weights=weights, + minlength=pixel_size, + ) + pix_output += xp.bincount( + inds_1D, + weights=weights * errors_xp, + minlength=pixel_size, + ) + + # reshape 1D arrays to 2D + pix_count = xp.reshape( + pix_count, + pixel_output, + ) + pix_output = xp.reshape( + pix_output, + pixel_output, + ) + + # kernel density estimate + pix_count = gaussian_filter(pix_count, kde_sigma) + pix_output = gaussian_filter(pix_output, kde_sigma) + sub = pix_count > 1e-3 + pix_output[sub] /= pix_count[sub] + pix_output[np.logical_not(sub)] = 1 + pix_output = pix_output[padding[0] : -padding[0], padding[1] : -padding[1]] + pix_output, _, _ = return_scaled_histogram_ordering( + pix_output.get(), normalize=True + ) + + ## Visualization + if plot_histogram: + spec = GridSpec( + ncols=1, + nrows=2, + height_ratios=[1, 4], + hspace=0.15, + ) + auto_figsize = (4, 5) + else: + spec = GridSpec( + ncols=1, + nrows=1, + ) + auto_figsize = (4, 4) + + figsize = kwargs.pop("figsize", auto_figsize) + + fig = plt.figure(figsize=figsize) + + if plot_histogram: + ax_hist = fig.add_subplot(spec[0]) + + counts, bins = np.histogram(errors, bins=50) + ax_hist.hist(bins[:-1], bins, weights=counts, color="#5ac8c8", alpha=0.5) + ax_hist.set_ylabel("Counts") + ax_hist.set_xlabel("Normalized squared error") + + ax = fig.add_subplot(spec[-1]) + + cmap = kwargs.pop("cmap", "magma") + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + + projected_cropped_potential, vmin, vmax = return_scaled_histogram_ordering( + projected_cropped_potential, + vmin=vmin, + vmax=vmax, + ) + + extent = [ + 0, + self.sampling[1] * projected_cropped_potential.shape[1], + self.sampling[0] * projected_cropped_potential.shape[0], + 0, + ] + + ax.imshow( + projected_cropped_potential, + vmin=vmin, + vmax=vmax, + extent=extent, + alpha=1 - pix_output, + cmap=cmap, + **kwargs, + ) + + if plot_contours: + aligned_points = asnumpy(rotated_points - padding) + aligned_points[:, 0] *= self.sampling[0] + aligned_points[:, 1] *= self.sampling[1] + + ax.tricontour( + aligned_points[:, 1], + aligned_points[:, 0], + errors, + colors="grey", + levels=5, + # linestyles='dashed', + linewidths=0.5, + ) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_xlim((extent[0], extent[1])) + ax.set_ylim((extent[2], extent[3])) + ax.xaxis.set_ticks_position("bottom") + + spec.tight_layout(fig) + + self.clear_device_mem(self._device, self._clear_fft_cache) diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py new file mode 100644 index 000000000..71fe65cd7 --- /dev/null +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -0,0 +1,942 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely (single-slice) ptychography. 
+""" + +from typing import Mapping, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg + +try: + import cupy as cp +except (ImportError, ModuleNotFoundError): + cp = np + +from emdfile import Custom, tqdmnd +from py4DSTEM.datacube import DataCube +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) +from py4DSTEM.process.phase.ptychographic_methods import ( + ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, +) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, +) + + +class SingleslicePtychography( + VisualizationsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + ObjectNDConstraintsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, + ObjectNDMethodsMixin, + PtychographicReconstruction, +): + """ + Iterative Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (Sx,Sy) + Reconstructed object dimensions : (Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our probe + and (Px,Py) is the padded-object size we position our ROI around in. + + Parameters + ---------- + energy: float + The electron energy of the wave functions in eV + datacube: DataCube + Input 4D diffraction pattern intensities + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (Px,Py) + If None, initialized to 1.0j + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Å for each diffraction intensity + If None, initialized to a grid scan + verbose: bool, optional + If True, class methods will inherit this and print additional information + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction + device: str, optional + Device calculation will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. 
Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. + """ + + # Class-specific Metadata + _class_specific_metadata = () + + def __init__( + self, + energy: float, + datacube: DataCube = None, + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + object_padding_px: Tuple[int, int] = None, + object_type: str = "complex", + positions_mask: np.ndarray = None, + verbose: bool = True, + device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, + name: str = "ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if storage is None: + storage = device + + self.set_device(device, clear_fft_cache) + self.set_storage(storage) + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + if object_type != "potential" and object_type != "complex": + raise ValueError( + f"object_type must be either 'potential' or 'complex', not {object_type}" + ) + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._object_padding_px = object_padding_px + self._positions_mask = positions_mask + self._verbose = verbose + self._preprocessed = False + + # Class-specific Metadata + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "fourier", + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_center_of_mass: str = "default", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = None, + plot_probe_overlaps: bool = True, + force_com_rotation: float = None, + force_com_transpose: float = None, + force_com_shifts: float = None, + vectorized_com_calculation: bool = True, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, + device: str = None, + clear_fft_cache: bool = None, + max_batch_size: int = None, + **kwargs, + ): + """ + Ptychographic preprocessing step. 
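+        Resamples and pads the diffraction intensities, fits and corrects the
+        center of mass (CoM), solves for the real/reciprocal-space rotation,
+        normalizes the amplitudes, and initializes the probe, object, and
+        scan positions.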
+
+        Parameters
+        ----------
+        diffraction_intensities_shape: Tuple[int,int], optional
+            Pixel dimensions (Qx',Qy') of the resampled diffraction intensities
+            If None, no resampling of diffraction intensities is performed
+        reshaping_method: str, optional
+            Method to use for reshaping, either 'bin', 'bilinear', or 'fourier' (default)
+        padded_diffraction_intensities_shape: (int,int), optional
+            Padded diffraction intensities shape.
+            If None, no padding is performed
+        region_of_interest_shape: (int,int), optional
+            If not None, explicitly sets region_of_interest_shape and resamples exit_waves
+            at the diffraction plane to allow comparison with experimental data
+        dp_mask: ndarray, optional
+            Mask for datacube intensities (Qx,Qy)
+        fit_function: str, optional
+            2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two'
+        plot_center_of_mass: str, optional
+            If 'default', the corrected CoM arrays will be displayed
+            If 'all', the computed and fitted CoM arrays will be displayed
+        plot_rotation: bool, optional
+            If True, the CoM curl minimization search result will be displayed
+        maximize_divergence: bool, optional
+            If True, the divergence of the CoM gradient vector field is maximized
+        rotation_angles_deg: np.ndarray, optional
+            Array of angles in degrees to perform curl minimization over
+        plot_probe_overlaps: bool, optional
+            If True, initial probe overlaps scanned over the object will be displayed
+        force_com_rotation: float (degrees), optional
+            Force relative rotation angle between real and reciprocal space
+        force_com_transpose: bool, optional
+            Force whether diffraction intensities need to be transposed.
+        force_com_shifts: tuple of ndarrays (CoMx, CoMy)
+            Amplitudes come from diffraction patterns shifted with
+            the CoM in the upper left corner for each probe, unless
+            the shift is overwritten.
+        vectorized_com_calculation: bool, optional
+            If True (default), the memory-intensive CoM calculation is vectorized
+        force_scan_sampling: float, optional
+            Override DataCube real space scan pixel size calibrations, in Angstrom
+        force_angular_sampling: float, optional
+            Override DataCube reciprocal pixel size calibration, in mrad
+        force_reciprocal_sampling: float, optional
+            Override DataCube reciprocal pixel size calibration, in A^-1
+        object_fov_mask: np.ndarray (boolean)
+            Boolean mask of FOV. Used to calculate additional shrinkage of object
+            If None, probe_overlap intensity is thresholded
+        crop_patterns: bool
+            If True, crop patterns to avoid wrap-around of patterns when centering
+        device: str, optional
+            If not None, overwrites self._device to set the device preprocessing will be performed on.
+        clear_fft_cache: bool, optional
+            If True, and device = 'gpu', clears the cached fft plan at the end of function calls
+        max_batch_size: int, optional
+            Max number of probes to use at once in computing probe overlaps
+
+        Returns
+        --------
+        self: PtychographicReconstruction
+            Self to accommodate chaining
+        """
+
+        # handle device/storage
+        self.set_device(device, clear_fft_cache)
+
+        xp = self._xp
+        device = self._device
+        xp_storage = self._xp_storage
+        storage = self._storage
+        asnumpy = self._asnumpy
+
+        # set additional metadata
+        self._diffraction_intensities_shape = diffraction_intensities_shape
+        self._reshaping_method = reshaping_method
+        self._padded_diffraction_intensities_shape = (
+            padded_diffraction_intensities_shape
+        )
+        self._dp_mask = dp_mask
+
+        if self._datacube is None:
+            raise ValueError(
+                (
+                    "The preprocess() method requires a DataCube. 
" + "Please run ptycho.attach_datacube(DataCube) first." + ) + ) + + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + # preprocess datacube + ( + self._datacube, + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts, + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube, + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts, + ) + + # calibrations + _intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube, + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # calculate CoM + ( + self._com_measured_x, + self._com_measured_y, + self._com_fitted_x, + self._com_fitted_y, + self._com_normalized_x, + self._com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + _intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts, + vectorized_calculation=vectorized_com_calculation, + ) + + # estimate rotation / transpose + ( + self._rotation_best_rad, + self._rotation_best_transpose, + self._com_x, + self._com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + self._com_measured_x, + self._com_measured_y, + self._com_normalized_x, + self._com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=plot_center_of_mass, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + + # explicitly transfer arrays to storage + attrs = [ + "_com_measured_x", + "_com_measured_y", + "_com_fitted_x", + "_com_fitted_y", + "_com_normalized_x", + "_com_normalized_y", + "_com_x", + "_com_y", + ] + self.copy_attributes_to_device(attrs, storage) + + # corner-center amplitudes + ( + self._amplitudes, + self._mean_diffraction_intensity, + self._crop_mask, + ) = self._normalize_diffraction_intensities( + _intensities, + self._com_fitted_x, + self._com_fitted_y, + self._positions_mask, + crop_patterns, + ) + + # explicitly transfer arrays to storage + self._amplitudes = copy_to_device(self._amplitudes, storage) + del _intensities + + self._num_diffraction_patterns = self._amplitudes.shape[0] + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + # initialize probe positions + ( + self._positions_px, + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions, + self._positions_mask, + self._object_padding_px, + ) + + # initialize object + self._object = self._initialize_object( + self._object, + self._positions_px, + self._object_type, + ) + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape + + # center probe positions + 
self._positions_px = xp_storage.asarray( + self._positions_px, dtype=xp_storage.float32 + ) + self._positions_px_initial_com = self._positions_px.mean(0) + self._positions_px -= ( + self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) + + self._positions_px_initial = self._positions_px.copy() + self._positions_initial = self._positions_px_initial.copy() + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # initialize probe + self._probe, self._semiangle_cutoff = self._initialize_probe( + self._probe, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity, + self._semiangle_cutoff, + crop_patterns, + ) + + # initialize aberrations + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=device, + )._evaluate_ctf() + + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + + # overlaps + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift(self._probe, positions_px_fractional, xp) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + + # initialize object_fov_mask + if object_fov_mask is None: + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) + del probe_overlap_blurred + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + probe_overlap = asnumpy(probe_overlap) + + # plot probe overlaps + if plot_probe_overlaps: + figsize = kwargs.pop("figsize", (9, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered, + power=power, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(cax1, chroma_boost=chroma_boost) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe intensity") + + ax2.imshow( + probe_overlap, + extent=extent, + cmap="gray", + ) + ax2.scatter( + self.positions[:, 1], + self.positions[:, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_xlim((extent[0], extent[1])) + ax2.set_ylim((extent[2], extent[3])) + ax2.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + 
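+        # free cached device memory / FFT plans (see clear_fft_cache) and
+        # return self so calls can be chained, e.g. preprocess().reconstruct()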
self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def reconstruct( + self, + num_iter: int = 8, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.5, + pure_phase_object: bool = False, + fix_probe_com: bool = True, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude: bool = False, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions: bool = True, + fix_positions_com: bool = True, + max_position_update_distance: float = None, + max_position_total_distance: float = None, + global_affine_transformation: bool = False, + gaussian_filter_sigma: float = None, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, + butterworth_filter: bool = True, + q_lowpass: float = None, + q_highpass: float = None, + butterworth_order: float = 2, + tv_denoise: bool = True, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + device: str = None, + clear_fft_cache: bool = None, + object_type: str = None, + ): + """ + Ptychographic reconstruction main method. + + Parameters + -------- + num_iter: int, optional + Number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. 
+        max_batch_size: int, optional
+            Max number of probes to update at once
+        seed_random: int, optional
+            Seeds the random number generator, only applicable when max_batch_size is not None
+        step_size: float, optional
+            Update step size
+        normalization_min: float, optional
+            Probe normalization minimum as a fraction of the maximum overlap intensity
+        positions_step_size: float, optional
+            Positions update step size
+        pure_phase_object: bool, optional
+            If True, object amplitude is set to unity
+        fix_probe_com: bool, optional
+            If True, fixes center of mass of probe
+        fix_probe: bool, optional
+            If True, probe is fixed
+        fix_probe_aperture: bool, optional
+            If True, vacuum probe is used to fix Fourier amplitude
+        constrain_probe_amplitude: bool, optional
+            If True, real-space probe is constrained with a top-hat support.
+        constrain_probe_amplitude_relative_radius: float
+            Relative location of top-hat inflection point, between 0 and 0.5
+        constrain_probe_amplitude_relative_width: float
+            Relative width of top-hat sigmoid, between 0 and 0.5
+        constrain_probe_fourier_amplitude: bool, optional
+            If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency
+        constrain_probe_fourier_amplitude_max_width_pixels: float
+            Maximum pixel width of fitted sigmoid functions.
+        constrain_probe_fourier_amplitude_constant_intensity: bool
+            If True, the probe aperture is additionally constrained to a constant intensity.
+        fix_positions: bool, optional
+            If True, probe positions are fixed
+        fix_positions_com: bool, optional
+            If True, fixes the positions' CoM to the middle of the FOV
+        max_position_update_distance: float, optional
+            Maximum allowed distance for update in A
+        max_position_total_distance: float, optional
+            Maximum allowed distance from initial positions
+        global_affine_transformation: bool, optional
+            If True, positions are assumed to be a global affine transform from initial scan
+        gaussian_filter_sigma: float, optional
+            Standard deviation of gaussian kernel in A
+        gaussian_filter: bool, optional
+            If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering
+        fit_probe_aberrations: bool, optional
+            If True, probe aberrations are fitted to a low-order expansion
+        fit_probe_aberrations_max_angular_order: int
+            Max angular order of probe aberrations basis functions
+        fit_probe_aberrations_max_radial_order: int
+            Max radial order of probe aberrations basis functions
+        fit_probe_aberrations_remove_initial: bool
+            If True, initial probe aberrations are removed before fitting
+        fit_probe_aberrations_using_scikit_image: bool
+            If True, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads
+            to a documented bug where the kernel hangs.
+            If False, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations.
+        butterworth_filter: bool, optional
+            If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering
+        q_lowpass: float
+            Cut-off frequency in A^-1 for low-pass butterworth filter
+        q_highpass: float
+            Cut-off frequency in A^-1 for high-pass butterworth filter
+        butterworth_order: float
+            Butterworth filter order. Smaller gives a smoother filter
+        tv_denoise: bool, optional
+            If True and tv_denoise_weight is not None, object is smoothed using TV denoising
+        tv_denoise_weight: float
+            Denoising weight. The greater `weight`, the more denoising.
+ tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + device: str, optional + If not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + If true, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + Overwrites self._object_type + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probe", + "_probe_initial", + "_probe_initial_aperture", + ] + self.copy_attributes_to_device(attrs, device) + + if object_type is not None: + self._switch_object_type(object_type) + + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) + + if self._verbose: + self._report_reconstruction_summary( + num_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + max_batch_size, + step_size, + ) + + # batching + shuffled_indices = np.arange(self._num_diffraction_patterns) + + if max_batch_size is not None: + np.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + # initialization + self._reset_reconstruction(store_iterations, reset) + + # main loop + for a0 in tqdmnd( + num_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px[batch_indices] + positions_px_initial = self._positions_px_initial[batch_indices] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + self._probe, + positions_px_fractional, + amplitudes_device, + self._exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) 
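+                # _forward compared the propagated probe-object patches against the
+                # measured amplitudes, returning exit waves and the batch error;
+                # _adjoint below uses those exit waves to update the object and,
+                # unless fix_probe is set, the probe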
+ + # adjoint operator + self._object, self._probe = self._adjoint( + self._object, + self._probe, + object_patches, + shifted_probes, + positions_px, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=fix_probe, + ) + + # position correction + if not fix_positions: + self._positions_px[batch_indices] = self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) + + error += batch_error + + # Normalize Error + error /= self._mean_diffraction_intensity * self._num_diffraction_patterns + + # constraints + self._object, self._probe, self._positions_px = self._constraints( + self._object, + self._probe, + self._positions_px, + self._positions_px_initial, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=self._probe_initial_aperture, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + gaussian_filter=gaussian_filter and gaussian_filter_sigma is not None, + gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=butterworth_filter + and (q_lowpass is not None or q_highpass is not None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + tv_denoise=tv_denoise and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 + else None, + pure_phase_object=pure_phase_object and self._object_type == "complex", + ) + + self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object).copy()) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 79c365585..a25b7acd3 
100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -3,19 +3,25 @@ import matplotlib.pyplot as plt import numpy as np +from scipy.fft import dctn, idctn +from scipy.ndimage import gaussian_filter, uniform_filter1d, zoom from scipy.optimize import curve_fit try: import cupy as cp - from cupyx.scipy.fft import rfft + from cupyx.scipy.ndimage import zoom as zoom_cp + + get_array_module = cp.get_array_module except (ImportError, ModuleNotFoundError): cp = None - from scipy.fft import dstn, idstn + + def get_array_module(*args): + return np + from py4DSTEM.process.utils import get_CoM from py4DSTEM.process.utils.cross_correlate import align_and_shift_images from py4DSTEM.process.utils.utils import electron_wavelength_angstrom -from scipy.ndimage import gaussian_filter, uniform_filter1d from skimage.restoration import unwrap_phase # fmt: off @@ -404,16 +410,13 @@ def get_scattering_angles(self): def get_spatial_frequencies(self): xp = self._xp - kx, ky = spatial_frequencies(self._gpts, self._sampling) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) + kx, ky = spatial_frequencies(self._gpts, self._sampling, xp) return kx, ky def polar_coordinates(self, x, y): """Calculate a polar grid for a given Cartesian grid.""" xp = self._xp alpha = xp.sqrt(x[:, None] ** 2 + y[None, :] ** 2) - # phi = xp.arctan2(x.reshape((-1, 1)), y.reshape((1, -1))) # bug in abtem-legacy and py4DSTEM<=0.14.9 phi = xp.arctan2(y[None, :], x[:, None]) return alpha, phi @@ -441,7 +444,7 @@ def visualize(self, **kwargs): return self -def spatial_frequencies(gpts: Tuple[int, int], sampling: Tuple[float, float]): +def spatial_frequencies(gpts: Tuple[int, int], sampling: Tuple[float, float], xp=np): """ Calculate spatial frequencies of a grid. 
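For reviewers, a minimal sketch of the device-agnostic pattern this hunk introduces (not part of the diff; the grid size and sampling values below are arbitrary placeholders): passing the array module xp lets the same helper return NumPy arrays on CPU or CuPy arrays on GPU.

import numpy as np

from py4DSTEM.process.phase.utils import spatial_frequencies

# CPU: float32 NumPy frequency grids for a 256x256 ROI with 0.1 A sampling
kx, ky = spatial_frequencies((256, 256), (0.1, 0.1), xp=np)

# GPU (optional): the identical call returns CuPy arrays, avoiding a host round-trip
try:
    import cupy as cp

    kx_gpu, ky_gpu = spatial_frequencies((256, 256), (0.1, 0.1), xp=cp)
except (ImportError, ModuleNotFoundError):
    pass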
@@ -458,7 +461,7 @@ def spatial_frequencies(gpts: Tuple[int, int], sampling: Tuple[float, float]): """ return tuple( - np.fft.fftfreq(n, d).astype(np.float32) for n, d in zip(gpts, sampling) + xp.fft.fftfreq(n, d).astype(xp.float32) for n, d in zip(gpts, sampling) ) @@ -490,16 +493,14 @@ def fourier_translation_operator( if len(positions_shape) == 1: positions = positions[None] - kx, ky = spatial_frequencies(shape, (1.0, 1.0)) - kx = kx.reshape((1, -1, 1)) - ky = ky.reshape((1, 1, -1)) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) + kx, ky = spatial_frequencies(shape, (1.0, 1.0), xp=xp) positions = xp.asarray(positions, dtype=xp.float32) - x = positions[:, 0].reshape((-1,) + (1, 1)) - y = positions[:, 1].reshape((-1,) + (1, 1)) + x = positions[:, 0].ravel()[:, None, None] + y = positions[:, 1].ravel()[:, None, None] - result = xp.exp(-2.0j * np.pi * kx * x) * xp.exp(-2.0j * np.pi * ky * y) + result = xp.exp(-2.0j * np.pi * kx[None, :, None] * x) * xp.exp( + -2.0j * np.pi * ky[None, None, :] * y + ) if len(positions_shape) == 1: return result[0] @@ -1151,114 +1152,79 @@ def fourier_rotate_real_volume(array, angle, axes=(0, 1), xp=np): return output_arr +def array_slice(axis, ndim, start, end, step=1): + """Returns array slice along dynamic axis""" + return (slice(None),) * (axis % ndim) + (slice(start, end, step),) + + ### Divergence Projection Functions -def compute_divergence(vector_field, spacings, xp=np): +def periodic_centered_difference(array, spacing, axis, xp=np): + """Computes second-order centered difference with periodic BCs""" + return (xp.roll(array, -1, axis=axis) - xp.roll(array, 1, axis=axis)) / ( + 2 * spacing + ) + + +def compute_divergence_periodic(vector_field, spacings, xp=np): """Computes divergence of vector_field""" num_dims = len(spacings) div = xp.zeros_like(vector_field[0]) for i in range(num_dims): - div += xp.gradient(vector_field[i], spacings[i], axis=i) + div += periodic_centered_difference(vector_field[i], spacings[i], axis=i, xp=xp) return div -def compute_gradient(scalar_field, spacings, xp=np): +def compute_gradient_periodic(scalar_field, spacings, xp=np): """Computes gradient of scalar_field""" num_dims = len(spacings) grad = xp.zeros((num_dims,) + scalar_field.shape) for i in range(num_dims): - grad[i] = xp.gradient(scalar_field, spacings[i], axis=i) + grad[i] = periodic_centered_difference(scalar_field, spacings[i], axis=i, xp=xp) return grad -def array_slice(axis, ndim, start, end, step=1): - """Returns array slice along dynamic axis""" - return (slice(None),) * (axis % ndim) + (slice(start, end, step),) - - -def make_array_rfft_compatible(array_nd, axis=0, xp=np): - """Expand array to be rfft compatible""" - array_shape = np.array(array_nd.shape) - d = array_nd.ndim - n = array_shape[axis] - array_shape[axis] = (n + 1) * 2 - - dtype = array_nd.dtype - padded_array = xp.zeros(array_shape, dtype=dtype) - - padded_array[array_slice(axis, d, 1, n + 1)] = -array_nd - padded_array[array_slice(axis, d, None, -n - 1, -1)] = array_nd - - return padded_array - +def preconditioned_laplacian_periodic_3D(shape, xp=np): + """FFT eigenvalues""" + n, m, p = shape + i, j, k = xp.ogrid[0:n, 0:m, 0:p] -def dst_I(array_nd, xp=np): - """1D rfft-based DST-I""" - d = array_nd.ndim - for axis in range(d): - crop_slice = array_slice(axis, d, 1, -1) - array_nd = rfft( - make_array_rfft_compatible(array_nd, axis=axis, xp=xp), axis=axis - )[crop_slice].imag - - return array_nd - - -def idst_I(array_nd, xp=np): - """1D rfft-based iDST-I""" - 
scaling = np.prod((np.array(array_nd.shape) + 1) * 2) - return dst_I(array_nd, xp=xp) / scaling - - -def preconditioned_laplacian(num_exterior, spacing=1, xp=np): - """DST-I eigenvalues""" - n = num_exterior - 1 - evals_1d = 2 - 2 * xp.cos(np.pi * xp.arange(1, num_exterior) / num_exterior) - - op = ( - xp.repeat(evals_1d, n**2) - + xp.tile(evals_1d, n**2) - + xp.tile(xp.repeat(evals_1d, n), n) + op = 6 - 2 * xp.cos(2 * np.pi * i / n) * xp.cos(2 * np.pi * j / m) * xp.cos( + 2 * np.pi * k / p ) + op[0, 0, 0] = 1 # gauge invariance + return -op - return -op / spacing**2 +def preconditioned_poisson_solver_periodic_3D(rhs, gauge=None, xp=np): + """FFT based poisson solver""" + op = preconditioned_laplacian_periodic_3D(rhs.shape, xp=xp) -def preconditioned_poisson_solver(rhs_interior, spacing=1, xp=np): - """DST-I based poisson solver""" - nx, ny, nz = rhs_interior.shape - if nx != ny or nx != nz: - raise ValueError() - - op = preconditioned_laplacian(nx + 1, spacing=spacing, xp=xp) - if xp is np: - dst_rhs = dstn(rhs_interior, type=1).ravel() - dst_u = (dst_rhs / op).reshape((nx, ny, nz)) - sol = idstn(dst_u, type=1) - else: - dst_rhs = dst_I(rhs_interior, xp=xp).ravel() - dst_u = (dst_rhs / op).reshape((nx, ny, nz)) - sol = idst_I(dst_u, xp=xp) + if gauge is None: + gauge = xp.mean(rhs) + fft_rhs = xp.fft.fftn(rhs) + fft_rhs[0, 0, 0] = gauge # gauge invariance + sol = xp.fft.ifftn(fft_rhs / op).real return sol -def project_vector_field_divergence(vector_field, spacings=(1, 1, 1), xp=np): +def project_vector_field_divergence_periodic_3D(vector_field, xp=np): """ Returns solenoidal part of vector field using projection: f - \\grad{p} s.t. \\laplacian{p} = \\div{f} """ - - div_v = compute_divergence(vector_field, spacings, xp=xp) - p = preconditioned_poisson_solver(div_v, spacings[0], xp=xp) - grad_p = compute_gradient(p, spacings, xp=xp) + spacings = (1, 1, 1) + div_v = compute_divergence_periodic(vector_field, spacings, xp=xp) + p = preconditioned_poisson_solver_periodic_3D(div_v, xp=xp) + grad_p = compute_gradient_periodic(p, spacings, xp=xp) return vector_field - grad_p @@ -1611,25 +1577,226 @@ def aberrations_basis_function( return aberrations_basis, aberrations_mn +def interleave_ndarray_symmetrically(array_nd, axis, xp=np): + """[a,b,c,d,e,f] -> [a,c,e,f,d,b]""" + array_shape = np.array(array_nd.shape) + d = array_nd.ndim + n = array_shape[axis] + + array = xp.empty_like(array_nd) + array[array_slice(axis, d, None, (n - 1) // 2 + 1)] = array_nd[ + array_slice(axis, d, None, None, 2) + ] + + if n % 2: # odd + array[array_slice(axis, d, (n - 1) // 2 + 1, None)] = array_nd[ + array_slice(axis, d, -2, None, -2) + ] + else: # even + array[array_slice(axis, d, (n - 1) // 2 + 1, None)] = array_nd[ + array_slice(axis, d, None, None, -2) + ] + + return array + + +def return_exp_factors(size, ndim, axis, xp=np): + none_axes = [None] * ndim + none_axes[axis] = slice(None) + exp_factors = 2 * xp.exp(-1j * np.pi * xp.arange(size) / (2 * size)) + return exp_factors[tuple(none_axes)] + + +def dct_II_using_FFT_base(array_nd, xp=np): + """FFT-based DCT-II""" + d = array_nd.ndim + + for axis in range(d): + n = array_nd.shape[axis] + interleaved_array = interleave_ndarray_symmetrically(array_nd, axis=axis, xp=xp) + exp_factors = return_exp_factors(n, d, axis, xp) + interleaved_array = xp.fft.fft(interleaved_array, axis=axis) + interleaved_array *= exp_factors + array_nd = interleaved_array.real + + return array_nd + + +def dct_II_using_FFT(array_nd, xp=np): + if xp.iscomplexobj(array_nd): + real = 
dct_II_using_FFT_base(array_nd.real, xp=xp) + imag = dct_II_using_FFT_base(array_nd.imag, xp=xp) + return real + 1j * imag + else: + return dct_II_using_FFT_base(array_nd, xp=xp) + + +def interleave_ndarray_symmetrically_inverse(array_nd, axis, xp=np): + """[a,c,e,f,d,b] -> [a,b,c,d,e,f]""" + array_shape = np.array(array_nd.shape) + d = array_nd.ndim + n = array_shape[axis] + + array = xp.empty_like(array_nd) + array[array_slice(axis, d, None, None, 2)] = array_nd[ + array_slice(axis, d, None, (n - 1) // 2 + 1) + ] + + if n % 2: # odd + array[array_slice(axis, d, -2, None, -2)] = array_nd[ + array_slice(axis, d, (n - 1) // 2 + 1, None) + ] + else: # even + array[array_slice(axis, d, None, None, -2)] = array_nd[ + array_slice(axis, d, (n - 1) // 2 + 1, None) + ] + + return array + + +def return_exp_factors_inverse(size, ndim, axis, xp=np): + none_axes = [None] * ndim + none_axes[axis] = slice(None) + exp_factors = xp.exp(1j * np.pi * xp.arange(size) / (2 * size)) / 2 + return exp_factors[tuple(none_axes)] + + +def idct_II_using_FFT_base(array_nd, xp=np): + """FFT-based IDCT-II""" + d = array_nd.ndim + + for axis in range(d): + n = array_nd.shape[axis] + reversed_array = xp.roll( + array_nd[array_slice(axis, d, None, None, -1)], 1, axis=axis + ) # C(N-k) + reversed_array[array_slice(axis, d, 0, 1)] = 0 # set C(N) = 0 + + interleaved_array = array_nd - 1j * reversed_array + exp_factors = return_exp_factors_inverse(n, d, axis, xp) + interleaved_array *= exp_factors + + array_nd = xp.fft.ifft(interleaved_array, axis=axis).real + array_nd = interleave_ndarray_symmetrically_inverse(array_nd, axis=axis, xp=xp) + + return array_nd + + +def idct_II_using_FFT(array_nd, xp=np): + """FFT-based IDCT-II""" + if xp.iscomplexobj(array_nd): + real = idct_II_using_FFT_base(array_nd.real, xp=xp) + imag = idct_II_using_FFT_base(array_nd.imag, xp=xp) + return real + 1j * imag + else: + return idct_II_using_FFT_base(array_nd, xp=xp) + + +def preconditioned_laplacian_neumann_2D(shape, xp=np): + """DCT eigenvalues""" + n, m = shape + i, j = xp.ogrid[0:n, 0:m] + + op = 4 - 2 * xp.cos(np.pi * i / n) - 2 * xp.cos(np.pi * j / m) + op[0, 0] = 1 # gauge invariance + return -op + + +def preconditioned_poisson_solver_neumann_2D(rhs, gauge=None, xp=np): + """DCT based poisson solver""" + op = preconditioned_laplacian_neumann_2D(rhs.shape, xp=xp) + + if gauge is None: + gauge = xp.mean(rhs) + + if xp is np: + fft_rhs = dctn(rhs, type=2) + fft_rhs[0, 0] = gauge # gauge invariance + sol = idctn(fft_rhs / op, type=2).real + else: + fft_rhs = dct_II_using_FFT(rhs, xp) + fft_rhs[0, 0] = gauge # gauge invariance + sol = idct_II_using_FFT(fft_rhs / op, xp) + + return sol + + +def unwrap_phase_2d(array, weights=None, gauge=None, corner_centered=True, xp=np): + """Weigted phase unwrapping using DCT-based poisson solver""" + + if np.iscomplexobj(array): + raise ValueError() + + if corner_centered: + array = xp.fft.fftshift(array) + if weights is not None: + weights = xp.fft.fftshift(weights) + + dx = xp.mod(xp.diff(array, axis=0) + np.pi, 2 * np.pi) - np.pi + dy = xp.mod(xp.diff(array, axis=1) + np.pi, 2 * np.pi) - np.pi + + if weights is not None: + # normalize weights + weights -= weights.min() + weights /= weights.max() + + ww = weights**2 + dx *= xp.minimum(ww[:-1, :], ww[1:, :]) + dy *= xp.minimum(ww[:, :-1], ww[:, 1:]) + + rho = xp.diff(dx, axis=0, prepend=0, append=0) + rho += xp.diff(dy, axis=1, prepend=0, append=0) + + unwrapped_array = preconditioned_poisson_solver_neumann_2D(rho, gauge=gauge, xp=xp) + unwrapped_array -= 
unwrapped_array.min() + + if corner_centered: + unwrapped_array = xp.fft.ifftshift(unwrapped_array) + + return unwrapped_array + + +def unwrap_phase_2d_skimage(array, corner_centered=True, xp=np): + if xp is np: + array = array.astype(np.float64) + unwrapped_array = unwrap_phase(array, wrap_around=corner_centered).astype( + xp.float32 + ) + else: + array = xp.asnumpy(array).astype(np.float64) + unwrapped_array = unwrap_phase(array, wrap_around=corner_centered) + unwrapped_array = xp.asarray(unwrapped_array).astype(xp.float32) + + return unwrapped_array + + def fit_aberration_surface( complex_probe, probe_sampling, energy, max_angular_order, max_radial_order, + use_scikit_image, xp=np, ): """ """ probe_amp = xp.abs(complex_probe) probe_angle = -xp.angle(complex_probe) - if xp is np: - probe_angle = probe_angle.astype(np.float64) - unwrapped_angle = unwrap_phase(probe_angle, wrap_around=True).astype(xp.float32) + if use_scikit_image: + unwrapped_angle = unwrap_phase_2d_skimage( + probe_angle, + corner_centered=True, + xp=xp, + ) + else: - probe_angle = xp.asnumpy(probe_angle).astype(np.float64) - unwrapped_angle = unwrap_phase(probe_angle, wrap_around=True) - unwrapped_angle = xp.asarray(unwrapped_angle).astype(xp.float32) + unwrapped_angle = unwrap_phase_2d( + probe_angle, + weights=probe_amp, + corner_centered=True, + xp=xp, + ) raveled_basis, _ = aberrations_basis_function( complex_probe.shape, @@ -1647,6 +1814,8 @@ def fit_aberration_surface( coeff = xp.linalg.lstsq(Aw, bw, rcond=None)[0] fitted_angle = xp.tensordot(raveled_basis, coeff, axes=1).reshape(probe_angle.shape) + angle_offset = fitted_angle[0, 0] - probe_angle[0, 0] + fitted_angle -= angle_offset return fitted_angle, coeff @@ -2077,6 +2246,67 @@ def lanczos_kernel_density_estimate( return pix_output +def vectorized_bilinear_resample( + array, + scale=None, + output_size=None, + mode="grid-wrap", + grid_mode=True, + xp=np, +): + """ + Resize an array along its final two axes. + Note, this is vectorized and thus very memory-intensive. + + The scaling of the array can be specified by passing either `scale`, which sets + the scaling factor along both axes to be scaled; or by passing `output_size`, + which specifies the final dimensions of the scaled axes. 
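+    For example, scale=2 doubles the final two axes, while output_size=(128, 96)
+    sets their dimensions explicitly; one of the two must be provided.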
+ + Parameters + ---------- + array: np.ndarray + Input array to be resampled + scale: float + Scalar value giving the scaling factor for all dimensions + output_size: (int,int) + Tuple of two values giving the output size for the final two axes + xp: Callable + Array computing module + + Returns + ------- + resampled_array: np.ndarray + Resampled array + """ + + array_size = np.array(array.shape) + input_size = array_size[-2:].copy() + + if scale is not None: + scale = np.array(scale) + if scale.size == 1: + scale = np.tile(scale, 2) + + output_size = (input_size * scale).astype("int") + else: + if output_size is None: + raise ValueError("One of `scale` or `output_size` must be provided.") + output_size = np.array(output_size) + if output_size.size != 2: + raise ValueError("`output_size` must contain exactly two values.") + output_size = np.array(output_size) + + scale_output = tuple(output_size / input_size) + scale_output = (1,) * (array_size.size - input_size.size) + scale_output + + if xp is np: + array = zoom(array, scale_output, order=1, mode=mode, grid_mode=grid_mode) + else: + array = zoom_cp(array, scale_output, order=1, mode=mode, grid_mode=grid_mode) + + return array + + def vectorized_fourier_resample( array, scale=None, @@ -2122,6 +2352,7 @@ def vectorized_fourier_resample( else: if output_size is None: raise ValueError("One of `scale` or `output_size` must be provided.") + output_size = np.array(output_size) if output_size.size != 2: raise ValueError("`output_size` must contain exactly two values.") output_size = np.array(output_size) @@ -2244,3 +2475,29 @@ def vectorized_fourier_resample( array_resize *= scale_output return array_resize + + +def partition_list(lst, size): + """Partitions lst into chunks of size. Returns a generator.""" + for i in range(0, len(lst), size): + yield lst[i : i + size] + + +def copy_to_device(array, device="cpu"): + """Copies array to device. Default allows one to use this as asnumpy()""" + xp = get_array_module(array) + + if xp is np: + if device == "cpu": + return np.asarray(array) + elif device == "gpu": + return cp.asarray(array) + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + else: + if device == "cpu": + return cp.asnumpy(array) + elif device == "gpu": + return cp.asarray(array) + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index 8462eec7d..7430992e0 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -76,6 +76,7 @@ def show( theta=None, title=None, show_fft=False, + apply_hanning_window=True, show_cbar=False, **kwargs, ): @@ -305,6 +306,8 @@ def show( which will attempt to use it to overlay a scalebar. If True, uses calibraiton or pixelsize/pixelunits for scalebar. If False, no scalebar is added. 
show_fft (bool): if True, plots 2D-fft of array + apply_hanning_window (bool) + If True, a 2D Hann window is applied to the array before applying the FFT show_cbar (bool) : if True, adds cbar **kwargs: any keywords accepted by matplotlib's ax.matshow() @@ -369,9 +372,12 @@ def show( from py4DSTEM.visualize import show if show_fft: - n0 = ar.shape - w0 = np.hanning(n0[1]) * np.hanning(n0[0])[:, None] - ar = np.abs(np.fft.fftshift(np.fft.fft2(w0 * ar.copy()))) + if apply_hanning_window: + n0 = ar.shape + w0 = np.hanning(n0[1]) * np.hanning(n0[0])[:, None] + ar = np.abs(np.fft.fftshift(np.fft.fft2(w0 * ar.copy()))) + else: + ar = np.abs(np.fft.fftshift(np.fft.fft2(ar.copy()))) for a0 in range(num_images): im = show( ar[a0], @@ -451,7 +457,12 @@ def show( # Otherwise, plot one image if show_fft: if combine_images is False: - ar = np.abs(np.fft.fftshift(np.fft.fft2(ar.copy()))) + if apply_hanning_window: + n0 = ar.shape + w0 = np.hanning(n0[1]) * np.hanning(n0[0])[:, None] + ar = np.abs(np.fft.fftshift(np.fft.fft2(w0 * ar.copy()))) + else: + ar = np.abs(np.fft.fftshift(np.fft.fft2(ar.copy()))) # get image from a masked array if mask is not None: diff --git a/py4DSTEM/visualize/vis_grid.py b/py4DSTEM/visualize/vis_grid.py index d24b0b8d8..9be754689 100644 --- a/py4DSTEM/visualize/vis_grid.py +++ b/py4DSTEM/visualize/vis_grid.py @@ -205,7 +205,7 @@ def show_image_grid( ax = axs[i, j] N = i * W + j # make titles - if type(title) == list: + if type(title) == list and N < len(title): print_title = title[N] else: print_title = None diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 88b7d7815..c8b8a8b12 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -801,22 +801,31 @@ def show_complex( ) if scalebar is True: scalebar = { - "Nx": ar_complex[0].shape[0], - "Ny": ar_complex[0].shape[1], + "Nx": rgb[0].shape[0], + "Ny": rgb[0].shape[1], "pixelsize": pixelsize, "pixelunits": pixelunits, } add_scalebar(ax[0, 0], scalebar) else: + figsize = kwargs.pop("axsize", None) + figsize = kwargs.pop("figsize", figsize) + fig, ax = show( - rgb, vmin=0, vmax=1, intensity_range="absolute", returnfig=True, **kwargs + rgb, + vmin=0, + vmax=1, + intensity_range="absolute", + returnfig=True, + figsize=figsize, + **kwargs, ) if scalebar is True: scalebar = { - "Nx": ar_complex.shape[0], - "Ny": ar_complex.shape[1], + "Nx": rgb.shape[0], + "Ny": rgb.shape[1], "pixelsize": pixelsize, "pixelunits": pixelunits, } @@ -826,7 +835,7 @@ def show_complex( # add color bar if cbar: if is_grid: - for ax_flat in ax.flatten(): + for ax_flat in ax.flatten()[: len(rgb)]: divider = make_axes_locatable(ax_flat) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) diff --git a/setup.py b/setup.py index 58f61dc4a..5c662dcf0 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ author_email="ben.savitzky@gmail.com", license="GNU GPLv3", keywords="STEM 4DSTEM", - python_requires=">=3.9,<3.13", + python_requires=">=3.10", install_requires=[ "numpy >= 1.19", "scipy >= 1.5.2",
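For reviewers, a minimal end-to-end sketch of the new SingleslicePtychography API added above (not part of the diff; the file name, energy, and semiangle cutoff are placeholder values, and the input DataCube is assumed to carry real- and reciprocal-space calibrations).

import py4DSTEM
from py4DSTEM.process.phase.singleslice_ptychography import SingleslicePtychography

# placeholder dataset; any calibrated 4D-STEM DataCube works here
datacube = py4DSTEM.import_file("ptycho_dataset.dm4")

ptycho = SingleslicePtychography(
    datacube=datacube,
    energy=300e3,           # eV, placeholder
    semiangle_cutoff=21.4,  # mrad, placeholder
    object_type="potential",
    device="cpu",           # or "gpu" (optionally with storage="cpu" to keep large arrays on host)
)

# preprocess() and reconstruct() return self, so calls chain
ptycho = ptycho.preprocess(plot_probe_overlaps=False).reconstruct(
    num_iter=8,
    store_iterations=True,
)
ptycho.visualize(iterations_grid="auto")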